From 848419d1ba89eddea2bcd721b1d59469fb4e68e4 Mon Sep 17 00:00:00 2001 From: Chen Chen Date: Tue, 9 Dec 2025 22:14:05 +0800 Subject: [PATCH] [Bugfix] Disable the dispatch_ffn_combine kernel in MTP path (#4751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? This PR is to fix a smoke test failure. Adjust mtp_proposer and model_runner_v1 to route MTP decoding through the non‑fused MoE implementation while keeping the overall inference flow unchanged. - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: mojave2 Co-authored-by: Mengqing Cao --- vllm_ascend/spec_decode/mtp_proposer.py | 6 +++++- vllm_ascend/worker/model_runner_v1.py | 13 +++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 08930190..eb71bfb2 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -27,7 +27,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.ascend_forward_context import (MoECommType, + set_ascend_forward_context) from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, @@ -237,6 +238,9 @@ class MtpProposer(Proposer): ) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill) moe_comm_type = self.runner._select_moe_comm_method(num_tokens) + # TODO: remove this after moe_comm_type selection logic is finalized + moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type + == MoECommType.FUSED_ALLTOALL else 
moe_comm_type) if skip_attn: attn_metadata = None diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f6fded7d..24501d3d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -52,8 +52,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group, - get_pcp_group, get_pp_group, - get_tp_group, + get_ep_group, get_pcp_group, + get_pp_group, get_tp_group, is_global_first_rank) from vllm.forward_context import get_forward_context from vllm.logger import logger @@ -2332,10 +2332,11 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendDeviceType._910_93}: - moe_comm_type = (MoECommType.MC2 - if num_tokens <= self.mc2_tokens_capacity else - MoECommType.FUSED_ALLTOALL if quant_type - == "w8a8_dynamic" else MoECommType.ALLTOALL) + # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes + moe_comm_type = ( + MoECommType.MC2 if num_tokens <= self.mc2_tokens_capacity else + MoECommType.FUSED_ALLTOALL if quant_type == "w8a8_dynamic" + and get_ep_group().world_size <= 16 else MoECommType.ALLTOALL) else: raise ValueError(f"Unsupported soc_version: {soc_version}")