diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 042bc57e..d97eed6c 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -20,6 +20,8 @@ import os
 import torch
 import torch_npu
+from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding,
@@ -222,6 +224,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         dtype: torch.dtype,
     ) -> None:
         super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
+        vllm_config = get_current_vllm_config()
+        self.use_mtp = vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp"
         _record_cos_sin_cache(self.cos_sin_cache)
         _record_cos_and_sin_cache_interleaved(self.cos_sin_cache)
@@ -236,6 +240,10 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         is_neox_style = self.is_neox_style
         if is_neox_style_override is not None:
             is_neox_style = is_neox_style_override
+        is_draft_model = get_forward_context().is_draft_model
+        flash_comm_v1_enabled = get_forward_context().flash_comm_v1_enabled
+        if is_draft_model and self.use_mtp and flash_comm_v1_enabled:
+            positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True)
         return torch.ops.vllm.npu_rotary_embedding(
             positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
         )
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index cc0dcd90..89ecb2e5 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -1331,14 +1331,14 @@ class EagleProposer(VllmEagleProposer):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        forward_context = get_forward_context()
         if self.method == "mtp":
-            if self.enable_shared_expert_dp:
+            if forward_context.flash_comm_v1_enabled:
                 hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
                 positions = positions.unsqueeze(-1)
                 positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
                 positions = positions.squeeze(-1)
         else:
-            forward_context = get_forward_context()
             if forward_context.flash_comm_v1_enabled:
                 hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
         return hidden_states, positions
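
A minimal, self-contained sketch of the gating condition this patch adds to `AscendRotaryEmbedding.forward`: `positions` is only all-gathered for the MTP draft-model pass when flash_comm_v1 is enabled. The `SpeculativeConfig`, `VllmConfig`, and `ForwardContext` dataclasses and the `needs_position_all_gather` helper below are illustrative stand-ins, not the real vLLM types.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpeculativeConfig:
    method: str = "mtp"


@dataclass
class VllmConfig:
    speculative_config: Optional[SpeculativeConfig] = None


@dataclass
class ForwardContext:
    is_draft_model: bool = False
    flash_comm_v1_enabled: bool = False


def needs_position_all_gather(vllm_config: VllmConfig,
                              forward_context: ForwardContext) -> bool:
    # Mirrors the patch: use_mtp is derived from the speculative config at
    # init time; the gather fires only for the draft pass with flash_comm_v1.
    use_mtp = (vllm_config.speculative_config is not None
               and vllm_config.speculative_config.method == "mtp")
    return (use_mtp and forward_context.is_draft_model
            and forward_context.flash_comm_v1_enabled)


cfg = VllmConfig(speculative_config=SpeculativeConfig(method="mtp"))
# Draft pass with flash_comm_v1 on: gather positions.
assert needs_position_all_gather(cfg, ForwardContext(True, True))
# Target pass: no gather.
assert not needs_position_all_gather(cfg, ForwardContext(False, True))
# No speculative config at all: no gather.
assert not needs_position_all_gather(VllmConfig(), ForwardContext(True, True))
```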