diff --git a/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py b/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py
new file mode 100644
index 00000000..4129fdd0
--- /dev/null
+++ b/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+# ruff: noqa: E501
+# mypy: ignore-errors
+import torch
+from vllm.triton_utils import HAS_TRITON, tl, triton
+
+if HAS_TRITON:
+    import torch_npu._inductor  # noqa: F401
+
+
+@triton.jit
+def fused_qkvzba_split_reshape_cat_kernel(
+    mixed_qkv,
+    z,
+    b,
+    a,
+    mixed_qkvz,
+    mixed_ba,
+    NUM_HEADS_QK: tl.constexpr,
+    NUM_HEADS_V: tl.constexpr,
+    HEAD_QK: tl.constexpr,
+    HEAD_V: tl.constexpr,
+):
+    i_bs, i_qk = tl.program_id(0), tl.program_id(1)
+    QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
+    BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
+    QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+    q_end: tl.constexpr = HEAD_QK
+    blk_q_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
+                 i_qk * QKVZ_DIM_T + tl.arange(0, q_end))
+    k_end: tl.constexpr = q_end + HEAD_QK
+    blk_k_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
+                 i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end))
+    v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+    blk_v_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
+                 i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end))
+    z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
+    blk_z_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
+                 i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end))
+    blk_q_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
+                    i_qk * HEAD_QK + tl.arange(0, HEAD_QK))
+    blk_k_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
+                    NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK +
+                    tl.arange(0, HEAD_QK))
+    blk_v_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
+                    NUM_HEADS_QK * HEAD_QK * 2 +
+                    i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
+                    tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
+    blk_z_st_ptr = (z + i_bs * NUM_HEADS_V * HEAD_V +
+                    i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
+                    tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
+    tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
+    tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
+    tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
+    tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
+    b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
+    a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
+    for i in tl.static_range(b_end):
+        blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
+        blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
+        tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
+    for i in tl.static_range(b_end, a_end):
+        blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
+        blk_a_st_ptr = (a + i_bs * NUM_HEADS_V +
+                        i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end))
+        tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
+
+
+def fused_qkvzba_split_reshape_cat(
+    mixed_qkvz,
+    mixed_ba,
+    num_heads_qk,
+    num_heads_v,
+    head_qk,
+    head_v,
+):
+    batch, seq_len = mixed_qkvz.shape[0], 1
+    qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
+    mixed_qkv = torch.empty(
+        [batch * seq_len, qkv_dim_t],
+        dtype=mixed_qkvz.dtype,
+        device=mixed_qkvz.device,
+    )
+    z = torch.empty(
+        [batch * seq_len, num_heads_v, head_v],
+        dtype=mixed_qkvz.dtype,
+        device=mixed_qkvz.device,
+    )
+    b = torch.empty(
+        [batch * seq_len, num_heads_v],
+        dtype=mixed_ba.dtype,
+        device=mixed_ba.device,
+    )
+    a = torch.empty_like(b)
+    grid = (batch * seq_len, num_heads_qk)
+    fused_qkvzba_split_reshape_cat_kernel[grid](
+        mixed_qkv,
+        z,
+        b,
+        a,
+        mixed_qkvz,
+        mixed_ba,
+        num_heads_qk,
+        num_heads_v,
+        head_qk,
+        head_v,
+        num_warps=1,
+        num_stages=3,
+    )
+    return mixed_qkv, z, b, a
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 33da508f..4d3a9daf 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -272,4 +272,16 @@
 # 1. make these functions as class func of RejectionSampler, create AscendRejectionSampler
 #    to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
 # 2. make these functions as costom op, then remove AscendRejectionSampler
-#
\ No newline at end of file
+#
+# ** 14.File: worker/patch_qwen3_next.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet.forward`
+#    Why:
+#       The upstream Qwen3NextGatedDeltaNet.forward has no branch for dispatching to custom fused operators.
+#    How:
+#       Add a branch in Qwen3NextGatedDeltaNet.forward that dispatches to fused_qkvzba_split_reshape_cat.
+#    Related PR (if no, explain why):
+#       https://github.com/vllm-project/vllm/pull/30863
+#    Future Plan:
+#       Remove this patch once vLLM supports these operators.
+#
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 07a77f7d..23faa67a 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -32,5 +32,6 @@ import vllm_ascend.patch.worker.patch_qwen2_5_vl  # noqa
 import vllm_ascend.patch.worker.patch_qwen2_5_omni  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_vl  # noqa
 import vllm_ascend.patch.worker.patch_rope  # noqa
+import vllm_ascend.patch.worker.patch_qwen3_next  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa
 import vllm_ascend.patch.worker.patch_rejection_sampler  # noqa
diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py
new file mode 100644
index 00000000..20bf4b01
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_qwen3_next.py
@@ -0,0 +1,104 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from einops import rearrange
+from torch import nn
+from vllm.config import CUDAGraphMode
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
+from vllm.triton_utils import triton
+
+from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
+    fused_qkvzba_split_reshape_cat
+
+
+class AscendQwen3NextGatedDeltaNet(nn.Module, MambaBase):
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+    ):
+        """
+        Forward pass with three parts:
+        1. Input projection
+        2. Core attention (custom op)
+        3. Output projection
+        """
+        num_tokens = hidden_states.size(0)
+
+        # ============================================================
+        # Part 1: Input Projection
+        # ============================================================
+        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+        projected_states_ba, _ = self.in_proj_ba(hidden_states)
+        forward_context = get_forward_context()
+        is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
+        # The Triton launch grid must stay below 65536.
+        divide_grid = projected_states_qkvz.shape[0] * triton.cdiv(
+            self.num_k_heads, self.tp_size)
+        if self.num_v_heads // self.num_k_heads in [1, 2, 4] and \
+                is_cuda_graph and divide_grid < 65536:
+            mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
+                projected_states_qkvz,
+                projected_states_ba,
+                triton.cdiv(self.num_k_heads, self.tp_size),
+                triton.cdiv(self.num_v_heads, self.tp_size),
+                self.head_k_dim,
+                self.head_v_dim,
+            )
+        else:
+            query, key, value, z, b, a = self.fix_query_key_value_ordering(
+                projected_states_qkvz, projected_states_ba)
+            query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
+                                    (query, key, value))
+            mixed_qkv = torch.cat((query, key, value), dim=-1)
+
+        # ============================================================
+        # Part 2: Core Attention (Custom Op)
+        # ============================================================
+        # Note: unlike other attention backends, we must not use torch.empty
+        # here; see the discussion in https://github.com/vllm-project/vllm/pull/28182
+        core_attn_out = torch.zeros(
+            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        torch.ops.vllm.gdn_attention_core(
+            mixed_qkv,
+            b,
+            a,
+            core_attn_out,
+            self.prefix,
+        )
+
+        # ============================================================
+        # Part 3: Output Projection
+        # ============================================================
+        z_shape_og = z.shape
+        # Reshape input data into 2D tensors
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
+        output[:num_tokens], _ = self.out_proj(core_attn_out)
+
+
+Qwen3NextGatedDeltaNet.forward = AscendQwen3NextGatedDeltaNet.forward
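
Reviewer note: below is a minimal pure-PyTorch sketch of the layout contract the fused Triton kernel above implements, written from the kernel's pointer arithmetic; it can back a CPU unit test against the NPU output. The helper name `qkvzba_split_reshape_cat_ref` and the example dimensions are ours, not part of this PR, and the sketch assumes the decode-only layout the wrapper hardcodes via `seq_len = 1`.

```python
import torch


def qkvzba_split_reshape_cat_ref(mixed_qkvz, mixed_ba, num_heads_qk,
                                 num_heads_v, head_qk, head_v):
    """Pure-PyTorch reference for fused_qkvzba_split_reshape_cat."""
    batch = mixed_qkvz.shape[0]
    ratio = num_heads_v // num_heads_qk  # v heads packed per qk head
    # Each qk-head row of mixed_qkvz packs [q | k | v*ratio | z*ratio].
    qkvz = mixed_qkvz.view(batch, num_heads_qk,
                           2 * head_qk + 2 * ratio * head_v)
    q, k, v, z = torch.split(
        qkvz, [head_qk, head_qk, ratio * head_v, ratio * head_v], dim=-1)
    # Output layout: all q heads, then all k heads, then all v heads.
    mixed_qkv = torch.cat(
        [q.reshape(batch, -1), k.reshape(batch, -1), v.reshape(batch, -1)],
        dim=-1)
    z = z.reshape(batch, num_heads_v, head_v)
    # Each qk-head row of mixed_ba packs [b*ratio | a*ratio].
    ba = mixed_ba.view(batch, num_heads_qk, 2 * ratio)
    b = ba[..., :ratio].reshape(batch, num_heads_v)
    a = ba[..., ratio:].reshape(batch, num_heads_v)
    return mixed_qkv, z, b, a


# Example shapes (illustrative): 16 qk heads, 32 v heads, head dims 128.
B, HQK, HV, DQK, DV = 4, 16, 32, 128, 128
mixed_qkvz = torch.randn(B, HQK * (2 * DQK + 2 * (HV // HQK) * DV))
mixed_ba = torch.randn(B, HQK * 2 * (HV // HQK))
mixed_qkv, z, b, a = qkvzba_split_reshape_cat_ref(mixed_qkvz, mixed_ba,
                                                  HQK, HV, DQK, DV)
assert mixed_qkv.shape == (B, 2 * HQK * DQK + HV * DV)
assert z.shape == (B, HV, DV) and b.shape == a.shape == (B, HV)
```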
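And a hedged mirror of the dispatch guard in the patched forward, to make the fallback condition explicit: the fused path is only taken for head ratios the kernel's `tl.static_range` loops cover ({1, 2, 4}), under graph capture, and while the launch grid stays below 65536 (per the comment in the patch). `should_use_fused_path` is an illustrative helper, not an API added by this PR.

```python
def should_use_fused_path(num_tokens: int, num_k_heads: int,
                          num_v_heads: int, tp_size: int,
                          is_graph_mode: bool) -> bool:
    """Mirror of the fused-path guard in the patched forward (illustrative)."""
    ratio = num_v_heads // num_k_heads
    heads_per_rank = -(-num_k_heads // tp_size)  # ceil division, as triton.cdiv
    # Require: a head ratio the kernel's static loops cover, graph capture
    # mode, and a launch grid (tokens x per-rank qk heads) below 65536.
    return (ratio in (1, 2, 4) and is_graph_mode
            and num_tokens * heads_per_rank < 65536)


# Qwen3-Next-like shape: 32 v heads / 16 k heads under TP=4, 1024 tokens.
assert should_use_fused_path(1024, 16, 32, 4, True)
assert not should_use_fused_path(1024, 16, 32, 4, False)  # eager mode falls back
```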