[Feat] Add mrope fusion op (#3708)
### What this PR does / why we need it? Add mrope fusion op for qwen2.5-vl. This mrope operator dosen't support Qwen3-VL currently. Thus could only take affect in qwen2.5-vl - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: shaopeng666 <shaopeng666@noreply.gitcode.com> Co-authored-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
This commit is contained in:
@@ -22,7 +22,7 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding,
|
||||
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding)
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
@@ -395,3 +395,37 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
|
||||
is_neox_style, offsets)
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
class AscendMRotaryEmbedding(MRotaryEmbedding):
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
):
|
||||
if self.mrope_section != [16, 24, 24]:
|
||||
return super().forward_oot(positions, query, key)
|
||||
|
||||
import torch_npu
|
||||
mrope_section = [0, 0, 0
|
||||
] if positions.ndim == 1 else self.mrope_section
|
||||
|
||||
if self.cos_sin_cache.device != query.device: # type: ignore
|
||||
self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore
|
||||
query.device) # type: ignore
|
||||
|
||||
if self.cos_sin_cache.dtype != query.dtype: # type: ignore
|
||||
self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore
|
||||
query.dtype) # type: ignore
|
||||
|
||||
query, key = torch_npu.npu_mrope(positions,
|
||||
query.contiguous(),
|
||||
key.contiguous(),
|
||||
self.cos_sin_cache.contiguous(),
|
||||
self.head_size,
|
||||
mrope_section=mrope_section,
|
||||
rotary_mode='half')
|
||||
|
||||
return query, key
|
||||
Reference in New Issue
Block a user