[Main][Ops] Make triton rope support index_selecting from cos_sin_cache (#5450)
### What this PR does / why we need it?
This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:
1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see #6612 ). This help to recover accept rate && accuracy in
that case.
In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).
**Highlights**
- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
5326c89803
Additional test on Qwen3-235b accuracy:
| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -164,8 +164,6 @@ def set_ascend_forward_context(
|
||||
|
||||
_mc2_tokens_capacity: int | None = None
|
||||
_reserved_mc2_mask: torch.Tensor | None = None
|
||||
_sin: torch.Tensor | None = None
|
||||
_cos: torch.Tensor | None = None
|
||||
|
||||
|
||||
def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
|
||||
|
||||
Reference in New Issue
Block a user