[perf]Support MOE Multi-stream in Deepseek (#947)
### What this PR does / why we need it?
Support MOE inner Multi-stream for Deepseek. This feature requires graph mode with mc2 enabled.

Signed-off-by: David9857 <985700846@qq.com>
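Below, a minimal launch sketch for context. The `torchair_graph_config` keys are taken from this PR's diff; routing them through vLLM's `additional_config` and the exact model name are assumptions for illustration (mc2 here refers to Ascend's merged compute-communication kernels).

```python
# Hypothetical launch sketch. The torchair_graph_config keys appear in this
# PR; passing them via additional_config and the model name are assumptions.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # any Deepseek MoE checkpoint
    additional_config={
        "torchair_graph_config": {
            "enabled": True,  # multistream only takes effect in graph mode
            "enable_multistream_shared_expert": True,
        },
    },
)
```

The diff touches `CustomDeepseekV2MoE` in four hunks; a short note follows each.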
```diff
@@ -216,6 +216,8 @@ class CustomDeepseekV2MoE(nn.Module):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_shared_expert = \
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert
 
     def forward(
         self,
```
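The new flag mirrors `torchair_graph_enabled` on the line above: both are resolved once at construction. A rough sketch of the config surface being read, assuming plain dataclasses (the real `get_ascend_config` structure in vllm-ascend may differ):

```python
# Assumed shape of the config accessed above; the actual AscendConfig in
# vllm-ascend may be structured differently.
from dataclasses import dataclass, field

@dataclass
class TorchairGraphConfig:
    enabled: bool = False
    enable_multistream_shared_expert: bool = False

@dataclass
class AscendConfig:
    torchair_graph_config: TorchairGraphConfig = field(
        default_factory=TorchairGraphConfig)
```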
```diff
@@ -238,6 +240,8 @@ class CustomDeepseekV2MoE(nn.Module):
         num_tokens, hidden_size = hidden_states.shape
 
+        multistream = self.enable_multistream_shared_expert and not is_prefill
+
         old_hidden_states = hidden_states.clone()
 
         if self.tp_size > 1:
```
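The gate is evaluated per forward call, so the overlap applies only to decode steps while prefill keeps the sequential path. Note that `old_hidden_states` is cloned before any tensor-parallel communication, so the shared experts later consume the original per-rank activations. A tiny illustration of how the gate resolves:

```python
# Illustration of the gate added above: overlap is used only when the flag
# is set and the step is decode, never during prefill.
enable_multistream_shared_expert = True  # from torchair_graph_config

for is_prefill in (True, False):
    multistream = enable_multistream_shared_expert and not is_prefill
    print(f"is_prefill={is_prefill}: multistream={multistream}")
# is_prefill=True: multistream=False
# is_prefill=False: multistream=True
```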
```diff
@@ -259,13 +263,25 @@ class CustomDeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
+        kwargs = {}
+        if multistream:
+            kwargs.update({
+                "shared_experts": self.shared_experts,
+                "shared_hidden_states": old_hidden_states
+            })
+
         hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+            **kwargs)
+
+        if multistream:
+            hidden_states, shared_output = hidden_states
+
+        hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
             if self.torchair_graph_enabled:
```
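With `multistream` set, `self.experts` receives the shared experts and their input through `**kwargs` and returns a `(routed_output, shared_output)` tuple, which the unpack above consumes; `routed_scaling_factor` is applied afterwards so it scales only the routed output. The overlap itself lives inside the experts layer. A minimal sketch of the pattern, assuming `torch_npu`'s CUDA-like stream API (an illustration, not the actual vllm-ascend mc2 kernel):

```python
# Sketch of shared-expert/routed-expert overlap on two NPU streams.
# Assumption: torch_npu exposes CUDA-like stream APIs; the real mc2-based
# implementation inside the fused experts layer differs in detail.
import torch
import torch_npu  # noqa: F401  (registers the torch.npu backend)

_second_stream = torch.npu.Stream()

def experts_with_shared_overlap(routed_fn, shared_experts,
                                hidden_states, shared_hidden_states):
    # The side stream must wait for whatever produced the inputs.
    _second_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(_second_stream):
        shared_output = shared_experts(shared_hidden_states)
    # Routed experts (token dispatch/combine via mc2) run concurrently
    # on the default stream.
    routed_output = routed_fn(hidden_states)
    # Re-join before the default stream reads shared_output.
    torch.npu.current_stream().wait_stream(_second_stream)
    shared_output.record_stream(torch.npu.current_stream())
    return routed_output, shared_output
```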
```diff
@@ -288,7 +304,8 @@ class CustomDeepseekV2MoE(nn.Module):
             hidden_states = hidden_states[:-num_padding_tokens]
 
         if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(old_hidden_states)
+            if not multistream:
+                shared_output = self.shared_experts(old_hidden_states)
 
         if shared_output is not None:
             hidden_states = hidden_states + shared_output
```
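Design note: when the overlapped path has already produced `shared_output` inside the experts call, the new `if not multistream:` guard skips recomputing `self.shared_experts(old_hidden_states)`; the final `shared_output is not None` combine is identical on both paths.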