[Ops][BugFix] Fix RoPE shape mismatch for mtp models with flashcomm v1 enabled (#6939)
What this PR does / why we need it?
When using a draft model (e.g., in MTP speculative decoding) with shared
expert data parallelism (enabled via flashcomm), a shape mismatch error
occurs in the rotary embedding calculation for models like GLM-4.7. This
is because the positions tensor has an incorrect shape for this specific
configuration.
This PR fixes the issue by adding a check in
AscendRotaryEmbedding.forward_oot: when the model is an MTP draft model
and flashcomm v1 is enabled, the positions tensor is first passed through
torch.ops.vllm.maybe_all_gather_and_maybe_unpad so that it has the
expected shape before the rotary embedding is applied, which resolves the
shape mismatch.
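For illustration, a minimal sketch of the added guard (it mirrors the
forward_oot hunk in the diff below; the comment explaining why the shapes
diverge is an assumption inferred from the gather-and-unpad call, not text
from the original code):

    # Inside AscendRotaryEmbedding.forward_oot, just before the rotary kernel.
    is_draft_model = get_forward_context().is_draft_model
    flash_comm_v1_enabled = get_forward_context().flash_comm_v1_enabled
    if is_draft_model and self.use_mtp and flash_comm_v1_enabled:
        # Assumption: with flashcomm v1 the positions tensor still has its
        # per-rank (padded, sequence-parallel) layout here, so it is gathered
        # and unpadded back to the full token count to match query and key.
        positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            positions.contiguous(), True)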
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: Zhu Jiyang <zhujiyang2@huawei.com>
@@ -20,6 +20,8 @@ import os
 
 import torch
 import torch_npu
+from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding,
     MRotaryEmbedding,
@@ -222,6 +224,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         dtype: torch.dtype,
     ) -> None:
         super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
+        vllm_config = get_current_vllm_config()
+        self.use_mtp = vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp"
         _record_cos_sin_cache(self.cos_sin_cache)
         _record_cos_and_sin_cache_interleaved(self.cos_sin_cache)
 
@@ -236,6 +240,10 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         is_neox_style = self.is_neox_style
         if is_neox_style_override is not None:
             is_neox_style = is_neox_style_override
+        is_draft_model = get_forward_context().is_draft_model
+        flash_comm_v1_enabled = get_forward_context().flash_comm_v1_enabled
+        if is_draft_model and self.use_mtp and flash_comm_v1_enabled:
+            positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True)
         return torch.ops.vllm.npu_rotary_embedding(
             positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
         )
@@ -1331,14 +1331,14 @@ class EagleProposer(VllmEagleProposer):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        forward_context = get_forward_context()
         if self.method == "mtp":
-            if self.enable_shared_expert_dp:
+            if forward_context.flash_comm_v1_enabled:
                 hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
                 positions = positions.unsqueeze(-1)
                 positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
                 positions = positions.squeeze(-1)
         else:
-            forward_context = get_forward_context()
             if forward_context.flash_comm_v1_enabled:
                 hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
         return hidden_states, positions