From aa02a85e4d10294e9f6b4678cbd9d85480f8e975 Mon Sep 17 00:00:00 2001
From: Chen Chen
Date: Mon, 15 Dec 2025 14:18:23 +0800
Subject: [PATCH] [bugfix] Fix dummy-run and multi-node issues in MoE routing and MTP (#4947)

### What this PR does / why we need it?

- Fix a premature `return` in `moe_init_routing_quant_v2.cpp` so the routing kernel completes correctly instead of exiting early in certain paths.
- Switch `FusedAlltoAllCommImpl` to use the MC2-based token dispatcher and prepare/finalize routines, aligning MoE communication with the MC2 algorithm optimized for Ascend devices.
- Add a temporary override in `MtpProposer` to map `FUSED_ALLTOALL` back to `ALLTOALL` until the MoE communication type selection logic is fully finalized, avoiding incorrect behavior in dummy-run flows.
- Simplify the MoE communication selection for Ascend 910-93 in `NPUModelRunner` by removing the EP-size guard on `FUSED_ALLTOALL`, which fixes failures in multi-node / larger-EP configurations while keeping MC2 routing under the configured token capacity.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: mojave2
---
 .../moe_init_routing_quant_v2.cpp       | 1 -
 vllm_ascend/spec_decode/mtp_proposer.py | 3 +++
 vllm_ascend/utils.py                    | 7 ++++---
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp
index 811b2ce9..9180b06d 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp
+++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp
@@ -114,7 +114,6 @@ __aicore__ inline void moe_init_routing_quant_v2(
         srcToDstAndGatherOp.Init(x, scale, expandedRowIdx, expandedX, dynamicQuantScale, workspace, tilingData, &srcToDstGatherPipe);
         srcToDstAndGatherOp.Process();
         srcToDstGatherPipe.Destroy();
-        return;
     }
 }

diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 14db8976..28882924 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -734,6 +734,9 @@ class MtpProposer(Proposer):
                 num_input_tokens, self.runner.with_prefill)
         moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
+        # TODO: remove this after moe_comm_type selection logic is finalized
+        moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
+                         == MoECommType.FUSED_ALLTOALL else moe_comm_type)
         # Enable shared_expert_dp and MTP FULL graph may cause accuracy issues.
         if scheduler_output and not self.enable_shared_expert_dp:

diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index ef74c354..1dad4a28 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -920,16 +920,17 @@ def calculate_ep_buffer_size() -> int:
     try:
         from vllm.config import get_current_vllm_config
         vllm_config = get_current_vllm_config()
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
         hf_config = vllm_config.model_config.hf_config
         hidden_size = hf_config.hidden_size
-        topk = getattr(hf_config, "num_experts_per_token", 1)
-        batch_size = vllm_config.scheduler_config.max_num_batched_tokens
+        topk = getattr(hf_config, "num_experts_per_tok", 1)
+        batch_size = vllm_config.scheduler_config.max_num_batched_tokens // tp_size
         int8_size = torch.iinfo(torch.int8).bits // 8
         bf16_size = torch.finfo(torch.bfloat16).bits // 8
         ep_buffer_size = math.ceil(
             (batch_size * hidden_size * topk *
-             (int8_size * 2 + bf16_size)) / (1024 * 1024))
+             (int8_size + bf16_size) * 3) / (1024 * 1024))
     except Exception:
         pass
     return max(ep_buffer_size, _DEFAULT_BUFFER_SIZE)
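
Reviewer note (not part of the patch): below is a minimal, standalone sketch of the revised buffer-size arithmetic in `calculate_ep_buffer_size()`, showing the effect of dividing the batched tokens by the TP size and switching the per-element factor from `int8_size * 2 + bf16_size` to `(int8_size + bf16_size) * 3`. The token count, TP size, hidden size, and top-k values are assumptions chosen only for illustration.

```python
# Illustrative sketch only; not part of the patch. Mirrors the revised
# arithmetic in calculate_ep_buffer_size() with hard-coded example values.
import math


def ep_buffer_size_mib(max_num_batched_tokens: int, tp_size: int,
                       hidden_size: int, topk: int) -> int:
    int8_size = 1  # bytes per int8 element
    bf16_size = 2  # bytes per bfloat16 element
    # Tokens are now split across the TP group before sizing the buffer.
    batch_size = max_num_batched_tokens // tp_size
    # The per-element factor changes from (int8_size * 2 + bf16_size) = 4
    # to (int8_size + bf16_size) * 3 = 9 bytes per routed element.
    return math.ceil(batch_size * hidden_size * topk *
                     (int8_size + bf16_size) * 3 / (1024 * 1024))


# Assumed example: 4096 batched tokens, TP=8, hidden_size=7168, top-8 routing.
print(ep_buffer_size_mib(4096, 8, 7168, 8))  # -> 252 (MiB)
```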