qwen3_next: add triton op fused_qkvzba_split_reshape (#4788)

### What this PR does / why we need it?
Add the Triton op `fused_qkvzba_split_reshape_cat` for the Qwen3Next
GatedDeltaNet: it fuses the split, reshape, and concatenation of the packed
QKVZ and BA projections into a single kernel launch.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Author: ZT-AIA
Date: 2025-12-18 11:31:04 +08:00
Committed by: GitHub
Parent: 07014e2101
Commit: 39fb9e7c83
4 changed files with 237 additions and 1 deletion

File: vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import torch
from vllm.triton_utils import HAS_TRITON, tl, triton

if HAS_TRITON:
    import torch_npu._inductor  # noqa: F401


@triton.jit
def fused_qkvzba_split_reshape_cat_kernel(
    mixed_qkv,
    z,
    b,
    a,
    mixed_qkvz,
    mixed_ba,
    NUM_HEADS_QK: tl.constexpr,
    NUM_HEADS_V: tl.constexpr,
    HEAD_QK: tl.constexpr,
    HEAD_V: tl.constexpr,
):
    # One program per (token, qk-head) pair.
    i_bs, i_qk = tl.program_id(0), tl.program_id(1)
    # Per-qk-head widths of the packed layouts:
    # mixed_qkvz packs [q | k | v | z], mixed_ba packs [b | a].
    QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
    BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
    QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    # Source pointers into this head's interleaved [q | k | v | z] segment.
    q_end: tl.constexpr = HEAD_QK
    blk_q_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
                 i_qk * QKVZ_DIM_T + tl.arange(0, q_end))
    k_end: tl.constexpr = q_end + HEAD_QK
    blk_k_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
                 i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end))
    v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_v_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
                 i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end))
    z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_z_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
                 i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end))
    # Destination pointers into the contiguous [all q | all k | all v] output.
    blk_q_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
                    i_qk * HEAD_QK + tl.arange(0, HEAD_QK))
    blk_k_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
                    NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK +
                    tl.arange(0, HEAD_QK))
    blk_v_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
                    NUM_HEADS_QK * HEAD_QK * 2 +
                    i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
                    tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
    blk_z_st_ptr = (z + i_bs * NUM_HEADS_V * HEAD_V +
                    i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
                    tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
    tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
    tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
    tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
    tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
    # Scatter the per-qk-head [b | a] scalars into per-v-head b and a tensors.
    b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
    a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
    for i in tl.static_range(b_end):
        blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
        tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
    for i in tl.static_range(b_end, a_end):
        blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_a_st_ptr = (a + i_bs * NUM_HEADS_V +
                        i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end))
        tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))


def fused_qkvzba_split_reshape_cat(
    mixed_qkvz,
    mixed_ba,
    num_heads_qk,
    num_heads_v,
    head_qk,
    head_v,
):
    batch, seq_len = mixed_qkvz.shape[0], 1
    qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
    mixed_qkv = torch.empty(
        [batch * seq_len, qkv_dim_t],
        dtype=mixed_qkvz.dtype,
        device=mixed_qkvz.device,
    )
    z = torch.empty(
        [batch * seq_len, num_heads_v, head_v],
        dtype=mixed_qkvz.dtype,
        device=mixed_qkvz.device,
    )
    b = torch.empty(
        [batch * seq_len, num_heads_v],
        dtype=mixed_ba.dtype,
        device=mixed_ba.device,
    )
    a = torch.empty_like(b)
    grid = (batch * seq_len, num_heads_qk)
    fused_qkvzba_split_reshape_cat_kernel[grid](
        mixed_qkv,
        z,
        b,
        a,
        mixed_qkvz,
        mixed_ba,
        num_heads_qk,
        num_heads_v,
        head_qk,
        head_v,
        num_warps=1,
        num_stages=3,
    )
    return mixed_qkv, z, b, a
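The UT mentioned in the PR description is not part of this diff. As a hedged sketch, an equivalence test against an eager PyTorch reference could look like the following; the shapes, the bfloat16 dtype, and the `npu` device are illustrative assumptions, not values taken from this PR:

```python
import torch

from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
    fused_qkvzba_split_reshape_cat

# Hypothetical shapes; the real UT may use different values.
B, HK, HV, DK, DV = 4, 8, 16, 128, 64   # tokens, qk heads, v heads, head dims
R = HV // HK                            # v heads per qk head

mixed_qkvz = torch.randn(B, HK * (2 * DK + 2 * R * DV),
                         dtype=torch.bfloat16, device="npu")
mixed_ba = torch.randn(B, HK * 2 * R, dtype=torch.bfloat16, device="npu")

mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
    mixed_qkvz, mixed_ba, HK, HV, DK, DV)

# Eager reference: split each head's [q | k | v | z] segment, then regroup.
qkvz = mixed_qkvz.view(B, HK, 2 * DK + 2 * R * DV)
q, k, v, z_ref = torch.split(qkvz, [DK, DK, R * DV, R * DV], dim=-1)
ref_qkv = torch.cat(
    [q.reshape(B, -1), k.reshape(B, -1), v.reshape(B, -1)], dim=-1)
b_ref, a_ref = torch.split(mixed_ba.view(B, HK, 2 * R), [R, R], dim=-1)

# The kernel is a pure relayout, so results should match bit-for-bit.
assert torch.equal(mixed_qkv, ref_qkv)
assert torch.equal(z, z_ref.reshape(B, HV, DV))
assert torch.equal(b, b_ref.reshape(B, HV))
assert torch.equal(a, a_ref.reshape(B, HV))
```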

File: vllm_ascend/patch/__init__.py

@@ -272,4 +272,16 @@
# 1. make these functions class methods of RejectionSampler, create AscendRejectionSampler
#    to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
# 2. make these functions a custom op, then remove AscendRejectionSampler
#
#
# ** 14.File: worker/patch_qwen3_next.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet.forward`
#    Why:
#       The upstream Qwen3Next GatedDeltaNet forward provides no hook for
#       plugging in custom operators.
#    How:
#       Add a branch in Qwen3NextGatedDeltaNet.forward that dispatches to
#       fused_qkvzba_split_reshape_cat.
#    Related PR (if no, explain why):
#       https://github.com/vllm-project/vllm/pull/30863
#    Future Plan:
#       Remove this patch once vLLM supports these operators.
#

File: vllm_ascend/patch/worker/__init__.py

@@ -32,5 +32,6 @@ import vllm_ascend.patch.worker.patch_qwen2_5_vl # noqa
import vllm_ascend.patch.worker.patch_qwen2_5_omni # noqa
import vllm_ascend.patch.worker.patch_qwen3_vl # noqa
import vllm_ascend.patch.worker.patch_rope # noqa
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
import vllm_ascend.patch.worker.patch_rejection_sampler # noqa

File: vllm_ascend/patch/worker/patch_qwen3_next.py

@@ -0,0 +1,105 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import rearrange
from torch import nn
from vllm.config import CUDAGraphMode
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
from vllm.triton_utils import triton

from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
    fused_qkvzba_split_reshape_cat


class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """
        Forward pass with three parts:
        1. Input projection
        2. Core attention (custom op)
        3. Output projection
        """
        num_tokens = hidden_states.size(0)

        # ============================================================
        # Part 1: Input Projection
        # ============================================================
        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
        projected_states_ba, _ = self.in_proj_ba(hidden_states)
        forward_context = get_forward_context()
        is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
        # The Triton launch grid must stay below 65536 programs.
        divide_grid = projected_states_qkvz.shape[0] * triton.cdiv(
            self.num_k_heads, self.tp_size)
        if self.num_v_heads // self.num_k_heads in [1, 2, 4] and \
                is_cuda_graph and divide_grid < 65536:
            # Fast path: one fused kernel performs the split/reshape/cat.
            mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
                projected_states_qkvz,
                projected_states_ba,
                triton.cdiv(self.num_k_heads, self.tp_size),
                triton.cdiv(self.num_v_heads, self.tp_size),
                self.head_k_dim,
                self.head_v_dim,
            )
        else:
            # Fallback: eager split followed by rearrange and concatenation.
            query, key, value, z, b, a = self.fix_query_key_value_ordering(
                projected_states_qkvz, projected_states_ba)
            query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
                                    (query, key, value))
            mixed_qkv = torch.cat((query, key, value), dim=-1)

        # ============================================================
        # Part 2: Core Attention (Custom Op)
        # ============================================================
        # Note: we should not use torch.empty here like other attention
        # backends; see the discussion in
        # https://github.com/vllm-project/vllm/pull/28182
        core_attn_out = torch.zeros(
            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        torch.ops.vllm.gdn_attention_core(
            mixed_qkv,
            b,
            a,
            core_attn_out,
            self.prefix,
        )

        # ============================================================
        # Part 3: Output Projection
        # ============================================================
        z_shape_og = z.shape
        # Reshape input data into 2D tensors for normalization
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
        output[:num_tokens], _ = self.out_proj(core_attn_out)


# Apply the patch: replace the upstream forward with the Ascend variant.
Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward
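
For intuition on the dispatch condition in Part 1, here is a quick back-of-envelope check. The head counts below are hypothetical, Qwen3-Next-like illustrative values, not numbers taken from this PR:

```python
# Hypothetical config for illustration; real model/deployment values may differ.
num_k_heads, num_v_heads, tp_size = 16, 32, 4

ratio = num_v_heads // num_k_heads             # 2, which is in [1, 2, 4]
k_heads_per_rank = -(-num_k_heads // tp_size)  # ceil-div, as triton.cdiv does: 4
# divide_grid = num_tokens * k_heads_per_rank must stay below 65536, so the
# fused kernel is used for batches of fewer than 65536 // 4 = 16384 tokens.
print(ratio, k_heads_per_rank, 65536 // k_heads_per_rank)  # -> 2 4 16384
```

Larger batches, head ratios outside {1, 2, 4}, or eager (non-graph) execution fall back to the original `fix_query_key_value_ordering` path.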