[Ops][Misc] Optimize split_qkv_rmsnorm_rope op (#6827)
### What this PR does / why we need it?
This PR optimizes the `split_qkv_rmsnorm_rope` operator by introducing a
new Triton kernel, `split_qkv_rmsnorm_rope_prefill_kernel`, for the
prefill stage (i.e., large batch sizes). The implementation now
dynamically selects between the existing decode kernel and the new
prefill kernel based on the batch size, which improves performance for
large batch scenarios.
Additionally, the RoPE implementation is updated to support partial
rotation dimensions (`rope_dim`), making the operator more flexible.
### Does this PR introduce _any_ user-facing change?
No. This is a performance optimization and is not expected to introduce
any user-facing changes.
### How was this patch tested?
CI should pass with existing tests. The new prefill path is triggered
when the batch size is larger than the number of available vector cores.
The partial RoPE feature can be tested by passing the `rope_dim`
argument.
- vLLM version: v0.15.0
- vLLM main:
83b47f67b1
---------
Signed-off-by: guzhiyong <guzhiyong5@h-partners.com>
Signed-off-by: frank <2547457096@qq.com>
Co-authored-by: guzhiyong <guzhiyong5@h-partners.com>
This commit is contained in:
@@ -20,17 +20,15 @@ import triton # type: ignore
|
||||
import triton.language as tl # type: ignore
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_element, get_vectorcore_num, insert_slice
|
||||
|
||||
|
||||
@triton.jit
|
||||
def split_qkv_rmsnorm_rope_kernel(
|
||||
input_ptr,
|
||||
cos_sin_ptr,
|
||||
pos_ptr,
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
input_gm_ptr,
|
||||
q_gm_ptr,
|
||||
k_gm_ptr,
|
||||
v_gm_ptr,
|
||||
q_weight_ptr,
|
||||
q_bias_ptr,
|
||||
k_weight_ptr,
|
||||
@@ -40,158 +38,228 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
kv_hidden_size: tl.constexpr,
|
||||
total_hidden_size: tl.constexpr,
|
||||
eps: tl.constexpr,
|
||||
Q_BLOCK_SIZE: tl.constexpr,
|
||||
KV_BLOCK_SIZE: tl.constexpr,
|
||||
BIAS: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
HALF_HEAD_DIM: tl.constexpr,
|
||||
ROPE_DIM: tl.constexpr,
|
||||
HALF_ROPE_DIM: tl.constexpr,
|
||||
IS_PARTIAL_ROPE: tl.constexpr,
|
||||
num_vectorcore: tl.constexpr,
|
||||
batch_size_per_iter_per_vec: tl.constexpr,
|
||||
qk_head_nums_per_iter_per_vec: tl.constexpr,
|
||||
q_head_num: tl.constexpr,
|
||||
kv_head_num: tl.constexpr,
|
||||
qk_head_num_sum: tl.constexpr,
|
||||
v_batch_size_per_iter_per_vec: tl.constexpr,
|
||||
positions_gm_ptr,
|
||||
cos_sin_cache_gm_ptr,
|
||||
):
|
||||
row_pid = tl.program_id(0)
|
||||
col_pid = tl.program_id(1)
|
||||
row_step = tl.num_programs(0)
|
||||
# q
|
||||
weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
|
||||
if BIAS:
|
||||
bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
|
||||
input_offset = row_pid * total_hidden_size
|
||||
output_offset = row_pid * q_hidden_size
|
||||
input_offset_step = row_step * total_hidden_size
|
||||
output_offset_step = row_step * q_hidden_size
|
||||
for row_idx in tl.range(row_pid, batch_size, row_step):
|
||||
col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
|
||||
valid_mask = col_indices < q_hidden_size
|
||||
input_values = (
|
||||
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
|
||||
.to(tl.float32)
|
||||
.reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
|
||||
)
|
||||
squares = input_values * input_values
|
||||
variances = tl.sum(squares, axis=1) / HEAD_DIM
|
||||
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(Q_BLOCK_SIZE // HEAD_DIM, 1)
|
||||
normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
|
||||
if BIAS:
|
||||
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
|
||||
|
||||
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
|
||||
cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
q_weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
|
||||
k_weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
|
||||
|
||||
batch_size_per_vec = tl.cdiv(batch_size, num_vectorcore)
|
||||
iter_num_per_vec = tl.cdiv(batch_size_per_vec, batch_size_per_iter_per_vec)
|
||||
v_iter_num_per_vec = tl.cdiv(batch_size_per_vec, v_batch_size_per_iter_per_vec)
|
||||
input_batch_offset = row_pid * batch_size_per_vec
|
||||
mblk_idx = tl.arange(0, batch_size_per_iter_per_vec) + input_batch_offset
|
||||
nblk_idx = tl.arange(0, q_hidden_size + kv_hidden_size)
|
||||
nmask = nblk_idx < total_hidden_size
|
||||
|
||||
input_batch_offset_end = min(input_batch_offset + batch_size_per_vec, batch_size)
|
||||
|
||||
pos_indices = input_batch_offset + tl.arange(0, batch_size_per_iter_per_vec)
|
||||
output_q_nblk_idx = tl.arange(0, q_hidden_size)
|
||||
output_q_nmask = output_q_nblk_idx < q_hidden_size
|
||||
output_kv_nblk_idx = tl.arange(0, kv_hidden_size)
|
||||
output_kv_nmask = output_kv_nblk_idx < kv_hidden_size
|
||||
sin_cos_range = tl.arange(0, ROPE_DIM)
|
||||
cos_sin_cache_offset = cos_sin_cache_gm_ptr + sin_cos_range
|
||||
|
||||
for iter in tl.range(iter_num_per_vec):
|
||||
pos_offset = iter * batch_size_per_iter_per_vec
|
||||
x = tl.load(
|
||||
positions_gm_ptr + pos_indices + pos_offset, mask=(pos_indices + pos_offset) < input_batch_offset_end
|
||||
)
|
||||
mmask = (mblk_idx + pos_offset) < input_batch_offset_end
|
||||
mask = (mmask[:, None]) & (nmask[None, :])
|
||||
idx = (mblk_idx + pos_offset)[:, None] * total_hidden_size + nblk_idx[None, :]
|
||||
values_tmp1 = tl.load(input_gm_ptr + idx, mask=mask).reshape(qk_head_nums_per_iter_per_vec, HEAD_DIM)
|
||||
if BIAS:
|
||||
q_bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
|
||||
k_bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
|
||||
|
||||
values_tmp3 = tl.zeros((batch_size_per_iter_per_vec, ROPE_DIM), dtype=tl.bfloat16)
|
||||
for i in tl.range(batch_size_per_iter_per_vec):
|
||||
pos = get_element(x, (i,))
|
||||
values_tmp3 = insert_slice(
|
||||
values_tmp3.reshape(batch_size_per_iter_per_vec, ROPE_DIM),
|
||||
tl.load(pos * ROPE_DIM + cos_sin_cache_offset[:, None]).reshape(1, ROPE_DIM),
|
||||
offsets=(i, 0),
|
||||
sizes=(1, ROPE_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
values_tmp3 = values_tmp3.reshape(batch_size_per_iter_per_vec, 1, ROPE_DIM)
|
||||
cos = extract_slice(
|
||||
values_tmp3,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
sin = extract_slice(
|
||||
values_tmp3,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, 1, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
|
||||
normalized_values = values_tmp1.to(tl.float32)
|
||||
normalized_values = normalized_values * normalized_values
|
||||
normalized_values = tl.sum(normalized_values, axis=1) / HEAD_DIM
|
||||
normalized_values = 1 / tl.sqrt(normalized_values + eps).reshape(qk_head_nums_per_iter_per_vec, 1)
|
||||
normalized_values = values_tmp1 * normalized_values
|
||||
|
||||
normalized_values_tmp = extract_slice(
|
||||
normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM),
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HEAD_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
|
||||
if BIAS:
|
||||
normalized_values_tmp = (normalized_values_tmp * q_weight_values + q_bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values_tmp = (normalized_values_tmp * q_weight_values).to(tl.bfloat16)
|
||||
|
||||
# q rope
|
||||
values_tmp = tl.zeros((batch_size_per_iter_per_vec, q_head_num, ROPE_DIM), dtype=tl.bfloat16)
|
||||
x1 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
normalized_values_tmp,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
x2 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
normalized_values_tmp,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
roped_q1 = x1 * cos - x2 * sin
|
||||
roped_q2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_q = insert_slice(
|
||||
roped_q,
|
||||
roped_q1,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
values_tmp = insert_slice(
|
||||
values_tmp,
|
||||
x1 * cos - x2 * sin,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
roped_q = insert_slice(
|
||||
roped_q,
|
||||
roped_q2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
values_tmp = insert_slice(
|
||||
values_tmp,
|
||||
x2 * cos + x1 * sin,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
tl.store(
|
||||
q_ptr + output_offset + col_indices,
|
||||
roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty),
|
||||
mask=valid_mask,
|
||||
)
|
||||
input_offset += input_offset_step
|
||||
output_offset += output_offset_step
|
||||
|
||||
weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
|
||||
if BIAS:
|
||||
bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
|
||||
input_offset = row_pid * total_hidden_size + q_hidden_size
|
||||
output_offset = row_pid * kv_hidden_size
|
||||
output_offset_step = row_step * kv_hidden_size
|
||||
for row_idx in tl.range(row_pid, batch_size, row_step):
|
||||
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
|
||||
valid_mask = col_indices < kv_hidden_size
|
||||
input_values = (
|
||||
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
|
||||
.to(tl.float32)
|
||||
.reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
|
||||
)
|
||||
squares = input_values * input_values
|
||||
variances = tl.sum(squares, axis=1) / HEAD_DIM
|
||||
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(KV_BLOCK_SIZE // HEAD_DIM, 1)
|
||||
normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
|
||||
if BIAS:
|
||||
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
|
||||
q_output_idx = output_q_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * q_hidden_size
|
||||
mask = (mmask[:, None]) & (output_q_nmask[None, :])
|
||||
if IS_PARTIAL_ROPE:
|
||||
normalized_values_tmp = insert_slice(
|
||||
normalized_values_tmp,
|
||||
values_tmp,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, q_head_num, ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
tl.store(
|
||||
q_gm_ptr + q_output_idx,
|
||||
normalized_values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size),
|
||||
mask=mask,
|
||||
)
|
||||
else:
|
||||
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
|
||||
tl.store(
|
||||
q_gm_ptr + q_output_idx,
|
||||
values_tmp.reshape(batch_size_per_iter_per_vec, q_hidden_size),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# k rope
|
||||
normalized_values_tmp1 = extract_slice(
|
||||
normalized_values.reshape(batch_size_per_iter_per_vec, qk_head_num_sum, HEAD_DIM),
|
||||
offsets=(0, q_head_num, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HEAD_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
|
||||
if BIAS:
|
||||
normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values + k_bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values_tmp1 = (normalized_values_tmp1 * k_weight_values).to(tl.bfloat16)
|
||||
|
||||
values_tmp2 = tl.zeros((batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM), dtype=tl.bfloat16)
|
||||
|
||||
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
|
||||
cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
x1 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
normalized_values_tmp1,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
x2 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
normalized_values_tmp1,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
values_tmp2 = insert_slice(
|
||||
values_tmp2,
|
||||
x1 * cos - x2 * sin,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
values_tmp2 = insert_slice(
|
||||
values_tmp2,
|
||||
x2 * cos + x1 * sin,
|
||||
offsets=(0, 0, HALF_ROPE_DIM),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, HALF_ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
roped_k1 = x1 * cos - x2 * sin
|
||||
roped_k2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_k = insert_slice(
|
||||
roped_k,
|
||||
roped_k1,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_k = insert_slice(
|
||||
roped_k,
|
||||
roped_k2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
tl.store(
|
||||
k_ptr + output_offset + col_indices,
|
||||
roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE),
|
||||
mask=valid_mask,
|
||||
)
|
||||
input_offset += input_offset_step
|
||||
output_offset += output_offset_step
|
||||
kv_output_idx = output_kv_nblk_idx[None, :] + (mblk_idx + pos_offset)[:, None] * kv_hidden_size
|
||||
mask = (mmask[:, None]) & (output_kv_nmask[None, :])
|
||||
if IS_PARTIAL_ROPE:
|
||||
normalized_values_tmp1 = insert_slice(
|
||||
normalized_values_tmp1,
|
||||
values_tmp2,
|
||||
offsets=(0, 0, 0),
|
||||
sizes=(batch_size_per_iter_per_vec, kv_head_num, ROPE_DIM),
|
||||
strides=(1, 1, 1),
|
||||
)
|
||||
tl.store(
|
||||
k_gm_ptr + kv_output_idx,
|
||||
normalized_values_tmp1.reshape(batch_size_per_iter_per_vec, kv_hidden_size),
|
||||
mask=mask,
|
||||
)
|
||||
else:
|
||||
tl.store(
|
||||
k_gm_ptr + kv_output_idx,
|
||||
values_tmp2.reshape(batch_size_per_iter_per_vec, kv_hidden_size),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size
|
||||
output_offset = row_pid * kv_hidden_size
|
||||
for _ in tl.range(row_pid, batch_size, row_step):
|
||||
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
|
||||
valid_mask = col_indices < kv_hidden_size
|
||||
input_values = tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
|
||||
tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask)
|
||||
input_offset += input_offset_step
|
||||
output_offset += output_offset_step
|
||||
mblk_idx = tl.arange(0, v_batch_size_per_iter_per_vec) + input_batch_offset
|
||||
nblk_idx = tl.arange(q_hidden_size + kv_hidden_size, total_hidden_size)
|
||||
nmask = nblk_idx < total_hidden_size
|
||||
out_nblk_idx = tl.arange(0, kv_hidden_size)
|
||||
out_nmask = out_nblk_idx < kv_hidden_size
|
||||
|
||||
for _ in tl.range(v_iter_num_per_vec):
|
||||
mmask = mblk_idx < input_batch_offset_end
|
||||
mask = (mmask[:, None]) & (nmask[None, :])
|
||||
idx = mblk_idx[:, None] * total_hidden_size + nblk_idx[None, :]
|
||||
values = tl.load(input_gm_ptr + idx, mask=mask)
|
||||
out_idx = mblk_idx[:, None] * kv_hidden_size + out_nblk_idx[None, :]
|
||||
out_mask = (mmask[:, None]) & (out_nmask[None, :])
|
||||
tl.store(v_gm_ptr + out_idx, values, mask=out_mask)
|
||||
mblk_idx += v_batch_size_per_iter_per_vec
|
||||
|
||||
|
||||
def split_qkv_rmsnorm_rope_impl(
|
||||
@@ -207,25 +275,46 @@ def split_qkv_rmsnorm_rope_impl(
|
||||
q_bias: torch.Tensor | None = None,
|
||||
k_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
|
||||
assert head_dim == KV_BLOCK_SIZE
|
||||
assert q_hidden_size % kv_hidden_size == 0
|
||||
Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
|
||||
# get available vector core
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
rope_dim = cos_sin_cache.shape[-1]
|
||||
batch_size = input.shape[0]
|
||||
BIAS = q_bias is not None
|
||||
IS_PARTIAL_ROPE = rope_dim != head_dim
|
||||
# Q + K + V
|
||||
total_hidden_size = q_hidden_size + kv_hidden_size * 2
|
||||
|
||||
q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype)
|
||||
k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
|
||||
v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype)
|
||||
n_cols = kv_hidden_size // KV_BLOCK_SIZE
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
assert num_vectorcore % n_cols == 0
|
||||
n_rows = num_vectorcore // n_cols
|
||||
BIAS = q_bias is not None
|
||||
|
||||
split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)](
|
||||
q_head_num = q_hidden_size // head_dim
|
||||
kv_head_num = kv_hidden_size // head_dim
|
||||
|
||||
# set number of line loading from GM data is x
|
||||
# x*(q_head_num + kv_head_num)*HEAD_DIM: values_tmp
|
||||
# 2x*(q_head_num + kv_head_num)*HEAD_DIM: normalized_values(float32)
|
||||
# x*ROPE_DIM*2 : cos/sin
|
||||
# x*q_head_num*HEAD_DIM*2: normalized_values_tmp
|
||||
# x*q_head_num*ROPE_DIM*(0.5) (not IS_PARTIAL_ROPE) x*q_head_num*ROPE_DIM*(0.5): y
|
||||
UB_SIZE = 87040 # 85K = 85 * 1024
|
||||
# the factor is the sum of elements number
|
||||
if IS_PARTIAL_ROPE:
|
||||
factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 4 + q_head_num * rope_dim
|
||||
batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor
|
||||
else:
|
||||
factor = 5 * q_hidden_size + 3 * kv_hidden_size + rope_dim * 2 + q_head_num * rope_dim // 2
|
||||
batch_size_per_iter_per_vec = int(UB_SIZE / input.element_size()) // factor
|
||||
batch_size_per_iter_per_vec = max(1, batch_size_per_iter_per_vec)
|
||||
qk_head_num_sum = int(q_head_num + kv_head_num)
|
||||
qk_head_nums_per_iter_per_vec = batch_size_per_iter_per_vec * qk_head_num_sum
|
||||
|
||||
grid = (num_vectorcore, 1, 1)
|
||||
# v tiling
|
||||
v_batch_size_per_iter_per_vec = UB_SIZE / torch.bfloat16.itemsize // (kv_hidden_size + 1)
|
||||
|
||||
split_qkv_rmsnorm_rope_kernel[grid](
|
||||
input,
|
||||
cos_sin_cache,
|
||||
positions,
|
||||
q_output,
|
||||
k_output,
|
||||
v_output,
|
||||
@@ -238,11 +327,20 @@ def split_qkv_rmsnorm_rope_impl(
|
||||
kv_hidden_size,
|
||||
total_hidden_size,
|
||||
eps,
|
||||
Q_BLOCK_SIZE,
|
||||
KV_BLOCK_SIZE,
|
||||
BIAS,
|
||||
head_dim,
|
||||
head_dim // 2,
|
||||
rope_dim,
|
||||
rope_dim // 2,
|
||||
IS_PARTIAL_ROPE,
|
||||
num_vectorcore,
|
||||
int(batch_size_per_iter_per_vec),
|
||||
int(qk_head_nums_per_iter_per_vec),
|
||||
q_head_num,
|
||||
kv_head_num,
|
||||
qk_head_num_sum,
|
||||
int(v_batch_size_per_iter_per_vec),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
)
|
||||
return q_output, k_output, v_output
|
||||
|
||||
@@ -265,19 +363,19 @@ def split_qkv_rmsnorm_rope_impl_fake(
|
||||
batch_size = input.shape[0]
|
||||
q_output = torch.empty(
|
||||
batch_size,
|
||||
q_hidden_size,
|
||||
int(q_hidden_size),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
k_output = torch.empty(
|
||||
batch_size,
|
||||
kv_hidden_size,
|
||||
int(kv_hidden_size),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
v_output = torch.empty(
|
||||
batch_size,
|
||||
kv_hidden_size,
|
||||
int(kv_hidden_size),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user