[Ops][Misc] Optimize split_qkv_rmsnorm_rope op (#6827)
### What this PR does / why we need it?
This PR optimizes the `split_qkv_rmsnorm_rope` operator by introducing a
new Triton kernel, `split_qkv_rmsnorm_rope_prefill_kernel`, for the
prefill stage (i.e., large batch sizes). The implementation now
dynamically selects between the existing decode kernel and the new
prefill kernel based on the batch size, which improves performance for
large batch scenarios.
Additionally, the RoPE implementation is updated to support partial
rotation dimensions (`rope_dim`), making the operator more flexible.
### Does this PR introduce _any_ user-facing change?
No. This is a performance optimization and is not expected to introduce
any user-facing changes.
### How was this patch tested?
CI should pass with existing tests. The new prefill path is triggered
when the batch size is larger than the number of available vector cores.
The partial RoPE feature can be tested by passing the `rope_dim`
argument.
- vLLM version: v0.15.0
- vLLM main:
83b47f67b1
---------
Signed-off-by: guzhiyong <guzhiyong5@h-partners.com>
Signed-off-by: frank <2547457096@qq.com>
Co-authored-by: guzhiyong <guzhiyong5@h-partners.com>
This commit is contained in:
@@ -11,6 +11,7 @@ MAX_POSITION_EMBEDDINGS = [262144]
|
||||
NUM_TOKENS = [1, 4, 8, 16, 1024]
|
||||
NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)]
|
||||
HEAD_SIZES = [128]
|
||||
ROPE_DIMS = [64, 128]
|
||||
EPS = [1e-6]
|
||||
DTYPES = [torch.bfloat16]
|
||||
SEEDS = [0]
|
||||
@@ -23,19 +24,26 @@ def custom_rope(q, k, sin, cos):
|
||||
rotary_dim = sin.shape[-1]
|
||||
sin = sin.to(torch.float32)
|
||||
cos = cos.to(torch.float32)
|
||||
x1 = q[..., :rotary_dim // 2]
|
||||
x2 = q[..., rotary_dim // 2:]
|
||||
cat_x = torch.cat([-x2, x1], axis=-1)
|
||||
mul1 = cat_x * sin
|
||||
mul2 = q * cos
|
||||
res1 = mul1 + mul2
|
||||
q_rot = q[..., :rotary_dim]
|
||||
k_rot = k[..., :rotary_dim]
|
||||
q_pass = q[..., rotary_dim:]
|
||||
k_pass = k[..., rotary_dim:]
|
||||
|
||||
x1 = k[..., :rotary_dim // 2]
|
||||
x2 = k[..., rotary_dim // 2:]
|
||||
x1 = q_rot[..., :rotary_dim // 2]
|
||||
x2 = q_rot[..., rotary_dim // 2:]
|
||||
cat_x = torch.cat([-x2, x1], axis=-1)
|
||||
mul1 = cat_x * sin
|
||||
mul2 = k * cos
|
||||
res2 = mul1 + mul2
|
||||
mul2 = q_rot * cos
|
||||
q_rot = mul1 + mul2
|
||||
res1 = torch.cat([q_rot, q_pass], dim=-1)
|
||||
|
||||
x1 = k_rot[..., :rotary_dim // 2]
|
||||
x2 = k_rot[..., rotary_dim // 2:]
|
||||
cat_x = torch.cat([-x2, x1], axis=-1)
|
||||
mul1 = cat_x * sin
|
||||
mul2 = k_rot * cos
|
||||
k_rot = mul1 + mul2
|
||||
res2 = torch.cat([k_rot, k_pass], dim=-1)
|
||||
return res1, res2
|
||||
|
||||
|
||||
@@ -64,9 +72,10 @@ def rms_norm(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("rope_dim", ROPE_DIMS)
|
||||
@torch.inference_mode()
|
||||
def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads,
|
||||
head_size, eps, dtype, seed, device):
|
||||
head_size, eps, dtype, seed, device, rope_dim):
|
||||
torch.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
init_device_properties_triton()
|
||||
@@ -81,7 +90,7 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads
|
||||
k_weight = torch.randn(head_size, dtype=dtype, device=device)
|
||||
cos_sin_cache = torch.from_numpy(
|
||||
np.random.uniform(0, 1,
|
||||
[max_position_embeddings, head_size])).to(dtype).npu()
|
||||
[max_position_embeddings, rope_dim])).to(dtype).npu()
|
||||
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
|
||||
# fused kernel
|
||||
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
|
||||
@@ -141,10 +150,11 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("rope_dim", ROPE_DIMS)
|
||||
@torch.inference_mode()
|
||||
def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads,
|
||||
num_kv_heads, head_size, eps, dtype,
|
||||
seed, device):
|
||||
seed, device, rope_dim):
|
||||
torch.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
init_device_properties_triton()
|
||||
@@ -161,7 +171,7 @@ def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, n
|
||||
k_bias = torch.randn(head_size, dtype=dtype, device=device)
|
||||
cos_sin_cache = torch.from_numpy(
|
||||
np.random.uniform(0, 1,
|
||||
[max_position_embeddings, head_size])).to(dtype).npu()
|
||||
[max_position_embeddings, rope_dim])).to(dtype).npu()
|
||||
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
|
||||
# fused kernel
|
||||
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -16,6 +17,8 @@ from vllm_ascend.compilation.passes.qknorm_rope_fusion_pass import (
|
||||
)
|
||||
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
|
||||
MAX_POSITION_EMBEDDING = 262144
|
||||
|
||||
|
||||
def find_op(gm, op_default):
|
||||
return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes)
|
||||
@@ -207,8 +210,10 @@ def test_rmsnorm_quant_fusion(
|
||||
model = model.to("npu")
|
||||
seq_len = 5
|
||||
qkv = torch.randn(seq_len, qkv_size, device="npu", dtype=dtype)
|
||||
cos = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)
|
||||
sin = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)
|
||||
cos_sin_cache = torch.from_numpy(np.random.uniform(0, 1, [MAX_POSITION_EMBEDDING, head_dim])).to(dtype).npu()
|
||||
positions = torch.randint(
|
||||
low=0, high=MAX_POSITION_EMBEDDING, size=(num_tokens,), dtype=torch.int64, device="npu"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
original_optimize = torchair.npu_fx_compiler._optimize_fx
|
||||
@@ -218,6 +223,6 @@ def test_rmsnorm_quant_fusion(
|
||||
|
||||
compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)
|
||||
|
||||
compiled_model(qkv, cos, sin)
|
||||
compiled_model(qkv, cos_sin_cache, positions)
|
||||
|
||||
torchair.npu_fx_compiler._optimize_fx = original_optimize
|
||||
|
||||
@@ -74,7 +74,7 @@ CASE_QWEN_FULL_DECODE_ONLY = LLMTestCase(
|
||||
prompts=PROMPTS_LONG,
|
||||
golden_answers=[
|
||||
" \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the",
|
||||
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over",
|
||||
" \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
|
||||
" \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can",
|
||||
],
|
||||
)
|
||||
@@ -95,7 +95,7 @@ CASE_QWEN_EX = LLMTestCase(
|
||||
prompts=PROMPTS_LONG,
|
||||
golden_answers=[
|
||||
" \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the",
|
||||
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over",
|
||||
" \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
|
||||
" \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -20,17 +20,15 @@ import triton # type: ignore
|
||||
import triton.language as tl # type: ignore
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_element, get_vectorcore_num, insert_slice
|
||||
|
||||
|
||||
@triton.jit
|
||||
def split_qkv_rmsnorm_rope_kernel(
|
||||
input_ptr,
|
||||
cos_sin_ptr,
|
||||
pos_ptr,
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
input_gm_ptr,
|
||||
q_gm_ptr,
|
||||
k_gm_ptr,
|
||||
v_gm_ptr,
|
||||
q_weight_ptr,
|
||||
q_bias_ptr,
|
||||
k_weight_ptr,
|
||||
@@ -40,158 +38,228 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
kv_hidden_size: tl.constexpr,
|
||||
total_hidden_size: tl.constexpr,
|
||||
eps: tl.constexpr,
|
||||
Q_BLOCK_SIZE: tl.constexpr,
|
||||
KV_BLOCK_SIZE: tl.constexpr,
|
||||
BIAS: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
HALF_HEAD_DIM: tl.constexpr,
|
||||
ROPE_DIM: tl.constexpr,
|
||||
HALF_ROPE_DIM: tl.constexpr,
|
||||
IS_PARTIAL_ROPE: tl.constexpr,
|
||||
num_vectorcore: tl.constexpr,
|
||||
batch_size_per_iter_per_vec: tl.constexpr,
|
||||
qk_head_nums_per_iter_per_vec: tl.constexpr,
|
||||
q_head_num: tl.constexpr,
|
||||
kv_head_num: tl.constexpr,
|
||||
qk_head_num_sum: tl.constexpr,
|
||||
v_batch_size_per_iter_per_vec: tl.constexpr,
|
||||
positions_gm_ptr,
|
||||
cos_sin_cache_gm_ptr,
|
||||
):
|
||||
row_pid = tl.program_id(0)
|
||||
col_pid = tl.program_id(1)
|
||||
row_step = tl.num_programs(0)
|
||||
# q
|
||||
weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
|
||||
if BIAS:
|
||||
bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
|
||||
input_offset = row_pid * total_hidden_size
|
||||
output_offset = row_pid * q_hidden_size
|
||||
input_offset_step = row_step * total_hidden_size
|
||||
output_offset_step = row_step * q_hidden_size
|
||||
for row_idx in tl.range(row_pid, batch_size, row_step):
|
||||
col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
|
||||
valid_mask = col_indices < q_hidden_size
|
||||
input_values = (
|
||||
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
|
||||
.to(tl.float32)
|
||||
.reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
|
||||
)
|
||||
squares = input_values * input_values
|
||||
variances = tl.sum(squares, axis=1) / HEAD_DIM
|
||||
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(Q_BLOCK_SIZE // HEAD_DIM, 1)
|
||||
normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
|
||||
if BIAS:
|
||||
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
|
||||
|
||||
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
|
||||
cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
q_weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
|
||||
k_weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
|
||||
|
||||
batch_size_per_vec = tl.cdiv(batch_size, num_vectorcore)
|
||||
iter_num_per_vec = tl.cdiv(batch_size_per_vec, batch_size_per_iter_per_vec)
|
||||
v_iter_num_per_vec = tl.cdiv(batch_size_per_vec, v_batch_size_per_iter_per_vec)
|
||||
input_batch_offset = row_pid * batch_size_per_vec
|
||||
mblk_idx = tl.arange(0, batch_size_per_iter_per_vec) + input_batch_offset
|
||||
nblk_idx = tl.arange(0, q_hidden_size + kv_hidden_size)
|
||||
nmask = nblk_idx < total_hidden_size
|
||||
|
||||
input_batch_offset_end = min(input_batch_offset + batch_size_per_vec, batch_size)
|
||||
|
||||
pos_indices = input_batch_offset + tl.arange(0, batch_size_per_iter_per_vec)
|
||||
output_q_nblk_idx = tl.arange(0, q_hidden_size)
|
||||
output_q_nmask = output_q_nblk_idx < q_hidden_size
|
||||
output_kv_nblk_idx = tl.arange(0, kv_hidden_size)
|
||||
output_kv_nmask = output_kv_nblk_idx < kv_hidden_size
|
||||
sin_cos_range = tl.arange(0, ROPE_DIM)
|
||||
cos_sin_cache_offset = cos_sin_cache_gm_ptr + sin_cos_range
|
||||
|
||||
for iter in tl.range(iter_num_per_vec):
|
||||
pos_offset = iter * batch_size_per_iter_per_vec
|
||||
x = tl.load(
|
||||
positions_gm_ptr + pos_indices + pos_offset, mask=(pos_indices + pos_offset) < input_batch_offset_end
|
||||
)
|
||||
mmask = (mblk_idx + pos_offset) < input_batch_offset_end
|
||||
mask = (mmask[:, None]) & (nmask[None, :])
|
||||
idx = (mblk_idx + pos_offset)[:, None] * total_hidden_size + nblk_idx[None, :]
|
||||
values_tmp1 = tl.load(input_gm_ptr + idx, mask=mask).reshape(qk_head_nums_per_iter_per_vec, HEAD_DIM)
|
||||
if BIAS:
|
||||
q_bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
|
||||
k_bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
|
||||
|
||||
values_tmp3 = tl.zeros((batch_size_per_iter_per_vec, ROPE_DIM), dtype=tl.bfloat16)
|
||||
for i in tl.range(batch_size_per_iter_per_vec):
|
||||
pos = get_element(x, (i,))
|
||||
values_tmp3 = insert_slice(
|
||||
values_tmp3.reshape(batch_size_per_iter_per_vec, ROPE_DIM),
|
||||
tl.load(pos * ROPE_DIM + cos_sin_cache_offset[:, None]).reshape(1, ROPE_DIM),
|
||||
offsets=(i, 0),
|
||||
sizes=(1, ROPE_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
values_tmp3 = values_tmp3.reshape(batch_size_per_iter_per_vec, 1, ROPE_DIM)
|
||||
cos = extract_slice(
|
||||
values_tmp3,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
sin = extract_slice(
|
||||
values_tmp3,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
|
||||
normalized_values = values_tmp1.to(tl.float32)
|
||||
normalized_values = normalized_values * normalized_values
|
||||
normalized_values = tl.sum(normalized_values, axis=1) / HEAD_DIM
|
||||
normalized_values = 1 / tl.sqrt(normalized_values + eps).reshape(qk_head_nums_per_iter_per_vec, 1)
|
||||
normalized_values = values_tmp1 * normalized_values
|
||||
|
||||
normalized_values_tmp = extract_slice(
|
||||
normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM),
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HEAD_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
|
||||
if BIAS:
|
||||
normalized_values_tmp = (normalized_values_tmp * q_weight_values + q_bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values_tmp = (normalized_values_tmp * q_weight_values).to(tl.bfloat16)
|
||||
|
||||
# q rope
|
||||
values_tmp = tl.zeros((batch_size_per_iter_per_vec, q_head_num, ROPE_DIM), dtype=tl.bfloat16)
|
||||
x1 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
normalized_values_tmp,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
x2 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
normalized_values_tmp,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
roped_q1 = x1 * cos - x2 * sin
|
||||
roped_q2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_q = insert_slice(
|
||||
roped_q,
|
||||
roped_q1,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
values_tmp = insert_slice(
|
||||
values_tmp,
|
||||
x1 * cos - x2 * sin,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
roped_q = insert_slice(
|
||||
roped_q,
|
||||
roped_q2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
values_tmp = insert_slice(
|
||||
values_tmp,
|
||||
x2 * cos + x1 * sin,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
tl.store(
|
||||
q_ptr + output_offset + col_indices,
|
||||
roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty),
|
||||
mask=valid_mask,
|
||||
)
|
||||
input_offset += input_offset_step
|
||||
output_offset += output_offset_step
|
||||
|
||||
weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
|
||||
if BIAS:
|
||||
bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
|
||||
input_offset = row_pid * total_hidden_size + q_hidden_size
|
||||
output_offset = row_pid * kv_hidden_size
|
||||
output_offset_step = row_step * kv_hidden_size
|
||||
for row_idx in tl.range(row_pid, batch_size, row_step):
|
||||
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
|
||||
valid_mask = col_indices < kv_hidden_size
|
||||
input_values = (
|
||||
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
|
||||
.to(tl.float32)
|
||||
.reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
|
||||
)
|
||||
squares = input_values * input_values
|
||||
variances = tl.sum(squares, axis=1) / HEAD_DIM
|
||||
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(KV_BLOCK_SIZE // HEAD_DIM, 1)
|
||||
normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
|
||||
if BIAS:
|
||||
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
|
||||
q_output_idx = output_q_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * q_hidden_size
|
||||
mask = (mmask[:, None]) & (output_q_nmask[None, :])
|
||||
if IS_PARTIAL_ROPE:
|
||||
normalized_values_tmp = insert_slice(
|
||||
normalized_values_tmp,
|
||||
values_tmp,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
tl.store(
|
||||
q_gm_ptr + q_output_idx,
|
||||
normalized_values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size),
|
||||
mask=mask,
|
||||
)
|
||||
else:
|
||||
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
|
||||
tl.store(
|
||||
q_gm_ptr + q_output_idx,
|
||||
values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# k rope
|
||||
normalized_values_tmp1 = extract_slice(
|
||||
normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM),
|
||||
offsets=(0, q_head_num, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HEAD_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
|
||||
if BIAS:
|
||||
normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values + k_bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values).to(tl.bfloat16)
|
||||
|
||||
values_tmp2 = tl.zeros((batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM), dtype=tl.bfloat16)
|
||||
|
||||
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
|
||||
cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
x1 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
normalized_values_tmp1,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
x2 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
normalized_values_tmp1,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
values_tmp2 = insert_slice(
|
||||
values_tmp2,
|
||||
x1 * cos - x2 * sin,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
values_tmp2 = insert_slice(
|
||||
values_tmp2,
|
||||
x2 * cos + x1 * sin,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
roped_k1 = x1 * cos - x2 * sin
|
||||
roped_k2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_k = insert_slice(
|
||||
roped_k,
|
||||
roped_k1,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_k = insert_slice(
|
||||
roped_k,
|
||||
roped_k2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
tl.store(
|
||||
k_ptr + output_offset + col_indices,
|
||||
roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE),
|
||||
mask=valid_mask,
|
||||
)
|
||||
input_offset += input_offset_step
|
||||
output_offset += output_offset_step
|
||||
kv_output_idx = output_kv_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * kv_hidden_size
|
||||
mask = (mmask[:, None]) & (output_kv_nmask[None, :])
|
||||
if IS_PARTIAL_ROPE:
|
||||
normalized_values_tmp1 = insert_slice(
|
||||
normalized_values_tmp1,
|
||||
values_tmp2,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
tl.store(
|
||||
k_gm_ptr + kv_output_idx,
|
||||
normalized_values_tmp1.reshape(batch_size_per_iter_per_vec, kv_hidden_size),
|
||||
mask=mask,
|
||||
)
|
||||
else:
|
||||
tl.store(
|
||||
k_gm_ptr + kv_output_idx,
|
||||
values_tmp2.reshape(batch_size_per_iter_per_vec, kv_hidden_size),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size
|
||||
output_offset = row_pid * kv_hidden_size
|
||||
for _ in tl.range(row_pid, batch_size, row_step):
|
||||
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
|
||||
valid_mask = col_indices < kv_hidden_size
|
||||
input_values = tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
|
||||
tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask)
|
||||
input_offset += input_offset_step
|
||||
output_offset += output_offset_step
|
||||
mblk_idx = tl.arange(0, v_batch_size_per_iter_per_vec) + input_batch_offset
|
||||
nblk_idx = tl.arange(q_hidden_size + kv_hidden_size, total_hidden_size)
|
||||
nmask = nblk_idx < total_hidden_size
|
||||
out_nblk_idx = tl.arange(0, kv_hidden_size)
|
||||
out_nmask = out_nblk_idx < kv_hidden_size
|
||||
|
||||
for _ in tl.range(v_iter_num_per_vec):
|
||||
mmask = mblk_idx < input_batch_offset_end
|
||||
mask = (mmask[:, None]) & (nmask[None, :])
|
||||
idx = mblk_idx[:, None] * total_hidden_size + nblk_idx[None, :]
|
||||
values = tl.load(input_gm_ptr + idx, mask=mask)
|
||||
out_idx = mblk_idx[:, None] * kv_hidden_size + out_nblk_idx[None, :]
|
||||
out_mask = (mmask[:, None]) & (out_nmask[None, :])
|
||||
tl.store(v_gm_ptr + out_idx, values, mask=out_mask)
|
||||
mblk_idx += v_batch_size_per_iter_per_vec
|
||||
|
||||
|
||||
def split_qkv_rmsnorm_rope_impl(
|
||||
@@ -207,25 +275,46 @@ def split_qkv_rmsnorm_rope_impl(
|
||||
q_bias: torch.Tensor | None = None,
|
||||
k_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
|
||||
assert head_dim == KV_BLOCK_SIZE
|
||||
assert q_hidden_size % kv_hidden_size == 0
|
||||
Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
|
||||
# get available vector core
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
rope_dim = cos_sin_cache.shape[-1]
|
||||
batch_size = input.shape[0]
|
||||
BIAS = q_bias is not None
|
||||
IS_PARTIAL_ROPE = rope_dim != head_dim
|
||||
# Q + K + V
|
||||
total_hidden_size = q_hidden_size + kv_hidden_size * 2
|
||||
|
||||
q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype)
|
||||
k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
|
||||
v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
|
||||
n_cols = kv_hidden_size // KV_BLOCK_SIZE
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
assert num_vectorcore % n_cols == 0
|
||||
n_rows = num_vectorcore // n_cols
|
||||
BIAS = q_bias is not None
|
||||
|
||||
split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)](
|
||||
q_head_num = q_hidden_size // head_dim
|
||||
kv_head_num = kv_hidden_size // head_dim
|
||||
|
||||
# set number of line loading from GM data is x
|
||||
# x*(q_head_num + kv_head_num)*HEAD_DIM: values_tmp
|
||||
# 2x*(q_head_num + kv_head_num)*HEAD_DIM: normalized_values(float32)
|
||||
# x*ROPE_DIM*2 : cos/sin
|
||||
# x*q_head_num*HEAD_DIM*2: normalized_values_tmp
|
||||
# x*q_head_num*ROPE_DIM*(0.5) (not IS_PARTIAL_ROPE) x*q_head_num*ROPE_DIM*(0.5): y
|
||||
UB_SIZE = 87040 # 85K = 85 * 1024
|
||||
# the factor is the sum of elements number
|
||||
if IS_PARTIAL_ROPE:
|
||||
factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 4 + q_head_num * rope_dim
|
||||
batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor
|
||||
else:
|
||||
factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 2 + q_head_num * rope_dim // 2
|
||||
batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor
|
||||
batch_size_per_iter_per_vec = max(1, batch_size_per_iter_per_vec)
|
||||
qk_head_num_sum = int(q_head_num + kv_head_num)
|
||||
qk_head_nums_per_iter_per_vec = batch_size_per_iter_per_vec * qk_head_num_sum
|
||||
|
||||
grid = (num_vectorcore, 1, 1)
|
||||
# v tiling
|
||||
v_batch_size_per_iter_per_vec = UB_SIZE / torch.bfloat16.itemsize // (kv_hidden_size + 1)
|
||||
|
||||
split_qkv_rmsnorm_rope_kernel[grid](
|
||||
input,
|
||||
cos_sin_cache,
|
||||
positions,
|
||||
q_output,
|
||||
k_output,
|
||||
v_output,
|
||||
@@ -238,11 +327,20 @@ def split_qkv_rmsnorm_rope_impl(
|
||||
kv_hidden_size,
|
||||
total_hidden_size,
|
||||
eps,
|
||||
Q_BLOCK_SIZE,
|
||||
KV_BLOCK_SIZE,
|
||||
BIAS,
|
||||
head_dim,
|
||||
head_dim // 2,
|
||||
rope_dim,
|
||||
rope_dim // 2,
|
||||
IS_PARTIAL_ROPE,
|
||||
num_vectorcore,
|
||||
int(batch_size_per_iter_per_vec),
|
||||
int(qk_head_nums_per_iter_per_vec),
|
||||
q_head_num,
|
||||
kv_head_num,
|
||||
qk_head_num_sum,
|
||||
int(v_batch_size_per_iter_per_vec),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
)
|
||||
return q_output, k_output, v_output
|
||||
|
||||
@@ -265,19 +363,19 @@ def split_qkv_rmsnorm_rope_impl_fake(
|
||||
batch_size = input.shape[0]
|
||||
q_output = torch.empty(
|
||||
batch_size,
|
||||
q_hidden_size,
|
||||
int(q_hidden_size),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
k_output = torch.empty(
|
||||
batch_size,
|
||||
kv_hidden_size,
|
||||
int(kv_hidden_size),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
v_output = torch.empty(
|
||||
batch_size,
|
||||
kv_hidden_size,
|
||||
int(kv_hidden_size),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user