From c3c265648f6fb3bf9ea2f6c0e43a4a2e67973d40 Mon Sep 17 00:00:00 2001
From: Zhujiyang2
Date: Wed, 4 Mar 2026 16:02:08 +0800
Subject: [PATCH] [Ops][BugFix] Fix RoPE shape mismatch for mtp models with
 flashcomm v1 enabled (#6939)

What this PR does / why we need it?

When using a draft model (e.g., in MTP speculative decoding) with shared
expert data parallelism (enabled via flashcomm), a shape mismatch error
occurs in the rotary embedding calculation for models like GLM-4.7. This
is because the positions tensor has an incorrect shape for this specific
configuration.

This PR fixes the issue by adding a check in
AscendRotaryEmbedding.forward_oot. If the model is a draft model and
shared expert DP is enabled, it processes the positions tensor using
torch.ops.vllm.maybe_all_gather_and_maybe_unpad to ensure its shape is
correct before applying the rotary embedding. This resolves the shape
mismatch error.

- vLLM version: v0.16.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7

---------

Signed-off-by: Zhu Jiyang
---
 vllm_ascend/ops/rotary_embedding.py       | 8 ++++++++
 vllm_ascend/spec_decode/eagle_proposer.py | 4 ++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 042bc57e..d97eed6c 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -20,6 +20,8 @@ import os
 
 import torch
 import torch_npu
+from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding,
     RotaryEmbedding)
@@ -222,6 +224,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         dtype: torch.dtype,
     ) -> None:
         super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
+        vllm_config = get_current_vllm_config()
+        self.use_mtp = vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp"
         _record_cos_sin_cache(self.cos_sin_cache)
         _record_cos_and_sin_cache_interleaved(self.cos_sin_cache)
 
@@ -236,6 +240,10 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         is_neox_style = self.is_neox_style
         if is_neox_style_override is not None:
             is_neox_style = is_neox_style_override
+        is_draft_model = get_forward_context().is_draft_model
+        flash_comm_v1_enabled = get_forward_context().flash_comm_v1_enabled
+        if is_draft_model and self.use_mtp and flash_comm_v1_enabled:
+            positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True)
         return torch.ops.vllm.npu_rotary_embedding(
             positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
         )
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index cc0dcd90..89ecb2e5 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -1331,14 +1331,14 @@ class EagleProposer(VllmEagleProposer):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        forward_context = get_forward_context()
         if self.method == "mtp":
-            if self.enable_shared_expert_dp:
+            if forward_context.flash_comm_v1_enabled:
                 hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
                 positions = positions.unsqueeze(-1)
                 positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
                 positions = positions.squeeze(-1)
         else:
-            forward_context = get_forward_context()
             if forward_context.flash_comm_v1_enabled:
                 hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
         return hidden_states, positions