Optimize Permute Kernel in DeepEP (#4643)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
xutizhou
2025-03-23 05:30:34 +08:00
committed by GitHub
parent f8f9244a61
commit c2bd094d6e
4 changed files with 101 additions and 230 deletions

View File

@@ -294,7 +294,7 @@ class DeepseekV2MoE(nn.Module):
correction_bias=self.correction_bias,
)
if self.tp_size > 1:
recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = (
recv_hidden_states, reorder_topk_ids, seg_indptr = (
self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
@@ -306,7 +306,8 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = (
self.experts(
hidden_states=recv_hidden_states,
tokens_per_expert=tokens_per_expert,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
forward_mode=forward_mode,
)
* self.routed_scaling_factor