From b2c121637fd8b8045e66e24ea0f63cb17ffb3b69 Mon Sep 17 00:00:00 2001 From: Ascendyh Date: Mon, 22 Dec 2025 14:09:19 +0800 Subject: [PATCH] [task] Add fused gdn gating triton kernel (#4304) ### What this PR does / why we need it? This commit introduces a Triton-based fused GDN gating kernel for Ascend NPU, aimed at improving performance in the Gated Delta Net workflow. ### Does this PR introduce _any_ user-facing change? It only adds and refactors internal Triton kernels and wrappers for Ascend. These are backend implementation details. There are no new APIs, flags, CLI options, or behavior changes visible to end users. ### How was this patch tested? - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 --------- Signed-off-by: Ascendyh --- vllm_ascend/ops/triton/fused_gdn_gating.py | 118 +++++++++++++++++++ vllm_ascend/patch/__init__.py | 10 ++ vllm_ascend/patch/worker/patch_qwen3_next.py | 10 +- 3 files changed, 135 insertions(+), 3 deletions(-) create mode 100644 vllm_ascend/ops/triton/fused_gdn_gating.py diff --git a/vllm_ascend/ops/triton/fused_gdn_gating.py b/vllm_ascend/ops/triton/fused_gdn_gating.py new file mode 100644 index 00000000..dfd5dde2 --- /dev/null +++ b/vllm_ascend/ops/triton/fused_gdn_gating.py @@ -0,0 +1,118 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_next.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.triton_utils import tl, triton + +from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num + +UNIFIED_BUFFER_SIZE = 1572864 + + +@triton.jit +def fused_gdn_gating_kernel( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + NUM_BATCHES: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, + COL_ITER: tl.constexpr, + BLK_BATCHES: tl.constexpr, + ROW_ITER: tl.constexpr, +): + i_b, i_s = tl.program_id(0), tl.program_id(1) + for row_idx in range(0, ROW_ITER): + batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange( + 0, BLK_BATCHES) + + for col_idx in range(0, COL_ITER): + head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS) + + off = batch_off[:, + None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[ + None, :] + head_mask = head_off < NUM_HEADS + mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES) + + blk_A_log = tl.load(A_log + head_off, mask=head_mask) + blk_a = tl.load(a + off, mask=mask) + blk_b = tl.load(b + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=head_mask) + + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :] + softplus_x = tl.where(beta * x <= threshold, + (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + # compute beta_output = sigmoid(b) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store(beta_output + off, + blk_beta_output.to(beta_output.dtype.element_ty), + mask=mask) + + +def fused_gdn_gating_patch( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> tuple[torch.Tensor, torch.Tensor]: + batch, num_heads = a.shape + seq_len = 1 + + num_cores = get_vectorcore_num() + + BLK_HEADS = 8 + COL_ITER = triton.cdiv(num_heads, BLK_HEADS) + + if batch <= num_cores: + progs = batch + BLK_BATCHES = 1 + ROW_ITER = 1 + else: + progs = num_cores + FACTOR = 8 * num_heads + row_per_core = triton.cdiv(batch, num_cores) + BLK_BATCHES = triton.next_power_of_2( + triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) // + a.element_size()) // 2 + ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES) + + g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) + beta_output = torch.empty(1, + batch, + num_heads, + dtype=b.dtype, + device=b.device) + + grid = (progs, seq_len) + fused_gdn_gating_kernel[grid]( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + num_heads, + batch, + beta, + threshold, + BLK_HEADS=BLK_HEADS, + COL_ITER=COL_ITER, + BLK_BATCHES=BLK_BATCHES, + ROW_ITER=ROW_ITER, + ) + return g, beta_output diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 89d7c957..c3486d61 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -297,3 +297,13 @@ # Future Plan: # Remove this patch when vLLM support these operators. # +# 2. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet._forward_core` +# Why: +# The Qwen3Next GatedDeltaNet _forward_core cannot directly add custom operators. +# How: +# Add a branch in Qwen3NextGatedDeltaNet._forward_core to adapt to fused_gdn_gating_patch. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/31002 +# Future Plan: +# Remove this patch when vLLM support these operators. +# diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py index 172aab8f..e7604aef 100644 --- a/vllm_ascend/patch/worker/patch_qwen3_next.py +++ b/vllm_ascend/patch/worker/patch_qwen3_next.py @@ -35,6 +35,7 @@ from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \ fused_qkvzba_split_reshape_cat from vllm_ascend.ops.triton.fla.sigmoid_gating import \ fused_sigmoid_gating_delta_rule_update +from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase): @@ -151,7 +152,6 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase): # 1. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if spec_sequence_masks is not None: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: mixed_qkv_spec = mixed_qkv @@ -211,14 +211,18 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase): ) else: mixed_qkv_non_spec = None - query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( mixed_qkv_spec) query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( mixed_qkv_non_spec) if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None: - g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) + is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE + if (is_cuda_graph): + g, beta = fused_gdn_gating_patch(self.A_log, a, b, + self.dt_bias) + else: + g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) if spec_sequence_masks is not None: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: