[Main][Ops] Make triton rope support index_selecting from cos_sin_cache (#5450)

### What this PR does / why we need it?

This PR extends the original `rope_forward_triton` and
`split_qkv_rmsnorm_rope` to accept `cos_sin_cache` and `positions` as
inputs, fully aligning with the vLLM RoPE API. Compared with the
earlier RoPE implementation, the benefits are:

1. It avoids pre-computing `cos` and `sin` before model execution,
which removes redundant code.
2. It allows the eagle3 draft model to use RoPE parameters that differ
from the main model's (see #6612), which helps recover the acceptance
rate and accuracy in that case.

In addition, this kernel change introduces only a very small
performance penalty: the former `index_select` and `chunk` operations
become simple memory accesses inside the Triton kernel (for example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).
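
For reference, this is the host-side selection that the kernels now
perform inline. A minimal PyTorch sketch; the helper name and the
half-cos/half-sin row layout are taken from the diff below, not from
any existing API:

```python
import torch

# Reference-only sketch of the gather that the Triton kernels now do inline.
# Assumed layout (see the kernel diff): cos_sin_cache[p] stores cos values in
# the first half of the row and sin values in the second half.
def select_cos_sin(cos_sin_cache: torch.Tensor, positions: torch.Tensor):
    rows = cos_sin_cache.index_select(0, positions)  # [num_tokens, rope_dim]
    cos, sin = rows.chunk(2, dim=-1)                 # [num_tokens, rope_dim // 2] each
    return cos, sin

# Previously this ran on the host before model execution; now each kernel
# program loads pos = positions[row_idx] and reads the cos/sin halves of
# cos_sin_cache[pos] directly from global memory.
```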

**Highlights**

- **RoPE Cache Unification**: Replaced the separate `_sin` and `_cos`
global tensors with a unified `cos_sin_cache` and an explicit
`positions` tensor for Rotary Positional Embeddings (RoPE),
streamlining data handling.
- **Triton Kernel Integration**: Updated the Triton kernels
(`split_qkv_rmsnorm_rope_kernel`, `_triton_rope`) to consume
`cos_sin_cache` and `positions` directly, making the RoPE calculation
more efficient and self-contained.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified `rope_forward_oot` to
accept the new `cos_sin_cache` and `positions` arguments, enabling a
more flexible and integrated RoPE application.
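
The registration code itself is not part of the excerpt below, so here
is only a minimal sketch of how a custom op of this shape can be
registered with `torch.library`; the namespace, signature, and in-place
convention are assumptions, not the PR's actual code:

```python
import torch

# Hypothetical registration sketch; names and signature are illustrative.
@torch.library.custom_op("ascend::rope_forward_oot", mutates_args=("q", "k"))
def rope_forward_oot(
    positions: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
) -> None:
    # Would dispatch to rope_forward_triton(q, k, cos_sin_cache=...,
    # positions=...) in the real implementation.
    ...

@rope_forward_oot.register_fake
def _(positions, q, k, cos_sin_cache, is_neox_style) -> None:
    # q and k are mutated in place, so there is nothing to allocate here.
    pass
```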

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 5326c89803

Additional accuracy results for Qwen3-235B:

| AIME2024 | GSM8K | LiveCodeBench |
| -------- | ----- | ------------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>

---

```diff
@@ -26,8 +26,8 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
 @triton.jit
 def split_qkv_rmsnorm_rope_kernel(
     input_ptr,
-    sin_ptr,
-    cos_ptr,
+    cos_sin_ptr,
+    pos_ptr,
     q_ptr,
     k_ptr,
     v_ptr,
@@ -74,9 +74,11 @@ def split_qkv_rmsnorm_rope_kernel(
     else:
         normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
-    sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
-    sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
-    cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
+    pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
+    cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
+    sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
+    cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
+    sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
     x1 = tl.extract_slice(
         normalized_values,
         offsets=(0, 0),
@@ -89,22 +91,24 @@ def split_qkv_rmsnorm_rope_kernel(
         sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
         strides=(1, 1),
     )
-    cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
-    cat_x = tl.insert_slice(
-        cat_x,
-        -x2,
+    roped_q1 = x1 * cos - x2 * sin
+    roped_q2 = x2 * cos + x1 * sin
+    roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
+    roped_q = tl.insert_slice(
+        roped_q,
+        roped_q1,
         offsets=(0, 0),
         sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
         strides=(1, 1),
     )
-    cat_x = tl.insert_slice(
-        cat_x,
-        x1,
+    roped_q = tl.insert_slice(
+        roped_q,
+        roped_q2,
         offsets=(0, HALF_HEAD_DIM),
         sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
         strides=(1, 1),
     )
-    roped_q = cat_x * sin + normalized_values * cos
     tl.store(
         q_ptr + output_offset + col_indices,
         roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty),
@@ -135,9 +139,12 @@ def split_qkv_rmsnorm_rope_kernel(
         normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
     else:
         normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
-    sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
-    sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
-    cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
+    pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
+    cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
+    sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
+    cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
+    sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
     x1 = tl.extract_slice(
         normalized_values,
         offsets=(0, 0),
@@ -150,23 +157,24 @@ def split_qkv_rmsnorm_rope_kernel(
         sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
         strides=(1, 1),
     )
-    cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
-    cat_x = tl.insert_slice(
-        cat_x,
-        -x2,
+    roped_k1 = x1 * cos - x2 * sin
+    roped_k2 = x2 * cos + x1 * sin
+    roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
+    roped_k = tl.insert_slice(
+        roped_k,
+        roped_k1,
         offsets=(0, 0),
         sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
         strides=(1, 1),
     )
-    cat_x = tl.insert_slice(
-        cat_x,
-        x1,
+    roped_k = tl.insert_slice(
+        roped_k,
+        roped_k2,
         offsets=(0, HALF_HEAD_DIM),
         sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
         strides=(1, 1),
     )
-    roped_k = cat_x * sin + normalized_values * cos
     tl.store(
         k_ptr + output_offset + col_indices,
         roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE),
@@ -188,8 +196,8 @@ def split_qkv_rmsnorm_rope_kernel(
 def split_qkv_rmsnorm_rope_impl(
     input: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    positions: torch.Tensor,
     q_weight: torch.Tensor,
     k_weight: torch.Tensor,
     q_hidden_size: int,
@@ -216,8 +224,8 @@ def split_qkv_rmsnorm_rope_impl(
     split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)](
         input,
-        sin,
-        cos,
+        cos_sin_cache,
+        positions,
         q_output,
         k_output,
         v_output,
@@ -241,8 +249,8 @@ def split_qkv_rmsnorm_rope_impl(
 def split_qkv_rmsnorm_rope_impl_fake(
     input: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    positions: torch.Tensor,
     q_weight: torch.Tensor,
     k_weight: torch.Tensor,
     q_hidden_size: int,
```
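
Both the old `cat_x` formulation and the new `roped_q1`/`roped_q2`
formulation above implement the same NeoX-style rotation; the new one
simply works with half-width `cos`/`sin` loaded straight from the cache
row, so no host-side duplication is needed. A small PyTorch sketch of
the equivalence (shapes illustrative):

```python
import torch

def rotate_both_ways(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # x: [n_heads, head_dim]; cos/sin: [1, head_dim // 2].
    x1, x2 = x.chunk(2, dim=-1)
    # New kernel path: rotate the two halves with half-width cos/sin.
    new = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    # Old kernel path: build cat_x = [-x2, x1] and use cos/sin duplicated
    # to the full head_dim, i.e. cat_x * sin_full + x * cos_full.
    cos_full, sin_full = torch.cat((cos, cos), -1), torch.cat((sin, sin), -1)
    old = torch.cat((-x2, x1), dim=-1) * sin_full + x * cos_full
    assert torch.allclose(new, old)  # same rotation, different data layout
    return new
```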

---

```diff
@@ -14,7 +14,6 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
 import torch
 from vllm.triton_utils import tl, triton
@@ -30,10 +29,13 @@ def _triton_rope(
     q_row_stride,
     k_ptr,
     k_row_stride,
-    cos,
+    cos_ptr,
     cos_row_stride,
-    sin,
+    sin_ptr,
     sin_row_stride,
+    cos_sin_ptr,
+    cos_sin_row_stride,
+    pos_ptr,
     num_tokens,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
@@ -44,6 +46,7 @@ def _triton_rope(
     pad_rope_dim: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     IS_NEOX_STYLE: tl.constexpr,
+    USE_COS_SIN: tl.constexpr,
 ):
     """
     This triton kernel applies rotary embedding on q and k.
@@ -84,13 +87,19 @@ def _triton_rope(
     # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
     # m of this program instance
     # ####################################################################
-    cos_start_ptr = cos + row_idx * cos_row_stride
-    sin_start_ptr = sin + row_idx * sin_row_stride
     cos_offsets = tl.arange(0, pad_rope_dim // 2)
+    sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim)
     cos_mask = cos_offsets < (rope_dim // 2)
-    cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
-    sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
+    if USE_COS_SIN:
+        pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
+        cos_start_ptr = cos_sin_ptr + pos_idx * cos_sin_row_stride
+        cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
+        sin_row = tl.load(cos_start_ptr + sin_offsets, mask=cos_mask, other=0).to(tl.float32)
+    else:
+        cos_start_ptr = cos_ptr + row_idx * cos_row_stride
+        sin_start_ptr = sin_ptr + row_idx * sin_row_stride
+        cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
+        sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
     # ####################################################################
     # Load the left and right half of q and k for the current
```
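
The `cos_offsets`/`sin_offsets` split above assumes one cache row per
position, with cos in the first half and sin in the second, which is
how vLLM lays out its RoPE cache. A sketch of constructing such a cache
(the helper name and `base` default are illustrative):

```python
import torch

# Illustrative construction of a cache with the row layout the kernel reads:
#   cache[p] = [cos(p * theta_0), ..., cos(p * theta_{d/2-1}),
#               sin(p * theta_0), ..., sin(p * theta_{d/2-1})]
def build_cos_sin_cache(max_positions: int, rope_dim: int,
                        base: float = 10000.0) -> torch.Tensor:
    inv_freq = 1.0 / (base ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
    freqs = torch.outer(torch.arange(max_positions).float(), inv_freq)
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_positions, rope_dim]
```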
```diff
@@ -140,8 +149,10 @@ def rope_forward_triton(
 def rope_forward_triton(
     q: torch.Tensor,
     k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
+    cos: torch.Tensor = None,
+    sin: torch.Tensor = None,
+    cos_sin_cache: torch.Tensor = None,
+    positions: torch.Tensor = None,
     rope_dim: int = -1,
     is_neox_style: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -152,12 +163,6 @@ def rope_forward_triton(
     num_tokens, n_q_head, head_dim = q.shape
     n_kv_head = k.shape[1]
-    cos = cos.view(num_tokens, -1)
-    sin = sin.view(num_tokens, -1)
-    if rope_dim == -1:
-        # If rope_dim is not specified, we assume that input cos/sin is not
-        # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
-        rope_dim = cos.shape[-1] * 2
     assert rope_dim <= head_dim
     pad_rope_dim = triton.next_power_of_2(rope_dim)
     pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -166,24 +171,69 @@ def rope_forward_triton(
     num_vectorcore = get_vectorcore_num()
     n_row = min(num_tokens, num_vectorcore)
-    _triton_rope[(n_row,)](
-        q,
-        q.stride(0),
-        k,
-        k.stride(0),
-        cos,
-        cos.stride(0),
-        sin,
-        sin.stride(0),
-        num_tokens,
-        n_q_head,
-        n_kv_head,
-        head_dim,
-        rope_dim,
-        pad_n_q_head,
-        pad_n_kv_head,
-        pad_rope_dim,
-        BLOCK_SIZE=BLOCK_SIZE,
-        IS_NEOX_STYLE=is_neox_style,
-    )
+    if cos_sin_cache is not None and positions is not None:
+        assert positions.shape[0] == num_tokens
+        _triton_rope[(n_row,)](
+            q,
+            q.stride(0),
+            k,
+            k.stride(0),
+            None,
+            None,
+            None,
+            None,
+            cos_sin_cache,
+            cos_sin_cache.stride(0),
+            positions,
+            num_tokens,
+            n_q_head,
+            n_kv_head,
+            head_dim,
+            rope_dim,
+            pad_n_q_head,
+            pad_n_kv_head,
+            pad_rope_dim,
+            BLOCK_SIZE=BLOCK_SIZE,
+            IS_NEOX_STYLE=is_neox_style,
+            USE_COS_SIN=True,
+        )
+    elif cos is not None and sin is not None:
+        assert cos.shape[0] == num_tokens and sin.shape[0] == num_tokens
+        cos = cos.view(num_tokens, -1)
+        sin = sin.view(num_tokens, -1)
+        if rope_dim == -1:
+            # If rope_dim is not specified, we assume that input cos/sin is not
+            # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
+            rope_dim = cos.shape[-1] * 2
+        _triton_rope[(n_row,)](
+            q,
+            q.stride(0),
+            k,
+            k.stride(0),
+            cos,
+            cos.stride(0),
+            sin,
+            sin.stride(0),
+            None,
+            None,
+            None,
+            num_tokens,
+            n_q_head,
+            n_kv_head,
+            head_dim,
+            rope_dim,
+            pad_n_q_head,
+            pad_n_kv_head,
+            pad_rope_dim,
+            BLOCK_SIZE=BLOCK_SIZE,
+            IS_NEOX_STYLE=is_neox_style,
+            USE_COS_SIN=False,
+        )
+    else:
+        raise ValueError(
+            "Currently, rope_forward_triton supports passing:\n"
+            "1. positions and original cos_sin_cache.\n"
+            "2. cos and sin which are already selected by positions\n"
+            "Please check whether you call rope_forward_triton correctly."
+        )
     return q, k
```
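
Finally, a usage sketch of the two call patterns that the dispatch
above accepts; the shapes are illustrative and `build_cos_sin_cache` is
the hypothetical helper sketched earlier:

```python
import torch

num_tokens, n_q_head, n_kv_head, head_dim = 8, 4, 2, 128
q = torch.randn(num_tokens, n_q_head, head_dim, dtype=torch.bfloat16)
k = torch.randn(num_tokens, n_kv_head, head_dim, dtype=torch.bfloat16)
positions = torch.arange(num_tokens)
cache = build_cos_sin_cache(4096, head_dim)

# 1. New path: pass the full cache and positions; the kernel gathers per
#    token. rope_dim is passed explicitly since this branch does not derive
#    it from cos/sin shapes.
q, k = rope_forward_triton(q, k, cos_sin_cache=cache, positions=positions,
                           rope_dim=head_dim)

# 2. Legacy path: select cos/sin on the host first, then pass them per token.
selected = cache.index_select(0, positions)  # [num_tokens, head_dim]
cos, sin = selected.chunk(2, dim=-1)         # each [num_tokens, head_dim // 2]
q, k = rope_forward_triton(q, k, cos=cos.contiguous(), sin=sin.contiguous())
```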