[Feat] fused_qkvzba_split_reshape supports token counts greater than 65536 (#6740)
### What this PR does / why we need it?
This pull request optimizes the fused_qkvzba_split_reshape_cat Triton
kernel for Qwen3-Next GatedDeltaNet model and removes the previous
conditional restrictions in the forward pass.
Key changes:
1. Refactored Triton kernel implementation: The
fused_qkvzba_split_reshape_cat_kernel has been optimized with a new
loop-based approach that supports arbitrary num_v_heads / num_k_heads
ratios and batch sizes. The kernel now uses configurable ROWS_PER_ITER
for better memory utilization.
2. The optimized kernel now handles all scenarios directly without
requiring a fallback path using fix_query_key_value_ordering and
torch.cat.
### Does this PR introduce _any_ user-facing change?
No. This is an internal optimization of the Triton kernel implementation
and does not introduce any user-facing changes.
### How was this patch tested?
CI is expected to pass with existing tests.
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: songjianquan <songjianquan1@huawei.com>
Co-authored-by: songjianquan <songjianquan1@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -12,8 +12,12 @@
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
|
||||
@triton.jit
|
||||
MAX_ROWS_PER_ITER = 64
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["total_rows", "rows_per_vec"])
|
||||
def fused_qkvzba_split_reshape_cat_kernel(
|
||||
mixed_qkv,
|
||||
z,
|
||||
@@ -25,50 +29,116 @@ def fused_qkvzba_split_reshape_cat_kernel(
|
||||
NUM_HEADS_V: tl.constexpr,
|
||||
HEAD_QK: tl.constexpr,
|
||||
HEAD_V: tl.constexpr,
|
||||
total_rows,
|
||||
rows_per_vec,
|
||||
QKVZ_ROW_STRIDE: tl.constexpr,
|
||||
BA_ROW_STRIDE: tl.constexpr,
|
||||
QKV_ROW_STRIDE: tl.constexpr,
|
||||
Z_ROW_STRIDE: tl.constexpr,
|
||||
BA_OUT_ROW_STRIDE: tl.constexpr,
|
||||
ROWS_PER_ITER: tl.constexpr,
|
||||
):
|
||||
i_bs, i_qk = tl.program_id(0), tl.program_id(1)
|
||||
QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
|
||||
BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
|
||||
QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
|
||||
q_end: tl.constexpr = HEAD_QK
|
||||
blk_q_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(0, q_end)
|
||||
k_end: tl.constexpr = q_end + HEAD_QK
|
||||
blk_k_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end)
|
||||
v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
|
||||
blk_v_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end)
|
||||
z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
|
||||
blk_z_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end)
|
||||
blk_q_st_ptr = mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + i_qk * HEAD_QK + tl.arange(0, HEAD_QK)
|
||||
blk_k_st_ptr = (
|
||||
mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK + tl.arange(0, HEAD_QK)
|
||||
)
|
||||
blk_v_st_ptr = (
|
||||
mixed_qkv
|
||||
+ i_bs * NUM_HEADS_QK * QKV_DIM_T
|
||||
+ NUM_HEADS_QK * HEAD_QK * 2
|
||||
+ i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
|
||||
+ tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
|
||||
)
|
||||
blk_z_st_ptr = (
|
||||
z
|
||||
+ i_bs * NUM_HEADS_V * HEAD_V
|
||||
+ i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
|
||||
+ tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
|
||||
)
|
||||
tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
|
||||
tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
|
||||
tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
|
||||
tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
|
||||
b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
|
||||
a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
|
||||
for i in tl.static_range(b_end):
|
||||
blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
|
||||
blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
|
||||
tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
|
||||
for i in tl.static_range(b_end, a_end):
|
||||
blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
|
||||
blk_a_st_ptr = a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
|
||||
tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
|
||||
"""
|
||||
Fused kernel to split and reshape mixed QKVZ and BA tensors.
|
||||
|
||||
This kernel performs the following transformations:
|
||||
- Input mixed_qkvz: [num_tokens, num_heads_qk * (Q + K + V + Z)] where each
|
||||
head block contains [Q(HEAD_QK), K(HEAD_QK), V(V_DIM_PER_QK), Z(V_DIM_PER_QK)]
|
||||
- Input mixed_ba: [num_tokens, num_heads_qk * (B + A)] where each head block
|
||||
contains [B(V_HEADS_PER_QK), A(V_HEADS_PER_QK)]
|
||||
- Output mixed_qkv: [num_tokens, Q_all | K_all | V_all] concatenated by type
|
||||
- Output z: [num_tokens, num_heads_v, head_v]
|
||||
- Output b, a: [num_tokens, num_heads_v]
|
||||
"""
|
||||
# Each vector core processes a contiguous chunk of rows
|
||||
vec_id = tl.program_id(0)
|
||||
|
||||
V_HEADS_PER_QK: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
|
||||
V_DIM_PER_QK: tl.constexpr = V_HEADS_PER_QK * HEAD_V
|
||||
QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + V_DIM_PER_QK * 2
|
||||
BA_DIM_T: tl.constexpr = V_HEADS_PER_QK * 2
|
||||
|
||||
Q_TOTAL: tl.constexpr = NUM_HEADS_QK * HEAD_QK
|
||||
K_TOTAL: tl.constexpr = NUM_HEADS_QK * HEAD_QK
|
||||
|
||||
row_start = vec_id * rows_per_vec
|
||||
row_end = min(row_start + rows_per_vec, total_rows)
|
||||
|
||||
row_offset = row_start
|
||||
|
||||
iter_count = (row_end - row_start + ROWS_PER_ITER - 1) // ROWS_PER_ITER
|
||||
|
||||
# ========== Main Iteration Loop ==========
|
||||
for _ in tl.range(iter_count):
|
||||
row_indices = tl.arange(0, ROWS_PER_ITER) + row_offset
|
||||
row_mask = row_indices < row_end
|
||||
|
||||
# ========== Head Iteration Loop ==========
|
||||
# Iterate over each Q/K head group to extract and rearrange data
|
||||
for head_id in tl.static_range(NUM_HEADS_QK):
|
||||
# Byte offset to the current head's data block in mixed_qkvz
|
||||
src_head_offset = head_id * QKVZ_DIM_T
|
||||
|
||||
# ----- Q (Query) Extraction -----
|
||||
# Source layout: mixed_qkvz[row, head_id * QKVZ_DIM_T + 0:HEAD_QK]
|
||||
# Dest layout: mixed_qkv[row, head_id * HEAD_QK : (head_id+1) * HEAD_QK]
|
||||
q_range = tl.arange(0, HEAD_QK)
|
||||
q_src = row_indices[:, None] * QKVZ_ROW_STRIDE + src_head_offset + q_range[None, :]
|
||||
q_dst = row_indices[:, None] * QKV_ROW_STRIDE + head_id * HEAD_QK + q_range[None, :]
|
||||
q_data = tl.load(mixed_qkvz + q_src, mask=row_mask[:, None])
|
||||
tl.store(mixed_qkv + q_dst, q_data, mask=row_mask[:, None])
|
||||
|
||||
# ----- K (Key) Extraction -----
|
||||
# Source layout: mixed_qkvz[row, head_id * QKVZ_DIM_T + HEAD_QK : +HEAD_QK]
|
||||
# Dest layout: mixed_qkv[row, Q_TOTAL + head_id * HEAD_QK : ...]
|
||||
# K is stored after Q in the source; in dest, K starts after all Q heads
|
||||
k_src = row_indices[:, None] * QKVZ_ROW_STRIDE + src_head_offset + HEAD_QK + q_range[None, :]
|
||||
k_dst = row_indices[:, None] * QKV_ROW_STRIDE + Q_TOTAL + head_id * HEAD_QK + q_range[None, :]
|
||||
k_data = tl.load(mixed_qkvz + k_src, mask=row_mask[:, None])
|
||||
tl.store(mixed_qkv + k_dst, k_data, mask=row_mask[:, None])
|
||||
|
||||
# ----- V (Value) Extraction -----
|
||||
# Source layout: mixed_qkvz[row, head_id * QKVZ_DIM_T + HEAD_QK*2 : +V_DIM_PER_QK]
|
||||
# Dest layout: mixed_qkv[row, Q_TOTAL + K_TOTAL + head_id * V_DIM_PER_QK : ...]
|
||||
# V follows Q and K in source; in dest, V starts after all Q and K heads
|
||||
v_range = tl.arange(0, V_DIM_PER_QK)
|
||||
v_src = row_indices[:, None] * QKVZ_ROW_STRIDE + src_head_offset + HEAD_QK * 2 + v_range[None, :]
|
||||
v_dst = (
|
||||
row_indices[:, None] * QKV_ROW_STRIDE + Q_TOTAL + K_TOTAL + head_id * V_DIM_PER_QK + v_range[None, :]
|
||||
)
|
||||
v_data = tl.load(mixed_qkvz + v_src, mask=row_mask[:, None])
|
||||
tl.store(mixed_qkv + v_dst, v_data, mask=row_mask[:, None])
|
||||
|
||||
# ----- Z Extraction -----
|
||||
# Source layout: mixed_qkvz[row, head_id * QKVZ_DIM_T + HEAD_QK*2 + V_DIM_PER_QK : ...]
|
||||
# Dest layout: z[row, head_id * V_DIM_PER_QK : (head_id+1) * V_DIM_PER_QK]
|
||||
# Z follows V in source; output z is reshaped to [batch, num_heads_v, head_v]
|
||||
z_src = (
|
||||
row_indices[:, None] * QKVZ_ROW_STRIDE + src_head_offset + HEAD_QK * 2 + V_DIM_PER_QK + v_range[None, :]
|
||||
)
|
||||
z_dst = row_indices[:, None] * Z_ROW_STRIDE + head_id * V_DIM_PER_QK + v_range[None, :]
|
||||
z_data = tl.load(mixed_qkvz + z_src, mask=row_mask[:, None])
|
||||
tl.store(z + z_dst, z_data, mask=row_mask[:, None])
|
||||
|
||||
# ----- B Extraction -----
|
||||
# Source layout: mixed_ba[row, head_id * BA_DIM_T : +V_HEADS_PER_QK]
|
||||
# Dest layout: b[row, head_id * V_HEADS_PER_QK : (head_id+1) * V_HEADS_PER_QK]
|
||||
b_range = tl.arange(0, V_HEADS_PER_QK)
|
||||
ba_head_offset = head_id * BA_DIM_T
|
||||
b_src = row_indices[:, None] * BA_ROW_STRIDE + ba_head_offset + b_range[None, :]
|
||||
b_dst = row_indices[:, None] * BA_OUT_ROW_STRIDE + head_id * V_HEADS_PER_QK + b_range[None, :]
|
||||
b_data = tl.load(mixed_ba + b_src, mask=row_mask[:, None])
|
||||
tl.store(b + b_dst, b_data, mask=row_mask[:, None])
|
||||
|
||||
# ----- A Extraction -----
|
||||
# Source layout: mixed_ba[row, head_id * BA_DIM_T + V_HEADS_PER_QK : ...]
|
||||
# Dest layout: a[row, head_id * V_HEADS_PER_QK : ...] (same as b_dst)
|
||||
# A follows B in source; output layout is same as B
|
||||
a_src = row_indices[:, None] * BA_ROW_STRIDE + ba_head_offset + V_HEADS_PER_QK + b_range[None, :]
|
||||
a_data = tl.load(mixed_ba + a_src, mask=row_mask[:, None])
|
||||
tl.store(a + b_dst, a_data, mask=row_mask[:, None])
|
||||
|
||||
row_offset += ROWS_PER_ITER
|
||||
|
||||
|
||||
def fused_qkvzba_split_reshape_cat(
|
||||
@@ -80,6 +150,20 @@ def fused_qkvzba_split_reshape_cat(
|
||||
head_v,
|
||||
):
|
||||
batch, seq_len = mixed_qkvz.shape[0], 1
|
||||
total_rows = batch * seq_len
|
||||
|
||||
v_heads_per_qk = num_heads_v // num_heads_qk
|
||||
v_dim_per_qk = v_heads_per_qk * head_v
|
||||
qkvz_dim_t = head_qk * 2 + v_dim_per_qk * 2
|
||||
ba_dim_t = v_heads_per_qk * 2
|
||||
|
||||
# row stride
|
||||
qkvz_row_stride = num_heads_qk * qkvz_dim_t
|
||||
ba_row_stride = num_heads_qk * ba_dim_t
|
||||
qkv_row_stride = num_heads_qk * head_qk * 2 + num_heads_v * head_v
|
||||
z_row_stride = num_heads_v * head_v
|
||||
ba_out_row_stride = num_heads_v
|
||||
|
||||
qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
|
||||
mixed_qkv = torch.empty(
|
||||
[batch * seq_len, qkv_dim_t],
|
||||
@@ -96,8 +180,28 @@ def fused_qkvzba_split_reshape_cat(
|
||||
dtype=mixed_ba.dtype,
|
||||
device=mixed_ba.device,
|
||||
)
|
||||
a = torch.empty_like(b)
|
||||
grid = (batch * seq_len, num_heads_qk)
|
||||
a = torch.empty(
|
||||
[batch * seq_len, num_heads_v],
|
||||
dtype=mixed_ba.dtype,
|
||||
device=mixed_ba.device,
|
||||
)
|
||||
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
|
||||
grid_size = min(num_vectorcore, total_rows)
|
||||
grid_size = max(1, grid_size)
|
||||
|
||||
rows_per_vec = triton.cdiv(total_rows, grid_size)
|
||||
|
||||
ub_size = 85 * 1024 // mixed_qkvz.element_size()
|
||||
|
||||
elements_per_row = qkvz_row_stride + ba_row_stride + qkv_row_stride + z_row_stride + ba_out_row_stride * 2
|
||||
|
||||
rows_per_iter = max(1, ub_size // elements_per_row)
|
||||
rows_per_iter = triton.next_power_of_2(rows_per_iter)
|
||||
rows_per_iter = min(rows_per_iter, rows_per_vec, MAX_ROWS_PER_ITER)
|
||||
|
||||
grid = (grid_size, 1)
|
||||
fused_qkvzba_split_reshape_cat_kernel[grid](
|
||||
mixed_qkv,
|
||||
z,
|
||||
@@ -109,7 +213,13 @@ def fused_qkvzba_split_reshape_cat(
|
||||
num_heads_v,
|
||||
head_qk,
|
||||
head_v,
|
||||
num_warps=1,
|
||||
num_stages=3,
|
||||
total_rows,
|
||||
rows_per_vec,
|
||||
qkvz_row_stride,
|
||||
ba_row_stride,
|
||||
qkv_row_stride,
|
||||
z_row_stride,
|
||||
ba_out_row_stride,
|
||||
rows_per_iter,
|
||||
)
|
||||
return mixed_qkv, z, b, a
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from vllm.config import CUDAGraphMode
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
@@ -52,23 +51,15 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
# ============================================================
|
||||
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
|
||||
projected_states_ba, _ = self.in_proj_ba(hidden_states)
|
||||
forward_context = get_forward_context()
|
||||
is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
|
||||
# triton grid should be less than 65536
|
||||
divide_grid = projected_states_qkvz.shape[0] * triton.cdiv(self.num_k_heads, self.tp_size)
|
||||
if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph and divide_grid < 65536:
|
||||
mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
|
||||
projected_states_qkvz,
|
||||
projected_states_ba,
|
||||
triton.cdiv(self.num_k_heads, self.tp_size),
|
||||
triton.cdiv(self.num_v_heads, self.tp_size),
|
||||
self.head_k_dim,
|
||||
self.head_v_dim,
|
||||
)
|
||||
else:
|
||||
query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
|
||||
query, key, value = map(lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value))
|
||||
mixed_qkv = torch.cat((query, key, value), dim=-1)
|
||||
|
||||
mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
|
||||
projected_states_qkvz,
|
||||
projected_states_ba,
|
||||
triton.cdiv(self.num_k_heads, self.tp_size),
|
||||
triton.cdiv(self.num_v_heads, self.tp_size),
|
||||
self.head_k_dim,
|
||||
self.head_v_dim,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Part 2: Core Attention (Custom Op)
|
||||
|
||||
Reference in New Issue
Block a user