Support Triton mrope (#5664)
### What this PR does / why we need it?
This PR adds a Triton mrope path, analogous to `cuda_forward`; its performance is on par with the AscendC op.
The Triton op requires CANN 8.5.0.
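For context, mrope partitions the rotary half-dimension into per-axis sections ([16, 24, 24] for temporal/height/width in the Qwen-VL family) and draws each section's cos/sin from the matching row of the 2-D position ids. Below is a minimal PyTorch sketch of that selection, mirroring the eager mrope path rather than the Triton kernel itself; the cache size and shapes are illustrative:

```python
import torch

num_tokens, rotary_half = 4, 64                 # rotary_dim = 128
mrope_section = [16, 24, 24]                    # half-dim sections, sums to 64
positions = torch.zeros(3, num_tokens, dtype=torch.long)   # one row per axis (t/h/w)

cos_sin_cache = torch.randn(1024, 2 * rotary_half)         # [cos | sin] per position
cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)       # each (3, num_tokens, 64)

# Each section of the rotary half-dim takes its cos/sin from its own position axis.
cos = torch.cat([s[i] for i, s in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
sin = torch.cat([s[i] for i, s in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
# cos/sin are now (num_tokens, 64) and can be applied as in standard RoPE.
```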
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested accuracy of qwen3-vl-235b on TextVQA:
- native: 81.82
- NPU Triton: 81.58
- CUDA Triton: 81.52
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
Signed-off-by: shiyuan680 <917935075@qq.com>
```diff
@@ -25,6 +25,10 @@ from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
    YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
```
```diff
@@ -527,12 +531,50 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):


class AscendMRotaryEmbedding(MRotaryEmbedding):

    def forward_triton(self,
                       positions: torch.Tensor,
                       query: torch.Tensor,
                       key: torch.Tensor | None = None,
                       offsets: torch.Tensor | None = None):
        assert positions.ndim == 2
        assert key is not None

        self._match_cos_sin_cache_dtype(query)
        # cos/sin are rebuilt from the fused cache on every call, using one
        # row of positions per mrope axis (temporal/height/width).
        cos_sin = self.cos_sin_cache[positions]  # type: ignore
        cos, sin = cos_sin.chunk(2, dim=-1)
        self.cos = cos.contiguous()
        self.sin = sin.contiguous()

        query_shape = query.shape
        key_shape = key.shape

        assert self.mrope_section

        q, k = triton_mrope(
            query,
            key,
            self.cos,
            self.sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
        )

        return q.reshape(query_shape), k.reshape(key_shape)

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ):
        # 2-D positions (3, num_tokens) mark the mrope path; take the Triton
        # kernel when it is available.
        if HAS_TRITON and positions.ndim == 2:
            # TODO: requires the CANN 8.5.0 update.
            return self.forward_triton(positions, query, key)

        if self.mrope_section != [16, 24, 24] or \
                get_ascend_device_type() == AscendDeviceType.A5:
            return super().forward_oot(positions, query, key)
```
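For reference, a minimal usage sketch of the new dispatch; the constructor arguments and tensor shapes below are illustrative assumptions, not part of the diff:

```python
import torch

# Hypothetical instantiation; real instances are built by the model's
# attention layers with the model's own rope parameters.
rope = AscendMRotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=32768,
    base=1000000,
    is_neox_style=True,
    dtype=torch.bfloat16,
    mrope_section=[16, 24, 24],
)

num_tokens = 8
positions = torch.zeros(3, num_tokens, dtype=torch.long)         # (3, num_tokens): t/h/w
query = torch.randn(num_tokens, 32 * 128, dtype=torch.bfloat16)  # flattened q heads
key = torch.randn(num_tokens, 4 * 128, dtype=torch.bfloat16)     # flattened kv heads

# With Triton available and 2-D positions, forward_oot takes the Triton path.
q, k = rope.forward_oot(positions, query, key)
```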