Support shared experts overlap in cutlass moe (#11611)

This commit is contained in:
fzyzcjy
2025-10-18 07:59:40 +08:00
committed by GitHub
parent 8a382fd399
commit 505329cab0
3 changed files with 41 additions and 5 deletions

View File

@@ -655,6 +655,7 @@ class DeepseekV2MoE(nn.Module):
self._enable_a2a_moe = (
get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
)
self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo()
def get_moe_weights(self):
return [
@@ -746,9 +747,10 @@ class DeepseekV2MoE(nn.Module):
return self.forward_cpu(hidden_states, should_allreduce_fusion)
if hidden_states.shape[0] > 0:
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
if not self._fuse_shared_experts_inside_sbo:
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
topk_output = self.topk(hidden_states, router_logits)
@@ -756,7 +758,27 @@ class DeepseekV2MoE(nn.Module):
shared_output = None
topk_output = self.topk.empty_topk_output(hidden_states.device)
final_hidden_states = self.experts(hidden_states, topk_output)
if self._fuse_shared_experts_inside_sbo:
shared_output = None
def _forward_shared_experts_and_put_results():
nonlocal shared_output
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
final_hidden_states = self.experts(
hidden_states,
topk_output,
**(
dict(
forward_shared_experts=_forward_shared_experts_and_put_results,
alt_stream=self.alt_stream,
)
if self._fuse_shared_experts_inside_sbo
else {}
),
)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor