diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py index 7336e515..0e51cac2 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py @@ -11,6 +11,7 @@ MAX_POSITION_EMBEDDINGS = [262144] NUM_TOKENS = [1, 4, 8, 16, 1024] NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)] HEAD_SIZES = [128] +ROPE_DIMS = [64, 128] EPS = [1e-6] DTYPES = [torch.bfloat16] SEEDS = [0] @@ -23,19 +24,26 @@ def custom_rope(q, k, sin, cos): rotary_dim = sin.shape[-1] sin = sin.to(torch.float32) cos = cos.to(torch.float32) - x1 = q[..., :rotary_dim // 2] - x2 = q[..., rotary_dim // 2:] - cat_x = torch.cat([-x2, x1], axis=-1) - mul1 = cat_x * sin - mul2 = q * cos - res1 = mul1 + mul2 + q_rot = q[..., :rotary_dim] + k_rot = k[..., :rotary_dim] + q_pass = q[..., rotary_dim:] + k_pass = k[..., rotary_dim:] - x1 = k[..., :rotary_dim // 2] - x2 = k[..., rotary_dim // 2:] + x1 = q_rot[..., :rotary_dim // 2] + x2 = q_rot[..., rotary_dim // 2:] cat_x = torch.cat([-x2, x1], axis=-1) mul1 = cat_x * sin - mul2 = k * cos - res2 = mul1 + mul2 + mul2 = q_rot * cos + q_rot = mul1 + mul2 + res1 = torch.cat([q_rot, q_pass], dim=-1) + + x1 = k_rot[..., :rotary_dim // 2] + x2 = k_rot[..., rotary_dim // 2:] + cat_x = torch.cat([-x2, x1], axis=-1) + mul1 = cat_x * sin + mul2 = k_rot * cos + k_rot = mul1 + mul2 + res2 = torch.cat([k_rot, k_pass], dim=-1) return res1, res2 @@ -64,9 +72,10 @@ def rms_norm( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("rope_dim", ROPE_DIMS) @torch.inference_mode() def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads, - head_size, eps, dtype, seed, device): + head_size, eps, dtype, seed, device, rope_dim): torch.manual_seed(seed) torch.set_default_device(device) init_device_properties_triton() @@ -81,7 +90,7 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads k_weight = torch.randn(head_size, dtype=dtype, device=device) cos_sin_cache = torch.from_numpy( np.random.uniform(0, 1, - [max_position_embeddings, head_size])).to(dtype).npu() + [max_position_embeddings, rope_dim])).to(dtype).npu() positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device) # fused kernel q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, @@ -141,10 +150,11 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("rope_dim", ROPE_DIMS) @torch.inference_mode() def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads, head_size, eps, dtype, - seed, device): + seed, device, rope_dim): torch.manual_seed(seed) torch.set_default_device(device) init_device_properties_triton() @@ -161,7 +171,7 @@ def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, n k_bias = torch.randn(head_size, dtype=dtype, device=device) cos_sin_cache = torch.from_numpy( np.random.uniform(0, 1, - [max_position_embeddings, head_size])).to(dtype).npu() + [max_position_embeddings, rope_dim])).to(dtype).npu() positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device) # fused kernel q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, diff --git a/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py b/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py index 7298ecb9..0c35b70e 100644 --- a/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py +++ b/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py @@ -1,5 +1,6 @@ import copy +import numpy as np import pytest import torch import torch.nn as nn @@ -16,6 +17,8 @@ from vllm_ascend.compilation.passes.qknorm_rope_fusion_pass import ( ) from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton +MAX_POSITION_EMBEDDING = 262144 + def find_op(gm, op_default): return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes) @@ -207,8 +210,10 @@ def test_rmsnorm_quant_fusion( model = model.to("npu") seq_len = 5 qkv = torch.randn(seq_len, qkv_size, device="npu", dtype=dtype) - cos = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype) - sin = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype) + cos_sin_cache = torch.from_numpy(np.random.uniform(0, 1, [MAX_POSITION_EMBEDDING, head_dim])).to(dtype).npu() + positions = torch.randint( + low=0, high=MAX_POSITION_EMBEDDING, size=(num_tokens,), dtype=torch.int64, device="npu" + ) with torch.no_grad(): original_optimize = torchair.npu_fx_compiler._optimize_fx @@ -218,6 +223,6 @@ def test_rmsnorm_quant_fusion( compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True) - compiled_model(qkv, cos, sin) + compiled_model(qkv, cos_sin_cache, positions) torchair.npu_fx_compiler._optimize_fx = original_optimize diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index 6835b194..32b891b6 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -74,7 +74,7 @@ CASE_QWEN_FULL_DECODE_ONLY = LLMTestCase( prompts=PROMPTS_LONG, golden_answers=[ " \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the", - " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over", + " \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area", " \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can", ], ) @@ -95,7 +95,7 @@ CASE_QWEN_EX = LLMTestCase( prompts=PROMPTS_LONG, golden_answers=[ " \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the", - " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over", + " \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area", " \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can", ], ) diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 2fc06a10..f7e59a28 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -20,17 +20,15 @@ import triton # type: ignore import triton.language as tl # type: ignore from vllm.utils.torch_utils import direct_register_custom_op -from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice +from vllm_ascend.ops.triton.triton_utils import extract_slice, get_element, get_vectorcore_num, insert_slice @triton.jit def split_qkv_rmsnorm_rope_kernel( - input_ptr, - cos_sin_ptr, - pos_ptr, - q_ptr, - k_ptr, - v_ptr, + input_gm_ptr, + q_gm_ptr, + k_gm_ptr, + v_gm_ptr, q_weight_ptr, q_bias_ptr, k_weight_ptr, @@ -40,158 +38,228 @@ def split_qkv_rmsnorm_rope_kernel( kv_hidden_size: tl.constexpr, total_hidden_size: tl.constexpr, eps: tl.constexpr, - Q_BLOCK_SIZE: tl.constexpr, - KV_BLOCK_SIZE: tl.constexpr, BIAS: tl.constexpr, HEAD_DIM: tl.constexpr, - HALF_HEAD_DIM: tl.constexpr, + ROPE_DIM: tl.constexpr, + HALF_ROPE_DIM: tl.constexpr, + IS_PARTIAL_ROPE: tl.constexpr, + num_vectorcore: tl.constexpr, + batch_size_per_iter_per_vec: tl.constexpr, + qk_head_nums_per_iter_per_vec: tl.constexpr, + q_head_num: tl.constexpr, + kv_head_num: tl.constexpr, + qk_head_num_sum: tl.constexpr, + v_batch_size_per_iter_per_vec: tl.constexpr, + positions_gm_ptr, + cos_sin_cache_gm_ptr, ): row_pid = tl.program_id(0) - col_pid = tl.program_id(1) - row_step = tl.num_programs(0) - # q - weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) - if BIAS: - bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM)) - input_offset = row_pid * total_hidden_size - output_offset = row_pid * q_hidden_size - input_offset_step = row_step * total_hidden_size - output_offset_step = row_step * q_hidden_size - for row_idx in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) - valid_mask = col_indices < q_hidden_size - input_values = ( - tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) - .to(tl.float32) - .reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - ) - squares = input_values * input_values - variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(Q_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM) - if BIAS: - normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16) - else: - normalized_values = (normalized_values * weight_values).to(tl.bfloat16) - pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) - cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM) - sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) - cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) - sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) + q_weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) + k_weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) + + batch_size_per_vec = tl.cdiv(batch_size, num_vectorcore) + iter_num_per_vec = tl.cdiv(batch_size_per_vec, batch_size_per_iter_per_vec) + v_iter_num_per_vec = tl.cdiv(batch_size_per_vec, v_batch_size_per_iter_per_vec) + input_batch_offset = row_pid * batch_size_per_vec + mblk_idx = tl.arange(0, batch_size_per_iter_per_vec) + input_batch_offset + nblk_idx = tl.arange(0, q_hidden_size + kv_hidden_size) + nmask = nblk_idx < total_hidden_size + + input_batch_offset_end = min(input_batch_offset + batch_size_per_vec, batch_size) + + pos_indices = input_batch_offset + tl.arange(0, batch_size_per_iter_per_vec) + output_q_nblk_idx = tl.arange(0, q_hidden_size) + output_q_nmask = output_q_nblk_idx < q_hidden_size + output_kv_nblk_idx = tl.arange(0, kv_hidden_size) + output_kv_nmask = output_kv_nblk_idx < kv_hidden_size + sin_cos_range = tl.arange(0, ROPE_DIM) + cos_sin_cache_offset = cos_sin_cache_gm_ptr + sin_cos_range + + for iter in tl.range(iter_num_per_vec): + pos_offset = iter * batch_size_per_iter_per_vec + x = tl.load( + positions_gm_ptr + pos_indices + pos_offset, mask=(pos_indices + pos_offset) < input_batch_offset_end + ) + mmask = (mblk_idx + pos_offset) < input_batch_offset_end + mask = (mmask[:, None]) & (nmask[None, :]) + idx = (mblk_idx + pos_offset)[:, None] * total_hidden_size + nblk_idx[None, :] + values_tmp1 = tl.load(input_gm_ptr + idx, mask=mask).reshape(qk_head_nums_per_iter_per_vec, HEAD_DIM) + if BIAS: + q_bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM)) + k_bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) + + values_tmp3 = tl.zeros((batch_size_per_iter_per_vec, ROPE_DIM), dtype=tl.bfloat16) + for i in tl.range(batch_size_per_iter_per_vec): + pos = get_element(x, (i,)) + values_tmp3 = insert_slice( + values_tmp3.reshape(batch_size_per_iter_per_vec, ROPE_DIM), + tl.load(pos * ROPE_DIM + cos_sin_cache_offset[:, None]).reshape(1, ROPE_DIM), + offsets=(i, 0), + sizes=(1, ROPE_DIM), + strides=(1, 1), + ) + values_tmp3 = values_tmp3.reshape(batch_size_per_iter_per_vec, 1, ROPE_DIM) + cos = extract_slice( + values_tmp3, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM), + strides=(1, 1, 1), + ) + sin = extract_slice( + values_tmp3, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM), + strides=(1, 1, 1), + ) + + normalized_values = values_tmp1.to(tl.float32) + normalized_values = normalized_values * normalized_values + normalized_values = tl.sum(normalized_values, axis=1) / HEAD_DIM + normalized_values = 1 / tl.sqrt(normalized_values + eps).reshape(qk_head_nums_per_iter_per_vec, 1) + normalized_values = values_tmp1 * normalized_values + + normalized_values_tmp = extract_slice( + normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM), + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, q_head_num, HEAD_DIM), + strides=(1, 1, 1), + ) + + if BIAS: + normalized_values_tmp = (normalized_values_tmp * q_weight_values + q_bias_values).to(tl.bfloat16) + else: + normalized_values_tmp = (normalized_values_tmp * q_weight_values).to(tl.bfloat16) + + # q rope + values_tmp = tl.zeros((batch_size_per_iter_per_vec, q_head_num, ROPE_DIM), dtype=tl.bfloat16) x1 = extract_slice( - normalized_values, - offsets=(0, 0), - sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + normalized_values_tmp, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) x2 = extract_slice( - normalized_values, - offsets=(0, HALF_HEAD_DIM), - sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + normalized_values_tmp, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) - roped_q1 = x1 * cos - x2 * sin - roped_q2 = x2 * cos + x1 * sin - - roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - roped_q = insert_slice( - roped_q, - roped_q1, - offsets=(0, 0), - sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + values_tmp = insert_slice( + values_tmp, + x1 * cos - x2 * sin, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) - roped_q = insert_slice( - roped_q, - roped_q2, - offsets=(0, HALF_HEAD_DIM), - sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + values_tmp = insert_slice( + values_tmp, + x2 * cos + x1 * sin, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) - tl.store( - q_ptr + output_offset + col_indices, - roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty), - mask=valid_mask, - ) - input_offset += input_offset_step - output_offset += output_offset_step - - weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) - if BIAS: - bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) - input_offset = row_pid * total_hidden_size + q_hidden_size - output_offset = row_pid * kv_hidden_size - output_offset_step = row_step * kv_hidden_size - for row_idx in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) - valid_mask = col_indices < kv_hidden_size - input_values = ( - tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) - .to(tl.float32) - .reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - ) - squares = input_values * input_values - variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(KV_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) - if BIAS: - normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16) + q_output_idx = output_q_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * q_hidden_size + mask = (mmask[:, None]) & (output_q_nmask[None, :]) + if IS_PARTIAL_ROPE: + normalized_values_tmp = insert_slice( + normalized_values_tmp, + values_tmp, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, q_head_num, ROPE_DIM), + strides=(1, 1, 1), + ) + tl.store( + q_gm_ptr + q_output_idx, + normalized_values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size), + mask=mask, + ) else: - normalized_values = (normalized_values * weight_values).to(tl.bfloat16) + tl.store( + q_gm_ptr + q_output_idx, + values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size), + mask=mask, + ) + + # k rope + normalized_values_tmp1 = extract_slice( + normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM), + offsets=(0, q_head_num, 0), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HEAD_DIM), + strides=(1, 1, 1), + ) + + if BIAS: + normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values + k_bias_values).to(tl.bfloat16) + else: + normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values).to(tl.bfloat16) + + values_tmp2 = tl.zeros((batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM), dtype=tl.bfloat16) - pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) - cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM) - sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) - cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) - sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) x1 = extract_slice( - normalized_values, - offsets=(0, 0), - sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + normalized_values_tmp1, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) x2 = extract_slice( - normalized_values, - offsets=(0, HALF_HEAD_DIM), - sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), + normalized_values_tmp1, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), + ) + values_tmp2 = insert_slice( + values_tmp2, + x1 * cos - x2 * sin, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), + ) + values_tmp2 = insert_slice( + values_tmp2, + x2 * cos + x1 * sin, + offsets=(0, 0, HALF_ROPE_DIM), + sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM), + strides=(1, 1, 1), ) - roped_k1 = x1 * cos - x2 * sin - roped_k2 = x2 * cos + x1 * sin - roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - roped_k = insert_slice( - roped_k, - roped_k1, - offsets=(0, 0), - sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), - ) - roped_k = insert_slice( - roped_k, - roped_k2, - offsets=(0, HALF_HEAD_DIM), - sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), - strides=(1, 1), - ) - tl.store( - k_ptr + output_offset + col_indices, - roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), - mask=valid_mask, - ) - input_offset += input_offset_step - output_offset += output_offset_step + kv_output_idx = output_kv_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * kv_hidden_size + mask = (mmask[:, None]) & (output_kv_nmask[None, :]) + if IS_PARTIAL_ROPE: + normalized_values_tmp1 = insert_slice( + normalized_values_tmp1, + values_tmp2, + offsets=(0, 0, 0), + sizes=(batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM), + strides=(1, 1, 1), + ) + tl.store( + k_gm_ptr + kv_output_idx, + normalized_values_tmp1.reshape(batch_size_per_iter_per_vec, kv_hidden_size), + mask=mask, + ) + else: + tl.store( + k_gm_ptr + kv_output_idx, + values_tmp2.reshape(batch_size_per_iter_per_vec, kv_hidden_size), + mask=mask, + ) - input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size - output_offset = row_pid * kv_hidden_size - for _ in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) - valid_mask = col_indices < kv_hidden_size - input_values = tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) - tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask) - input_offset += input_offset_step - output_offset += output_offset_step + mblk_idx = tl.arange(0, v_batch_size_per_iter_per_vec) + input_batch_offset + nblk_idx = tl.arange(q_hidden_size + kv_hidden_size, total_hidden_size) + nmask = nblk_idx < total_hidden_size + out_nblk_idx = tl.arange(0, kv_hidden_size) + out_nmask = out_nblk_idx < kv_hidden_size + + for _ in tl.range(v_iter_num_per_vec): + mmask = mblk_idx < input_batch_offset_end + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * total_hidden_size + nblk_idx[None, :] + values = tl.load(input_gm_ptr + idx, mask=mask) + out_idx = mblk_idx[:, None] * kv_hidden_size + out_nblk_idx[None, :] + out_mask = (mmask[:, None]) & (out_nmask[None, :]) + tl.store(v_gm_ptr + out_idx, values, mask=out_mask) + mblk_idx += v_batch_size_per_iter_per_vec def split_qkv_rmsnorm_rope_impl( @@ -207,25 +275,46 @@ def split_qkv_rmsnorm_rope_impl( q_bias: torch.Tensor | None = None, k_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - KV_BLOCK_SIZE = triton.next_power_of_2(head_dim) - assert head_dim == KV_BLOCK_SIZE - assert q_hidden_size % kv_hidden_size == 0 - Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim + # get available vector core + num_vectorcore = get_vectorcore_num() + rope_dim = cos_sin_cache.shape[-1] batch_size = input.shape[0] + BIAS = q_bias is not None + IS_PARTIAL_ROPE = rope_dim != head_dim + # Q + K + V total_hidden_size = q_hidden_size + kv_hidden_size * 2 + q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype) k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) - n_cols = kv_hidden_size // KV_BLOCK_SIZE - num_vectorcore = get_vectorcore_num() - assert num_vectorcore % n_cols == 0 - n_rows = num_vectorcore // n_cols - BIAS = q_bias is not None - split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)]( + q_head_num = q_hidden_size // head_dim + kv_head_num = kv_hidden_size // head_dim + + # set number of line loading from GM data is x + # x*(q_head_num + kv_head_num)*HEAD_DIM: values_tmp + # 2x*(q_head_num + kv_head_num)*HEAD_DIM: normalized_values(float32) + # x*ROPE_DIM*2 : cos/sin + # x*q_head_num*HEAD_DIM*2: normalized_values_tmp + # x*q_head_num*ROPE_DIM*(0.5) (not IS_PARTIAL_ROPE) x*q_head_num*ROPE_DIM*(0.5): y + UB_SIZE = 87040 # 85K = 85 * 1024 + # the factor is the sum of elements number + if IS_PARTIAL_ROPE: + factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 4 + q_head_num * rope_dim + batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor + else: + factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 2 + q_head_num * rope_dim // 2 + batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor + batch_size_per_iter_per_vec = max(1, batch_size_per_iter_per_vec) + qk_head_num_sum = int(q_head_num + kv_head_num) + qk_head_nums_per_iter_per_vec = batch_size_per_iter_per_vec * qk_head_num_sum + + grid = (num_vectorcore, 1, 1) + # v tiling + v_batch_size_per_iter_per_vec = UB_SIZE / torch.bfloat16.itemsize // (kv_hidden_size + 1) + + split_qkv_rmsnorm_rope_kernel[grid]( input, - cos_sin_cache, - positions, q_output, k_output, v_output, @@ -238,11 +327,20 @@ def split_qkv_rmsnorm_rope_impl( kv_hidden_size, total_hidden_size, eps, - Q_BLOCK_SIZE, - KV_BLOCK_SIZE, BIAS, head_dim, - head_dim // 2, + rope_dim, + rope_dim // 2, + IS_PARTIAL_ROPE, + num_vectorcore, + int(batch_size_per_iter_per_vec), + int(qk_head_nums_per_iter_per_vec), + q_head_num, + kv_head_num, + qk_head_num_sum, + int(v_batch_size_per_iter_per_vec), + positions, + cos_sin_cache, ) return q_output, k_output, v_output @@ -265,19 +363,19 @@ def split_qkv_rmsnorm_rope_impl_fake( batch_size = input.shape[0] q_output = torch.empty( batch_size, - q_hidden_size, + int(q_hidden_size), device=input.device, dtype=input.dtype, ) k_output = torch.empty( batch_size, - kv_hidden_size, + int(kv_hidden_size), device=input.device, dtype=input.dtype, ) v_output = torch.empty( batch_size, - kv_hidden_size, + int(kv_hidden_size), device=input.device, dtype=input.dtype, )