Optimize Permute Kernel in DeepEP (#4643)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
xutizhou
2025-03-23 05:30:34 +08:00
committed by GitHub
parent f8f9244a61
commit c2bd094d6e
4 changed files with 101 additions and 230 deletions

View File

@@ -831,19 +831,23 @@ class DeepEPMoE(EPMoE):
def forward(
self,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
forward_mode: ForwardMode,
):
# Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
if True: # not forward_mode.is_decode():
return self.forward_normal(hidden_states, tokens_per_expert)
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
else:
return self.forward_deepgemm_masked(hidden_states, tokens_per_expert)
return self.forward_deepgemm_masked(
hidden_states, reorder_topk_ids, seg_indptr
)
def forward_normal(
self,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
):
assert self.quant_method is not None
assert self.activation == "silu"
@@ -851,15 +855,7 @@ class DeepEPMoE(EPMoE):
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
)
seg_indptr_cur_rank = torch.cat(
[
torch.zeros(
1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype
),
torch.cumsum(tokens_per_expert, dim=0),
]
)
reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
@@ -881,6 +877,7 @@ class DeepEPMoE(EPMoE):
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if hidden_states.shape[0] > 0:
gateup_output = self.grouped_gemm_runner(
a=hidden_states,
@@ -888,7 +885,7 @@ class DeepEPMoE(EPMoE):
c=gateup_output,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
seg_indptr=seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale,
@@ -946,7 +943,7 @@ class DeepEPMoE(EPMoE):
c=down_output,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
seg_indptr=seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale,