From d94b1dc2d0c852cbd3fe8dd9ad6990232f32de1a Mon Sep 17 00:00:00 2001
From: wangbj127 <256472688+wangbj127@users.noreply.github.com>
Date: Mon, 13 Apr 2026 17:36:09 +0800
Subject: [PATCH] [v0.18.0][BugFix] Fix Qwen3.5 MoE flash comm v1 shared
 expert shape error of mtp layer on A2 (#8004)

### What this PR does / why we need it?
Fix the Qwen3.5 MoE MTP layer shared-expert shape error when flash comm v1 is
enabled.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main: https://github.com/vllm-project/vllm/commit/35141a7eeda941a60ad5a4956670c60fd5a77029

Signed-off-by: Wangbingjie
---
 vllm_ascend/ops/register_custom_ops.py    | 2 +-
 vllm_ascend/spec_decode/eagle_proposer.py | 4 +---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index dea1e0bf..bc287dd2 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -80,7 +80,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
     flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled and not (
         enable_sp_by_pass() and is_ep_comm
     )
-    if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model()):
+    if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model() and not is_ep_comm):
         return tensor_model_parallel_all_reduce(x)

     dp_metadata = forward_context.dp_metadata
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index e47f4590..7e193371 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -1595,10 +1595,8 @@ class SpecDecodeBaseProposer(EagleProposer):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        if self.is_multimodal_model and _EXTRA_CTX.flash_comm_v1_enabled:
-            return hidden_states, positions
         if self.method == "mtp":
-            if _EXTRA_CTX.flash_comm_v1_enabled:
+            if _EXTRA_CTX.flash_comm_v1_enabled and not self.is_multimodal_model:
                 hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
                 positions = positions.unsqueeze(-1)
                 positions = torch.ops.vllm.maybe_pad_and_reduce(positions)