[Feat]fused_qkvzba_split_reshape supports token number greater than 65536 (#6740)
### What this PR does / why we need it?
This pull request optimizes the fused_qkvzba_split_reshape_cat Triton
kernel for the Qwen3-Next GatedDeltaNet model and removes the conditional
restrictions that previously guarded its use in the forward pass.
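
For reference, the restriction being removed can be restated as a small standalone function. This is only a sketch that mirrors the condition in the removed lines of the diff; the helper names `cdiv` and `old_fused_path_allowed` are hypothetical, and the head counts in the usage note are illustrative values, not taken from any model config.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers (same arithmetic as triton.cdiv)."""
    return (a + b - 1) // b


def old_fused_path_allowed(
    num_tokens: int,
    num_k_heads: int,
    num_v_heads: int,
    tp_size: int,
    is_cuda_graph: bool,
) -> bool:
    """Return True if the pre-PR forward pass would have used the fused kernel.

    Hypothetical helper: it restates the removed `if` condition so the
    limits are easy to check by hand.
    """
    # The old launch used one program per (token, local k-head) pair.
    divide_grid = num_tokens * cdiv(num_k_heads, tp_size)
    # Only these v-head / k-head ratios were supported by the old kernel.
    ratio_ok = (num_v_heads // num_k_heads) in (1, 2, 4)
    return ratio_ok and is_cuda_graph and divide_grid < 65536
```

With 16 key heads, 32 value heads, and `tp_size=4` (assumed values), 1024 tokens give a grid of 4096 and pass the check, while 16384 tokens give exactly 65536 and fall through to the fallback path.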
Key changes:
1. Refactored Triton kernel implementation: The
fused_qkvzba_split_reshape_cat_kernel now uses a loop-based approach
that supports arbitrary num_v_heads / num_k_heads ratios and batch
sizes. The number of rows handled per loop iteration is configurable
through ROWS_PER_ITER for better memory utilization.
2. The optimized kernel now handles all scenarios directly without
requiring a fallback path using fix_query_key_value_ordering and
torch.cat.
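
The kernel source is not included in the diff shown on this page, so the following is only a sketch of one common way a loop-based kernel can cover more rows than the launch grid has programs. It assumes a fixed number of programs that each stride over row blocks of size `rows_per_iter`; the function name `rows_for_program` and this exact striding scheme are assumptions, not the PR's implementation.

```python
from typing import Iterator, Tuple


def rows_for_program(
    pid: int, num_programs: int, num_rows: int, rows_per_iter: int
) -> Iterator[Tuple[int, int]]:
    """Yield the half-open (start, stop) row ranges handled by program `pid`.

    Hypothetical illustration of a grid-stride loop: program `pid` takes
    row block `pid`, then `pid + num_programs`, and so on, so the grid
    size no longer has to grow with the number of tokens.
    """
    stride = num_programs * rows_per_iter
    start = pid * rows_per_iter
    while start < num_rows:
        # The last block may be shorter than rows_per_iter.
        yield start, min(start + rows_per_iter, num_rows)
        start += stride
```

Under this scheme, a launch with 48 programs and `rows_per_iter=16` visits every one of 100000 rows exactly once, even though the row count is well above 65536.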
### Does this PR introduce _any_ user-facing change?
No. This is an internal optimization of the Triton kernel implementation
and does not introduce any user-facing changes.
### How was this patch tested?
CI is expected to pass with existing tests.
- vLLM version: v0.15.0
- vLLM main: 9562912cea
---------
Signed-off-by: songjianquan <songjianquan1@huawei.com>
Co-authored-by: songjianquan <songjianquan1@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -18,7 +18,6 @@
 import torch
 from einops import rearrange
 from torch import nn
-from vllm.config import CUDAGraphMode
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -52,23 +51,15 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
         # ============================================================
         projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
         projected_states_ba, _ = self.in_proj_ba(hidden_states)
-        forward_context = get_forward_context()
-        is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
-        # triton grid should be less than 66536
-        divide_grid = projected_states_qkvz.shape[0] * triton.cdiv(self.num_k_heads, self.tp_size)
-        if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph and divide_grid < 65536:
-            mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
-                projected_states_qkvz,
-                projected_states_ba,
-                triton.cdiv(self.num_k_heads, self.tp_size),
-                triton.cdiv(self.num_v_heads, self.tp_size),
-                self.head_k_dim,
-                self.head_v_dim,
-            )
-        else:
-            query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
-            query, key, value = map(lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value))
-            mixed_qkv = torch.cat((query, key, value), dim=-1)
+
+        mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
+            projected_states_qkvz,
+            projected_states_ba,
+            triton.cdiv(self.num_k_heads, self.tp_size),
+            triton.cdiv(self.num_v_heads, self.tp_size),
+            self.head_k_dim,
+            self.head_v_dim,
+        )
 
         # ============================================================
         # Part 2: Core Attention (Custom Op)