@@ -155,11 +155,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits) * self.routed_scaling_factor
+        else:
+            # This is a special case to avoid FP16 overflow
+            final_hidden_states = self.experts(hidden_states=hidden_states,
+                                               router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            if hidden_states.dtype != torch.float16:
+                final_hidden_states = final_hidden_states + shared_output
+            else:
+                # This is a special case to avoid FP16 overflow
+                final_hidden_states = final_hidden_states + shared_output \
+                    * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
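The float16 branch above exists because fp16's largest finite value is 65504, so eagerly multiplying the routed experts' output by self.routed_scaling_factor (a constant read from the model config) can overflow to inf. A minimal sketch of the failure mode, using an illustrative magnitude and an assumed factor of 16.0 rather than real model activations:

import torch

# Illustrative numbers only: the real factor comes from
# config.routed_scaling_factor and the activations from the routed experts.
routed_scaling_factor = 16.0
expert_out = torch.full((4,), 8192.0, dtype=torch.float16)

print(torch.finfo(torch.float16).max)               # 65504.0
print(expert_out * routed_scaling_factor)           # inf: 131072 exceeds the fp16 range
print(expert_out.float() * routed_scaling_factor)   # 131072.0: finite in float32 (and bfloat16)

The fp16 path therefore skips the multiply, leaving the MoE output a factor of routed_scaling_factor smaller than in other dtypes, and the shared-expert output is divided by the same factor so the two branches stay on the same scale.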
@@ -531,6 +541,7 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
@@ -551,9 +562,18 @@ def forward(
         )
 
         # Fully Connected
+        if isinstance(self.mlp, DeepseekV2MoE) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(self.mlp, DeepseekV2MLP) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
+            residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual
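Read together with the MoE change above, this hunk appears to keep the fp16 residual stream at 1/routed_scaling_factor of its nominal magnitude instead of scaling the expert output up: the attention output is divided down before the post-attention norm in MoE layers, the routed-expert output skips its final multiply, the shared-expert output is divided by the factor, and dense DeepseekV2MLP layers divide both hidden_states and residual to keep the stream consistent. This rewrite only preserves the computation because RMSNorm is, up to its epsilon term, invariant to a uniform rescaling of its input, so each attention/MLP block still sees the same normalized values and the final pre-logits norm absorbs the remaining factor. A small self-contained check of that property, using a plain-PyTorch RMSNorm as a stand-in for vLLM's fused kernel:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm: normalize by the root mean square, then apply the learned weight.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)

torch.manual_seed(0)
x = torch.randn(2, 8)
w = torch.ones(8)
scale = 16.0  # stand-in for routed_scaling_factor

# Dividing the input by a constant barely changes the normalized output;
# the difference comes only from the eps term.
diff = (rms_norm(x, w) - rms_norm(x / scale, w)).abs().max()
print(diff)  # on the order of 1e-4 or smaller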