From 18b52afe2bb2e3ac80e988a2788ac1e3bbdb898a Mon Sep 17 00:00:00 2001 From: frank <2547457096@qq.com> Date: Fri, 6 Mar 2026 09:30:31 +0800 Subject: [PATCH] [Ops][Misc] Optimize split_qkv_rmsnorm_rope op (#6827) ### What this PR does / why we need it? This PR optimizes the `split_qkv_rmsnorm_rope` operator by introducing a new Triton kernel, `split_qkv_rmsnorm_rope_prefill_kernel`, for the prefill stage (i.e., large batch sizes). The implementation now dynamically selects between the existing decode kernel and the new prefill kernel based on the batch size, which improves performance for large batch scenarios. Additionally, the RoPE implementation is updated to support partial rotation dimensions (`rope_dim`), making the operator more flexible. ### Does this PR introduce _any_ user-facing change? No. This is a performance optimization and is not expected to introduce any user-facing changes. ### How was this patch tested? CI should pass with existing tests. The new prefill path is triggered when the batch size is larger than the number of available vector cores. The partial RoPE feature can be tested by passing the `rope_dim` argument. - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83b47f67b1dfad505606070ae4d9f83e50ad4ebd --------- Signed-off-by: guzhiyong Signed-off-by: frank <2547457096@qq.com> Co-authored-by: guzhiyong --- .../triton/test_split_qkv_rmsnorm_rope.py | 38 +- .../test_graphex_qknorm_rope_fusion.py | 11 +- .../e2e/singlecard/test_aclgraph_accuracy.py | 4 +- .../linearnorm/split_qkv_rmsnorm_rope.py | 414 +++++++++++------- 4 files changed, 290 insertions(+), 177 deletions(-) diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py index 7336e515..0e51cac2 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py @@ -11,6 +11,7 @@ MAX_POSITION_EMBEDDINGS = [262144] NUM_TOKENS = [1, 4, 8, 16, 1024] NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)] HEAD_SIZES = [128] +ROPE_DIMS = [64, 128] EPS = [1e-6] DTYPES = [torch.bfloat16] SEEDS = [0] @@ -23,19 +24,26 @@ def custom_rope(q, k, sin, cos): rotary_dim = sin.shape[-1] sin = sin.to(torch.float32) cos = cos.to(torch.float32) - x1 = q[..., :rotary_dim // 2] - x2 = q[..., rotary_dim // 2:] - cat_x = torch.cat([-x2, x1], axis=-1) - mul1 = cat_x * sin - mul2 = q * cos - res1 = mul1 + mul2 + q_rot = q[..., :rotary_dim] + k_rot = k[..., :rotary_dim] + q_pass = q[..., rotary_dim:] + k_pass = k[..., rotary_dim:] - x1 = k[..., :rotary_dim // 2] - x2 = k[..., rotary_dim // 2:] + x1 = q_rot[..., :rotary_dim // 2] + x2 = q_rot[..., rotary_dim // 2:] cat_x = torch.cat([-x2, x1], axis=-1) mul1 = cat_x * sin - mul2 = k * cos - res2 = mul1 + mul2 + mul2 = q_rot * cos + q_rot = mul1 + mul2 + res1 = torch.cat([q_rot, q_pass], dim=-1) + + x1 = k_rot[..., :rotary_dim // 2] + x2 = k_rot[..., rotary_dim // 2:] + cat_x = torch.cat([-x2, x1], axis=-1) + mul1 = cat_x * sin + mul2 = k_rot * cos + k_rot = mul1 + mul2 + res2 = torch.cat([k_rot, k_pass], dim=-1) return res1, res2 @@ -64,9 +72,10 @@ def rms_norm( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("rope_dim", ROPE_DIMS) @torch.inference_mode() def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads, - head_size, eps, dtype, seed, device): + head_size, eps, dtype, seed, device, rope_dim): torch.manual_seed(seed) torch.set_default_device(device) init_device_properties_triton() @@ -81,7 +90,7 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads k_weight = torch.randn(head_size, dtype=dtype, device=device) cos_sin_cache = torch.from_numpy( np.random.uniform(0, 1, - [max_position_embeddings, head_size])).to(dtype).npu() + [max_position_embeddings, rope_dim])).to(dtype).npu() positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device) # fused kernel q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, @@ -141,10 +150,11 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("rope_dim", ROPE_DIMS) @torch.inference_mode() def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads, head_size, eps, dtype, - seed, device): + seed, device, rope_dim): torch.manual_seed(seed) torch.set_default_device(device) init_device_properties_triton() @@ -161,7 +171,7 @@ def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, n k_bias = torch.randn(head_size, dtype=dtype, device=device) cos_sin_cache = torch.from_numpy( np.random.uniform(0, 1, - [max_position_embeddings, head_size])).to(dtype).npu() + [max_position_embeddings, rope_dim])).to(dtype).npu() positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device) # fused kernel q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, diff --git a/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py b/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py index 7298ecb9..0c35b70e 100644 --- a/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py +++ b/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py @@ -1,5 +1,6 @@ import copy +import numpy as np import pytest import torch import torch.nn as nn @@ -16,6 +17,8 @@ from vllm_ascend.compilation.passes.qknorm_rope_fusion_pass import ( ) from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton +MAX_POSITION_EMBEDDING = 262144 + def find_op(gm, op_default): return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes) @@ -207,8 +210,10 @@ def test_rmsnorm_quant_fusion( model = model.to("npu") seq_len = 5 qkv = torch.randn(seq_len, qkv_size, device="npu", dtype=dtype) - cos = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype) - sin = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype) + cos_sin_cache = torch.from_numpy(np.random.uniform(0, 1, [MAX_POSITION_EMBEDDING, head_dim])).to(dtype).npu() + positions = torch.randint( + low=0, high=MAX_POSITION_EMBEDDING, size=(num_tokens,), dtype=torch.int64, device="npu" + ) with torch.no_grad(): original_optimize = torchair.npu_fx_compiler._optimize_fx @@ -218,6 +223,6 @@ def test_rmsnorm_quant_fusion( compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True) - compiled_model(qkv, cos, sin) + compiled_model(qkv, cos_sin_cache, positions) torchair.npu_fx_compiler._optimize_fx = original_optimize diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index 6835b194..32b891b6 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -74,7 +74,7 @@ CASE_QWEN_FULL_DECODE_ONLY = LLMTestCase( prompts=PROMPTS_LONG, golden_answers=[ " \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the", - " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over", + " \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area", " \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can", ], ) @@ -95,7 +95,7 @@ CASE_QWEN_EX = LLMTestCase( prompts=PROMPTS_LONG, golden_answers=[ " \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the", - " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over", + " \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area", " \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can", ], ) diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 2fc06a10..f7e59a28 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -20,17 +20,15 @@ import triton # type: ignore import triton.language as tl # type: ignore from vllm.utils.torch_utils import direct_register_custom_op -from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice +from vllm_ascend.ops.triton.triton_utils import extract_slice, get_element, get_vectorcore_num, insert_slice @triton.jit def split_qkv_rmsnorm_rope_kernel( - input_ptr, - cos_sin_ptr, - pos_ptr, - q_ptr, - k_ptr, - v_ptr, + input_gm_ptr, + q_gm_ptr, + k_gm_ptr, + v_gm_ptr, q_weight_ptr, q_bias_ptr, k_weight_ptr, @@ -40,158 +38,228 @@ def split_qkv_rmsnorm_rope_kernel( kv_hidden_size: tl.constexpr, total_hidden_size: tl.constexpr, eps: tl.constexpr, - Q_BLOCK_SIZE: tl.constexpr, - KV_BLOCK_SIZE: tl.constexpr, BIAS: tl.constexpr, HEAD_DIM: tl.constexpr, - HALF_HEAD_DIM: tl.constexpr, + ROPE_DIM: tl.constexpr, + HALF_ROPE_DIM: tl.constexpr, + IS_PARTIAL_ROPE: tl.constexpr, + num_vectorcore: tl.constexpr, + batch_size_per_iter_per_vec: tl.constexpr, + qk_head_nums_per_iter_per_vec: tl.constexpr, + q_head_num: tl.constexpr, + kv_head_num: tl.constexpr, + qk_head_num_sum: tl.constexpr, + v_batch_size_per_iter_per_vec: tl.constexpr, + positions_gm_ptr, + cos_sin_cache_gm_ptr, ): row_pid = tl.program_id(0) - col_pid = tl.program_id(1) - row_step = tl.num_programs(0) - # q - weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) - if BIAS: - bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM)) - input_offset = row_pid * total_hidden_size - output_offset = row_pid * q_hidden_size - input_offset_step = row_step * total_hidden_size - output_offset_step = row_step * q_hidden_size - for row_idx in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) - valid_mask = col_indices < q_hidden_size - input_values = ( - tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) - .to(tl.float32) - .reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - ) - squares = input_values * input_values - variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(Q_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM) - if BIAS: - normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16) - else: - normalized_values = (normalized_values * weight_values).to(tl.bfloat16) - pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) - cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM) - sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) - cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) - sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) + q_weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) + k_weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) + + batch_size_per_vec = tl.cdiv(batch_size, num_vectorcore) + iter_num_per_vec = tl.cdiv(batch_size_per_vec, batch_size_per_iter_per_vec) + v_iter_num_per_vec = tl.cdiv(batch_size_per_vec, v_batch_size_per_iter_per_vec) + input_batch_offset = row_pid * batch_size_per_vec + mblk_idx = tl.arange(0, batch_size_per_iter_per_vec) + input_batch_offset + nblk_idx = tl.arange(0, q_hidden_size + kv_hidden_size) + nmask = nblk_idx < total_hidden_size + + input_batch_offset_end = min(input_batch_offset + batch_size_per_vec, batch_size) + + pos_indices = input_batch_offset + tl.arange(0, batch_size_per_iter_per_vec) + output_q_nblk_idx = tl.arange(0, q_hidden_size) + output_q_nmask = output_q_nblk_idx < q_hidden_size + output_kv_nblk_idx = tl.arange(0, kv_hidden_size) + output_kv_nmask = output_kv_nblk_idx < kv_hidden_size + sin_cos_range = tl.arange(0, ROPE_DIM) + cos_sin_cache_offset = cos_sin_cache_gm_ptr + sin_cos_range + + for iter in tl.range(iter_num_per_vec): + pos_offset = iter * batch_size_per_iter_per_vec + x = tl.load( + positions_gm_ptr + pos_indices + pos_offset, mask=(pos_indices + pos_offset) < input_batch_offset_end + ) + mmask = (mblk_idx + pos_offset) < input_batch_offset_end + mask = (mmask[:, None]) & (nmask[None, :]) + idx = (mblk_idx + pos_offset)[:, None] * total_hidden_size + nblk_idx[None, :] + values_tmp1 = tl.load(input_gm_ptr + idx, mask=mask).reshape(qk_head_nums_per_iter_per_vec, HEAD_DIM) + if BIAS: + q_bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM)) + k_bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) + + values_tmp3 = tl.zeros((batch_size_per_iter_per_vec, ROPE_DIM), dtype=tl.bfloat16) + for i in tl.range(batch_size_per_iter_per_vec): + pos = get_element(x, (i,)) + values_tmp3 = insert_slice( + values_tmp3.reshape(batch_size_per_iter_per_vec, ROPE_DIM), + tl.load(pos * ROPE_DIM + cos_sin_cache_offset[:, None]).reshape(1, ROPE_DIM), + offsets=(i, 0), + sizes=(1, ROPE_DIM), + strides=(1, 1), + ) + values_tmp3 = values_tmp3.reshape(batch_size_per_iter_per_vec, 1, ROPE_DIM) + cos = extract_slice( + values_tmp3, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM), + strides=(1, 1, 1), + ) + sin = extract_slice( + values_tmp3, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM), + strides=(1, 1, 1), + ) + + normalized_values = values_tmp1.to(tl.float32) + normalized_values = normalized_values * normalized_values + normalized_values = tl.sum(normalized_values, axis=1) / HEAD_DIM + normalized_values = 1 / tl.sqrt(normalized_values + eps).reshape(qk_head_nums_per_iter_per_vec, 1) + normalized_values = values_tmp1 * normalized_values + + normalized_values_tmp = extract_slice( + normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM), + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, q_head_num, HEAD_DIM), + strides=(1, 1, 1), + ) + + if BIAS: + normalized_values_tmp = (normalized_values_tmp * q_weight_values + q_bias_values).to(tl.bfloat16) + else: + normalized_values_tmp = (normalized_values_tmp * q_weight_values).to(tl.bfloat16) + + # q rope + values_tmp = tl.zeros((batch_size_per_iter_per_vec, q_head_num, ROPE_DIM), dtype=tl.bfloat16) x1 = extract_slice( - normalized_values, - offsets=(0, 0), - sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + normalized_values_tmp, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) x2 = extract_slice( - normalized_values, - offsets=(0, HALF_HEAD_DIM), - sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + normalized_values_tmp, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) - roped_q1 = x1 * cos - x2 * sin - roped_q2 = x2 * cos + x1 * sin - - roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - roped_q = insert_slice( - roped_q, - roped_q1, - offsets=(0, 0), - sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + values_tmp = insert_slice( + values_tmp, + x1 * cos - x2 * sin, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) - roped_q = insert_slice( - roped_q, - roped_q2, - offsets=(0, HALF_HEAD_DIM), - sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + values_tmp = insert_slice( + values_tmp, + x2 * cos + x1 * sin, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) - tl.store( - q_ptr + output_offset + col_indices, - roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty), - mask=valid_mask, - ) - input_offset += input_offset_step - output_offset += output_offset_step - - weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) - if BIAS: - bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) - input_offset = row_pid * total_hidden_size + q_hidden_size - output_offset = row_pid * kv_hidden_size - output_offset_step = row_step * kv_hidden_size - for row_idx in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) - valid_mask = col_indices < kv_hidden_size - input_values = ( - tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) - .to(tl.float32) - .reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - ) - squares = input_values * input_values - variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(KV_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) - if BIAS: - normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16) + q_output_idx = output_q_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * q_hidden_size + mask = (mmask[:, None]) & (output_q_nmask[None, :]) + if IS_PARTIAL_ROPE: + normalized_values_tmp = insert_slice( + normalized_values_tmp, + values_tmp, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, q_head_num, ROPE_DIM), + strides=(1, 1, 1), + ) + tl.store( + q_gm_ptr + q_output_idx, + normalized_values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size), + mask=mask, + ) else: - normalized_values = (normalized_values * weight_values).to(tl.bfloat16) + tl.store( + q_gm_ptr + q_output_idx, + values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size), + mask=mask, + ) + + # k rope + normalized_values_tmp1 = extract_slice( + normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM), + offsets=(0, q_head_num, 0), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HEAD_DIM), + strides=(1, 1, 1), + ) + + if BIAS: + normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values + k_bias_values).to(tl.bfloat16) + else: + normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values).to(tl.bfloat16) + + values_tmp2 = tl.zeros((batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM), dtype=tl.bfloat16) - pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) - cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM) - sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) - cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) - sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) x1 = extract_slice( - normalized_values, - offsets=(0, 0), - sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + normalized_values_tmp1, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) x2 = extract_slice( - normalized_values, - offsets=(0, HALF_HEAD_DIM), - sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + normalized_values_tmp1, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), + ) + values_tmp2 = insert_slice( + values_tmp2, + x1 * cos - x2 * sin, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), + ) + values_tmp2 = insert_slice( + values_tmp2, + x2 * cos + x1 * sin, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) - roped_k1 = x1 * cos - x2 * sin - roped_k2 = x2 * cos + x1 * sin - roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - roped_k = insert_slice( - roped_k, - roped_k1, - offsets=(0, 0), - sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), - ) - roped_k = insert_slice( - roped_k, - roped_k2, - offsets=(0, HALF_HEAD_DIM), - sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), - ) - tl.store( - k_ptr + output_offset + col_indices, - roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), - mask=valid_mask, - ) - input_offset += input_offset_step - output_offset += output_offset_step + kv_output_idx = output_kv_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * kv_hidden_size + mask = (mmask[:, None]) & (output_kv_nmask[None, :]) + if IS_PARTIAL_ROPE: + normalized_values_tmp1 = insert_slice( + normalized_values_tmp1, + values_tmp2, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM), + strides=(1, 1, 1), + ) + tl.store( + k_gm_ptr + kv_output_idx, + normalized_values_tmp1.reshape(batch_size_per_iter_per_vec, kv_hidden_size), + mask=mask, + ) + else: + tl.store( + k_gm_ptr + kv_output_idx, + values_tmp2.reshape(batch_size_per_iter_per_vec, kv_hidden_size), + mask=mask, + ) - input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size - output_offset = row_pid * kv_hidden_size - for _ in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) - valid_mask = col_indices < kv_hidden_size - input_values = tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) - tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask) - input_offset += input_offset_step - output_offset += output_offset_step + mblk_idx = tl.arange(0, v_batch_size_per_iter_per_vec) + input_batch_offset + nblk_idx = tl.arange(q_hidden_size + kv_hidden_size, total_hidden_size) + nmask = nblk_idx < total_hidden_size + out_nblk_idx = tl.arange(0, kv_hidden_size) + out_nmask = out_nblk_idx < kv_hidden_size + + for _ in tl.range(v_iter_num_per_vec): + mmask = mblk_idx < input_batch_offset_end + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * total_hidden_size + nblk_idx[None, :] + values = tl.load(input_gm_ptr + idx, mask=mask) + out_idx = mblk_idx[:, None] * kv_hidden_size + out_nblk_idx[None, :] + out_mask = (mmask[:, None]) & (out_nmask[None, :]) + tl.store(v_gm_ptr + out_idx, values, mask=out_mask) + mblk_idx += v_batch_size_per_iter_per_vec def split_qkv_rmsnorm_rope_impl( @@ -207,25 +275,46 @@ def split_qkv_rmsnorm_rope_impl( q_bias: torch.Tensor | None = None, k_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - KV_BLOCK_SIZE = triton.next_power_of_2(head_dim) - assert head_dim == KV_BLOCK_SIZE - assert q_hidden_size % kv_hidden_size == 0 - Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim + # get available vector core + num_vectorcore = get_vectorcore_num() + rope_dim = cos_sin_cache.shape[-1] batch_size = input.shape[0] + BIAS = q_bias is not None + IS_PARTIAL_ROPE = rope_dim != head_dim + # Q + K + V total_hidden_size = q_hidden_size + kv_hidden_size * 2 + q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype) k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) - n_cols = kv_hidden_size // KV_BLOCK_SIZE - num_vectorcore = get_vectorcore_num() - assert num_vectorcore % n_cols == 0 - n_rows = num_vectorcore // n_cols - BIAS = q_bias is not None - split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)]( + q_head_num = q_hidden_size // head_dim + kv_head_num = kv_hidden_size // head_dim + + # set number of line loading from GM data is x + # x*(q_head_num + kv_head_num)*HEAD_DIM: values_tmp + # 2x*(q_head_num + kv_head_num)*HEAD_DIM: normalized_values(float32) + # x*ROPE_DIM*2 : cos/sin + # x*q_head_num*HEAD_DIM*2: normalized_values_tmp + # x*q_head_num*ROPE_DIM*(0.5) (not IS_PARTIAL_ROPE) x*q_head_num*ROPE_DIM*(0.5): y + UB_SIZE = 87040 # 85K = 85 * 1024 + # the factor is the sum of elements number + if IS_PARTIAL_ROPE: + factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 4 + q_head_num * rope_dim + batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor + else: + factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 2 + q_head_num * rope_dim // 2 + batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor + batch_size_per_iter_per_vec = max(1, batch_size_per_iter_per_vec) + qk_head_num_sum = int(q_head_num + kv_head_num) + qk_head_nums_per_iter_per_vec = batch_size_per_iter_per_vec * qk_head_num_sum + + grid = (num_vectorcore, 1, 1) + # v tiling + v_batch_size_per_iter_per_vec = UB_SIZE / torch.bfloat16.itemsize // (kv_hidden_size + 1) + + split_qkv_rmsnorm_rope_kernel[grid]( input, - cos_sin_cache, - positions, q_output, k_output, v_output, @@ -238,11 +327,20 @@ def split_qkv_rmsnorm_rope_impl( kv_hidden_size, total_hidden_size, eps, - Q_BLOCK_SIZE, - KV_BLOCK_SIZE, BIAS, head_dim, - head_dim // 2, + rope_dim, + rope_dim // 2, + IS_PARTIAL_ROPE, + num_vectorcore, + int(batch_size_per_iter_per_vec), + int(qk_head_nums_per_iter_per_vec), + q_head_num, + kv_head_num, + qk_head_num_sum, + int(v_batch_size_per_iter_per_vec), + positions, + cos_sin_cache, ) return q_output, k_output, v_output @@ -265,19 +363,19 @@ def split_qkv_rmsnorm_rope_impl_fake( batch_size = input.shape[0] q_output = torch.empty( batch_size, - q_hidden_size, + int(q_hidden_size), device=input.device, dtype=input.dtype, ) k_output = torch.empty( batch_size, - kv_hidden_size, + int(kv_hidden_size), device=input.device, dtype=input.dtype, ) v_output = torch.empty( batch_size, - kv_hidden_size, + int(kv_hidden_size), device=input.device, dtype=input.dtype, )