From d840f153f42f291ca2078d3fe4d027efdcb3d2db Mon Sep 17 00:00:00 2001 From: wangqiankun13 Date: Thu, 15 Jan 2026 09:21:18 +0800 Subject: [PATCH] [Bugfix] Fix acc bug when enabling dispatch_gmm_combine_decode and eplb (#5806) ### What this PR does / why we need it? Fix acc bug when enabling dispatch_gmm_combine_decode and eplb. After eplb, the expert table may change, so mapping is needed, while fused_mc2 misses the mapping. For more info about this operator, please refer to RFC: issue https://github.com/vllm-project/vllm-ascend/issues/5476 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? without this pr, qwen3-235b eplb with dispatch_gmm_combine_decode gets acc 3.33% on aime2024. with this pr, test qwen3-235b eplb on a single A3 node(ep16) without dispatch_gmm_combine_decode | dataset | version | metric | mode | vllm-api-stream-chat | |----- | ----- | ----- | ----- | -----| | aime2024 | 604a78 | accuracy | gen | 86.67 | with dispatch_gmm_combine_decode | dataset | version | metric | mode | vllm-api-stream-chat | |----- | ----- | ----- | ----- | -----| | aime2024 | 604a78 | accuracy | gen | 86.67 | - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d Signed-off-by: wangqiankun --- vllm_ascend/ops/fused_moe/moe_comm_method.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 631830bb..1692f145 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -300,6 +300,11 @@ class FusedMC2CommImpl(MoECommMethod): assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \ "token_dispatcher must be an instance of TokenDispatcherWithMC2." 
+ + # Apply log2phy if needed + if log2phy is not None: + topk_ids = log2phy[topk_ids] + group_list_type = None expert_tokens = None if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: @@ -331,7 +336,7 @@ class FusedMC2CommImpl(MoECommMethod): group_ep=self.token_dispatcher.moe_all_to_all_group_name, ep_rank_size=self.token_dispatcher.ep_world_size, ep_rank_id=self.token_dispatcher.ep_rank_id, - moe_expert_num=len(expert_map), + moe_expert_num=self.moe_config.num_experts, global_bs=self.token_dispatcher.fused_global_bs) else: raise ValueError(