[Feat]fused_qkvzba_split_reshape supports token number greater than 65536 (#6740)

### What this PR does / why we need it?

This pull request optimizes the fused_qkvzba_split_reshape_cat Triton
kernel for Qwen3-Next GatedDeltaNet model and removes the previous
conditional restrictions in the forward pass.
Key changes:
1. Refactored Triton kernel implementation: The
fused_qkvzba_split_reshape_cat_kernel has been optimized with a new
loop-based approach that supports arbitrary num_v_heads / num_k_heads
ratios and batch sizes. The kernel now uses configurable ROWS_PER_ITER
for better memory utilization .
2. The optimized kernel now handles all scenarios directly without
requiring a fallback path using fix_query_key_value_ordering and
torch.cat.

### Does this PR introduce _any_ user-facing change?
No. This is an internal optimization of the Triton kernel implementation
and does not introduce any user-facing changes.

### How was this patch tested?
CI is expected to pass with existing tests.

- vLLM version: v0.15.0
- vLLM main:
9562912cea

---------

Signed-off-by: songjianquan <songjianquan1@huawei.com>
Co-authored-by: songjianquan <songjianquan1@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
songjianquan
2026-03-05 14:41:38 +08:00
committed by GitHub
parent 13777bf3f0
commit 43c8da3574
2 changed files with 167 additions and 66 deletions

View File

@@ -18,7 +18,6 @@
import torch
from einops import rearrange
from torch import nn
from vllm.config import CUDAGraphMode
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -52,23 +51,15 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
# ============================================================
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
projected_states_ba, _ = self.in_proj_ba(hidden_states)
forward_context = get_forward_context()
is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
# triton grid should be less than 66536
divide_grid = projected_states_qkvz.shape[0] * triton.cdiv(self.num_k_heads, self.tp_size)
if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph and divide_grid < 65536:
mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
projected_states_qkvz,
projected_states_ba,
triton.cdiv(self.num_k_heads, self.tp_size),
triton.cdiv(self.num_v_heads, self.tp_size),
self.head_k_dim,
self.head_v_dim,
)
else:
query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
query, key, value = map(lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value))
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
projected_states_qkvz,
projected_states_ba,
triton.cdiv(self.num_k_heads, self.tp_size),
triton.cdiv(self.num_v_heads, self.tp_size),
self.head_k_dim,
self.head_v_dim,
)
# ============================================================
# Part 2: Core Attention (Custom Op)