[Ops][BugFix] Fix RoPE shape mismatch for mtp models with flashcomm v1 enabled (#6939)

What this PR does / why we need it?
When using a draft model (e.g., in MTP speculative decoding) with shared
expert data parallelism (enabled via flashcomm v1), a shape mismatch error
occurs in the rotary embedding calculation for models like GLM-4.7. In this
configuration the positions tensor has been padded and split across ranks
for flashcomm v1, so its length no longer matches the query/key tensors at
the rotary embedding call.
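
For intuition, a minimal sketch of the kind of mismatch involved (the shapes and names below are made up for illustration and are not the actual vLLM tensors): under flashcomm v1 the token dimension is padded and split across the tensor-parallel group, so a per-rank positions tensor can end up shorter than the query/key tensors seen by the rotary embedding.

```python
import torch

# Hypothetical sizes, only to illustrate the mismatch.
tp_size, num_tokens, num_heads, head_size = 2, 7, 4, 64
padded = (num_tokens + tp_size - 1) // tp_size * tp_size   # pad 7 -> 8 tokens

# Per-rank positions after the flashcomm v1 pad-and-split: 8 / 2 = 4 entries.
positions_local = torch.arange(padded // tp_size)           # shape [4]

# query/key at the rotary-embedding call still cover the full token range.
query = torch.randn(num_tokens, num_heads * head_size)      # shape [7, 256]

# The per-rank positions no longer line up with query/key token-for-token.
print(positions_local.shape[0], query.shape[0])              # 4 vs 7
```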

This PR fixes the issue by adding a check in
AscendRotaryEmbedding.forward_oot: if the current forward pass belongs to an
MTP draft model and flashcomm v1 is enabled, the positions tensor is first
gathered (and unpadded) via torch.ops.vllm.maybe_all_gather_and_maybe_unpad
so that its shape matches query/key before the rotary embedding is applied.
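
Conceptually, the gather restores full-length positions from the per-rank padded chunks so they line up with query/key again. The helper below is a simplified stand-in written for illustration, not the actual torch.ops.vllm.maybe_all_gather_and_maybe_unpad implementation:

```python
import torch

def gather_and_unpad(per_rank_chunks: list[torch.Tensor], orig_len: int) -> torch.Tensor:
    # Stand-in for the all-gather: concatenate the token-split chunks along
    # dim 0, then drop the rows that only exist as padding.
    return torch.cat(per_rank_chunks, dim=0)[:orig_len]

# Two ranks each hold half of 7 positions padded to 8 (last entry is padding).
chunks = [torch.tensor([0, 1, 2, 3]), torch.tensor([4, 5, 6, 0])]
positions_full = gather_and_unpad(chunks, orig_len=7)
print(positions_full)   # tensor([0, 1, 2, 3, 4, 5, 6]) -> matches query/key again
```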
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

---------

Signed-off-by: Zhu Jiyang <zhujiyang2@huawei.com>
Zhujiyang2 authored on 2026-03-04 16:02:08 +08:00, committed by GitHub
parent 95b44d7b73
commit c3c265648f
2 changed files with 10 additions and 2 deletions


@@ -20,6 +20,8 @@ import os
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding,
    MRotaryEmbedding,
@@ -222,6 +224,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
        dtype: torch.dtype,
    ) -> None:
        super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
        vllm_config = get_current_vllm_config()
        self.use_mtp = vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp"
        _record_cos_sin_cache(self.cos_sin_cache)
        _record_cos_and_sin_cache_interleaved(self.cos_sin_cache)
@@ -236,6 +240,10 @@ class AscendRotaryEmbedding(RotaryEmbedding):
        is_neox_style = self.is_neox_style
        if is_neox_style_override is not None:
            is_neox_style = is_neox_style_override
        is_draft_model = get_forward_context().is_draft_model
        flash_comm_v1_enabled = get_forward_context().flash_comm_v1_enabled
        if is_draft_model and self.use_mtp and flash_comm_v1_enabled:
            positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True)
        return torch.ops.vllm.npu_rotary_embedding(
            positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
        )


@@ -1331,14 +1331,14 @@ class EagleProposer(VllmEagleProposer):
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        forward_context = get_forward_context()
        if self.method == "mtp":
            if self.enable_shared_expert_dp:
                if forward_context.flash_comm_v1_enabled:
                    hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
                    positions = positions.unsqueeze(-1)
                    positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
                    positions = positions.squeeze(-1)
        else:
            forward_context = get_forward_context()
            if forward_context.flash_comm_v1_enabled:
                hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
        return hidden_states, positions
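
A note on the proposer change above: positions is 1-D, while the pad-and-reduce path (like the other flashcomm helpers) appears to expect 2-D [num_tokens, hidden_size] activations, hence the unsqueeze(-1)/squeeze(-1) pair around the call. A minimal sketch of that reshaping, with the collective itself stubbed out as a hypothetical pass-through:

```python
import torch

def fake_pad_and_reduce(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for torch.ops.vllm.maybe_pad_and_reduce; the real
    # op pads/splits along dim 0, here we only check the expected 2-D layout.
    assert x.dim() == 2
    return x

positions = torch.arange(7)                   # [7], 1-D token positions
positions_2d = positions.unsqueeze(-1)        # [7, 1], shaped like an activation
positions_2d = fake_pad_and_reduce(positions_2d)
positions = positions_2d.squeeze(-1)          # back to [7]
print(positions.shape)                        # torch.Size([7])
```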