From c5dfa8d645958729e96edde318becde848e81a14 Mon Sep 17 00:00:00 2001 From: Fager10086 <77871921+Fager10086@users.noreply.github.com> Date: Fri, 6 Mar 2026 16:18:37 +0800 Subject: [PATCH] [OPS]add split_qkv_rmsnorm_mrope ops (#6730) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? This PR adds split_qkv_rmsnorm_mrope kernel with interleaved for qwen3.5 and qwen3-vl to improve performance. ### Does this PR introduce _any_ user-facing change? Does not. ### How to use? ```python real_q, real_k, real_v, real_gate = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope( qkv=qkv, q_weight=q_weight, k_weight=k_weight, cos_sin=cos_sin, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_size=head_size, eps=eps, mrope_section=mrope_section, is_interleaved=is_interleaved, rope_dim=rope_dim, has_gate=has_gate, ) ``` ### How was this patch tested? - vLLM version: v0.16.0 - Accuracy test script: ```shell pytest tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_mrope.py ``` --------- Signed-off-by: Fager <865071616@qq.com> Signed-off-by: Fager10086 <77871921+Fager10086@users.noreply.github.com> Signed-off-by: fager <865071616@qq.com> --- .../triton/test_split_qkv_rmsnorm_mrope.py | 343 ++++++++++++++ vllm_ascend/ops/__init__.py | 1 + .../linearnorm/split_qkv_rmsnorm_mrope.py | 428 ++++++++++++++++++ 3 files changed, 772 insertions(+) create mode 100644 tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_mrope.py create mode 100644 vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_mrope.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_mrope.py new file mode 100644 index 00000000..957de91d --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_mrope.py @@ -0,0 +1,343 @@ +import gc + +import pytest +import torch + +from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton + +NUM_TOKENS = [1, 4, 8, 16, 1024, 4096] +NUM_QKV_HEADS = [(8, 2), (2, 1), (16, 2)] +HEAD_SIZES = [128, 256] +EPS = [1e-6] +MROPE_SECTION = [[11, 11, 10], [24, 20, 20]] +IS_INTERLEAVED = [True, False] +HAS_GATE = [True, False] +DTYPES = [torch.bfloat16, torch.float16] +DEVICES = [f"npu:{0}"] +DEFAULT_ATOL = 1e-2 +DEFAULT_RTOL = 1e-2 + + +def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + """ + x_t = x[0].clone() + x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3] + x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3] + return x_t + + +def rms_norm(x: torch.Tensor, + norm_weight: torch.Tensor, + eps, + norm_bias=None, +): + x = x.cpu() + norm_weight = norm_weight.cpu() + + x = x.to(torch.float32) + norm_weight = norm_weight.to(torch.float32).cpu() + reciprocal_std = 1 / torch.sqrt( + torch.mean(x ** 2, axis=-1, keepdims=True) + eps) + out = x * reciprocal_std * norm_weight + + if norm_bias is not None: + norm_bias = norm_bias.cpu().to(torch.float32) + out = out + norm_bias + + return out + + +def naive_split_qkv_rmsnorm_mrope( + qkv: torch.Tensor, + q_weight: torch.Tensor, + q_bias: torch.Tensor, + k_weight: torch.Tensor, + k_bias: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + eps: float, + mrope_section: list[int], + rope_dim: int, +): + q_size = num_q_heads * head_size + kv_size = num_kv_heads * head_size + + # split + qkv = qkv.cpu() + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + # norm + q = rms_norm(q.reshape(-1, head_size), q_weight, eps, norm_bias=q_bias) + k = rms_norm(k.reshape(-1, head_size), k_weight, eps, norm_bias=k_bias) + + # mrope + rotary_dim = rope_dim + num_tokens = qkv.shape[0] + n_q_head = num_q_heads + n_kv_head = num_kv_heads + q_reshaped = q.view(num_tokens, n_q_head, head_size) + k_reshaped = k.view(num_tokens, n_kv_head, head_size) + cos_reshaped = cos.permute(1, 2, 0) + sin_reshaped = sin.permute(1, 2, 0) + half_rd = rotary_dim // 2 + + for token_idx in range(num_tokens): + token_cos = cos_reshaped[token_idx] + token_sin = sin_reshaped[token_idx] + + cos_row = torch.zeros(half_rd, device=q.device, dtype=q.dtype) + sin_row = torch.zeros(half_rd, device=q.device, dtype=q.dtype) + + t_end = mrope_section[0] + h_end = t_end + mrope_section[1] + + if t_end > 0: + cos_row[:t_end] = token_cos[:t_end, 0] + sin_row[:t_end] = token_sin[:t_end, 0] + + if mrope_section[1] > 0: + cos_row[t_end:h_end] = token_cos[t_end:h_end, 1] + sin_row[t_end:h_end] = token_sin[t_end:h_end, 1] + + if mrope_section[2] > 0: + w_start = h_end + cos_row[w_start:half_rd] = token_cos[w_start:half_rd, 2] + sin_row[w_start:half_rd] = token_sin[w_start:half_rd, 2] + + q_token = q_reshaped[token_idx] + k_token = k_reshaped[token_idx] + + q1 = q_token[:, :half_rd] + q2 = q_token[:, half_rd:rotary_dim] + k1 = k_token[:, :half_rd] + k2 = k_token[:, half_rd:rotary_dim] + + cos_half = cos_row.unsqueeze(0) + sin_half = sin_row.unsqueeze(0) + + new_q1 = q1 * cos_half - q2 * sin_half + new_q2 = q2 * cos_half + q1 * sin_half + + new_k1 = k1 * cos_half - k2 * sin_half + new_k2 = k2 * cos_half + k1 * sin_half + + q_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_q1, new_q2], dim=1) + k_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_k1, new_k2], dim=1) + + q_result = q_reshaped.view(num_tokens, -1) + k_result = k_reshaped.view(num_tokens, -1) + + q = q_result.to(qkv.dtype) + k = k_result.to(qkv.dtype) + v = v.to(qkv.dtype) + + return q, k, v + + +def naive_split_qkv_rmsnorm_mrope_interleaved( + qkv: torch.Tensor, + q_weight: torch.Tensor, + q_bias: torch.Tensor, + k_weight: torch.Tensor, + k_bias: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + eps: float, + mrope_section: list[int], + rope_dim: int, +): + q_size = num_q_heads * head_size + kv_size = num_kv_heads * head_size + + # split + qkv = qkv.cpu() + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + # norm + q = rms_norm(q.reshape(-1, head_size), q_weight, eps, norm_bias=q_bias) + k = rms_norm(k.reshape(-1, head_size), k_weight, eps, norm_bias=k_bias) + + # mrope + rotary_dim = rope_dim + num_tokens = qkv.shape[0] + n_q_head = num_q_heads + n_kv_head = num_kv_heads + q_reshaped = q.view(num_tokens, n_q_head, head_size) + k_reshaped = k.view(num_tokens, n_kv_head, head_size) + cos_reshaped = apply_interleaved_rope(cos, mrope_section) + sin_reshaped = apply_interleaved_rope(sin, mrope_section) + half_rd = rotary_dim // 2 + + for token_idx in range(num_tokens): + cos_row = cos_reshaped[token_idx] + sin_row = sin_reshaped[token_idx] + + q_token = q_reshaped[token_idx] + k_token = k_reshaped[token_idx] + + q1 = q_token[:, :half_rd] + q2 = q_token[:, half_rd:rotary_dim] + k1 = k_token[:, :half_rd] + k2 = k_token[:, half_rd:rotary_dim] + + cos_half = cos_row.unsqueeze(0) + sin_half = sin_row.unsqueeze(0) + + new_q1 = q1 * cos_half - q2 * sin_half + new_q2 = q2 * cos_half + q1 * sin_half + + new_k1 = k1 * cos_half - k2 * sin_half + new_k2 = k2 * cos_half + k1 * sin_half + + q_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_q1, new_q2], dim=1) + k_reshaped[token_idx, :, :rotary_dim] = torch.cat([new_k1, new_k2], dim=1) + + q_result = q_reshaped.view(num_tokens, -1) + k_result = k_reshaped.view(num_tokens, -1) + + q = q_result.to(qkv.dtype) + k = k_result.to(qkv.dtype) + v = v.to(qkv.dtype) + + return q, k, v + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("eps", EPS) +@pytest.mark.parametrize("mrope_section", MROPE_SECTION) +@pytest.mark.parametrize("is_interleaved", IS_INTERLEAVED) +@pytest.mark.parametrize("has_gate", HAS_GATE) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@torch.inference_mode() +def test_split_qkv_rmsnorm_mrope( + num_tokens: int, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + mrope_section: list[int], + eps: float, + dtype: torch.dtype, + device: str, + is_interleaved: bool, + has_gate: bool, +): + + torch.set_default_device(device) + init_device_properties_triton() + rope_dim = 2 * sum(mrope_section) + q_size = num_q_heads * head_size + kv_size = num_kv_heads * head_size + + # input tensor + if has_gate: + qkv = torch.randn(num_tokens, + 2 * q_size + kv_size * 2, + dtype=dtype, + device=device) + else: + qkv = torch.randn(num_tokens, + q_size + kv_size * 2, + dtype=dtype, + device=device) + q_weight = torch.randn(head_size, dtype=dtype, device=device) + k_weight = torch.randn(head_size, dtype=dtype, device=device) + q_bias = None + k_bias = None + + cos_sin = torch.randn(3, num_tokens, rope_dim, dtype=dtype, + device=device) + cos, sin = cos_sin.chunk(2, dim=-1) + + cos = cos.contiguous() + sin = sin.contiguous() + + if has_gate: + q_gate_data = qkv[:, :q_size * 2].view(-1, num_q_heads, head_size * 2) + q_data, golden_gate = torch.chunk(q_gate_data, 2, dim=-1) + golden_gate = golden_gate.reshape(-1, q_size) + q_data = q_data.reshape(-1, q_size) + k_data = qkv[:, 2 * q_size:2 * q_size + kv_size] + v_data = qkv[:, 2 * q_size + kv_size:] + qkv_for_ref = torch.cat([q_data, k_data, v_data], dim=-1) + else: + qkv_for_ref = qkv + + if is_interleaved: + golden_q, golden_k, golden_v = naive_split_qkv_rmsnorm_mrope_interleaved(qkv_for_ref.cpu(), + q_weight.cpu(), + q_bias, + k_weight.cpu(), + k_bias, + cos.cpu(), + sin.cpu(), + num_q_heads, + num_kv_heads, + head_size, + eps, + mrope_section, + rope_dim) + else: + golden_q, golden_k, golden_v = naive_split_qkv_rmsnorm_mrope(qkv_for_ref.cpu(), + q_weight.cpu(), + q_bias, + k_weight.cpu(), + k_bias, + cos.cpu(), + sin.cpu(), + num_q_heads, + num_kv_heads, + head_size, + eps, + mrope_section, + rope_dim) + + real_q, real_k, real_v, real_gate = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope( + qkv=qkv, + q_weight=q_weight, + k_weight=k_weight, + cos_sin=cos_sin, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + eps=eps, + mrope_section=mrope_section, + is_interleaved=is_interleaved, + rope_dim=rope_dim, + has_gate=has_gate, + ) + + torch.testing.assert_close(real_q.cpu(), + golden_q.cpu(), + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(real_k.cpu(), + golden_k.cpu(), + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(real_v.cpu(), + golden_v.cpu(), + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + if has_gate: + torch.testing.assert_close(real_gate.cpu(), + golden_gate.cpu(), + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() \ No newline at end of file diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index 46c0ceff..95b498a5 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -24,6 +24,7 @@ import vllm_ascend.ops.register_custom_ops # noqa if HAS_TRITON: import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa + import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_mrope import vllm_ascend.ops.vocab_parallel_embedding # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py new file mode 100644 index 00000000..6fd5b879 --- /dev/null +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py @@ -0,0 +1,428 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + + +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from vllm.utils.torch_utils import direct_register_custom_op + +from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num + + +@triton.jit( + do_not_specialize=["num_tokens", "front_core_num", "num_tokens_each_front_core", "num_tokens_each_tail_core"] +) +def split_qkv_rmsnorm_mrope_kernel( + in_qkv_ptr: torch.Tensor, + q_weight_ptr: torch.Tensor, + q_bias_ptr: torch.Tensor, + k_weight_ptr: torch.Tensor, + k_bias_ptr: torch.Tensor, + cos_sin_ptr: torch.Tensor, + out_q_ptr: torch.Tensor, + out_k_ptr: torch.Tensor, + out_v_ptr: torch.Tensor, + out_gate_ptr: torch.Tensor, + num_tokens, + front_core_num, + num_tokens_each_front_core, + num_tokens_each_tail_core, + num_q_heads: tl.constexpr, + num_kv_heads: tl.constexpr, + head_size: tl.constexpr, + q_size: tl.constexpr, + kv_size: tl.constexpr, + eps: tl.constexpr, + mrope_section_t: tl.constexpr, + mrope_section_h: tl.constexpr, + mrope_section_w: tl.constexpr, + has_bias: tl.constexpr, + is_interleaved: tl.constexpr, + rope_dim: tl.constexpr, + half_rope_dim: tl.constexpr, + IS_PARTIAL_ROPE: tl.constexpr, + gate_size: tl.constexpr, +): + block_idx = tl.program_id(0) + + loop_num = num_tokens_each_front_core + if block_idx >= front_core_num: + loop_num = num_tokens_each_tail_core + + block_offset = num_tokens_each_front_core * block_idx + if block_idx >= front_core_num: + block_offset = ( + num_tokens_each_front_core * front_core_num + (block_idx - front_core_num) * num_tokens_each_tail_core + ) + + q_rmsnorm_weight = tl.load(q_weight_ptr + tl.arange(0, head_size)) + k_rmsnorm_weight = tl.load(k_weight_ptr + tl.arange(0, head_size)) + + if has_bias: + q_bias = tl.load(q_bias_ptr + tl.arange(0, head_size)) + k_bias = tl.load(k_bias_ptr + tl.arange(0, head_size)) + + for index in range(loop_num): + ## load ## + # q + in_q_offset = in_qkv_ptr + (block_offset + index) * (q_size + gate_size + 2 * kv_size) + if gate_size > 0: + in_q_gate_tensor = ( + tl.load(in_q_offset + tl.arange(0, q_size + gate_size)) + .to(tl.float32) + .reshape(num_q_heads, head_size * 2) + ) + in_q_tensor = tl.extract_slice( + in_q_gate_tensor, + offsets=(0, 0), + sizes=(num_q_heads, head_size), + strides=(1, 1), + ) + in_gate_tensor = tl.extract_slice( + in_q_gate_tensor, + offsets=(0, head_size), + sizes=(num_q_heads, head_size), + strides=(1, 1), + ).reshape(q_size) + else: + in_q_tensor = tl.load(in_q_offset + tl.arange(0, q_size)).to(tl.float32).reshape(num_q_heads, head_size) + + # k + in_k_offset = in_q_offset + q_size + gate_size + in_k_tensor = tl.load(in_k_offset + tl.arange(0, kv_size)).to(tl.float32).reshape(num_kv_heads, head_size) + # v + in_v_offset = in_k_offset + kv_size + in_v_tensor = tl.load(in_v_offset + tl.arange(0, kv_size)) + + # cos, sin + cos_offsets = tl.arange(0, half_rope_dim) + if is_interleaved: + h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) + w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) + t_mask = ~(h_mask | w_mask) + else: + t_mask = cos_offsets < mrope_section_t + h_mask = (mrope_section_t - 1 < cos_offsets) & (cos_offsets < mrope_section_t + mrope_section_h) + w_mask = (mrope_section_t + mrope_section_h - 1 < cos_offsets) & ( + cos_offsets < mrope_section_t + mrope_section_h + mrope_section_w + ) + + t_cos_offset = cos_sin_ptr + (block_offset + index) * rope_dim + h_cos_offset = t_cos_offset + num_tokens * rope_dim + w_cos_offset = h_cos_offset + num_tokens * rope_dim + + t_sin_offset = cos_sin_ptr + (block_offset + index) * rope_dim + half_rope_dim + h_sin_offset = t_sin_offset + num_tokens * rope_dim + w_sin_offset = h_sin_offset + num_tokens * rope_dim + + t_cos_tensor = tl.load(t_cos_offset + cos_offsets, mask=t_mask, other=0) + h_cos_tensor = tl.load(h_cos_offset + cos_offsets, mask=h_mask, other=0) + w_cos_tensor = tl.load(w_cos_offset + cos_offsets, mask=w_mask, other=0) + t_sin_tensor = tl.load(t_sin_offset + cos_offsets, mask=t_mask, other=0) + h_sin_tensor = tl.load(h_sin_offset + cos_offsets, mask=h_mask, other=0) + w_sin_tensor = tl.load(w_sin_offset + cos_offsets, mask=w_mask, other=0) + + cos_tensor = (t_cos_tensor + h_cos_tensor + w_cos_tensor).to(tl.float32).reshape(1, half_rope_dim) + cos_tensor = tl.broadcast_to(cos_tensor, (2, half_rope_dim)).reshape(1, rope_dim) + + sin_tensor = (t_sin_tensor + h_sin_tensor + w_sin_tensor).to(tl.float32).reshape(1, half_rope_dim) + sin_tensor = tl.broadcast_to(sin_tensor, (2, half_rope_dim)).reshape(1, rope_dim) + + ## compute ## + # q-rmsnorm + squares = in_q_tensor * in_q_tensor + variances = tl.sum(squares, axis=1) / head_size + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(num_q_heads, 1) + q_normalized = in_q_tensor * reciprocal_std + q_normalized = q_normalized * q_rmsnorm_weight + if has_bias: + q_normalized = q_normalized + q_bias + + # k-rmsnorm + squares = in_k_tensor * in_k_tensor + variances = tl.sum(squares, axis=1) / head_size + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(num_kv_heads, 1) + k_normalized = in_k_tensor * reciprocal_std + k_normalized = k_normalized * k_rmsnorm_weight + if has_bias: + k_normalized = k_normalized + k_bias + + # q-mrope + x1 = tl.extract_slice( + q_normalized, + offsets=(0, 0), + sizes=(num_q_heads, half_rope_dim), + strides=(1, 1), + ) + x2 = tl.extract_slice( + q_normalized, + offsets=(0, half_rope_dim), + sizes=(num_q_heads, half_rope_dim), + strides=(1, 1), + ) + cat_x = tl.zeros((num_q_heads, rope_dim), dtype=tl.float32) + cat_x = tl.insert_slice( + cat_x, + -x2, + offsets=(0, 0), + sizes=(num_q_heads, half_rope_dim), + strides=(1, 1), + ) + cat_x = tl.insert_slice( + cat_x, + x1, + offsets=(0, half_rope_dim), + sizes=(num_q_heads, half_rope_dim), + strides=(1, 1), + ) + if IS_PARTIAL_ROPE: + orig_qk = tl.extract_slice( + q_normalized, + offsets=(0, 0), + sizes=(num_q_heads, rope_dim), + strides=(1, 1), + ) + else: + orig_qk = q_normalized + roped_q = cat_x * sin_tensor + orig_qk * cos_tensor + + # k-mrope + y1 = tl.extract_slice( + k_normalized, + offsets=(0, 0), + sizes=(num_kv_heads, half_rope_dim), + strides=(1, 1), + ) + y2 = tl.extract_slice( + k_normalized, + offsets=(0, half_rope_dim), + sizes=(num_kv_heads, half_rope_dim), + strides=(1, 1), + ) + cat_y = tl.zeros((num_kv_heads, rope_dim), dtype=tl.float32) + cat_y = tl.insert_slice( + cat_y, + -y2, + offsets=(0, 0), + sizes=(num_kv_heads, half_rope_dim), + strides=(1, 1), + ) + cat_y = tl.insert_slice( + cat_y, + y1, + offsets=(0, half_rope_dim), + sizes=(num_kv_heads, half_rope_dim), + strides=(1, 1), + ) + if IS_PARTIAL_ROPE: + orig_qk = tl.extract_slice( + k_normalized, + offsets=(0, 0), + sizes=(num_kv_heads, rope_dim), + strides=(1, 1), + ) + else: + orig_qk = k_normalized + roped_k = cat_y * sin_tensor + orig_qk * cos_tensor + + if IS_PARTIAL_ROPE: + q_normalized = tl.insert_slice( + q_normalized, + roped_q, + offsets=(0, 0), + sizes=(num_q_heads, rope_dim), + strides=(1, 1), + ) + k_normalized = tl.insert_slice( + k_normalized, + roped_k, + offsets=(0, 0), + sizes=(num_kv_heads, rope_dim), + strides=(1, 1), + ) + else: + q_normalized = roped_q + k_normalized = roped_k + + ## store ## + # out_q + out_q_offset = out_q_ptr + (block_offset + index) * q_size + out_q_indices = tl.arange(0, q_size) + tl.store(out_q_offset + out_q_indices, q_normalized.reshape(q_size)) + + # out_k + out_k_offset = out_k_ptr + (block_offset + index) * kv_size + out_k_indices = tl.arange(0, kv_size) + tl.store(out_k_offset + out_k_indices, k_normalized.reshape(kv_size)) + + # out_v + out_v_offset = out_v_ptr + (block_offset + index) * kv_size + tl.store(out_v_offset + tl.arange(0, kv_size), in_v_tensor) + + # out_gate + if gate_size > 0: + out_gate_offset = out_gate_ptr + (block_offset + index) * gate_size + tl.store(out_gate_offset + tl.arange(0, gate_size), in_gate_tensor) + + +def triton_split_qkv_rmsnorm_mrope( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin: torch.Tensor, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + eps: float, + mrope_section: list[int], + is_interleaved: bool, + rope_dim: int | None = None, + q_bias: torch.Tensor | None = None, + k_bias: torch.Tensor | None = None, + has_gate: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + core_num = get_vectorcore_num() + + q_size = num_q_heads * head_size + kv_size = num_kv_heads * head_size + num_tokens = qkv.shape[0] + + gate_size = q_size if has_gate else 0 + + if rope_dim is None: + rope_dim = head_size + IS_PARTIAL_ROPE = rope_dim != head_size + + front_core_num = core_num + if num_tokens % core_num != 0: + front_core_num = num_tokens % core_num + + num_tokens_each_front_core = (num_tokens + core_num - 1) // core_num + + tail_core_num = 0 + if num_tokens > core_num: + tail_core_num = core_num - front_core_num + + num_tokens_each_tail_core = num_tokens // core_num + + q_output = torch.empty(num_tokens, q_size, device=qkv.device, dtype=qkv.dtype) + k_output = torch.empty(num_tokens, kv_size, device=qkv.device, dtype=qkv.dtype) + v_output = torch.empty(num_tokens, kv_size, device=qkv.device, dtype=qkv.dtype) + gate_output = torch.empty(num_tokens, gate_size, device=qkv.device, dtype=qkv.dtype) + + total_core = front_core_num + tail_core_num + block_dim = core_num + if total_core < core_num: + block_dim = total_core + + has_bias = q_bias is not None + + split_qkv_rmsnorm_mrope_kernel[(block_dim,)]( + qkv, + q_weight, + q_bias, + k_weight, + k_bias, + cos_sin, + q_output, + k_output, + v_output, + gate_output, + num_tokens, + front_core_num, + num_tokens_each_front_core, + num_tokens_each_tail_core, + num_q_heads, + num_kv_heads, + head_size, + q_size, + kv_size, + eps, + mrope_section[0], + mrope_section[1], + mrope_section[2], + has_bias, + is_interleaved, + rope_dim, + rope_dim // 2, + IS_PARTIAL_ROPE, + gate_size, + ) + + return q_output, k_output, v_output, gate_output + + +def triton_split_qkv_rmsnorm_mrope_fake( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin: torch.Tensor, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + eps: float, + mrope_section: list[int], + is_interleaved: bool, + rope_dim: int | None = None, + q_bias: torch.Tensor | None = None, + k_bias: torch.Tensor | None = None, + has_gate: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + num_tokens = qkv.shape[0] + q_size = num_q_heads * head_size + kv_size = num_kv_heads * head_size + gate_size = q_size if has_gate else 0 + + q_output = torch.empty( + num_tokens, + q_size, + device=qkv.device, + dtype=qkv.dtype, + ) + + k_output = torch.empty( + num_tokens, + kv_size, + device=qkv.device, + dtype=qkv.dtype, + ) + + v_output = torch.empty( + num_tokens, + kv_size, + device=qkv.device, + dtype=qkv.dtype, + ) + + gate_output = torch.empty( + num_tokens, + gate_size, + device=qkv.device, + dtype=qkv.dtype, + ) + + return q_output, k_output, v_output, gate_output + + +direct_register_custom_op( + op_name="triton_split_qkv_rmsnorm_mrope", + op_func=triton_split_qkv_rmsnorm_mrope, + fake_impl=triton_split_qkv_rmsnorm_mrope_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +)