[task] Add fused gdn gating triton kernel (#4304)

### What this PR does / why we need it?
This commit introduces a Triton-based fused gating kernel for the Ascend
NPU, aimed at improving the performance of the Gated Delta Net (GDN) gating step.
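
In short, the kernel fuses two elementwise computations into one pass over `[batch, num_heads]` inputs: the decay gate `g = -exp(A_log) * softplus(a + dt_bias)` and the mixing coefficient `beta = sigmoid(b)`. A minimal unfused PyTorch sketch of the same math (shapes assumed from the kernel below: `a`, `b` are `[batch, num_heads]`; `A_log`, `dt_bias` are `[num_heads]`):

```python
import torch

def gdn_gating_reference(A_log, a, b, dt_bias, beta=1.0, threshold=20.0):
    # Unfused reference for the math the Triton kernel computes (sketch only).
    x = a.float() + dt_bias.float()
    # Thresholded softplus: identity for large inputs to avoid exp() overflow.
    softplus_x = torch.where(beta * x <= threshold,
                             (1.0 / beta) * torch.log1p(torch.exp(beta * x)),
                             x)
    g = -torch.exp(A_log.float()) * softplus_x
    return g, torch.sigmoid(b.float())
```
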
### Does this PR introduce _any_ user-facing change?
No. It only adds and refactors internal Triton kernels and wrappers for
Ascend; these are backend implementation details. There are no new APIs,
flags, CLI options, or behavior changes visible to end users.
### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: Ascendyh <hw7osiris@outlook.com>
Author: Ascendyh
Committed: 2025-12-22 14:09:19 +08:00 (by GitHub)
Commit: b2c121637f (parent: ea6206bb18)
3 changed files with 135 additions and 3 deletions


vllm_ascend/ops/triton/fused_gdn_gating.py
@@ -0,0 +1,118 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_next.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.triton_utils import tl, triton

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num

# Unified-buffer budget (in bytes) used below to size per-core batch tiles.
UNIFIED_BUFFER_SIZE = 1572864


@triton.jit
def fused_gdn_gating_kernel(
    g,
    beta_output,
    A_log,
    a,
    b,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    NUM_BATCHES: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
    COL_ITER: tl.constexpr,
    BLK_BATCHES: tl.constexpr,
    ROW_ITER: tl.constexpr,
):
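    # Program axis 0 walks batch tiles, axis 1 walks sequence positions;
    # each program covers ROW_ITER x COL_ITER tiles of shape
    # [BLK_BATCHES, BLK_HEADS], so every head of its batch rows is visited.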
    i_b, i_s = tl.program_id(0), tl.program_id(1)
    for row_idx in range(0, ROW_ITER):
        batch_off = (i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES +
                     tl.arange(0, BLK_BATCHES))
        for col_idx in range(0, COL_ITER):
            head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS)
            # Flat offsets into the [batch, seq_len, num_heads] layout.
            off = (batch_off[:, None] * seq_len * NUM_HEADS +
                   i_s * NUM_HEADS + head_off[None, :])
            head_mask = head_off < NUM_HEADS
            mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
            blk_A_log = tl.load(A_log + head_off, mask=head_mask)
            blk_a = tl.load(a + off, mask=mask)
            blk_b = tl.load(b + off, mask=mask)
            blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
            x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
            # Thresholded softplus: fall back to identity for large inputs
            # so exp() cannot overflow.
            softplus_x = tl.where(beta * x <= threshold,
                                  (1 / beta) * tl.log(1 + tl.exp(beta * x)),
                                  x)
            # g = -exp(A_log) * softplus(a + dt_bias)
            blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
            tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
            # beta_output = sigmoid(b)
            blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
            tl.store(beta_output + off,
                     blk_beta_output.to(beta_output.dtype.element_ty),
                     mask=mask)


def fused_gdn_gating_patch(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    batch, num_heads = a.shape
    seq_len = 1
    num_cores = get_vectorcore_num()
    BLK_HEADS = 8
    COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
    if batch <= num_cores:
        # One batch row per program; a single tile covers it.
        progs = batch
        BLK_BATCHES = 1
        ROW_ITER = 1
    else:
        # Spread rows over all vector cores, sizing each batch tile to fit
        # within the unified-buffer budget.
        progs = num_cores
        FACTOR = 8 * num_heads
        row_per_core = triton.cdiv(batch, num_cores)
        BLK_BATCHES = triton.next_power_of_2(
            triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) //
            a.element_size()) // 2
        ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
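    # Worked example for the tile sizing above (assumed values, for
    # illustration only): batch=1024, num_cores=40, num_heads=32, fp16
    # inputs (element_size=2):
    #   FACTOR = 8 * 32 = 256
    #   cdiv(1572864, 256 * 8) = 768; 768 // 2 = 384
    #   next_power_of_2(384) = 512; 512 // 2 -> BLK_BATCHES = 256
    #   ROW_ITER = cdiv(cdiv(1024, 40), 256) = cdiv(26, 256) = 1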
    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(1,
                              batch,
                              num_heads,
                              dtype=b.dtype,
                              device=b.device)
    grid = (progs, seq_len)
    fused_gdn_gating_kernel[grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        seq_len,
        num_heads,
        batch,
        beta,
        threshold,
        BLK_HEADS=BLK_HEADS,
        COL_ITER=COL_ITER,
        BLK_BATCHES=BLK_BATCHES,
        ROW_ITER=ROW_ITER,
    )
    return g, beta_output
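
For illustration, a minimal invocation of the wrapper might look like the following sketch (dummy tensors; the `npu` device string and shapes are assumptions, not from this PR):

```python
import torch

from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch

batch, num_heads = 32, 16
device = "npu"  # assumed Ascend device string
A_log = torch.randn(num_heads, device=device)
dt_bias = torch.randn(num_heads, device=device)
a = torch.randn(batch, num_heads, dtype=torch.float16, device=device)
b = torch.randn(batch, num_heads, dtype=torch.float16, device=device)

# g is [1, batch, num_heads] in float32; beta matches b.dtype.
g, beta = fused_gdn_gating_patch(A_log, a, b, dt_bias)
```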


@@ -297,3 +297,13 @@
# Future Plan:
# Remove this patch when vLLM supports these operators.
#
# 2. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet._forward_core`
# Why:
# `Qwen3NextGatedDeltaNet._forward_core` provides no hook for adding custom
# operators directly.
# How:
# Add a branch in `Qwen3NextGatedDeltaNet._forward_core` that dispatches to
# `fused_gdn_gating_patch`.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/31002
# Future Plan:
# Remove this patch once vLLM supports these operators.
#


@@ -35,6 +35,7 @@
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
    fused_qkvzba_split_reshape_cat
from vllm_ascend.ops.triton.fla.sigmoid_gating import \
    fused_sigmoid_gating_delta_rule_update
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch


class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
@@ -151,7 +152,6 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
        # 1. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))
        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                mixed_qkv_spec = mixed_qkv
@@ -211,14 +211,18 @@
            )
        else:
            mixed_qkv_non_spec = None

        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
            mixed_qkv_spec)
        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
            mixed_qkv_non_spec)

        if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
            is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
            if is_cuda_graph:
                g, beta = fused_gdn_gating_patch(self.A_log, a, b,
                                                 self.dt_bias)
            else:
                g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)

            if spec_sequence_masks is not None:
                if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: