Commit c982ac5

[Bugfix] Fix FP16 overflow for DeepSeek V2 (vllm-project#13232)
Signed-off-by: Yida Wu <[email protected]>
1 parent 4290b70 commit c982ac5

1 file changed: +24 -4 lines changed

vllm/model_executor/models/deepseek_v2.py
```diff
@@ -155,11 +155,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits) * self.routed_scaling_factor
+        else:
+            # This is a special case to avoid FP16 overflow
+            final_hidden_states = self.experts(hidden_states=hidden_states,
+                                               router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            if hidden_states.dtype != torch.float16:
+                final_hidden_states = final_hidden_states + shared_output
+            else:
+                # This is a special case to avoid FP16 overflow
+                final_hidden_states = final_hidden_states + shared_output \
+                    * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
```
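The rationale for this hunk: FP16's largest finite value is 65504, so multiplying already-large routed-expert outputs by `routed_scaling_factor` can saturate to `inf`. The FP16 path therefore leaves the routed output unscaled and divides the shared-expert term instead. A minimal sketch of both behaviors; the activation magnitude 8192 and the scale 16 are illustrative assumptions, not values taken from the model:

```python
import torch

scale = 16.0  # hypothetical routed_scaling_factor, for illustration only
x = torch.full((4,), 8192.0, dtype=torch.float16)  # assumed activation magnitude

# Naive scaling overflows: 8192 * 16 = 131072 > 65504 (FP16 max).
print(x * scale)  # tensor([inf, inf, inf, inf], dtype=torch.float16)

# The patched FP16 path keeps x unscaled and rescales the other addend instead:
shared = torch.full((4,), 1024.0, dtype=torch.float16)
print(x + shared * (1.0 / scale))  # tensor([8256., 8256., 8256., 8256.], dtype=torch.float16)
```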
```diff
@@ -531,6 +541,7 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor

     def forward(
         self,
```
```diff
@@ -551,9 +562,18 @@ def forward(
         )

         # Fully Connected
+        if isinstance(self.mlp, DeepseekV2MoE) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(self.mlp, DeepseekV2MLP) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
+            residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual
```
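Within the MoE block itself (first hunk), the rearrangement is a plain algebraic identity: dividing the shared-expert term by the same factor the routed output skips, then scaling everything up later, yields the same result. The decoder-layer hunks apply matching `1 / routed_scaling_factor` rescalings to `hidden_states` and `residual` so the stream stays inside the FP16 range. A quick float32 check of the identity, using random stand-ins rather than real expert outputs:

```python
import torch

s = 16.0                 # stand-in for routed_scaling_factor
e = torch.randn(8)       # stand-in for the routed-expert output
shared = torch.randn(8)  # stand-in for the shared-expert output

# Unpatched path vs. rearranged FP16 path, equal up to rounding:
#   e * s + shared  ==  (e + shared / s) * s
lhs = e * s + shared
rhs = (e + shared * (1.0 / s)) * s
print(torch.allclose(lhs, rhs))  # True
```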
