From 847d12a389217e4cbcc5fff70abd72d9b15ad5c4 Mon Sep 17 00:00:00 2001 From: dsxsteven <36877507+dsxsteven@users.noreply.github.com> Date: Mon, 13 Oct 2025 11:38:57 +0800 Subject: [PATCH] [BugFix]Fix moe load problems in torchair when using dynamic eplb (#3381) ### What this PR does / why we need it? When using dynamic eplb, the MoE load (`moe_load`) is not collected in the torchair path. We fix this problem by returning the per-expert token counts together with the hidden states from the torchair MC2 fused-experts function, and by correcting how the returned tuple is unpacked and accumulated into `moe_load`. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Tested with DeepseekV3 on A3. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: daishixun --- vllm_ascend/torchair/ops/torchair_fused_moe.py | 8 +++++--- .../torchair/quantization/torchair_w8a8_dynamic.py | 8 +++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 8b1547d..56843fb 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -1279,13 +1279,15 @@ class TorchairAscendFusedMoE(FusedMoE): ) if shared_experts: - if isinstance(e_hidden_states, tuple): + if isinstance(e_hidden_states, + tuple) and len(e_hidden_states) == 2: e_hidden_states, shared_hidden_states = e_hidden_states if self.dynamic_eplb and isinstance( e_hidden_states, tuple) and len(e_hidden_states) == 3: - self.moe_load += e_hidden_states[2] if e_hidden_states[1] == 0 else \ - torch.cat(e_hidden_states[2][:1], e_hidden_states[2][1:] - e_hidden_states[2][:-1]) + e_hidden_states, group_list_type, expert_tokens = e_hidden_states + self.moe_load += expert_tokens if group_list_type else \ + torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) if (fused_moe_state not in [ FusedMoEState.AllGather, FusedMoEState.AllGatherEP, diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py 
b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 23c4699..1825b2b 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -220,6 +220,7 @@ def torchair_fused_experts_with_mc2( shared_dequant_scale: Optional[Any] = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, + dynamic_eplb: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: assert mc2_mask is not None if log2phy is not None: @@ -354,6 +355,9 @@ def torchair_fused_experts_with_mc2( ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( **kwargs_mc2) + if dynamic_eplb: + return (hidden_states, 1, expert_token_nums) + if shared_experts is None: return hidden_states else: @@ -832,6 +836,7 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: self.ep_group = get_ep_group() ascend_config = get_ascend_config() + self.dynamic_eplb = ascend_config.dynamic_eplb self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp @@ -994,7 +999,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: is_torchair=self.torchair_graph_enabled, mc2_mask=kwargs.get("mc2_mask", None), shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale) + shared_dequant_scale=shared_dequant_scale, + dynamic_eplb=self.dynamic_eplb) elif fused_moe_state in [ FusedMoEState.AllGather, FusedMoEState.NaiveMulticast ]: