[Ops] Add split_qkv_rmsnorm_mrope op (#6730)
### What this PR does / why we need it?
This PR adds a `split_qkv_rmsnorm_mrope` kernel with interleaved MRoPE support
for Qwen3.5 and Qwen3-VL to improve performance.
### Does this PR introduce _any_ user-facing change?
No.
### How to use?
```python
real_q, real_k, real_v, real_gate = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
    qkv=qkv,
    q_weight=q_weight,
    k_weight=k_weight,
    cos_sin=cos_sin,
    num_q_heads=num_q_heads,
    num_kv_heads=num_kv_heads,
    head_size=head_size,
    eps=eps,
    mrope_section=mrope_section,
    is_interleaved=is_interleaved,
    rope_dim=rope_dim,
    has_gate=has_gate,
)
```
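For reference, the inputs can be shaped as in the accuracy test below. This is an illustrative sketch only; the concrete values are assumptions, not part of the op's contract:

```python
import torch

# Illustrative shapes, mirroring the accuracy test below (values are assumptions).
num_tokens, num_q_heads, num_kv_heads, head_size = 16, 8, 2, 128
mrope_section = [11, 11, 10]               # T/H/W frequency split
rope_dim = 2 * sum(mrope_section)          # cos and sin each cover rope_dim // 2
q_size, kv_size = num_q_heads * head_size, num_kv_heads * head_size

dtype, device, eps = torch.bfloat16, "npu:0", 1e-6
# Packed projection output [Q | K | V]; with has_gate=True, the Q block doubles
# to [q_head | gate_head] per head, i.e. width 2 * q_size + 2 * kv_size.
qkv = torch.randn(num_tokens, q_size + 2 * kv_size, dtype=dtype, device=device)
q_weight = torch.randn(head_size, dtype=dtype, device=device)  # per-head RMSNorm weight
k_weight = torch.randn(head_size, dtype=dtype, device=device)
cos_sin = torch.randn(3, num_tokens, rope_dim, dtype=dtype, device=device)
```

The op returns `real_q` of shape `(num_tokens, q_size)`, `real_k` and `real_v` of shape `(num_tokens, kv_size)`, and `real_gate` of shape `(num_tokens, q_size)` when `has_gate=True` (a zero-width tensor otherwise).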
### How was this patch tested?
- vLLM version: v0.16.0
- Accuracy test script:
```shell
pytest tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_mrope.py
```
---------
Signed-off-by: Fager <865071616@qq.com>
Signed-off-by: Fager10086 <77871921+Fager10086@users.noreply.github.com>
Signed-off-by: fager <865071616@qq.com>
`tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_mrope.py` (new file, 343 lines):

```python
import gc

import pytest
import torch

from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

NUM_TOKENS = [1, 4, 8, 16, 1024, 4096]
NUM_QKV_HEADS = [(8, 2), (2, 1), (16, 2)]
HEAD_SIZES = [128, 256]
EPS = [1e-6]
MROPE_SECTION = [[11, 11, 10], [24, 20, 20]]
IS_INTERLEAVED = [True, False]
HAS_GATE = [True, False]
DTYPES = [torch.bfloat16, torch.float16]
DEVICES = [f"npu:{0}"]
DEFAULT_ATOL = 1e-2
DEFAULT_RTOL = 1e-2


def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    """Apply interleaved MRoPE to 3D rotary embeddings.

    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
    interleaved [THTHWHTHW...TT], preserving frequency continuity.
    """
    x_t = x[0].clone()
    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
    return x_t
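
# Worked example: with mrope_section = [2, 2, 2] (half rotary dim 6), the chunked
# layout T T H H W W becomes T H W T H W; positions i % 3 == 1 take H frequencies,
# positions i % 3 == 2 take W frequencies, and the rest keep T.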


def rms_norm(x: torch.Tensor,
             norm_weight: torch.Tensor,
             eps,
             norm_bias=None,
             ):
    x = x.cpu().to(torch.float32)
    norm_weight = norm_weight.cpu().to(torch.float32)
    reciprocal_std = 1 / torch.sqrt(
        torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    out = x * reciprocal_std * norm_weight

    if norm_bias is not None:
        norm_bias = norm_bias.cpu().to(torch.float32)
        out = out + norm_bias

    return out
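
# The reference above computes RMSNorm in float32 on CPU:
#   out = x * rsqrt(mean(x**2, dim=-1) + eps) * weight (+ optional bias)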


def naive_split_qkv_rmsnorm_mrope(
    qkv: torch.Tensor,
    q_weight: torch.Tensor,
    q_bias: torch.Tensor,
    k_weight: torch.Tensor,
    k_bias: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    eps: float,
    mrope_section: list[int],
    rope_dim: int,
):
    q_size = num_q_heads * head_size
    kv_size = num_kv_heads * head_size

    # split
    qkv = qkv.cpu()
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    # norm
    q = rms_norm(q.reshape(-1, head_size), q_weight, eps, norm_bias=q_bias)
    k = rms_norm(k.reshape(-1, head_size), k_weight, eps, norm_bias=k_bias)

    # mrope
    rotary_dim = rope_dim
    num_tokens = qkv.shape[0]
    n_q_head = num_q_heads
    n_kv_head = num_kv_heads
    q_reshaped = q.view(num_tokens, n_q_head, head_size)
    k_reshaped = k.view(num_tokens, n_kv_head, head_size)
    cos_reshaped = cos.permute(1, 2, 0)
    sin_reshaped = sin.permute(1, 2, 0)
    half_rd = rotary_dim // 2
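
    # Per token, assemble a single cos/sin row over the half rotary dim: the T, H,
    # and W frequency channels fill the disjoint ranges [0:t), [t:t+h), [t+h:t+h+w).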
    for token_idx in range(num_tokens):
        token_cos = cos_reshaped[token_idx]
        token_sin = sin_reshaped[token_idx]

        cos_row = torch.zeros(half_rd, device=q.device, dtype=q.dtype)
        sin_row = torch.zeros(half_rd, device=q.device, dtype=q.dtype)

        t_end = mrope_section[0]
        h_end = t_end + mrope_section[1]

        if t_end > 0:
            cos_row[:t_end] = token_cos[:t_end, 0]
            sin_row[:t_end] = token_sin[:t_end, 0]

        if mrope_section[1] > 0:
            cos_row[t_end:h_end] = token_cos[t_end:h_end, 1]
            sin_row[t_end:h_end] = token_sin[t_end:h_end, 1]

        if mrope_section[2] > 0:
            w_start = h_end
            cos_row[w_start:half_rd] = token_cos[w_start:half_rd, 2]
            sin_row[w_start:half_rd] = token_sin[w_start:half_rd, 2]

        q_token = q_reshaped[token_idx]
        k_token = k_reshaped[token_idx]

        q1 = q_token[:, :half_rd]
        q2 = q_token[:, half_rd:rotary_dim]
        k1 = k_token[:, :half_rd]
        k2 = k_token[:, half_rd:rotary_dim]

        cos_half = cos_row.unsqueeze(0)
        sin_half = sin_row.unsqueeze(0)
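
        # NeoX-style rotate-half: [x1, x2] -> [x1*cos - x2*sin, x2*cos + x1*sin]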
        new_q1 = q1 * cos_half - q2 * sin_half
        new_q2 = q2 * cos_half + q1 * sin_half

        new_k1 = k1 * cos_half - k2 * sin_half
        new_k2 = k2 * cos_half + k1 * sin_half

        q_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_q1, new_q2], dim=1)
        k_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_k1, new_k2], dim=1)

    q_result = q_reshaped.view(num_tokens, -1)
    k_result = k_reshaped.view(num_tokens, -1)

    q = q_result.to(qkv.dtype)
    k = k_result.to(qkv.dtype)
    v = v.to(qkv.dtype)

    return q, k, v


def naive_split_qkv_rmsnorm_mrope_interleaved(
    qkv: torch.Tensor,
    q_weight: torch.Tensor,
    q_bias: torch.Tensor,
    k_weight: torch.Tensor,
    k_bias: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    eps: float,
    mrope_section: list[int],
    rope_dim: int,
):
    q_size = num_q_heads * head_size
    kv_size = num_kv_heads * head_size

    # split
    qkv = qkv.cpu()
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    # norm
    q = rms_norm(q.reshape(-1, head_size), q_weight, eps, norm_bias=q_bias)
    k = rms_norm(k.reshape(-1, head_size), k_weight, eps, norm_bias=k_bias)

    # mrope
    rotary_dim = rope_dim
    num_tokens = qkv.shape[0]
    n_q_head = num_q_heads
    n_kv_head = num_kv_heads
    q_reshaped = q.view(num_tokens, n_q_head, head_size)
    k_reshaped = k.view(num_tokens, n_kv_head, head_size)
    cos_reshaped = apply_interleaved_rope(cos, mrope_section)
    sin_reshaped = apply_interleaved_rope(sin, mrope_section)
    half_rd = rotary_dim // 2

    for token_idx in range(num_tokens):
        cos_row = cos_reshaped[token_idx]
        sin_row = sin_reshaped[token_idx]

        q_token = q_reshaped[token_idx]
        k_token = k_reshaped[token_idx]

        q1 = q_token[:, :half_rd]
        q2 = q_token[:, half_rd:rotary_dim]
        k1 = k_token[:, :half_rd]
        k2 = k_token[:, half_rd:rotary_dim]

        cos_half = cos_row.unsqueeze(0)
        sin_half = sin_row.unsqueeze(0)

        new_q1 = q1 * cos_half - q2 * sin_half
        new_q2 = q2 * cos_half + q1 * sin_half

        new_k1 = k1 * cos_half - k2 * sin_half
        new_k2 = k2 * cos_half + k1 * sin_half

        q_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_q1, new_q2], dim=1)
        k_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_k1, new_k2], dim=1)

    q_result = q_reshaped.view(num_tokens, -1)
    k_result = k_reshaped.view(num_tokens, -1)

    q = q_result.to(qkv.dtype)
    k = k_result.to(qkv.dtype)
    v = v.to(qkv.dtype)

    return q, k, v
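
# The interleaved reference differs from the chunked one only in how the
# per-token cos/sin rows are built: apply_interleaved_rope merges the T/H/W
# channels up front, so no per-section copying is needed inside the loop.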


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("mrope_section", MROPE_SECTION)
@pytest.mark.parametrize("is_interleaved", IS_INTERLEAVED)
@pytest.mark.parametrize("has_gate", HAS_GATE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_mrope(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    mrope_section: list[int],
    eps: float,
    dtype: torch.dtype,
    device: str,
    is_interleaved: bool,
    has_gate: bool,
):
    torch.set_default_device(device)
    init_device_properties_triton()
    rope_dim = 2 * sum(mrope_section)
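    # cos and sin each cover half of rope_dim, i.e. sum(mrope_section) entries.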
    q_size = num_q_heads * head_size
    kv_size = num_kv_heads * head_size

    # input tensor
    if has_gate:
        qkv = torch.randn(num_tokens,
                          2 * q_size + kv_size * 2,
                          dtype=dtype,
                          device=device)
    else:
        qkv = torch.randn(num_tokens,
                          q_size + kv_size * 2,
                          dtype=dtype,
                          device=device)
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    q_bias = None
    k_bias = None

    cos_sin = torch.randn(3, num_tokens, rope_dim, dtype=dtype, device=device)
    cos, sin = cos_sin.chunk(2, dim=-1)

    cos = cos.contiguous()
    sin = sin.contiguous()
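
    # With has_gate, each query head is packed as [q_head | gate_head], so the
    # reference input is rebuilt by separating Q from the gate before QKV concat.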
    if has_gate:
        q_gate_data = qkv[:, :q_size * 2].view(-1, num_q_heads, head_size * 2)
        q_data, golden_gate = torch.chunk(q_gate_data, 2, dim=-1)
        golden_gate = golden_gate.reshape(-1, q_size)
        q_data = q_data.reshape(-1, q_size)
        k_data = qkv[:, 2 * q_size:2 * q_size + kv_size]
        v_data = qkv[:, 2 * q_size + kv_size:]
        qkv_for_ref = torch.cat([q_data, k_data, v_data], dim=-1)
    else:
        qkv_for_ref = qkv

    naive_fn = (naive_split_qkv_rmsnorm_mrope_interleaved
                if is_interleaved else naive_split_qkv_rmsnorm_mrope)
    golden_q, golden_k, golden_v = naive_fn(qkv_for_ref.cpu(),
                                            q_weight.cpu(),
                                            q_bias,
                                            k_weight.cpu(),
                                            k_bias,
                                            cos.cpu(),
                                            sin.cpu(),
                                            num_q_heads,
                                            num_kv_heads,
                                            head_size,
                                            eps,
                                            mrope_section,
                                            rope_dim)

    real_q, real_k, real_v, real_gate = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
        qkv=qkv,
        q_weight=q_weight,
        k_weight=k_weight,
        cos_sin=cos_sin,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        eps=eps,
        mrope_section=mrope_section,
        is_interleaved=is_interleaved,
        rope_dim=rope_dim,
        has_gate=has_gate,
    )

    torch.testing.assert_close(real_q.cpu(),
                               golden_q.cpu(),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(real_k.cpu(),
                               golden_k.cpu(),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(real_v.cpu(),
                               golden_v.cpu(),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    if has_gate:
        torch.testing.assert_close(real_gate.cpu(),
                                   golden_gate.cpu(),
                                   atol=DEFAULT_ATOL,
                                   rtol=DEFAULT_RTOL)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
```
Existing module hunk (`@@ -24,6 +24,7 @@`): the new kernel module is imported when Triton is available:

```python
import vllm_ascend.ops.register_custom_ops  # noqa

if HAS_TRITON:
    import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope  # noqa
    import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_mrope  # noqa

import vllm_ascend.ops.vocab_parallel_embedding  # noqa
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
```
`vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py` (new file, 428 lines):
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
import triton  # type: ignore
import triton.language as tl  # type: ignore
from vllm.utils.torch_utils import direct_register_custom_op

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num


@triton.jit(
    do_not_specialize=["num_tokens", "front_core_num", "num_tokens_each_front_core", "num_tokens_each_tail_core"]
)
def split_qkv_rmsnorm_mrope_kernel(
    in_qkv_ptr: torch.Tensor,
    q_weight_ptr: torch.Tensor,
    q_bias_ptr: torch.Tensor,
    k_weight_ptr: torch.Tensor,
    k_bias_ptr: torch.Tensor,
    cos_sin_ptr: torch.Tensor,
    out_q_ptr: torch.Tensor,
    out_k_ptr: torch.Tensor,
    out_v_ptr: torch.Tensor,
    out_gate_ptr: torch.Tensor,
    num_tokens,
    front_core_num,
    num_tokens_each_front_core,
    num_tokens_each_tail_core,
    num_q_heads: tl.constexpr,
    num_kv_heads: tl.constexpr,
    head_size: tl.constexpr,
    q_size: tl.constexpr,
    kv_size: tl.constexpr,
    eps: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    mrope_section_w: tl.constexpr,
    has_bias: tl.constexpr,
    is_interleaved: tl.constexpr,
    rope_dim: tl.constexpr,
    half_rope_dim: tl.constexpr,
    IS_PARTIAL_ROPE: tl.constexpr,
    gate_size: tl.constexpr,
):
    block_idx = tl.program_id(0)

    loop_num = num_tokens_each_front_core
    if block_idx >= front_core_num:
        loop_num = num_tokens_each_tail_core

    block_offset = num_tokens_each_front_core * block_idx
    if block_idx >= front_core_num:
        block_offset = (
            num_tokens_each_front_core * front_core_num + (block_idx - front_core_num) * num_tokens_each_tail_core
        )
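
    # Token partitioning across vector cores; for example, num_tokens=10 with
    # core_num=4 gives front_core_num=2 cores handling 3 tokens each and two
    # tail cores handling 2 tokens each.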

    q_rmsnorm_weight = tl.load(q_weight_ptr + tl.arange(0, head_size))
    k_rmsnorm_weight = tl.load(k_weight_ptr + tl.arange(0, head_size))

    if has_bias:
        q_bias = tl.load(q_bias_ptr + tl.arange(0, head_size))
        k_bias = tl.load(k_bias_ptr + tl.arange(0, head_size))

    for index in range(loop_num):
        ## load ##
        # q
        in_q_offset = in_qkv_ptr + (block_offset + index) * (q_size + gate_size + 2 * kv_size)
        if gate_size > 0:
            in_q_gate_tensor = (
                tl.load(in_q_offset + tl.arange(0, q_size + gate_size))
                .to(tl.float32)
                .reshape(num_q_heads, head_size * 2)
            )
            in_q_tensor = tl.extract_slice(
                in_q_gate_tensor,
                offsets=(0, 0),
                sizes=(num_q_heads, head_size),
                strides=(1, 1),
            )
            in_gate_tensor = tl.extract_slice(
                in_q_gate_tensor,
                offsets=(0, head_size),
                sizes=(num_q_heads, head_size),
                strides=(1, 1),
            ).reshape(q_size)
        else:
            in_q_tensor = tl.load(in_q_offset + tl.arange(0, q_size)).to(tl.float32).reshape(num_q_heads, head_size)

        # k
        in_k_offset = in_q_offset + q_size + gate_size
        in_k_tensor = tl.load(in_k_offset + tl.arange(0, kv_size)).to(tl.float32).reshape(num_kv_heads, head_size)
        # v
        in_v_offset = in_k_offset + kv_size
        in_v_tensor = tl.load(in_v_offset + tl.arange(0, kv_size))

        # cos, sin
        cos_offsets = tl.arange(0, half_rope_dim)
        if is_interleaved:
            h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
            w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
            t_mask = ~(h_mask | w_mask)
        else:
            t_mask = cos_offsets < mrope_section_t
            h_mask = (mrope_section_t - 1 < cos_offsets) & (cos_offsets < mrope_section_t + mrope_section_h)
            w_mask = (mrope_section_t + mrope_section_h - 1 < cos_offsets) & (
                cos_offsets < mrope_section_t + mrope_section_h + mrope_section_w
            )

        t_cos_offset = cos_sin_ptr + (block_offset + index) * rope_dim
        h_cos_offset = t_cos_offset + num_tokens * rope_dim
        w_cos_offset = h_cos_offset + num_tokens * rope_dim

        t_sin_offset = cos_sin_ptr + (block_offset + index) * rope_dim + half_rope_dim
        h_sin_offset = t_sin_offset + num_tokens * rope_dim
        w_sin_offset = h_sin_offset + num_tokens * rope_dim

        t_cos_tensor = tl.load(t_cos_offset + cos_offsets, mask=t_mask, other=0)
        h_cos_tensor = tl.load(h_cos_offset + cos_offsets, mask=h_mask, other=0)
        w_cos_tensor = tl.load(w_cos_offset + cos_offsets, mask=w_mask, other=0)
        t_sin_tensor = tl.load(t_sin_offset + cos_offsets, mask=t_mask, other=0)
        h_sin_tensor = tl.load(h_sin_offset + cos_offsets, mask=h_mask, other=0)
        w_sin_tensor = tl.load(w_sin_offset + cos_offsets, mask=w_mask, other=0)

        cos_tensor = (t_cos_tensor + h_cos_tensor + w_cos_tensor).to(tl.float32).reshape(1, half_rope_dim)
        cos_tensor = tl.broadcast_to(cos_tensor, (2, half_rope_dim)).reshape(1, rope_dim)

        sin_tensor = (t_sin_tensor + h_sin_tensor + w_sin_tensor).to(tl.float32).reshape(1, half_rope_dim)
        sin_tensor = tl.broadcast_to(sin_tensor, (2, half_rope_dim)).reshape(1, rope_dim)
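
        # The half-dim cos/sin rows are duplicated across both rotary halves,
        # giving cos_tensor and sin_tensor shape (1, rope_dim) = [row, row].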

        ## compute ##
        # q-rmsnorm
        squares = in_q_tensor * in_q_tensor
        variances = tl.sum(squares, axis=1) / head_size
        reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(num_q_heads, 1)
        q_normalized = in_q_tensor * reciprocal_std
        q_normalized = q_normalized * q_rmsnorm_weight
        if has_bias:
            q_normalized = q_normalized + q_bias

        # k-rmsnorm
        squares = in_k_tensor * in_k_tensor
        variances = tl.sum(squares, axis=1) / head_size
        reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(num_kv_heads, 1)
        k_normalized = in_k_tensor * reciprocal_std
        k_normalized = k_normalized * k_rmsnorm_weight
        if has_bias:
            k_normalized = k_normalized + k_bias

        # q-mrope
        x1 = tl.extract_slice(
            q_normalized,
            offsets=(0, 0),
            sizes=(num_q_heads, half_rope_dim),
            strides=(1, 1),
        )
        x2 = tl.extract_slice(
            q_normalized,
            offsets=(0, half_rope_dim),
            sizes=(num_q_heads, half_rope_dim),
            strides=(1, 1),
        )
        cat_x = tl.zeros((num_q_heads, rope_dim), dtype=tl.float32)
        cat_x = tl.insert_slice(
            cat_x,
            -x2,
            offsets=(0, 0),
            sizes=(num_q_heads, half_rope_dim),
            strides=(1, 1),
        )
        cat_x = tl.insert_slice(
            cat_x,
            x1,
            offsets=(0, half_rope_dim),
            sizes=(num_q_heads, half_rope_dim),
            strides=(1, 1),
        )
        if IS_PARTIAL_ROPE:
            orig_qk = tl.extract_slice(
                q_normalized,
                offsets=(0, 0),
                sizes=(num_q_heads, rope_dim),
                strides=(1, 1),
            )
        else:
            orig_qk = q_normalized
        roped_q = cat_x * sin_tensor + orig_qk * cos_tensor
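
        # cat_x holds [-x2, x1], so roped_q is the rotate-half identity:
        #   [x1 * cos - x2 * sin, x2 * cos + x1 * sin]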

        # k-mrope
        y1 = tl.extract_slice(
            k_normalized,
            offsets=(0, 0),
            sizes=(num_kv_heads, half_rope_dim),
            strides=(1, 1),
        )
        y2 = tl.extract_slice(
            k_normalized,
            offsets=(0, half_rope_dim),
            sizes=(num_kv_heads, half_rope_dim),
            strides=(1, 1),
        )
        cat_y = tl.zeros((num_kv_heads, rope_dim), dtype=tl.float32)
        cat_y = tl.insert_slice(
            cat_y,
            -y2,
            offsets=(0, 0),
            sizes=(num_kv_heads, half_rope_dim),
            strides=(1, 1),
        )
        cat_y = tl.insert_slice(
            cat_y,
            y1,
            offsets=(0, half_rope_dim),
            sizes=(num_kv_heads, half_rope_dim),
            strides=(1, 1),
        )
        if IS_PARTIAL_ROPE:
            orig_qk = tl.extract_slice(
                k_normalized,
                offsets=(0, 0),
                sizes=(num_kv_heads, rope_dim),
                strides=(1, 1),
            )
        else:
            orig_qk = k_normalized
        roped_k = cat_y * sin_tensor + orig_qk * cos_tensor

        if IS_PARTIAL_ROPE:
            q_normalized = tl.insert_slice(
                q_normalized,
                roped_q,
                offsets=(0, 0),
                sizes=(num_q_heads, rope_dim),
                strides=(1, 1),
            )
            k_normalized = tl.insert_slice(
                k_normalized,
                roped_k,
                offsets=(0, 0),
                sizes=(num_kv_heads, rope_dim),
                strides=(1, 1),
            )
        else:
            q_normalized = roped_q
            k_normalized = roped_k

        ## store ##
        # out_q
        out_q_offset = out_q_ptr + (block_offset + index) * q_size
        out_q_indices = tl.arange(0, q_size)
        tl.store(out_q_offset + out_q_indices, q_normalized.reshape(q_size))

        # out_k
        out_k_offset = out_k_ptr + (block_offset + index) * kv_size
        out_k_indices = tl.arange(0, kv_size)
        tl.store(out_k_offset + out_k_indices, k_normalized.reshape(kv_size))

        # out_v
        out_v_offset = out_v_ptr + (block_offset + index) * kv_size
        tl.store(out_v_offset + tl.arange(0, kv_size), in_v_tensor)

        # out_gate
        if gate_size > 0:
            out_gate_offset = out_gate_ptr + (block_offset + index) * gate_size
            tl.store(out_gate_offset + tl.arange(0, gate_size), in_gate_tensor)


def triton_split_qkv_rmsnorm_mrope(
    qkv: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    cos_sin: torch.Tensor,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    eps: float,
    mrope_section: list[int],
    is_interleaved: bool,
    rope_dim: int | None = None,
    q_bias: torch.Tensor | None = None,
    k_bias: torch.Tensor | None = None,
    has_gate: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    core_num = get_vectorcore_num()

    q_size = num_q_heads * head_size
    kv_size = num_kv_heads * head_size
    num_tokens = qkv.shape[0]

    gate_size = q_size if has_gate else 0

    if rope_dim is None:
        rope_dim = head_size
    IS_PARTIAL_ROPE = rope_dim != head_size
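
    # Partial RoPE: only the first rope_dim channels of each head are rotated;
    # the remaining channels pass through the RMSNorm result unchanged.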

    front_core_num = core_num
    if num_tokens % core_num != 0:
        front_core_num = num_tokens % core_num

    num_tokens_each_front_core = (num_tokens + core_num - 1) // core_num

    tail_core_num = 0
    if num_tokens > core_num:
        tail_core_num = core_num - front_core_num

    num_tokens_each_tail_core = num_tokens // core_num

    q_output = torch.empty(num_tokens, q_size, device=qkv.device, dtype=qkv.dtype)
    k_output = torch.empty(num_tokens, kv_size, device=qkv.device, dtype=qkv.dtype)
    v_output = torch.empty(num_tokens, kv_size, device=qkv.device, dtype=qkv.dtype)
    gate_output = torch.empty(num_tokens, gate_size, device=qkv.device, dtype=qkv.dtype)

    total_core = front_core_num + tail_core_num
    block_dim = core_num
    if total_core < core_num:
        block_dim = total_core

    has_bias = q_bias is not None

    split_qkv_rmsnorm_mrope_kernel[(block_dim,)](
        qkv,
        q_weight,
        q_bias,
        k_weight,
        k_bias,
        cos_sin,
        q_output,
        k_output,
        v_output,
        gate_output,
        num_tokens,
        front_core_num,
        num_tokens_each_front_core,
        num_tokens_each_tail_core,
        num_q_heads,
        num_kv_heads,
        head_size,
        q_size,
        kv_size,
        eps,
        mrope_section[0],
        mrope_section[1],
        mrope_section[2],
        has_bias,
        is_interleaved,
        rope_dim,
        rope_dim // 2,
        IS_PARTIAL_ROPE,
        gate_size,
    )

    return q_output, k_output, v_output, gate_output


def triton_split_qkv_rmsnorm_mrope_fake(
    qkv: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    cos_sin: torch.Tensor,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    eps: float,
    mrope_section: list[int],
    is_interleaved: bool,
    rope_dim: int | None = None,
    q_bias: torch.Tensor | None = None,
    k_bias: torch.Tensor | None = None,
    has_gate: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    num_tokens = qkv.shape[0]
    q_size = num_q_heads * head_size
    kv_size = num_kv_heads * head_size
    gate_size = q_size if has_gate else 0

    q_output = torch.empty(num_tokens, q_size, device=qkv.device, dtype=qkv.dtype)
    k_output = torch.empty(num_tokens, kv_size, device=qkv.device, dtype=qkv.dtype)
    v_output = torch.empty(num_tokens, kv_size, device=qkv.device, dtype=qkv.dtype)
    gate_output = torch.empty(num_tokens, gate_size, device=qkv.device, dtype=qkv.dtype)

    return q_output, k_output, v_output, gate_output


direct_register_custom_op(
    op_name="triton_split_qkv_rmsnorm_mrope",
    op_func=triton_split_qkv_rmsnorm_mrope,
    fake_impl=triton_split_qkv_rmsnorm_mrope_fake,
    mutates_args=[],
    dispatch_key="PrivateUse1",
)
```