[perf] Support MoE multi-stream in DeepSeek (#947)

### What this PR does / why we need it?
Support multi-stream execution inside the MoE layer for DeepSeek, so that the shared experts can run concurrently with the routed experts during decode.
This feature requires torchair graph mode with MC2 enabled.
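
For reference, a minimal enablement sketch is shown below. The two `torchair_graph_config` flags mirror the fields this change reads from `ascend_config`; the `additional_config` plumbing and the model name are assumptions for illustration, not part of this diff.

```python
# Hypothetical enablement sketch (not part of this PR): the two flags mirror the
# ascend_config.torchair_graph_config fields consumed by CustomDeepseekV2MoE.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",   # assumed model name, for illustration
    additional_config={
        "torchair_graph_config": {
            "enabled": True,                           # graph mode is required
            "enable_multistream_shared_expert": True,  # new multi-stream switch
        },
    },
)
```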

---------

Signed-off-by: David9857 <985700846@qq.com>
Author:  David9857
Date:    2025-06-05 23:39:38 +08:00 (committed by GitHub)
Parent:  908a851a77
Commit:  78431b3469

6 changed files with 133 additions and 45 deletions


@@ -216,6 +216,8 @@ class CustomDeepseekV2MoE(nn.Module):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_shared_expert = \
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert
 
     def forward(
         self,
@@ -238,6 +240,8 @@ class CustomDeepseekV2MoE(nn.Module):
         num_tokens, hidden_size = hidden_states.shape
+        multistream = self.enable_multistream_shared_expert and not is_prefill
         old_hidden_states = hidden_states.clone()
         if self.tp_size > 1:
@@ -259,13 +263,25 @@ class CustomDeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
+        kwargs = {}
+        if multistream:
+            kwargs.update({
+                "shared_experts": self.shared_experts,
+                "shared_hidden_states": old_hidden_states
+            })
         hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+            **kwargs)
+        if multistream:
+            hidden_states, shared_output = hidden_states
+        hidden_states = hidden_states * self.routed_scaling_factor
         if self.tp_size > 1:
             if self.torchair_graph_enabled:
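
The new `shared_experts` / `shared_hidden_states` kwargs let the fused experts path compute the shared experts concurrently with the routed experts. The sketch below is a non-authoritative illustration of that overlap pattern using CUDA streams (the actual implementation in this PR targets Ascend NPU streams and the MC2 fused dispatch); `routed_fn` is a hypothetical stand-in for the routed-expert computation.

```python
import torch

def experts_with_overlapped_shared(routed_fn, hidden_states,
                                   shared_experts=None, shared_hidden_states=None):
    """Illustrative overlap pattern only; not the vllm-ascend implementation."""
    if shared_experts is None:
        # Plain path: the caller computes the shared experts itself.
        return routed_fn(hidden_states)

    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        # Shared experts run on a secondary stream ...
        shared_output = shared_experts(shared_hidden_states)

    # ... while the routed experts run on the default stream.
    routed_output = routed_fn(hidden_states)

    torch.cuda.current_stream().wait_stream(side_stream)
    # Matches the caller-side unpacking: hidden_states, shared_output = hidden_states
    return routed_output, shared_output
```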
@@ -288,7 +304,8 @@ class CustomDeepseekV2MoE(nn.Module):
             hidden_states = hidden_states[:-num_padding_tokens]
         if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(old_hidden_states)
+            if not multistream:
+                shared_output = self.shared_experts(old_hidden_states)
         if shared_output is not None:
             hidden_states = hidden_states + shared_output
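
Taken together, the hunks above yield the control flow condensed in this hypothetical sketch (stub names, not repo code): when multistream is active, `self.experts(...)` returns a `(routed, shared)` tuple and the caller skips recomputing the shared experts; otherwise the original path is kept, with `routed_scaling_factor` now applied outside the experts call in both cases.

```python
# Hypothetical condensation of the forward control flow after this change.
def moe_forward(self, hidden_states, is_prefill):
    multistream = self.enable_multistream_shared_expert and not is_prefill
    old_hidden_states = hidden_states.clone()
    router_logits, _ = self.gate(hidden_states)

    kwargs = {}
    if multistream:
        # Shared experts are handed to the fused experts call and run in parallel.
        kwargs = {"shared_experts": self.shared_experts,
                  "shared_hidden_states": old_hidden_states}

    out = self.experts(hidden_states=hidden_states,
                       router_logits=router_logits,
                       is_prefill=is_prefill, **kwargs)

    if multistream:
        hidden_states, shared_output = out   # both computed inside experts
    else:
        hidden_states = out
        shared_output = self.shared_experts(old_hidden_states)

    hidden_states = hidden_states * self.routed_scaling_factor
    return hidden_states + shared_output
```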