[DeepEP] Reduce routed scaling overhead (#5277)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -337,16 +337,13 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
)
|
)
|
||||||
final_hidden_states = (
|
final_hidden_states = self.experts(
|
||||||
self.experts(
|
hidden_states=hidden_states,
|
||||||
hidden_states=hidden_states,
|
reorder_topk_ids=reorder_topk_ids,
|
||||||
reorder_topk_ids=reorder_topk_ids,
|
seg_indptr=seg_indptr,
|
||||||
seg_indptr=seg_indptr,
|
masked_m=masked_m,
|
||||||
masked_m=masked_m,
|
expected_m=expected_m,
|
||||||
expected_m=expected_m,
|
forward_mode=forward_mode,
|
||||||
forward_mode=forward_mode,
|
|
||||||
)
|
|
||||||
* self.routed_scaling_factor
|
|
||||||
)
|
)
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
final_hidden_states = self.deepep_dispatcher.combine(
|
final_hidden_states = self.deepep_dispatcher.combine(
|
||||||
@@ -355,6 +352,8 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
forward_mode,
|
forward_mode,
|
||||||
)
|
)
|
||||||
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user