[OPS]add split_qkv_rmsnorm_mrope ops (#6730)

### What this PR does / why we need it?
This PR adds split_qkv_rmsnorm_mrope kernel with interleaved for qwen3.5
and qwen3-vl to improve performance.

### Does this PR introduce _any_ user-facing change?
Does not.

### How to use?
```python
real_q, real_k, real_v, real_gate = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
            qkv=qkv,
            q_weight=q_weight,
            k_weight=k_weight,
            cos_sin=cos_sin,
            num_q_heads=num_q_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            eps=eps,
            mrope_section=mrope_section,
            is_interleaved=is_interleaved,
            rope_dim=rope_dim,
            has_gate=has_gate,
    )
```
### How was this patch tested?
- vLLM version: v0.16.0
- Accuracy test script:
```shell
pytest tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_mrope.py
```

---------

Signed-off-by: Fager <865071616@qq.com>
Signed-off-by: Fager10086 <77871921+Fager10086@users.noreply.github.com>
Signed-off-by: fager <865071616@qq.com>
This commit is contained in:
Fager10086
2026-03-06 16:18:37 +08:00
committed by GitHub
parent bc0fd7ca72
commit c5dfa8d645
3 changed files with 772 additions and 0 deletions

View File

@@ -0,0 +1,343 @@
import gc
import pytest
import torch
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
NUM_TOKENS = [1, 4, 8, 16, 1024, 4096]
NUM_QKV_HEADS = [(8, 2), (2, 1), (16, 2)]
HEAD_SIZES = [128, 256]
EPS = [1e-6]
MROPE_SECTION = [[11, 11, 10], [24, 20, 20]]
IS_INTERLEAVED = [True, False]
HAS_GATE = [True, False]
DTYPES = [torch.bfloat16, torch.float16]
DEVICES = [f"npu:{0}"]
DEFAULT_ATOL = 1e-2
DEFAULT_RTOL = 1e-2
def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
x_t = x[0].clone()
x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
return x_t
def rms_norm(x: torch.Tensor,
norm_weight: torch.Tensor,
eps,
norm_bias=None,
):
x = x.cpu()
norm_weight = norm_weight.cpu()
x = x.to(torch.float32)
norm_weight = norm_weight.to(torch.float32).cpu()
reciprocal_std = 1 / torch.sqrt(
torch.mean(x ** 2, axis=-1, keepdims=True) + eps)
out = x * reciprocal_std * norm_weight
if norm_bias is not None:
norm_bias = norm_bias.cpu().to(torch.float32)
out = out + norm_bias
return out
def naive_split_qkv_rmsnorm_mrope(
qkv: torch.Tensor,
q_weight: torch.Tensor,
q_bias: torch.Tensor,
k_weight: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
num_q_heads: int,
num_kv_heads: int,
head_size: int,
eps: float,
mrope_section: list[int],
rope_dim: int,
):
q_size = num_q_heads * head_size
kv_size = num_kv_heads * head_size
# split
qkv = qkv.cpu()
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
# norm
q = rms_norm(q.reshape(-1, head_size), q_weight, eps, norm_bias=q_bias)
k = rms_norm(k.reshape(-1, head_size), k_weight, eps, norm_bias=k_bias)
# mrope
rotary_dim = rope_dim
num_tokens = qkv.shape[0]
n_q_head = num_q_heads
n_kv_head = num_kv_heads
q_reshaped = q.view(num_tokens, n_q_head, head_size)
k_reshaped = k.view(num_tokens, n_kv_head, head_size)
cos_reshaped = cos.permute(1, 2, 0)
sin_reshaped = sin.permute(1, 2, 0)
half_rd = rotary_dim // 2
for token_idx in range(num_tokens):
token_cos = cos_reshaped[token_idx]
token_sin = sin_reshaped[token_idx]
cos_row = torch.zeros(half_rd, device=q.device, dtype=q.dtype)
sin_row = torch.zeros(half_rd, device=q.device, dtype=q.dtype)
t_end = mrope_section[0]
h_end = t_end + mrope_section[1]
if t_end > 0:
cos_row[:t_end] = token_cos[:t_end, 0]
sin_row[:t_end] = token_sin[:t_end, 0]
if mrope_section[1] > 0:
cos_row[t_end:h_end] = token_cos[t_end:h_end, 1]
sin_row[t_end:h_end] = token_sin[t_end:h_end, 1]
if mrope_section[2] > 0:
w_start = h_end
cos_row[w_start:half_rd] = token_cos[w_start:half_rd, 2]
sin_row[w_start:half_rd] = token_sin[w_start:half_rd, 2]
q_token = q_reshaped[token_idx]
k_token = k_reshaped[token_idx]
q1 = q_token[:, :half_rd]
q2 = q_token[:, half_rd:rotary_dim]
k1 = k_token[:, :half_rd]
k2 = k_token[:, half_rd:rotary_dim]
cos_half = cos_row.unsqueeze(0)
sin_half = sin_row.unsqueeze(0)
new_q1 = q1 * cos_half - q2 * sin_half
new_q2 = q2 * cos_half + q1 * sin_half
new_k1 = k1 * cos_half - k2 * sin_half
new_k2 = k2 * cos_half + k1 * sin_half
q_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_q1, new_q2], dim=1)
k_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_k1, new_k2], dim=1)
q_result = q_reshaped.view(num_tokens, -1)
k_result = k_reshaped.view(num_tokens, -1)
q = q_result.to(qkv.dtype)
k = k_result.to(qkv.dtype)
v = v.to(qkv.dtype)
return q, k, v
def naive_split_qkv_rmsnorm_mrope_interleaved(
qkv: torch.Tensor,
q_weight: torch.Tensor,
q_bias: torch.Tensor,
k_weight: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
num_q_heads: int,
num_kv_heads: int,
head_size: int,
eps: float,
mrope_section: list[int],
rope_dim: int,
):
q_size = num_q_heads * head_size
kv_size = num_kv_heads * head_size
# split
qkv = qkv.cpu()
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
# norm
q = rms_norm(q.reshape(-1, head_size), q_weight, eps, norm_bias=q_bias)
k = rms_norm(k.reshape(-1, head_size), k_weight, eps, norm_bias=k_bias)
# mrope
rotary_dim = rope_dim
num_tokens = qkv.shape[0]
n_q_head = num_q_heads
n_kv_head = num_kv_heads
q_reshaped = q.view(num_tokens, n_q_head, head_size)
k_reshaped = k.view(num_tokens, n_kv_head, head_size)
cos_reshaped = apply_interleaved_rope(cos, mrope_section)
sin_reshaped = apply_interleaved_rope(sin, mrope_section)
half_rd = rotary_dim // 2
for token_idx in range(num_tokens):
cos_row = cos_reshaped[token_idx]
sin_row = sin_reshaped[token_idx]
q_token = q_reshaped[token_idx]
k_token = k_reshaped[token_idx]
q1 = q_token[:, :half_rd]
q2 = q_token[:, half_rd:rotary_dim]
k1 = k_token[:, :half_rd]
k2 = k_token[:, half_rd:rotary_dim]
cos_half = cos_row.unsqueeze(0)
sin_half = sin_row.unsqueeze(0)
new_q1 = q1 * cos_half - q2 * sin_half
new_q2 = q2 * cos_half + q1 * sin_half
new_k1 = k1 * cos_half - k2 * sin_half
new_k2 = k2 * cos_half + k1 * sin_half
q_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_q1, new_q2], dim=1)
k_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_k1, new_k2], dim=1)
q_result = q_reshaped.view(num_tokens, -1)
k_result = k_reshaped.view(num_tokens, -1)
q = q_result.to(qkv.dtype)
k = k_result.to(qkv.dtype)
v = v.to(qkv.dtype)
return q, k, v
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("mrope_section", MROPE_SECTION)
@pytest.mark.parametrize("is_interleaved", IS_INTERLEAVED)
@pytest.mark.parametrize("has_gate", HAS_GATE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_mrope(
num_tokens: int,
num_q_heads: int,
num_kv_heads: int,
head_size: int,
mrope_section: list[int],
eps: float,
dtype: torch.dtype,
device: str,
is_interleaved: bool,
has_gate: bool,
):
torch.set_default_device(device)
init_device_properties_triton()
rope_dim = 2 * sum(mrope_section)
q_size = num_q_heads * head_size
kv_size = num_kv_heads * head_size
# input tensor
if has_gate:
qkv = torch.randn(num_tokens,
2 * q_size + kv_size * 2,
dtype=dtype,
device=device)
else:
qkv = torch.randn(num_tokens,
q_size + kv_size * 2,
dtype=dtype,
device=device)
q_weight = torch.randn(head_size, dtype=dtype, device=device)
k_weight = torch.randn(head_size, dtype=dtype, device=device)
q_bias = None
k_bias = None
cos_sin = torch.randn(3, num_tokens, rope_dim, dtype=dtype,
device=device)
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.contiguous()
sin = sin.contiguous()
if has_gate:
q_gate_data = qkv[:, :q_size * 2].view(-1, num_q_heads, head_size * 2)
q_data, golden_gate = torch.chunk(q_gate_data, 2, dim=-1)
golden_gate = golden_gate.reshape(-1, q_size)
q_data = q_data.reshape(-1, q_size)
k_data = qkv[:, 2 * q_size:2 * q_size + kv_size]
v_data = qkv[:, 2 * q_size + kv_size:]
qkv_for_ref = torch.cat([q_data, k_data, v_data], dim=-1)
else:
qkv_for_ref = qkv
if is_interleaved:
golden_q, golden_k, golden_v = naive_split_qkv_rmsnorm_mrope_interleaved(qkv_for_ref.cpu(),
q_weight.cpu(),
q_bias,
k_weight.cpu(),
k_bias,
cos.cpu(),
sin.cpu(),
num_q_heads,
num_kv_heads,
head_size,
eps,
mrope_section,
rope_dim)
else:
golden_q, golden_k, golden_v = naive_split_qkv_rmsnorm_mrope(qkv_for_ref.cpu(),
q_weight.cpu(),
q_bias,
k_weight.cpu(),
k_bias,
cos.cpu(),
sin.cpu(),
num_q_heads,
num_kv_heads,
head_size,
eps,
mrope_section,
rope_dim)
real_q, real_k, real_v, real_gate = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
qkv=qkv,
q_weight=q_weight,
k_weight=k_weight,
cos_sin=cos_sin,
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
eps=eps,
mrope_section=mrope_section,
is_interleaved=is_interleaved,
rope_dim=rope_dim,
has_gate=has_gate,
)
torch.testing.assert_close(real_q.cpu(),
golden_q.cpu(),
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(real_k.cpu(),
golden_k.cpu(),
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(real_v.cpu(),
golden_v.cpu(),
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
if has_gate:
torch.testing.assert_close(real_gate.cpu(),
golden_gate.cpu(),
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -24,6 +24,7 @@ import vllm_ascend.ops.register_custom_ops # noqa
if HAS_TRITON: if HAS_TRITON:
import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa
import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_mrope
import vllm_ascend.ops.vocab_parallel_embedding # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul

View File

@@ -0,0 +1,428 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
import triton # type: ignore
import triton.language as tl # type: ignore
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit(
do_not_specialize=["num_tokens", "front_core_num", "num_tokens_each_front_core", "num_tokens_each_tail_core"]
)
def split_qkv_rmsnorm_mrope_kernel(
in_qkv_ptr: torch.Tensor,
q_weight_ptr: torch.Tensor,
q_bias_ptr: torch.Tensor,
k_weight_ptr: torch.Tensor,
k_bias_ptr: torch.Tensor,
cos_sin_ptr: torch.Tensor,
out_q_ptr: torch.Tensor,
out_k_ptr: torch.Tensor,
out_v_ptr: torch.Tensor,
out_gate_ptr: torch.Tensor,
num_tokens,
front_core_num,
num_tokens_each_front_core,
num_tokens_each_tail_core,
num_q_heads: tl.constexpr,
num_kv_heads: tl.constexpr,
head_size: tl.constexpr,
q_size: tl.constexpr,
kv_size: tl.constexpr,
eps: tl.constexpr,
mrope_section_t: tl.constexpr,
mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
has_bias: tl.constexpr,
is_interleaved: tl.constexpr,
rope_dim: tl.constexpr,
half_rope_dim: tl.constexpr,
IS_PARTIAL_ROPE: tl.constexpr,
gate_size: tl.constexpr,
):
block_idx = tl.program_id(0)
loop_num = num_tokens_each_front_core
if block_idx >= front_core_num:
loop_num = num_tokens_each_tail_core
block_offset = num_tokens_each_front_core * block_idx
if block_idx >= front_core_num:
block_offset = (
num_tokens_each_front_core * front_core_num + (block_idx - front_core_num) * num_tokens_each_tail_core
)
q_rmsnorm_weight = tl.load(q_weight_ptr + tl.arange(0, head_size))
k_rmsnorm_weight = tl.load(k_weight_ptr + tl.arange(0, head_size))
if has_bias:
q_bias = tl.load(q_bias_ptr + tl.arange(0, head_size))
k_bias = tl.load(k_bias_ptr + tl.arange(0, head_size))
for index in range(loop_num):
## load ##
# q
in_q_offset = in_qkv_ptr + (block_offset + index) * (q_size + gate_size + 2 * kv_size)
if gate_size > 0:
in_q_gate_tensor = (
tl.load(in_q_offset + tl.arange(0, q_size + gate_size))
.to(tl.float32)
.reshape(num_q_heads, head_size * 2)
)
in_q_tensor = tl.extract_slice(
in_q_gate_tensor,
offsets=(0, 0),
sizes=(num_q_heads, head_size),
strides=(1, 1),
)
in_gate_tensor = tl.extract_slice(
in_q_gate_tensor,
offsets=(0, head_size),
sizes=(num_q_heads, head_size),
strides=(1, 1),
).reshape(q_size)
else:
in_q_tensor = tl.load(in_q_offset + tl.arange(0, q_size)).to(tl.float32).reshape(num_q_heads, head_size)
# k
in_k_offset = in_q_offset + q_size + gate_size
in_k_tensor = tl.load(in_k_offset + tl.arange(0, kv_size)).to(tl.float32).reshape(num_kv_heads, head_size)
# v
in_v_offset = in_k_offset + kv_size
in_v_tensor = tl.load(in_v_offset + tl.arange(0, kv_size))
# cos, sin
cos_offsets = tl.arange(0, half_rope_dim)
if is_interleaved:
h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
t_mask = ~(h_mask | w_mask)
else:
t_mask = cos_offsets < mrope_section_t
h_mask = (mrope_section_t - 1 < cos_offsets) & (cos_offsets < mrope_section_t + mrope_section_h)
w_mask = (mrope_section_t + mrope_section_h - 1 < cos_offsets) & (
cos_offsets < mrope_section_t + mrope_section_h + mrope_section_w
)
t_cos_offset = cos_sin_ptr + (block_offset + index) * rope_dim
h_cos_offset = t_cos_offset + num_tokens * rope_dim
w_cos_offset = h_cos_offset + num_tokens * rope_dim
t_sin_offset = cos_sin_ptr + (block_offset + index) * rope_dim + half_rope_dim
h_sin_offset = t_sin_offset + num_tokens * rope_dim
w_sin_offset = h_sin_offset + num_tokens * rope_dim
t_cos_tensor = tl.load(t_cos_offset + cos_offsets, mask=t_mask, other=0)
h_cos_tensor = tl.load(h_cos_offset + cos_offsets, mask=h_mask, other=0)
w_cos_tensor = tl.load(w_cos_offset + cos_offsets, mask=w_mask, other=0)
t_sin_tensor = tl.load(t_sin_offset + cos_offsets, mask=t_mask, other=0)
h_sin_tensor = tl.load(h_sin_offset + cos_offsets, mask=h_mask, other=0)
w_sin_tensor = tl.load(w_sin_offset + cos_offsets, mask=w_mask, other=0)
cos_tensor = (t_cos_tensor + h_cos_tensor + w_cos_tensor).to(tl.float32).reshape(1, half_rope_dim)
cos_tensor = tl.broadcast_to(cos_tensor, (2, half_rope_dim)).reshape(1, rope_dim)
sin_tensor = (t_sin_tensor + h_sin_tensor + w_sin_tensor).to(tl.float32).reshape(1, half_rope_dim)
sin_tensor = tl.broadcast_to(sin_tensor, (2, half_rope_dim)).reshape(1, rope_dim)
## compute ##
# q-rmsnorm
squares = in_q_tensor * in_q_tensor
variances = tl.sum(squares, axis=1) / head_size
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(num_q_heads, 1)
q_normalized = in_q_tensor * reciprocal_std
q_normalized = q_normalized * q_rmsnorm_weight
if has_bias:
q_normalized = q_normalized + q_bias
# k-rmsnorm
squares = in_k_tensor * in_k_tensor
variances = tl.sum(squares, axis=1) / head_size
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(num_kv_heads, 1)
k_normalized = in_k_tensor * reciprocal_std
k_normalized = k_normalized * k_rmsnorm_weight
if has_bias:
k_normalized = k_normalized + k_bias
# q-mrope
x1 = tl.extract_slice(
q_normalized,
offsets=(0, 0),
sizes=(num_q_heads, half_rope_dim),
strides=(1, 1),
)
x2 = tl.extract_slice(
q_normalized,
offsets=(0, half_rope_dim),
sizes=(num_q_heads, half_rope_dim),
strides=(1, 1),
)
cat_x = tl.zeros((num_q_heads, rope_dim), dtype=tl.float32)
cat_x = tl.insert_slice(
cat_x,
-x2,
offsets=(0, 0),
sizes=(num_q_heads, half_rope_dim),
strides=(1, 1),
)
cat_x = tl.insert_slice(
cat_x,
x1,
offsets=(0, half_rope_dim),
sizes=(num_q_heads, half_rope_dim),
strides=(1, 1),
)
if IS_PARTIAL_ROPE:
orig_qk = tl.extract_slice(
q_normalized,
offsets=(0, 0),
sizes=(num_q_heads, rope_dim),
strides=(1, 1),
)
else:
orig_qk = q_normalized
roped_q = cat_x * sin_tensor + orig_qk * cos_tensor
# k-mrope
y1 = tl.extract_slice(
k_normalized,
offsets=(0, 0),
sizes=(num_kv_heads, half_rope_dim),
strides=(1, 1),
)
y2 = tl.extract_slice(
k_normalized,
offsets=(0, half_rope_dim),
sizes=(num_kv_heads, half_rope_dim),
strides=(1, 1),
)
cat_y = tl.zeros((num_kv_heads, rope_dim), dtype=tl.float32)
cat_y = tl.insert_slice(
cat_y,
-y2,
offsets=(0, 0),
sizes=(num_kv_heads, half_rope_dim),
strides=(1, 1),
)
cat_y = tl.insert_slice(
cat_y,
y1,
offsets=(0, half_rope_dim),
sizes=(num_kv_heads, half_rope_dim),
strides=(1, 1),
)
if IS_PARTIAL_ROPE:
orig_qk = tl.extract_slice(
k_normalized,
offsets=(0, 0),
sizes=(num_kv_heads, rope_dim),
strides=(1, 1),
)
else:
orig_qk = k_normalized
roped_k = cat_y * sin_tensor + orig_qk * cos_tensor
if IS_PARTIAL_ROPE:
q_normalized = tl.insert_slice(
q_normalized,
roped_q,
offsets=(0, 0),
sizes=(num_q_heads, rope_dim),
strides=(1, 1),
)
k_normalized = tl.insert_slice(
k_normalized,
roped_k,
offsets=(0, 0),
sizes=(num_kv_heads, rope_dim),
strides=(1, 1),
)
else:
q_normalized = roped_q
k_normalized = roped_k
## store ##
# out_q
out_q_offset = out_q_ptr + (block_offset + index) * q_size
out_q_indices = tl.arange(0, q_size)
tl.store(out_q_offset + out_q_indices, q_normalized.reshape(q_size))
# out_k
out_k_offset = out_k_ptr + (block_offset + index) * kv_size
out_k_indices = tl.arange(0, kv_size)
tl.store(out_k_offset + out_k_indices, k_normalized.reshape(kv_size))
# out_v
out_v_offset = out_v_ptr + (block_offset + index) * kv_size
tl.store(out_v_offset + tl.arange(0, kv_size), in_v_tensor)
# out_gate
if gate_size > 0:
out_gate_offset = out_gate_ptr + (block_offset + index) * gate_size
tl.store(out_gate_offset + tl.arange(0, gate_size), in_gate_tensor)
def triton_split_qkv_rmsnorm_mrope(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin: torch.Tensor,
num_q_heads: int,
num_kv_heads: int,
head_size: int,
eps: float,
mrope_section: list[int],
is_interleaved: bool,
rope_dim: int | None = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
has_gate: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
core_num = get_vectorcore_num()
q_size = num_q_heads * head_size
kv_size = num_kv_heads * head_size
num_tokens = qkv.shape[0]
gate_size = q_size if has_gate else 0
if rope_dim is None:
rope_dim = head_size
IS_PARTIAL_ROPE = rope_dim != head_size
front_core_num = core_num
if num_tokens % core_num != 0:
front_core_num = num_tokens % core_num
num_tokens_each_front_core = (num_tokens + core_num - 1) // core_num
tail_core_num = 0
if num_tokens > core_num:
tail_core_num = core_num - front_core_num
num_tokens_each_tail_core = num_tokens // core_num
q_output = torch.empty(num_tokens, q_size, device=qkv.device, dtype=qkv.dtype)
k_output = torch.empty(num_tokens, kv_size, device=qkv.device, dtype=qkv.dtype)
v_output = torch.empty(num_tokens, kv_size, device=qkv.device, dtype=qkv.dtype)
gate_output = torch.empty(num_tokens, gate_size, device=qkv.device, dtype=qkv.dtype)
total_core = front_core_num + tail_core_num
block_dim = core_num
if total_core < core_num:
block_dim = total_core
has_bias = q_bias is not None
split_qkv_rmsnorm_mrope_kernel[(block_dim,)](
qkv,
q_weight,
q_bias,
k_weight,
k_bias,
cos_sin,
q_output,
k_output,
v_output,
gate_output,
num_tokens,
front_core_num,
num_tokens_each_front_core,
num_tokens_each_tail_core,
num_q_heads,
num_kv_heads,
head_size,
q_size,
kv_size,
eps,
mrope_section[0],
mrope_section[1],
mrope_section[2],
has_bias,
is_interleaved,
rope_dim,
rope_dim // 2,
IS_PARTIAL_ROPE,
gate_size,
)
return q_output, k_output, v_output, gate_output
def triton_split_qkv_rmsnorm_mrope_fake(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin: torch.Tensor,
num_q_heads: int,
num_kv_heads: int,
head_size: int,
eps: float,
mrope_section: list[int],
is_interleaved: bool,
rope_dim: int | None = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
has_gate: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
num_tokens = qkv.shape[0]
q_size = num_q_heads * head_size
kv_size = num_kv_heads * head_size
gate_size = q_size if has_gate else 0
q_output = torch.empty(
num_tokens,
q_size,
device=qkv.device,
dtype=qkv.dtype,
)
k_output = torch.empty(
num_tokens,
kv_size,
device=qkv.device,
dtype=qkv.dtype,
)
v_output = torch.empty(
num_tokens,
kv_size,
device=qkv.device,
dtype=qkv.dtype,
)
gate_output = torch.empty(
num_tokens,
gate_size,
device=qkv.device,
dtype=qkv.dtype,
)
return q_output, k_output, v_output, gate_output
direct_register_custom_op(
op_name="triton_split_qkv_rmsnorm_mrope",
op_func=triton_split_qkv_rmsnorm_mrope,
fake_impl=triton_split_qkv_rmsnorm_mrope_fake,
mutates_args=[],
dispatch_key="PrivateUse1",
)