support Triton mrope (#5664)

### What this PR does / why we need it?
This PR adds a Triton mrope path, used like `cuda_forward`, with performance on par with the AscendC op.
The Triton op requires CANN 8.5.0.
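For reference, here is a minimal pure-torch sketch (not the PR's Triton kernel; all sizes are illustrative assumptions) of what mrope does with 2-D positions: the rotary dim is split into bands by `mrope_section`, and each band takes its cos/sin from one of the temporal/height/width position rows.

```python
# Illustrative mrope cos/sin assembly in plain torch (not the Triton kernel
# added by this PR); all sizes below are assumptions for demonstration.
import torch

num_tokens, rotary_dim = 4, 128
mrope_section = [16, 24, 24]                           # sums to rotary_dim // 2

positions = torch.arange(num_tokens).repeat(3, 1)      # (3, num_tokens): t / h / w rows
cos_sin_cache = torch.randn(1024, rotary_dim)          # placeholder cache
cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)   # each (3, num_tokens, rotary_dim // 2)

# Band i of the rotary dim uses the cos/sin gathered from position row i.
cos = torch.cat([c[i] for i, c in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
sin = torch.cat([s[i] for i, s in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
assert cos.shape == (num_tokens, rotary_dim // 2)
```

In the diff below, the cos/sin gather from `cos_sin_cache` still happens in PyTorch, while `triton_mrope` handles the per-section selection and the query/key rotation.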
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Tested with Qwen3-VL-235B accuracy on TextVQA:

| Backend     | TextVQA acc |
|-------------|-------------|
| native      | 81.82       |
| NPU Triton  | 81.58       |
| CUDA Triton | 81.52       |
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

Signed-off-by: shiyuan680 <917935075@qq.com>
shiyuan680 committed 2026-01-13 09:13:51 +08:00 (committed by GitHub)
parent db7cf9b0ca
commit 7af3b880c1
2 changed files with 235 additions and 0 deletions


@@ -25,6 +25,10 @@ from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
    YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
@@ -527,12 +531,50 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
class AscendMRotaryEmbedding(MRotaryEmbedding):

    def forward_triton(self,
                       positions: torch.Tensor,
                       query: torch.Tensor,
                       key: torch.Tensor | None = None,
                       offsets: torch.Tensor | None = None):
        assert positions.ndim == 2
        assert key is not None

        self._match_cos_sin_cache_dtype(query)
        self.cos = None
        self.sin = None
        if self.cos is None and self.sin is None:
            cos_sin = self.cos_sin_cache[positions]  # type: ignore
            cos, sin = cos_sin.chunk(2, dim=-1)
            self.cos = cos.contiguous()
            self.sin = sin.contiguous()

        query_shape = query.shape
        key_shape = key.shape
        assert self.mrope_section

        q, k = triton_mrope(
            query,
            key,
            self.cos,
            self.sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
        )

        return q.reshape(query_shape), k.reshape(key_shape)
    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ):
        if HAS_TRITON and positions.ndim == 2:
            # TODO: requires CANN update to 8.5.0
            return self.forward_triton(positions, query, key)

        if self.mrope_section != [16, 24, 24] or \
                get_ascend_device_type() == AscendDeviceType.A5:
            return super().forward_oot(positions, query, key)
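A hypothetical smoke test for the new branch (the module path, constructor arguments, shapes, and the `.npu()` calls are assumptions, not taken from the PR; it presumes a vllm-ascend install with torch_npu, Triton, and CANN 8.5.0): 2-D `(3, num_tokens)` positions send `forward_oot` down the Triton path, while other cases fall through to the parent or AscendC implementations.

```python
# Hypothetical usage sketch; argument names and values are assumed for
# illustration and require an NPU environment (torch_npu + Triton).
import torch
# Module path assumed from the file touched by this diff.
from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding

num_tokens, num_q_heads, num_kv_heads, head_size = 8, 32, 4, 128

rope = AscendMRotaryEmbedding(head_size=head_size,
                              rotary_dim=head_size,
                              max_position_embeddings=32768,
                              base=1_000_000,
                              is_neox_style=True,
                              dtype=torch.bfloat16,
                              mrope_section=[16, 24, 24]).npu()

positions = torch.zeros(3, num_tokens, dtype=torch.long).npu()  # t / h / w rows
query = torch.randn(num_tokens, num_q_heads * head_size, dtype=torch.bfloat16).npu()
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.bfloat16).npu()

# With HAS_TRITON and 2-D positions this dispatches to forward_triton.
query, key = rope.forward_oot(positions, query, key)
```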