[Ops][Misc] Optimize split_qkv_rmsnorm_rope op (#6827)

### What this PR does / why we need it?

This PR optimizes the `split_qkv_rmsnorm_rope` operator by introducing a
new Triton kernel, `split_qkv_rmsnorm_rope_prefill_kernel`, for the
prefill stage (i.e., large batch sizes). The implementation now
dynamically selects between the existing decode kernel and the new
prefill kernel based on the batch size, which improves performance for
large batch scenarios.

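Additionally, the RoPE implementation is updated to support partial
rotary embeddings: `rope_dim` is taken from the last dimension of
`cos_sin_cache`, and when it is smaller than the head size only the first
`rope_dim` elements of each head are rotated while the remaining elements
pass through unchanged. This makes the operator more flexible.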

### Does this PR introduce _any_ user-facing change?

No. This is a performance optimization and is not expected to introduce
any user-facing changes.

### How was this patch tested?

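CI should pass with the existing tests, which this PR extends with a
`rope_dim` parametrization (`ROPE_DIMS = [64, 128]`). The new prefill path
is triggered when the batch size is larger than the number of available
vector cores. The partial RoPE feature is exercised by the `rope_dim = 64`
case, where `cos_sin_cache` is narrower than the head size of 128.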

- vLLM version: v0.15.0
- vLLM main: 83b47f67b1

---------

Signed-off-by: guzhiyong <guzhiyong5@h-partners.com>
Signed-off-by: frank <2547457096@qq.com>
Co-authored-by: guzhiyong <guzhiyong5@h-partners.com>
This commit is contained in:
frank
2026-03-06 09:30:31 +08:00
committed by GitHub
parent a60e179c7f
commit 18b52afe2b
4 changed files with 290 additions and 177 deletions

View File

@@ -11,6 +11,7 @@ MAX_POSITION_EMBEDDINGS = [262144]
NUM_TOKENS = [1, 4, 8, 16, 1024] NUM_TOKENS = [1, 4, 8, 16, 1024]
NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)] NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)]
HEAD_SIZES = [128] HEAD_SIZES = [128]
ROPE_DIMS = [64, 128]
EPS = [1e-6] EPS = [1e-6]
DTYPES = [torch.bfloat16] DTYPES = [torch.bfloat16]
SEEDS = [0] SEEDS = [0]
@@ -23,19 +24,26 @@ def custom_rope(q, k, sin, cos):
rotary_dim = sin.shape[-1] rotary_dim = sin.shape[-1]
sin = sin.to(torch.float32) sin = sin.to(torch.float32)
cos = cos.to(torch.float32) cos = cos.to(torch.float32)
x1 = q[..., :rotary_dim // 2] q_rot = q[..., :rotary_dim]
x2 = q[..., rotary_dim // 2:] k_rot = k[..., :rotary_dim]
cat_x = torch.cat([-x2, x1], axis=-1) q_pass = q[..., rotary_dim:]
mul1 = cat_x * sin k_pass = k[..., rotary_dim:]
mul2 = q * cos
res1 = mul1 + mul2
x1 = k[..., :rotary_dim // 2] x1 = q_rot[..., :rotary_dim // 2]
x2 = k[..., rotary_dim // 2:] x2 = q_rot[..., rotary_dim // 2:]
cat_x = torch.cat([-x2, x1], axis=-1) cat_x = torch.cat([-x2, x1], axis=-1)
mul1 = cat_x * sin mul1 = cat_x * sin
mul2 = k * cos mul2 = q_rot * cos
res2 = mul1 + mul2 q_rot = mul1 + mul2
res1 = torch.cat([q_rot, q_pass], dim=-1)
x1 = k_rot[..., :rotary_dim // 2]
x2 = k_rot[..., rotary_dim // 2:]
cat_x = torch.cat([-x2, x1], axis=-1)
mul1 = cat_x * sin
mul2 = k_rot * cos
k_rot = mul1 + mul2
res2 = torch.cat([k_rot, k_pass], dim=-1)
return res1, res2 return res1, res2
@@ -64,9 +72,10 @@ def rms_norm(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("rope_dim", ROPE_DIMS)
@torch.inference_mode() @torch.inference_mode()
def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads, def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads,
head_size, eps, dtype, seed, device): head_size, eps, dtype, seed, device, rope_dim):
torch.manual_seed(seed) torch.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
init_device_properties_triton() init_device_properties_triton()
@@ -81,7 +90,7 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads
k_weight = torch.randn(head_size, dtype=dtype, device=device) k_weight = torch.randn(head_size, dtype=dtype, device=device)
cos_sin_cache = torch.from_numpy( cos_sin_cache = torch.from_numpy(
np.random.uniform(0, 1, np.random.uniform(0, 1,
[max_position_embeddings, head_size])).to(dtype).npu() [max_position_embeddings, rope_dim])).to(dtype).npu()
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device) positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
# fused kernel # fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
@@ -141,10 +150,11 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("rope_dim", ROPE_DIMS)
@torch.inference_mode() @torch.inference_mode()
def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads, def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads,
num_kv_heads, head_size, eps, dtype, num_kv_heads, head_size, eps, dtype,
seed, device): seed, device, rope_dim):
torch.manual_seed(seed) torch.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
init_device_properties_triton() init_device_properties_triton()
@@ -161,7 +171,7 @@ def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, n
k_bias = torch.randn(head_size, dtype=dtype, device=device) k_bias = torch.randn(head_size, dtype=dtype, device=device)
cos_sin_cache = torch.from_numpy( cos_sin_cache = torch.from_numpy(
np.random.uniform(0, 1, np.random.uniform(0, 1,
[max_position_embeddings, head_size])).to(dtype).npu() [max_position_embeddings, rope_dim])).to(dtype).npu()
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device) positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
# fused kernel # fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,

View File

@@ -1,5 +1,6 @@
import copy import copy
import numpy as np
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -16,6 +17,8 @@ from vllm_ascend.compilation.passes.qknorm_rope_fusion_pass import (
) )
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
MAX_POSITION_EMBEDDING = 262144
def find_op(gm, op_default): def find_op(gm, op_default):
return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes) return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes)
@@ -207,8 +210,10 @@ def test_rmsnorm_quant_fusion(
model = model.to("npu") model = model.to("npu")
seq_len = 5 seq_len = 5
qkv = torch.randn(seq_len, qkv_size, device="npu", dtype=dtype) qkv = torch.randn(seq_len, qkv_size, device="npu", dtype=dtype)
cos = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype) cos_sin_cache = torch.from_numpy(np.random.uniform(0, 1, [MAX_POSITION_EMBEDDING, head_dim])).to(dtype).npu()
sin = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype) positions = torch.randint(
low=0, high=MAX_POSITION_EMBEDDING, size=(num_tokens,), dtype=torch.int64, device="npu"
)
with torch.no_grad(): with torch.no_grad():
original_optimize = torchair.npu_fx_compiler._optimize_fx original_optimize = torchair.npu_fx_compiler._optimize_fx
@@ -218,6 +223,6 @@ def test_rmsnorm_quant_fusion(
compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True) compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)
compiled_model(qkv, cos, sin) compiled_model(qkv, cos_sin_cache, positions)
torchair.npu_fx_compiler._optimize_fx = original_optimize torchair.npu_fx_compiler._optimize_fx = original_optimize

View File

@@ -74,7 +74,7 @@ CASE_QWEN_FULL_DECODE_ONLY = LLMTestCase(
prompts=PROMPTS_LONG, prompts=PROMPTS_LONG,
golden_answers=[ golden_answers=[
" \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the", " \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the",
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over", " \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
" \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can", " \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can",
], ],
) )
@@ -95,7 +95,7 @@ CASE_QWEN_EX = LLMTestCase(
prompts=PROMPTS_LONG, prompts=PROMPTS_LONG,
golden_answers=[ golden_answers=[
" \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the", " \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the",
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over", " \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
" \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can", " \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can",
], ],
) )

View File

@@ -20,17 +20,15 @@ import triton # type: ignore
import triton.language as tl # type: ignore import triton.language as tl # type: ignore
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice from vllm_ascend.ops.triton.triton_utils import extract_slice, get_element, get_vectorcore_num, insert_slice
@triton.jit @triton.jit
def split_qkv_rmsnorm_rope_kernel( def split_qkv_rmsnorm_rope_kernel(
input_ptr, input_gm_ptr,
cos_sin_ptr, q_gm_ptr,
pos_ptr, k_gm_ptr,
q_ptr, v_gm_ptr,
k_ptr,
v_ptr,
q_weight_ptr, q_weight_ptr,
q_bias_ptr, q_bias_ptr,
k_weight_ptr, k_weight_ptr,
@@ -40,158 +38,228 @@ def split_qkv_rmsnorm_rope_kernel(
kv_hidden_size: tl.constexpr, kv_hidden_size: tl.constexpr,
total_hidden_size: tl.constexpr, total_hidden_size: tl.constexpr,
eps: tl.constexpr, eps: tl.constexpr,
Q_BLOCK_SIZE: tl.constexpr,
KV_BLOCK_SIZE: tl.constexpr,
BIAS: tl.constexpr, BIAS: tl.constexpr,
HEAD_DIM: tl.constexpr, HEAD_DIM: tl.constexpr,
HALF_HEAD_DIM: tl.constexpr, ROPE_DIM: tl.constexpr,
HALF_ROPE_DIM: tl.constexpr,
IS_PARTIAL_ROPE: tl.constexpr,
num_vectorcore: tl.constexpr,
batch_size_per_iter_per_vec: tl.constexpr,
qk_head_nums_per_iter_per_vec: tl.constexpr,
q_head_num: tl.constexpr,
kv_head_num: tl.constexpr,
qk_head_num_sum: tl.constexpr,
v_batch_size_per_iter_per_vec: tl.constexpr,
positions_gm_ptr,
cos_sin_cache_gm_ptr,
): ):
row_pid = tl.program_id(0) row_pid = tl.program_id(0)
col_pid = tl.program_id(1)
row_step = tl.num_programs(0)
# q
weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
if BIAS:
bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
input_offset = row_pid * total_hidden_size
output_offset = row_pid * q_hidden_size
input_offset_step = row_step * total_hidden_size
output_offset_step = row_step * q_hidden_size
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
valid_mask = col_indices < q_hidden_size
input_values = (
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
.to(tl.float32)
.reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
)
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(Q_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) q_weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM) k_weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) batch_size_per_vec = tl.cdiv(batch_size, num_vectorcore)
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) iter_num_per_vec = tl.cdiv(batch_size_per_vec, batch_size_per_iter_per_vec)
v_iter_num_per_vec = tl.cdiv(batch_size_per_vec, v_batch_size_per_iter_per_vec)
input_batch_offset = row_pid * batch_size_per_vec
mblk_idx = tl.arange(0, batch_size_per_iter_per_vec) + input_batch_offset
nblk_idx = tl.arange(0, q_hidden_size + kv_hidden_size)
nmask = nblk_idx < total_hidden_size
input_batch_offset_end = min(input_batch_offset + batch_size_per_vec, batch_size)
pos_indices = input_batch_offset + tl.arange(0, batch_size_per_iter_per_vec)
output_q_nblk_idx = tl.arange(0, q_hidden_size)
output_q_nmask = output_q_nblk_idx < q_hidden_size
output_kv_nblk_idx = tl.arange(0, kv_hidden_size)
output_kv_nmask = output_kv_nblk_idx < kv_hidden_size
sin_cos_range = tl.arange(0, ROPE_DIM)
cos_sin_cache_offset = cos_sin_cache_gm_ptr + sin_cos_range
for iter in tl.range(iter_num_per_vec):
pos_offset = iter * batch_size_per_iter_per_vec
x = tl.load(
positions_gm_ptr + pos_indices + pos_offset, mask=(pos_indices + pos_offset) < input_batch_offset_end
)
mmask = (mblk_idx + pos_offset) < input_batch_offset_end
mask = (mmask[:, None]) & (nmask[None, :])
idx = (mblk_idx + pos_offset)[:, None] * total_hidden_size + nblk_idx[None, :]
values_tmp1 = tl.load(input_gm_ptr + idx, mask=mask).reshape(qk_head_nums_per_iter_per_vec, HEAD_DIM)
if BIAS:
q_bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
k_bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
values_tmp3 = tl.zeros((batch_size_per_iter_per_vec, ROPE_DIM), dtype=tl.bfloat16)
for i in tl.range(batch_size_per_iter_per_vec):
pos = get_element(x, (i,))
values_tmp3 = insert_slice(
values_tmp3.reshape(batch_size_per_iter_per_vec, ROPE_DIM),
tl.load(pos * ROPE_DIM + cos_sin_cache_offset[:, None]).reshape(1, ROPE_DIM),
offsets=(i, 0),
sizes=(1, ROPE_DIM),
strides=(1, 1),
)
values_tmp3 = values_tmp3.reshape(batch_size_per_iter_per_vec, 1, ROPE_DIM)
cos = extract_slice(
values_tmp3,
offsets=(0, 0, 0),
sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM),
strides=(1, 1, 1),
)
sin = extract_slice(
values_tmp3,
offsets=(0, 0, HALF_ROPE_DIM),
sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM),
strides=(1, 1, 1),
)
normalized_values = values_tmp1.to(tl.float32)
normalized_values = normalized_values * normalized_values
normalized_values = tl.sum(normalized_values, axis=1) / HEAD_DIM
normalized_values = 1 / tl.sqrt(normalized_values + eps).reshape(qk_head_nums_per_iter_per_vec, 1)
normalized_values = values_tmp1 * normalized_values
normalized_values_tmp = extract_slice(
normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM),
offsets=(0, 0, 0),
sizes=(batch_size_per_iter_per_vec, q_head_num, HEAD_DIM),
strides=(1, 1, 1),
)
if BIAS:
normalized_values_tmp = (normalized_values_tmp * q_weight_values + q_bias_values).to(tl.bfloat16)
else:
normalized_values_tmp = (normalized_values_tmp * q_weight_values).to(tl.bfloat16)
# q rope
values_tmp = tl.zeros((batch_size_per_iter_per_vec, q_head_num, ROPE_DIM), dtype=tl.bfloat16)
x1 = extract_slice( x1 = extract_slice(
normalized_values, normalized_values_tmp,
offsets=(0, 0), offsets=(0, 0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
strides=(1, 1), strides=(1, 1, 1),
) )
x2 = extract_slice( x2 = extract_slice(
normalized_values, normalized_values_tmp,
offsets=(0, HALF_HEAD_DIM), offsets=(0, 0, HALF_ROPE_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
strides=(1, 1), strides=(1, 1, 1),
) )
roped_q1 = x1 * cos - x2 * sin values_tmp = insert_slice(
roped_q2 = x2 * cos + x1 * sin values_tmp,
x1 * cos - x2 * sin,
roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) offsets=(0, 0, 0),
roped_q = insert_slice( sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
roped_q, strides=(1, 1, 1),
roped_q1,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
) )
roped_q = insert_slice( values_tmp = insert_slice(
roped_q, values_tmp,
roped_q2, x2 * cos + x1 * sin,
offsets=(0, HALF_HEAD_DIM), offsets=(0, 0, HALF_ROPE_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
strides=(1, 1), strides=(1, 1, 1),
) )
tl.store( q_output_idx = output_q_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * q_hidden_size
q_ptr + output_offset + col_indices, mask = (mmask[:, None]) & (output_q_nmask[None, :])
roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty), if IS_PARTIAL_ROPE:
mask=valid_mask, normalized_values_tmp = insert_slice(
) normalized_values_tmp,
input_offset += input_offset_step values_tmp,
output_offset += output_offset_step offsets=(0, 0, 0),
sizes=(batch_size_per_iter_per_vec, q_head_num, ROPE_DIM),
weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) strides=(1, 1, 1),
if BIAS: )
bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) tl.store(
input_offset = row_pid * total_hidden_size + q_hidden_size q_gm_ptr + q_output_idx,
output_offset = row_pid * kv_hidden_size normalized_values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size),
output_offset_step = row_step * kv_hidden_size mask=mask,
for row_idx in tl.range(row_pid, batch_size, row_step): )
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
valid_mask = col_indices < kv_hidden_size
input_values = (
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
.to(tl.float32)
.reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
)
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(KV_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
else: else:
normalized_values = (normalized_values * weight_values).to(tl.bfloat16) tl.store(
q_gm_ptr + q_output_idx,
values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size),
mask=mask,
)
# k rope
normalized_values_tmp1 = extract_slice(
normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM),
offsets=(0, q_head_num, 0),
sizes=(batch_size_per_iter_per_vec, kv_head_num, HEAD_DIM),
strides=(1, 1, 1),
)
if BIAS:
normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values + k_bias_values).to(tl.bfloat16)
else:
normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values).to(tl.bfloat16)
values_tmp2 = tl.zeros((batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM), dtype=tl.bfloat16)
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
x1 = extract_slice( x1 = extract_slice(
normalized_values, normalized_values_tmp1,
offsets=(0, 0), offsets=(0, 0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
strides=(1, 1), strides=(1, 1, 1),
) )
x2 = extract_slice( x2 = extract_slice(
normalized_values, normalized_values_tmp1,
offsets=(0, HALF_HEAD_DIM), offsets=(0, 0, HALF_ROPE_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
strides=(1, 1), strides=(1, 1, 1),
)
values_tmp2 = insert_slice(
values_tmp2,
x1 * cos - x2 * sin,
offsets=(0, 0, 0),
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
strides=(1, 1, 1),
)
values_tmp2 = insert_slice(
values_tmp2,
x2 * cos + x1 * sin,
offsets=(0, 0, HALF_ROPE_DIM),
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
strides=(1, 1, 1),
) )
roped_k1 = x1 * cos - x2 * sin
roped_k2 = x2 * cos + x1 * sin
roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) kv_output_idx = output_kv_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * kv_hidden_size
roped_k = insert_slice( mask = (mmask[:, None]) & (output_kv_nmask[None, :])
roped_k, if IS_PARTIAL_ROPE:
roped_k1, normalized_values_tmp1 = insert_slice(
offsets=(0, 0), normalized_values_tmp1,
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), values_tmp2,
strides=(1, 1), offsets=(0, 0, 0),
) sizes=(batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM),
roped_k = insert_slice( strides=(1, 1, 1),
roped_k, )
roped_k2, tl.store(
offsets=(0, HALF_HEAD_DIM), k_gm_ptr + kv_output_idx,
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), normalized_values_tmp1.reshape(batch_size_per_iter_per_vec, kv_hidden_size),
strides=(1, 1), mask=mask,
) )
tl.store( else:
k_ptr + output_offset + col_indices, tl.store(
roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), k_gm_ptr + kv_output_idx,
mask=valid_mask, values_tmp2.reshape(batch_size_per_iter_per_vec, kv_hidden_size),
) mask=mask,
input_offset += input_offset_step )
output_offset += output_offset_step
input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size mblk_idx = tl.arange(0, v_batch_size_per_iter_per_vec) + input_batch_offset
output_offset = row_pid * kv_hidden_size nblk_idx = tl.arange(q_hidden_size + kv_hidden_size, total_hidden_size)
for _ in tl.range(row_pid, batch_size, row_step): nmask = nblk_idx < total_hidden_size
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) out_nblk_idx = tl.arange(0, kv_hidden_size)
valid_mask = col_indices < kv_hidden_size out_nmask = out_nblk_idx < kv_hidden_size
input_values = tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask) for _ in tl.range(v_iter_num_per_vec):
input_offset += input_offset_step mmask = mblk_idx < input_batch_offset_end
output_offset += output_offset_step mask = (mmask[:, None]) & (nmask[None, :])
idx = mblk_idx[:, None] * total_hidden_size + nblk_idx[None, :]
values = tl.load(input_gm_ptr + idx, mask=mask)
out_idx = mblk_idx[:, None] * kv_hidden_size + out_nblk_idx[None, :]
out_mask = (mmask[:, None]) & (out_nmask[None, :])
tl.store(v_gm_ptr + out_idx, values, mask=out_mask)
mblk_idx += v_batch_size_per_iter_per_vec
def split_qkv_rmsnorm_rope_impl( def split_qkv_rmsnorm_rope_impl(
@@ -207,25 +275,46 @@ def split_qkv_rmsnorm_rope_impl(
q_bias: torch.Tensor | None = None, q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None, k_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
KV_BLOCK_SIZE = triton.next_power_of_2(head_dim) # get available vector core
assert head_dim == KV_BLOCK_SIZE num_vectorcore = get_vectorcore_num()
assert q_hidden_size % kv_hidden_size == 0 rope_dim = cos_sin_cache.shape[-1]
Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
batch_size = input.shape[0] batch_size = input.shape[0]
BIAS = q_bias is not None
IS_PARTIAL_ROPE = rope_dim != head_dim
# Q + K + V
total_hidden_size = q_hidden_size + kv_hidden_size * 2 total_hidden_size = q_hidden_size + kv_hidden_size * 2
q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype) q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype)
k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
n_cols = kv_hidden_size // KV_BLOCK_SIZE
num_vectorcore = get_vectorcore_num()
assert num_vectorcore % n_cols == 0
n_rows = num_vectorcore // n_cols
BIAS = q_bias is not None
split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)]( q_head_num = q_hidden_size // head_dim
kv_head_num = kv_hidden_size // head_dim
# set number of line loading from GM data is x
# x*(q_head_num + kv_head_num)*HEAD_DIM: values_tmp
# 2x*(q_head_num + kv_head_num)*HEAD_DIM: normalized_values(float32)
# x*ROPE_DIM*2 : cos/sin
# x*q_head_num*HEAD_DIM*2 normalized_values_tmp
# x*q_head_num*ROPE_DIM*(0.5) (not IS_PARTIAL_ROPE) x*q_head_num*ROPE_DIM*(0.5): y
UB_SIZE = 87040 # 85K = 85 * 1024
# the factor is the sum of elements number
if IS_PARTIAL_ROPE:
factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 4 + q_head_num * rope_dim
batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor
else:
factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 2 + q_head_num * rope_dim // 2
batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor
batch_size_per_iter_per_vec = max(1, batch_size_per_iter_per_vec)
qk_head_num_sum = int(q_head_num + kv_head_num)
qk_head_nums_per_iter_per_vec = batch_size_per_iter_per_vec * qk_head_num_sum
grid = (num_vectorcore, 1, 1)
# v tiling
v_batch_size_per_iter_per_vec = UB_SIZE / torch.bfloat16.itemsize // (kv_hidden_size + 1)
split_qkv_rmsnorm_rope_kernel[grid](
input, input,
cos_sin_cache,
positions,
q_output, q_output,
k_output, k_output,
v_output, v_output,
@@ -238,11 +327,20 @@ def split_qkv_rmsnorm_rope_impl(
kv_hidden_size, kv_hidden_size,
total_hidden_size, total_hidden_size,
eps, eps,
Q_BLOCK_SIZE,
KV_BLOCK_SIZE,
BIAS, BIAS,
head_dim, head_dim,
head_dim // 2, rope_dim,
rope_dim // 2,
IS_PARTIAL_ROPE,
num_vectorcore,
int(batch_size_per_iter_per_vec),
int(qk_head_nums_per_iter_per_vec),
q_head_num,
kv_head_num,
qk_head_num_sum,
int(v_batch_size_per_iter_per_vec),
positions,
cos_sin_cache,
) )
return q_output, k_output, v_output return q_output, k_output, v_output
@@ -265,19 +363,19 @@ def split_qkv_rmsnorm_rope_impl_fake(
batch_size = input.shape[0] batch_size = input.shape[0]
q_output = torch.empty( q_output = torch.empty(
batch_size, batch_size,
q_hidden_size, int(q_hidden_size),
device=input.device, device=input.device,
dtype=input.dtype, dtype=input.dtype,
) )
k_output = torch.empty( k_output = torch.empty(
batch_size, batch_size,
kv_hidden_size, int(kv_hidden_size),
device=input.device, device=input.device,
dtype=input.dtype, dtype=input.dtype,
) )
v_output = torch.empty( v_output = torch.empty(
batch_size, batch_size,
kv_hidden_size, int(kv_hidden_size),
device=input.device, device=input.device,
dtype=input.dtype, dtype=input.dtype,
) )