[BugFix]Fix moe load problems in torchair when using dynamic eplb (#3381)
### What this PR does / why we need it? When using dynamic eplb, moe load is not imported. We fix this problem by modifying the return value of hidden states in torchair. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? DeepseekV3 in A3. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: daishixun <dsxsteven@sina.com>
This commit is contained in:
@@ -1279,13 +1279,15 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
)
|
||||
|
||||
if shared_experts:
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
if isinstance(e_hidden_states,
|
||||
tuple) and len(e_hidden_states) == 2:
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if self.dynamic_eplb and isinstance(
|
||||
e_hidden_states, tuple) and len(e_hidden_states) == 3:
|
||||
self.moe_load += e_hidden_states[2] if e_hidden_states[1] == 0 else \
|
||||
torch.cat(e_hidden_states[2][:1], e_hidden_states[2][1:] - e_hidden_states[2][:-1])
|
||||
e_hidden_states, group_list_type, expert_tokens = e_hidden_states
|
||||
self.moe_load += expert_tokens if group_list_type else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
|
||||
@@ -220,6 +220,7 @@ def torchair_fused_experts_with_mc2(
|
||||
shared_dequant_scale: Optional[Any] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
dynamic_eplb: bool = False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert mc2_mask is not None
|
||||
if log2phy is not None:
|
||||
@@ -354,6 +355,9 @@ def torchair_fused_experts_with_mc2(
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
|
||||
if dynamic_eplb:
|
||||
return (hidden_states, 1, expert_token_nums)
|
||||
|
||||
if shared_experts is None:
|
||||
return hidden_states
|
||||
else:
|
||||
@@ -832,6 +836,7 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
@@ -994,7 +999,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
is_torchair=self.torchair_graph_enabled,
|
||||
mc2_mask=kwargs.get("mc2_mask", None),
|
||||
shared_gate_up=shared_gate_up,
|
||||
shared_dequant_scale=shared_dequant_scale)
|
||||
shared_dequant_scale=shared_dequant_scale,
|
||||
dynamic_eplb=self.dynamic_eplb)
|
||||
elif fused_moe_state in [
|
||||
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
||||
]:
|
||||
|
||||
Reference in New Issue
Block a user