[Main][Ops] Make triton rope support index_selecting from cos_sin_cache (#5450)
### What this PR does / why we need it?
This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:
1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see #6612 ). This help to recover accept rate && accuracy in
that case.
In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).
**Highlights**
- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
5326c89803
Additional test on Qwen3-235b accuracy:
| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -14,7 +14,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
||||
|
||||
@@ -188,15 +188,16 @@ def _quantize_impl_fake(
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False)
|
||||
|
||||
|
||||
def _rope_forward_triton_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
rope_dim: int = -1,
|
||||
def _rope_forward_oot_impl_fake(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
head_dim: int,
|
||||
rotary_dim: int,
|
||||
is_neox_style: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.empty_like(q), torch.empty_like(k)
|
||||
return query, key
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
@@ -262,10 +263,11 @@ direct_register_custom_op(
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rope_forward_triton",
|
||||
op_func=rope_forward_triton,
|
||||
fake_impl=_rope_forward_triton_fake,
|
||||
op_name="npu_rotary_embedding",
|
||||
op_func=rope_forward_oot,
|
||||
fake_impl=_rope_forward_oot_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
|
||||
@@ -29,11 +29,13 @@ from vllm.model_executor.layers.rotary_embedding import (
|
||||
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
|
||||
# Currently, rope ops used on npu requires detached cos && sin as inputs.
|
||||
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
|
||||
@@ -146,54 +148,38 @@ def get_cos_and_sin_slice():
|
||||
return _cos_slice, _sin_slice
|
||||
|
||||
|
||||
def _rope_forward_oot(
|
||||
self,
|
||||
def rope_forward_oot(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
is_neox_style: bool,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
cos, sin = get_cos_and_sin_slice()
|
||||
if offsets is not None:
|
||||
raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
|
||||
if (
|
||||
is_neox_style
|
||||
and self.head_size == 128
|
||||
and self.cos_sin_cache.shape[-1] == 128
|
||||
and cos is not None
|
||||
and sin is not None
|
||||
):
|
||||
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
|
||||
# This method requires head_size and rotary_dim equal 128 and neox_style is True
|
||||
query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
|
||||
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
|
||||
# Although this function modifies in-place, please retain the function's return value.
|
||||
# Otherwise, the graph fusion operation may fail.
|
||||
query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
|
||||
elif self.rotary_dim < self.head_size:
|
||||
if HAS_TRITON:
|
||||
cos = cos.view(-1, self.rotary_dim)
|
||||
sin = sin.view(-1, self.rotary_dim)
|
||||
q = query.contiguous().view(query.shape[0], -1, self.head_size)
|
||||
k = key.contiguous().view(key.shape[0], -1, self.head_size)
|
||||
query, key = torch.ops.vllm.rope_forward_triton(
|
||||
q, k, cos, sin, rope_dim=self.rotary_dim, is_neox_style=True
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
else:
|
||||
if HAS_TRITON:
|
||||
num_tokens = query.shape[0]
|
||||
query, key = rope_forward_triton(
|
||||
query.view(num_tokens, -1, head_size),
|
||||
key.view(num_tokens, -1, head_size),
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
positions=positions,
|
||||
rope_dim=rotary_dim,
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
else:
|
||||
if rotary_dim < head_size:
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
q_rot = query[..., : self.rotary_dim]
|
||||
q_pass = query[..., self.rotary_dim :]
|
||||
k_rot = key[..., : self.rotary_dim]
|
||||
k_pass = key[..., self.rotary_dim :]
|
||||
query = query.view(num_tokens, -1, head_size)
|
||||
key = key.view(num_tokens, -1, head_size)
|
||||
q_rot = query[..., :rotary_dim]
|
||||
q_pass = query[..., rotary_dim:]
|
||||
k_rot = key[..., :rotary_dim]
|
||||
k_pass = key[..., rotary_dim:]
|
||||
q_rot = q_rot.contiguous().view(num_tokens, -1)
|
||||
k_rot = k_rot.contiguous().view(num_tokens, -1)
|
||||
# only the rotary part is processed here,
|
||||
@@ -202,27 +188,26 @@ def _rope_forward_oot(
|
||||
positions,
|
||||
q_rot,
|
||||
k_rot,
|
||||
self.rotary_dim,
|
||||
self.cos_sin_cache,
|
||||
rotary_dim,
|
||||
cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
q_rot = q_rot.view(num_tokens, -1, rotary_dim)
|
||||
k_rot = k_rot.view(num_tokens, -1, rotary_dim)
|
||||
query = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
|
||||
key = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
|
||||
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
|
||||
return q, k
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
@@ -251,7 +236,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
is_neox_style = is_neox_style_override
|
||||
return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
|
||||
return torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
|
||||
)
|
||||
|
||||
|
||||
class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
|
||||
@@ -460,7 +447,9 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
query = query.view(b, h_q, d // 2, 2).transpose(3, 2).reshape(b, h_q, d)
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
|
||||
q_pe, k_pe = torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
|
||||
)
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
@triton.jit
|
||||
def split_qkv_rmsnorm_rope_kernel(
|
||||
input_ptr,
|
||||
sin_ptr,
|
||||
cos_ptr,
|
||||
cos_sin_ptr,
|
||||
pos_ptr,
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
@@ -74,9 +74,11 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
else:
|
||||
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
|
||||
|
||||
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
|
||||
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
|
||||
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
|
||||
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
|
||||
cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
x1 = tl.extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
@@ -89,22 +91,24 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
-x2,
|
||||
roped_q1 = x1 * cos - x2 * sin
|
||||
roped_q2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_q = tl.insert_slice(
|
||||
roped_q,
|
||||
roped_q1,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
x1,
|
||||
roped_q = tl.insert_slice(
|
||||
roped_q,
|
||||
roped_q2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_q = cat_x * sin + normalized_values * cos
|
||||
tl.store(
|
||||
q_ptr + output_offset + col_indices,
|
||||
roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty),
|
||||
@@ -135,9 +139,12 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
|
||||
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
|
||||
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
|
||||
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
|
||||
|
||||
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
|
||||
cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
x1 = tl.extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
@@ -150,23 +157,24 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
-x2,
|
||||
roped_k1 = x1 * cos - x2 * sin
|
||||
roped_k2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_k = tl.insert_slice(
|
||||
roped_k,
|
||||
roped_k1,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
x1,
|
||||
roped_k = tl.insert_slice(
|
||||
roped_k,
|
||||
roped_k2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_k = cat_x * sin + normalized_values * cos
|
||||
|
||||
tl.store(
|
||||
k_ptr + output_offset + col_indices,
|
||||
roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE),
|
||||
@@ -188,8 +196,8 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
|
||||
def split_qkv_rmsnorm_rope_impl(
|
||||
input: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_hidden_size: int,
|
||||
@@ -216,8 +224,8 @@ def split_qkv_rmsnorm_rope_impl(
|
||||
|
||||
split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)](
|
||||
input,
|
||||
sin,
|
||||
cos,
|
||||
cos_sin_cache,
|
||||
positions,
|
||||
q_output,
|
||||
k_output,
|
||||
v_output,
|
||||
@@ -241,8 +249,8 @@ def split_qkv_rmsnorm_rope_impl(
|
||||
|
||||
def split_qkv_rmsnorm_rope_impl_fake(
|
||||
input: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_hidden_size: int,
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
@@ -30,10 +29,13 @@ def _triton_rope(
|
||||
q_row_stride,
|
||||
k_ptr,
|
||||
k_row_stride,
|
||||
cos,
|
||||
cos_ptr,
|
||||
cos_row_stride,
|
||||
sin,
|
||||
sin_ptr,
|
||||
sin_row_stride,
|
||||
cos_sin_ptr,
|
||||
cos_sin_row_stride,
|
||||
pos_ptr,
|
||||
num_tokens,
|
||||
n_qh: tl.constexpr,
|
||||
n_kh: tl.constexpr,
|
||||
@@ -44,6 +46,7 @@ def _triton_rope(
|
||||
pad_rope_dim: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_NEOX_STYLE: tl.constexpr,
|
||||
USE_COS_SIN: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
This triton kernel applies rotary embedding on q and k.
|
||||
@@ -84,13 +87,19 @@ def _triton_rope(
|
||||
# get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
|
||||
# m of this program instance
|
||||
# ####################################################################
|
||||
cos_start_ptr = cos + row_idx * cos_row_stride
|
||||
sin_start_ptr = sin + row_idx * sin_row_stride
|
||||
|
||||
cos_offsets = tl.arange(0, pad_rope_dim // 2)
|
||||
sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim)
|
||||
cos_mask = cos_offsets < (rope_dim // 2)
|
||||
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
if USE_COS_SIN:
|
||||
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
|
||||
cos_start_ptr = cos_sin_ptr + pos_idx * cos_sin_row_stride
|
||||
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
sin_row = tl.load(cos_start_ptr + sin_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
else:
|
||||
cos_start_ptr = cos_ptr + row_idx * cos_row_stride
|
||||
sin_start_ptr = sin_ptr + row_idx * sin_row_stride
|
||||
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
|
||||
# ####################################################################
|
||||
# Load the left and right half of q and k for the current
|
||||
@@ -140,8 +149,10 @@ def _triton_rope(
|
||||
def rope_forward_triton(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor = None,
|
||||
sin: torch.Tensor = None,
|
||||
cos_sin_cache: torch.Tensor = None,
|
||||
positions: torch.Tensor = None,
|
||||
rope_dim: int = -1,
|
||||
is_neox_style: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -152,12 +163,6 @@ def rope_forward_triton(
|
||||
|
||||
num_tokens, n_q_head, head_dim = q.shape
|
||||
n_kv_head = k.shape[1]
|
||||
cos = cos.view(num_tokens, -1)
|
||||
sin = sin.view(num_tokens, -1)
|
||||
if rope_dim == -1:
|
||||
# If rope_dim is not specified, we assume that input cos/sin is not
|
||||
# duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
|
||||
rope_dim = cos.shape[-1] * 2
|
||||
assert rope_dim <= head_dim
|
||||
pad_rope_dim = triton.next_power_of_2(rope_dim)
|
||||
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
||||
@@ -166,24 +171,69 @@ def rope_forward_triton(
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
n_row = min(num_tokens, num_vectorcore)
|
||||
|
||||
_triton_rope[(n_row,)](
|
||||
q,
|
||||
q.stride(0),
|
||||
k,
|
||||
k.stride(0),
|
||||
cos,
|
||||
cos.stride(0),
|
||||
sin,
|
||||
sin.stride(0),
|
||||
num_tokens,
|
||||
n_q_head,
|
||||
n_kv_head,
|
||||
head_dim,
|
||||
rope_dim,
|
||||
pad_n_q_head,
|
||||
pad_n_kv_head,
|
||||
pad_rope_dim,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
IS_NEOX_STYLE=is_neox_style,
|
||||
)
|
||||
if cos_sin_cache is not None and positions is not None:
|
||||
assert positions.shape[0] == num_tokens
|
||||
_triton_rope[(n_row,)](
|
||||
q,
|
||||
q.stride(0),
|
||||
k,
|
||||
k.stride(0),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
cos_sin_cache,
|
||||
cos_sin_cache.stride(0),
|
||||
positions,
|
||||
num_tokens,
|
||||
n_q_head,
|
||||
n_kv_head,
|
||||
head_dim,
|
||||
rope_dim,
|
||||
pad_n_q_head,
|
||||
pad_n_kv_head,
|
||||
pad_rope_dim,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
IS_NEOX_STYLE=is_neox_style,
|
||||
USE_COS_SIN=True,
|
||||
)
|
||||
elif cos is not None and sin is not None:
|
||||
assert cos.shape[0] == num_tokens and sin.shape[0] == num_tokens
|
||||
cos = cos.view(num_tokens, -1)
|
||||
sin = sin.view(num_tokens, -1)
|
||||
if rope_dim == -1:
|
||||
# If rope_dim is not specified, we assume that input cos/sin is not
|
||||
# duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
|
||||
rope_dim = cos.shape[-1] * 2
|
||||
_triton_rope[(n_row,)](
|
||||
q,
|
||||
q.stride(0),
|
||||
k,
|
||||
k.stride(0),
|
||||
cos,
|
||||
cos.stride(0),
|
||||
sin,
|
||||
sin.stride(0),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
num_tokens,
|
||||
n_q_head,
|
||||
n_kv_head,
|
||||
head_dim,
|
||||
rope_dim,
|
||||
pad_n_q_head,
|
||||
pad_n_kv_head,
|
||||
pad_rope_dim,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
IS_NEOX_STYLE=is_neox_style,
|
||||
USE_COS_SIN=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Currently, rope_forward_triton supports passing:\n"
|
||||
"1. positions and original cos_sin_cache.\n"
|
||||
"2. cos and sin which are already selected by positions\n"
|
||||
"Please check whether you call rope_forward_triton correctly."
|
||||
)
|
||||
return q, k
|
||||
|
||||
Reference in New Issue
Block a user