[DeepEP] Reduce routed scaling overhead (#5277)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
yulei
2025-04-14 07:03:09 +08:00
committed by GitHub
parent 39d90449f3
commit adca585bfb

View File

@@ -337,16 +337,13 @@ class DeepseekV2MoE(nn.Module):
 topk_weights,
 forward_mode=forward_mode,
 )
-final_hidden_states = (
-self.experts(
-hidden_states=hidden_states,
-reorder_topk_ids=reorder_topk_ids,
-seg_indptr=seg_indptr,
-masked_m=masked_m,
-expected_m=expected_m,
-forward_mode=forward_mode,
-)
-* self.routed_scaling_factor
+final_hidden_states = self.experts(
+hidden_states=hidden_states,
+reorder_topk_ids=reorder_topk_ids,
+seg_indptr=seg_indptr,
+masked_m=masked_m,
+expected_m=expected_m,
+forward_mode=forward_mode,
 )
 if self.ep_size > 1:
 final_hidden_states = self.deepep_dispatcher.combine(
@@ -355,6 +352,8 @@ class DeepseekV2MoE(nn.Module):
 topk_weights,
 forward_mode,
 )
+final_hidden_states *= self.routed_scaling_factor
 if shared_output is not None:
 final_hidden_states = final_hidden_states + shared_output