[Main][Ops] Make triton rope support index_selecting from cos_sin_cache (#5450)

### What this PR does / why we need it?

This PR extends the original `rope_forward_triton` and
`split_qkv_rmsnorm_rope` to accept `cos_sin_cache` and `positions` as
inputs, fully aligning with the vLLM RoPE API interface. Compared with the
earlier RoPE implementation, the benefits are:

1. It avoids pre-computing `cos` and `sin` before model execution, which
removes redundant code.
2. It allows the eagle3 draft model to use RoPE parameters that differ from
the main model's (see #6612). This recovers the acceptance rate and
accuracy in that case.

In addition, this kernel change introduces only a very small performance
degradation: the former `index_select` and `chunk` operations become simple
memory accesses inside the Triton kernel (for example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81),
as sketched below.
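
To make the kernel-side change concrete, here is a minimal Triton sketch
(hypothetical names, not the PR's actual kernel) of the gather that replaces
the eager `index_select`/`chunk`: each program issues one indexed load of its
token's row from `cos_sin_cache`.

```python
import triton
import triton.language as tl


@triton.jit
def _gather_cos_sin_kernel(
    positions_ptr,      # [num_tokens] int64 position ids
    cos_sin_cache_ptr,  # [max_position, rotary_dim], each row = [cos | sin]
    cos_out_ptr,        # [num_tokens, rotary_dim // 2]
    sin_out_ptr,        # [num_tokens, rotary_dim // 2]
    ROTARY_DIM: tl.constexpr,
):
    token = tl.program_id(0)
    # One scalar load replaces the host-side index_select over positions.
    pos = tl.load(positions_ptr + token)
    HALF: tl.constexpr = ROTARY_DIM // 2
    offs = tl.arange(0, HALF)  # assumes HALF is a power of two
    row = cos_sin_cache_ptr + pos * ROTARY_DIM
    # Reading the two halves of the row directly replaces the eager chunk().
    cos = tl.load(row + offs)
    sin = tl.load(row + HALF + offs)
    tl.store(cos_out_ptr + token * HALF + offs, cos)
    tl.store(sin_out_ptr + token * HALF + offs, sin)
```

Launched on a `(num_tokens,)` grid, the lookup becomes part of the fused
kernel instead of a separate eager op.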

**Highlights**

- **RoPE Cache Unification**: Replaced the separate `_sin` and `_cos` global
tensors with a unified `cos_sin_cache` plus an explicit `positions` tensor
for Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated the Triton kernels
(`split_qkv_rmsnorm_rope_kernel`, `_triton_rope`) to consume `cos_sin_cache`
and `positions` directly, making the RoPE calculation more efficient and
self-contained.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a new
custom operation, allowing its use in fused compilation passes and providing
a dedicated entry point for the new RoPE implementation (see the sketch
after this list).
- **Refactored RoPE Forward Pass**: Modified `rope_forward_oot` to accept
the new `cos_sin_cache` and `positions` arguments, enabling a more flexible
and integrated RoPE application within the system.
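
As a rough, method-style sketch of the call shape this enables (attribute
names follow vLLM's `RotaryEmbedding`; the `torch.ops.vllm` namespace is an
assumption based on `direct_register_custom_op`'s default target library):

```python
import torch


def rope_forward(self, positions: torch.Tensor, query: torch.Tensor,
                 key: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # No host-side precomputation of cos/sin: the registered op (and the
    # Triton kernel behind it) gathers rows from cos_sin_cache by positions.
    return torch.ops.vllm.npu_rotary_embedding(
        positions,
        query,
        key,
        self.cos_sin_cache,
        self.head_size,      # passed as head_dim
        self.rotary_dim,
        self.is_neox_style,
    )
```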

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
5326c89803

Additional accuracy tests on Qwen3-235B:

| AIME 2024 | GSM8K | LiveCodeBench |
| --------- | ----- | ------------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>
Commit c0c2eb614e (parent 6bc44bf49b) by Angazenn, 2026-02-11 21:20:53
+08:00, committed by GitHub.
13 changed files with 378 additions and 243 deletions.


```diff
@@ -14,7 +14,7 @@ from vllm.forward_context import get_forward_context
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm_ascend.ascend_forward_context import MoECommType
-from vllm_ascend.ops.triton.rope import rope_forward_triton
+from vllm_ascend.ops.rotary_embedding import rope_forward_oot
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.utils import npu_stream_switch, prefetch_stream
@@ -188,15 +188,16 @@ def _quantize_impl_fake(
     return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False)


-def _rope_forward_triton_fake(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    rope_dim: int = -1,
+def _rope_forward_oot_impl_fake(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    head_dim: int,
+    rotary_dim: int,
+    is_neox_style: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    return torch.empty_like(q), torch.empty_like(k)
+    return query, key


 direct_register_custom_op(
@@ -262,10 +263,11 @@ direct_register_custom_op(
     mutates_args=[],
     dispatch_key="PrivateUse1",
 )

 direct_register_custom_op(
-    op_name="rope_forward_triton",
-    op_func=rope_forward_triton,
-    fake_impl=_rope_forward_triton_fake,
+    op_name="npu_rotary_embedding",
+    op_func=rope_forward_oot,
+    fake_impl=_rope_forward_oot_impl_fake,
     mutates_args=[],
     dispatch_key="PrivateUse1",
 )
```
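
The fake impl is what lets fused compilation passes trace through the op
without touching the NPU. A hedged way to smoke-test it (the
`torch.ops.vllm` namespace and the shapes below are assumptions, not taken
from the PR) is to call the op under `FakeTensorMode`, which dispatches to
the registered `fake_impl`:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    positions = torch.zeros(8, dtype=torch.int64)
    query = torch.empty(8, 32 * 128)  # hypothetical [num_tokens, heads * head_dim]
    key = torch.empty(8, 4 * 128)
    cos_sin_cache = torch.empty(4096, 128)
    out_q, out_k = torch.ops.vllm.npu_rotary_embedding(
        positions, query, key, cos_sin_cache, 128, 128, True)
    # The fake impl returns query/key as-is, so shapes must round-trip.
    assert out_q.shape == query.shape and out_k.shape == key.shape
```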