[Main][Ops] Make triton rope support index_selecting from cos_sin_cache (#5450)

### What this PR does / why we need it?

This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:

1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see #6612 ). This help to recover accept rate && accuracy in
that case.

In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).

**Highlights**

- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
5326c89803

Additional test on Qwen3-235b accuracy:

| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2026-02-11 21:20:53 +08:00
committed by GitHub
parent 6bc44bf49b
commit c0c2eb614e
13 changed files with 378 additions and 243 deletions

View File

@@ -45,16 +45,21 @@ class QKNormRopeFusionPattern:
def get_inputs(self):
T = 5
max_position_embeddings = 16384
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
return [qkv, q_weight, k_weight, cos, sin]
cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu")
positions = torch.ones(T, dtype=torch.int64, device="npu")
return [qkv, q_weight, k_weight, cos_sin_cache, positions]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
positions: torch.Tensor,
):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -65,17 +70,19 @@ class QKNormRopeFusionPattern:
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
q_flat = q_norm_out.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
k_flat = k_norm_out.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
)
return q_rope, k_rope, v
def replacement(
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
positions: torch.Tensor,
):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
@@ -87,8 +94,8 @@ class QKNormRopeFusionPattern:
eps=self.eps,
q_bias=None,
k_bias=None,
sin=sin,
cos=cos,
cos_sin_cache=cos_sin_cache,
positions=positions,
)
return results
@@ -109,15 +116,16 @@ class QKNormRopeFusionPatternWithBias:
def get_inputs(self):
T = 5
max_position_embeddings = 16384
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu")
positions = torch.ones(T, dtype=torch.int64, device="npu")
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
return [qkv, q_weight, k_weight, q_bias, k_bias, cos_sin_cache, positions]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
@@ -126,8 +134,8 @@ class QKNormRopeFusionPatternWithBias:
k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cos_sin_cache: torch.Tensor,
positions: torch.Tensor,
):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -140,12 +148,10 @@ class QKNormRopeFusionPatternWithBias:
k_normed = k_norm_out + k_bias
q_flat = q_normed.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
k_flat = k_normed.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
)
return q_rope, k_rope, v
@@ -155,8 +161,8 @@ class QKNormRopeFusionPatternWithBias:
k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cos_sin_cache: torch.Tensor,
positions: torch.Tensor,
):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
@@ -168,8 +174,8 @@ class QKNormRopeFusionPatternWithBias:
eps=self.eps,
q_bias=q_bias,
k_bias=k_bias,
cos=cos,
sin=sin,
cos_sin_cache=cos_sin_cache,
positions=positions,
)
return results
@@ -186,7 +192,7 @@ class QKNormRopeFusionPass(VllmInductorPass):
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass")
dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
if dtype not in (torch.bfloat16,):
logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
return