From 43c8da3574c96b9aaeaf4ef360c9b4aaf6a3e305 Mon Sep 17 00:00:00 2001 From: songjianquan <61267299+songjianquan@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:41:38 +0800 Subject: [PATCH] [Feat]fused_qkvzba_split_reshape supports token number greater than 65536 (#6740) ### What this PR does / why we need it? This pull request optimizes the fused_qkvzba_split_reshape_cat Triton kernel for Qwen3-Next GatedDeltaNet model and removes the previous conditional restrictions in the forward pass. Key changes: 1. Refactored Triton kernel implementation: The fused_qkvzba_split_reshape_cat_kernel has been optimized with a new loop-based approach that supports arbitrary num_v_heads / num_k_heads ratios and batch sizes. The kernel now uses configurable ROWS_PER_ITER for better memory utilization . 2. The optimized kernel now handles all scenarios directly without requiring a fallback path using fix_query_key_value_ordering and torch.cat. ### Does this PR introduce _any_ user-facing change? No. This is an internal optimization of the Triton kernel implementation and does not introduce any user-facing changes. ### How was this patch tested? CI is expected to pass with existing tests. - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/9562912cead1f11e8540fb91306c5cbda66f0007 --------- Signed-off-by: songjianquan Co-authored-by: songjianquan Co-authored-by: wangxiyuan --- .../triton/fla/fused_qkvzba_split_reshape.py | 206 ++++++++++++++---- vllm_ascend/patch/worker/patch_qwen3_next.py | 27 +-- 2 files changed, 167 insertions(+), 66 deletions(-) diff --git a/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py b/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py index e58e79bb..37319601 100644 --- a/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py +++ b/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py @@ -12,8 +12,12 @@ import torch from vllm.triton_utils import tl, triton +from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num -@triton.jit +MAX_ROWS_PER_ITER = 64 + + +@triton.jit(do_not_specialize=["total_rows", "rows_per_vec"]) def fused_qkvzba_split_reshape_cat_kernel( mixed_qkv, z, @@ -25,50 +29,116 @@ def fused_qkvzba_split_reshape_cat_kernel( NUM_HEADS_V: tl.constexpr, HEAD_QK: tl.constexpr, HEAD_V: tl.constexpr, + total_rows, + rows_per_vec, + QKVZ_ROW_STRIDE: tl.constexpr, + BA_ROW_STRIDE: tl.constexpr, + QKV_ROW_STRIDE: tl.constexpr, + Z_ROW_STRIDE: tl.constexpr, + BA_OUT_ROW_STRIDE: tl.constexpr, + ROWS_PER_ITER: tl.constexpr, ): - i_bs, i_qk = tl.program_id(0), tl.program_id(1) - QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2 - BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2 - QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V - q_end: tl.constexpr = HEAD_QK - blk_q_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(0, q_end) - k_end: tl.constexpr = q_end + HEAD_QK - blk_k_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end) - v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V - blk_v_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end) - z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V - blk_z_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end) - blk_q_st_ptr = mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + i_qk * HEAD_QK + tl.arange(0, HEAD_QK) - 
blk_k_st_ptr = ( - mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK + tl.arange(0, HEAD_QK) - ) - blk_v_st_ptr = ( - mixed_qkv - + i_bs * NUM_HEADS_QK * QKV_DIM_T - + NUM_HEADS_QK * HEAD_QK * 2 - + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK - + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) - ) - blk_z_st_ptr = ( - z - + i_bs * NUM_HEADS_V * HEAD_V - + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK - + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) - ) - tl.store(blk_q_st_ptr, tl.load(blk_q_ptr)) - tl.store(blk_k_st_ptr, tl.load(blk_k_ptr)) - tl.store(blk_v_st_ptr, tl.load(blk_v_ptr)) - tl.store(blk_z_st_ptr, tl.load(blk_z_ptr)) - b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK - a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK - for i in tl.static_range(b_end): - blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i - blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i - tl.store(blk_b_st_ptr, tl.load(blk_b_ptr)) - for i in tl.static_range(b_end, a_end): - blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i - blk_a_st_ptr = a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end) - tl.store(blk_a_st_ptr, tl.load(blk_a_ptr)) + """ + Fused kernel to split and reshape mixed QKVZ and BA tensors. + + This kernel performs the following transformations: + - Input mixed_qkvz: [num_tokens, num_heads_qk * (Q + K + V + Z)] where each + head block contains [Q(HEAD_QK), K(HEAD_QK), V(V_DIM_PER_QK), Z(V_DIM_PER_QK)] + - Input mixed_ba: [num_tokens, num_heads_qk * (B + A)] where each head block + contains [B(V_HEADS_PER_QK), A(V_HEADS_PER_QK)] + - Output mixed_qkv: [num_tokens, Q_all | K_all | V_all] concatenated by type + - Output z: [num_tokens, num_heads_v, head_v] + - Output b, a: [num_tokens, num_heads_v] + """ + # Each vector core processes a contiguous chunk of rows + vec_id = tl.program_id(0) + + V_HEADS_PER_QK: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK + V_DIM_PER_QK: tl.constexpr = V_HEADS_PER_QK * HEAD_V + QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + V_DIM_PER_QK * 2 + BA_DIM_T: tl.constexpr = V_HEADS_PER_QK * 2 + + Q_TOTAL: tl.constexpr = NUM_HEADS_QK * HEAD_QK + K_TOTAL: tl.constexpr = NUM_HEADS_QK * HEAD_QK + + row_start = vec_id * rows_per_vec + row_end = min(row_start + rows_per_vec, total_rows) + + row_offset = row_start + + iter_count = (row_end - row_start + ROWS_PER_ITER - 1) // ROWS_PER_ITER + + # ========== Main Iteration Loop ========== + for _ in tl.range(iter_count): + row_indices = tl.arange(0, ROWS_PER_ITER) + row_offset + row_mask = row_indices < row_end + + # ========== Head Iteration Loop ========== + # Iterate over each Q/K head group to extract and rearrange data + for head_id in tl.static_range(NUM_HEADS_QK): + # Byte offset to the current head's data block in mixed_qkvz + src_head_offset = head_id * QKVZ_DIM_T + + # ----- Q (Query) Extraction ----- + # Source layout: mixed_qkvz[row, head_id * QKVZ_DIM_T + 0:HEAD_QK] + # Dest layout: mixed_qkv[row, head_id * HEAD_QK : (head_id+1) * HEAD_QK] + q_range = tl.arange(0, HEAD_QK) + q_src = row_indices[:, None] * QKVZ_ROW_STRIDE + src_head_offset + q_range[None, :] + q_dst = row_indices[:, None] * QKV_ROW_STRIDE + head_id * HEAD_QK + q_range[None, :] + q_data = tl.load(mixed_qkvz + q_src, mask=row_mask[:, None]) + tl.store(mixed_qkv + q_dst, q_data, mask=row_mask[:, None]) + + # ----- K (Key) Extraction ----- + # Source layout: mixed_qkvz[row, head_id * QKVZ_DIM_T + HEAD_QK : +HEAD_QK] + # 
Dest layout: mixed_qkv[row, Q_TOTAL + head_id * HEAD_QK : ...] + # K is stored after Q in the source; in dest, K starts after all Q heads + k_src = row_indices[:, None] * QKVZ_ROW_STRIDE + src_head_offset + HEAD_QK + q_range[None, :] + k_dst = row_indices[:, None] * QKV_ROW_STRIDE + Q_TOTAL + head_id * HEAD_QK + q_range[None, :] + k_data = tl.load(mixed_qkvz + k_src, mask=row_mask[:, None]) + tl.store(mixed_qkv + k_dst, k_data, mask=row_mask[:, None]) + + # ----- V (Value) Extraction ----- + # Source layout: mixed_qkvz[row, head_id * QKVZ_DIM_T + HEAD_QK*2 : +V_DIM_PER_QK] + # Dest layout: mixed_qkv[row, Q_TOTAL + K_TOTAL + head_id * V_DIM_PER_QK : ...] + # V follows Q and K in source; in dest, V starts after all Q and K heads + v_range = tl.arange(0, V_DIM_PER_QK) + v_src = row_indices[:, None] * QKVZ_ROW_STRIDE + src_head_offset + HEAD_QK * 2 + v_range[None, :] + v_dst = ( + row_indices[:, None] * QKV_ROW_STRIDE + Q_TOTAL + K_TOTAL + head_id * V_DIM_PER_QK + v_range[None, :] + ) + v_data = tl.load(mixed_qkvz + v_src, mask=row_mask[:, None]) + tl.store(mixed_qkv + v_dst, v_data, mask=row_mask[:, None]) + + # ----- Z Extraction ----- + # Source layout: mixed_qkvz[row, head_id * QKVZ_DIM_T + HEAD_QK*2 + V_DIM_PER_QK : ...] + # Dest layout: z[row, head_id * V_DIM_PER_QK : (head_id+1) * V_DIM_PER_QK] + # Z follows V in source; output z is reshaped to [batch, num_heads_v, head_v] + z_src = ( + row_indices[:, None] * QKVZ_ROW_STRIDE + src_head_offset + HEAD_QK * 2 + V_DIM_PER_QK + v_range[None, :] + ) + z_dst = row_indices[:, None] * Z_ROW_STRIDE + head_id * V_DIM_PER_QK + v_range[None, :] + z_data = tl.load(mixed_qkvz + z_src, mask=row_mask[:, None]) + tl.store(z + z_dst, z_data, mask=row_mask[:, None]) + + # ----- B Extraction ----- + # Source layout: mixed_ba[row, head_id * BA_DIM_T : +V_HEADS_PER_QK] + # Dest layout: b[row, head_id * V_HEADS_PER_QK : (head_id+1) * V_HEADS_PER_QK] + b_range = tl.arange(0, V_HEADS_PER_QK) + ba_head_offset = head_id * BA_DIM_T + b_src = row_indices[:, None] * BA_ROW_STRIDE + ba_head_offset + b_range[None, :] + b_dst = row_indices[:, None] * BA_OUT_ROW_STRIDE + head_id * V_HEADS_PER_QK + b_range[None, :] + b_data = tl.load(mixed_ba + b_src, mask=row_mask[:, None]) + tl.store(b + b_dst, b_data, mask=row_mask[:, None]) + + # ----- A Extraction ----- + # Source layout: mixed_ba[row, head_id * BA_DIM_T + V_HEADS_PER_QK : ...] + # Dest layout: a[row, head_id * V_HEADS_PER_QK : ...] 
(same as b_dst) + # A follows B in source; output layout is same as B + a_src = row_indices[:, None] * BA_ROW_STRIDE + ba_head_offset + V_HEADS_PER_QK + b_range[None, :] + a_data = tl.load(mixed_ba + a_src, mask=row_mask[:, None]) + tl.store(a + b_dst, a_data, mask=row_mask[:, None]) + + row_offset += ROWS_PER_ITER def fused_qkvzba_split_reshape_cat( @@ -80,6 +150,20 @@ def fused_qkvzba_split_reshape_cat( head_v, ): batch, seq_len = mixed_qkvz.shape[0], 1 + total_rows = batch * seq_len + + v_heads_per_qk = num_heads_v // num_heads_qk + v_dim_per_qk = v_heads_per_qk * head_v + qkvz_dim_t = head_qk * 2 + v_dim_per_qk * 2 + ba_dim_t = v_heads_per_qk * 2 + + # row stride + qkvz_row_stride = num_heads_qk * qkvz_dim_t + ba_row_stride = num_heads_qk * ba_dim_t + qkv_row_stride = num_heads_qk * head_qk * 2 + num_heads_v * head_v + z_row_stride = num_heads_v * head_v + ba_out_row_stride = num_heads_v + qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v mixed_qkv = torch.empty( [batch * seq_len, qkv_dim_t], @@ -96,8 +180,28 @@ def fused_qkvzba_split_reshape_cat( dtype=mixed_ba.dtype, device=mixed_ba.device, ) - a = torch.empty_like(b) - grid = (batch * seq_len, num_heads_qk) + a = torch.empty( + [batch * seq_len, num_heads_v], + dtype=mixed_ba.dtype, + device=mixed_ba.device, + ) + + num_vectorcore = get_vectorcore_num() + + grid_size = min(num_vectorcore, total_rows) + grid_size = max(1, grid_size) + + rows_per_vec = triton.cdiv(total_rows, grid_size) + + ub_size = 85 * 1024 // mixed_qkvz.element_size() + + elements_per_row = qkvz_row_stride + ba_row_stride + qkv_row_stride + z_row_stride + ba_out_row_stride * 2 + + rows_per_iter = max(1, ub_size // elements_per_row) + rows_per_iter = triton.next_power_of_2(rows_per_iter) + rows_per_iter = min(rows_per_iter, rows_per_vec, MAX_ROWS_PER_ITER) + + grid = (grid_size, 1) fused_qkvzba_split_reshape_cat_kernel[grid]( mixed_qkv, z, @@ -109,7 +213,13 @@ def fused_qkvzba_split_reshape_cat( num_heads_v, head_qk, head_v, - num_warps=1, - num_stages=3, + total_rows, + rows_per_vec, + qkvz_row_stride, + ba_row_stride, + qkv_row_stride, + z_row_stride, + ba_out_row_stride, + rows_per_iter, ) return mixed_qkv, z, b, a diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py index ebd736f2..8ce2f0bd 100644 --- a/vllm_ascend/patch/worker/patch_qwen3_next.py +++ b/vllm_ascend/patch/worker/patch_qwen3_next.py @@ -18,7 +18,6 @@ import torch from einops import rearrange from torch import nn -from vllm.config import CUDAGraphMode from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule from vllm.model_executor.layers.mamba.abstract import MambaBase @@ -52,23 +51,15 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase): # ============================================================ projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) projected_states_ba, _ = self.in_proj_ba(hidden_states) - forward_context = get_forward_context() - is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE - # triton grid should be less than 66536 - divide_grid = projected_states_qkvz.shape[0] * triton.cdiv(self.num_k_heads, self.tp_size) - if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph and divide_grid < 65536: - mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( - projected_states_qkvz, - projected_states_ba, - triton.cdiv(self.num_k_heads, self.tp_size), - 
triton.cdiv(self.num_v_heads, self.tp_size), - self.head_k_dim, - self.head_v_dim, - ) - else: - query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) - query, key, value = map(lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)) - mixed_qkv = torch.cat((query, key, value), dim=-1) + + mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( + projected_states_qkvz, + projected_states_ba, + triton.cdiv(self.num_k_heads, self.tp_size), + triton.cdiv(self.num_v_heads, self.tp_size), + self.head_k_dim, + self.head_v_dim, + ) # ============================================================ # Part 2: Core Attention (Custom Op)
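
As a cross-check for the new kernel, the transformation it performs can be expressed in plain PyTorch. The sketch below follows only the per-head layouts documented in the kernel docstring above; the function name `qkvzba_split_reshape_cat_reference` is hypothetical and not part of this patch.

```python
import torch


def qkvzba_split_reshape_cat_reference(mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_qk, head_v):
    """Non-fused reference for the split/reshape/cat performed by the Triton kernel."""
    num_tokens = mixed_qkvz.shape[0]
    v_heads_per_qk = num_heads_v // num_heads_qk
    v_dim_per_qk = v_heads_per_qk * head_v

    # Each Q/K head group in mixed_qkvz packs [Q(head_qk) | K(head_qk) | V(v_dim_per_qk) | Z(v_dim_per_qk)].
    qkvz = mixed_qkvz.reshape(num_tokens, num_heads_qk, 2 * head_qk + 2 * v_dim_per_qk)
    q, k, v, z = torch.split(qkvz, [head_qk, head_qk, v_dim_per_qk, v_dim_per_qk], dim=-1)

    # mixed_qkv is concatenated by type: all Q heads, then all K heads, then all V heads.
    mixed_qkv = torch.cat(
        (q.reshape(num_tokens, -1), k.reshape(num_tokens, -1), v.reshape(num_tokens, -1)), dim=-1
    )
    z = z.reshape(num_tokens, num_heads_v, head_v)

    # Each Q/K head group in mixed_ba packs [B(v_heads_per_qk) | A(v_heads_per_qk)].
    ba = mixed_ba.reshape(num_tokens, num_heads_qk, 2 * v_heads_per_qk)
    b, a = torch.split(ba, [v_heads_per_qk, v_heads_per_qk], dim=-1)
    return mixed_qkv, z, b.reshape(num_tokens, num_heads_v), a.reshape(num_tokens, num_heads_v)
```

A unit test could compare this reference against `fused_qkvzba_split_reshape_cat` with `torch.testing.assert_close` for several num_v_heads / num_k_heads ratios (e.g. 1, 2, 3, 4) and for token counts above 65536, to exercise the new multi-row loop.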
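For intuition on the launch sizing in the new `fused_qkvzba_split_reshape_cat` wrapper, here is a worked example of the same arithmetic. All concrete numbers (head counts, head dims, token count, vector-core count, bfloat16 element size) are assumptions chosen for illustration, not values taken from this patch.

```python
# Assumed, illustrative shapes: 16 Q/K heads, 32 V heads, 128-dim heads,
# 100k tokens (> 65536), 40 vector cores, bfloat16 (2-byte) activations.
num_heads_qk, num_heads_v, head_qk, head_v = 16, 32, 128, 128
total_rows, num_vectorcore, element_size = 100_000, 40, 2

v_heads_per_qk = num_heads_v // num_heads_qk                                    # 2
qkvz_row_stride = num_heads_qk * (2 * head_qk + 2 * v_heads_per_qk * head_v)   # 12288
ba_row_stride = num_heads_qk * 2 * v_heads_per_qk                              # 64
qkv_row_stride = num_heads_qk * head_qk * 2 + num_heads_v * head_v             # 8192
z_row_stride = num_heads_v * head_v                                            # 4096
ba_out_row_stride = num_heads_v                                                # 32

grid_size = max(1, min(num_vectorcore, total_rows))                            # 40 programs, one per vector core
rows_per_vec = (total_rows + grid_size - 1) // grid_size                       # 2500 rows handled by each program
ub_size = 85 * 1024 // element_size                                            # 43520-element unified-buffer budget
elements_per_row = (qkvz_row_stride + ba_row_stride + qkv_row_stride
                    + z_row_stride + ba_out_row_stride * 2)                    # 24704
rows_per_iter = max(1, ub_size // elements_per_row)                            # 1, so each loop iteration moves 1 row
# After next_power_of_2 and clamping to rows_per_vec and MAX_ROWS_PER_ITER this stays 1,
# and each program runs rows_per_vec / rows_per_iter = 2500 iterations.
```

Because the grid is bounded by the vector-core count rather than by the token count times the number of Q/K heads, the kernel no longer hits the 65536 grid-size limit that previously forced the fallback path.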