[Main][Ops] Make triton rope support index_selecting from cos_sin_cache (#5450)
### What this PR does / why we need it?
This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:
1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see #6612 ). This help to recover accept rate && accuracy in
that case.
In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).
**Highlights**
- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
5326c89803
Additional test on Qwen3-235b accuracy:
| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -60,7 +60,7 @@ class ModelQKNormRopeWithoutBias(nn.Module):
|
||||
self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
|
||||
self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, qkv, cos, sin):
|
||||
def forward(self, qkv, cos_sin_cache, positions):
|
||||
"""
|
||||
Args:
|
||||
qkv: [T, q_size + 2*kv_size]
|
||||
@@ -82,13 +82,12 @@ class ModelQKNormRopeWithoutBias(nn.Module):
|
||||
|
||||
# Reshape for RoPE: [T, num_heads, head_dim] -> [1, T, num_heads, head_dim]
|
||||
q_flat = q_norm_out.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_norm_out.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
# Apply RoPE
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
|
||||
)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
@@ -116,7 +115,7 @@ class ModelQKNormRopeWithBias(nn.Module):
|
||||
self.q_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
|
||||
self.k_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, qkv, cos, sin):
|
||||
def forward(self, qkv, cos_sin_cache, positions):
|
||||
# Split QKV
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
@@ -132,13 +131,12 @@ class ModelQKNormRopeWithBias(nn.Module):
|
||||
|
||||
# Reshape for RoPE
|
||||
q_flat = q_normed.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_normed.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
# Apply RoPE
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
|
||||
)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
@@ -147,7 +145,7 @@ def assert_qknorm_rope_fusion(after_gm, expect_fused=True, use_bias=False):
|
||||
check_rules = [
|
||||
(torch.ops.vllm.qkv_rmsnorm_rope.default, expect_fused),
|
||||
(torch.ops.npu.npu_rms_norm.default, not expect_fused),
|
||||
(torch.ops.npu.npu_apply_rotary_pos_emb.default, not expect_fused),
|
||||
(torch.ops.vllm.npu_rotary_embedding.default, not expect_fused),
|
||||
]
|
||||
if use_bias:
|
||||
check_rules.append((torch.ops.aten.add.Tensor, not expect_fused))
|
||||
|
||||
Reference in New Issue
Block a user