[cherry-pick][Feat] Add mrope fusion op#3708 (#3735)

### What this PR does / why we need it?
Add mrope fusion op for qwen2.5-vl. This mrope operator dosen't
support Qwen3-VL currently. Thus could only take affect in qwen2.5-vl
cherry pick from 39b994a987

CI passed with existing test

Signed-off-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
Co-authored-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
This commit is contained in:
shaopeng-666
2025-10-25 11:41:23 +08:00
committed by GitHub
parent 0644113c35
commit fed8145aea
3 changed files with 123 additions and 4 deletions

View File

@@ -22,7 +22,7 @@ import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding,
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
YaRNScalingRotaryEmbedding)
from vllm_ascend.platform import NPUPlatform
@@ -395,3 +395,37 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
is_neox_style, offsets)
return q_pe, k_pe
class AscendMRotaryEmbedding(MRotaryEmbedding):
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
):
if self.mrope_section != [16, 24, 24]:
return super().forward_oot(positions, query, key)
import torch_npu
mrope_section = [0, 0, 0
] if positions.ndim == 1 else self.mrope_section
if self.cos_sin_cache.device != query.device: # type: ignore
self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore
query.device) # type: ignore
if self.cos_sin_cache.dtype != query.dtype: # type: ignore
self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore
query.dtype) # type: ignore
query, key = torch_npu.npu_mrope(positions,
query.contiguous(),
key.contiguous(),
self.cos_sin_cache.contiguous(),
self.head_size,
mrope_section=mrope_section,
rotary_mode='half')
return query, key