[Main][Ops] Make triton rope support index_selecting from cos_sin_cache (#5450)

### What this PR does / why we need it?

This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:

1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see #6612 ). This help to recover accept rate && accuracy in
that case.

In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).

**Highlights**

- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
5326c89803

Additional test on Qwen3-235b accuracy:

| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2026-02-11 21:20:53 +08:00
committed by GitHub
parent 6bc44bf49b
commit c0c2eb614e
13 changed files with 378 additions and 243 deletions

View File

@@ -29,11 +29,13 @@ from vllm.model_executor.layers.rotary_embedding import (
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
from vllm.triton_utils import HAS_TRITON
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
if HAS_TRITON:
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
from vllm_ascend.ops.triton.rope import rope_forward_triton
# Currently, rope ops used on npu requires detached cos && sin as inputs.
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
@@ -146,54 +148,38 @@ def get_cos_and_sin_slice():
return _cos_slice, _sin_slice
def _rope_forward_oot(
self,
def rope_forward_oot(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
cos_sin_cache: torch.Tensor,
head_size: int,
rotary_dim: int,
is_neox_style: bool,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
cos, sin = get_cos_and_sin_slice()
if offsets is not None:
raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
if (
is_neox_style
and self.head_size == 128
and self.cos_sin_cache.shape[-1] == 128
and cos is not None
and sin is not None
):
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim equal 128 and neox_style is True
query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
# Although this function modifies in-place, please retain the function's return value.
# Otherwise, the graph fusion operation may fail.
query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
elif self.rotary_dim < self.head_size:
if HAS_TRITON:
cos = cos.view(-1, self.rotary_dim)
sin = sin.view(-1, self.rotary_dim)
q = query.contiguous().view(query.shape[0], -1, self.head_size)
k = key.contiguous().view(key.shape[0], -1, self.head_size)
query, key = torch.ops.vllm.rope_forward_triton(
q, k, cos, sin, rope_dim=self.rotary_dim, is_neox_style=True
)
return query.view(query_shape), key.view(key_shape)
else:
if HAS_TRITON:
num_tokens = query.shape[0]
query, key = rope_forward_triton(
query.view(num_tokens, -1, head_size),
key.view(num_tokens, -1, head_size),
cos_sin_cache=cos_sin_cache,
positions=positions,
rope_dim=rotary_dim,
is_neox_style=is_neox_style,
)
else:
if rotary_dim < head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
q_rot = query[..., : self.rotary_dim]
q_pass = query[..., self.rotary_dim :]
k_rot = key[..., : self.rotary_dim]
k_pass = key[..., self.rotary_dim :]
query = query.view(num_tokens, -1, head_size)
key = key.view(num_tokens, -1, head_size)
q_rot = query[..., :rotary_dim]
q_pass = query[..., rotary_dim:]
k_rot = key[..., :rotary_dim]
k_pass = key[..., rotary_dim:]
q_rot = q_rot.contiguous().view(num_tokens, -1)
k_rot = k_rot.contiguous().view(num_tokens, -1)
# only the rotary part is processed here,
@@ -202,27 +188,26 @@ def _rope_forward_oot(
positions,
q_rot,
k_rot,
self.rotary_dim,
self.cos_sin_cache,
rotary_dim,
cos_sin_cache,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, rotary_dim)
k_rot = k_rot.view(num_tokens, -1, rotary_dim)
query = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
key = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
@@ -251,7 +236,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style = self.is_neox_style
if is_neox_style_override is not None:
is_neox_style = is_neox_style_override
return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
return torch.ops.vllm.npu_rotary_embedding(
positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
)
class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
@@ -460,7 +447,9 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
query = query.view(b, h_q, d // 2, 2).transpose(3, 2).reshape(b, h_q, d)
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
q_pe, k_pe = torch.ops.vllm.npu_rotary_embedding(
positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
)
return q_pe, k_pe