[task] Add fused gdn gating triton kernel (#4304)
### What this PR does / why we need it?
This commit introduces a Triton-based fused GDN gating kernel for Ascend
NPU, aimed at improving performance in the Gated Delta Net workflow.
### Does this PR introduce _any_ user-facing change?
It only adds and refactors internal Triton kernels and wrappers for
Ascend. These are backend implementation details. There are no new APIs,
flags, CLI options, or behavior changes visible to end users.
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Ascendyh <hw7osiris@outlook.com>
This commit is contained in:
118
vllm_ascend/ops/triton/fused_gdn_gating.py
Normal file
118
vllm_ascend/ops/triton/fused_gdn_gating.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_next.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
|
||||
UNIFIED_BUFFER_SIZE = 1572864
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_gdn_gating_kernel(
|
||||
g,
|
||||
beta_output,
|
||||
A_log,
|
||||
a,
|
||||
b,
|
||||
dt_bias,
|
||||
seq_len,
|
||||
NUM_HEADS: tl.constexpr,
|
||||
NUM_BATCHES: tl.constexpr,
|
||||
beta: tl.constexpr,
|
||||
threshold: tl.constexpr,
|
||||
BLK_HEADS: tl.constexpr,
|
||||
COL_ITER: tl.constexpr,
|
||||
BLK_BATCHES: tl.constexpr,
|
||||
ROW_ITER: tl.constexpr,
|
||||
):
|
||||
i_b, i_s = tl.program_id(0), tl.program_id(1)
|
||||
for row_idx in range(0, ROW_ITER):
|
||||
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(
|
||||
0, BLK_BATCHES)
|
||||
|
||||
for col_idx in range(0, COL_ITER):
|
||||
head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS)
|
||||
|
||||
off = batch_off[:,
|
||||
None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[
|
||||
None, :]
|
||||
head_mask = head_off < NUM_HEADS
|
||||
mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
|
||||
|
||||
blk_A_log = tl.load(A_log + head_off, mask=head_mask)
|
||||
blk_a = tl.load(a + off, mask=mask)
|
||||
blk_b = tl.load(b + off, mask=mask)
|
||||
blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
|
||||
|
||||
x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
|
||||
softplus_x = tl.where(beta * x <= threshold,
|
||||
(1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
|
||||
|
||||
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
|
||||
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
|
||||
|
||||
# compute beta_output = sigmoid(b)
|
||||
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
|
||||
tl.store(beta_output + off,
|
||||
blk_beta_output.to(beta_output.dtype.element_ty),
|
||||
mask=mask)
|
||||
|
||||
|
||||
def fused_gdn_gating_patch(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
beta: float = 1.0,
|
||||
threshold: float = 20.0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch, num_heads = a.shape
|
||||
seq_len = 1
|
||||
|
||||
num_cores = get_vectorcore_num()
|
||||
|
||||
BLK_HEADS = 8
|
||||
COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
|
||||
|
||||
if batch <= num_cores:
|
||||
progs = batch
|
||||
BLK_BATCHES = 1
|
||||
ROW_ITER = 1
|
||||
else:
|
||||
progs = num_cores
|
||||
FACTOR = 8 * num_heads
|
||||
row_per_core = triton.cdiv(batch, num_cores)
|
||||
BLK_BATCHES = triton.next_power_of_2(
|
||||
triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) //
|
||||
a.element_size()) // 2
|
||||
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
|
||||
|
||||
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
|
||||
beta_output = torch.empty(1,
|
||||
batch,
|
||||
num_heads,
|
||||
dtype=b.dtype,
|
||||
device=b.device)
|
||||
|
||||
grid = (progs, seq_len)
|
||||
fused_gdn_gating_kernel[grid](
|
||||
g,
|
||||
beta_output,
|
||||
A_log,
|
||||
a,
|
||||
b,
|
||||
dt_bias,
|
||||
seq_len,
|
||||
num_heads,
|
||||
batch,
|
||||
beta,
|
||||
threshold,
|
||||
BLK_HEADS=BLK_HEADS,
|
||||
COL_ITER=COL_ITER,
|
||||
BLK_BATCHES=BLK_BATCHES,
|
||||
ROW_ITER=ROW_ITER,
|
||||
)
|
||||
return g, beta_output
|
||||
@@ -297,3 +297,13 @@
|
||||
# Future Plan:
|
||||
# Remove this patch when vLLM support these operators.
|
||||
#
|
||||
# 2. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet._forward_core`
|
||||
# Why:
|
||||
# The Qwen3Next GatedDeltaNet _forward_core cannot directly add custom operators.
|
||||
# How:
|
||||
# Add a branch in Qwen3NextGatedDeltaNet._forward_core to adapt to fused_gdn_gating_patch.
|
||||
# Related PR (if no, explain why):
|
||||
# https://github.com/vllm-project/vllm/pull/31002
|
||||
# Future Plan:
|
||||
# Remove this patch when vLLM support these operators.
|
||||
#
|
||||
|
||||
@@ -35,6 +35,7 @@ from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
|
||||
fused_qkvzba_split_reshape_cat
|
||||
from vllm_ascend.ops.triton.fla.sigmoid_gating import \
|
||||
fused_sigmoid_gating_delta_rule_update
|
||||
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
||||
|
||||
|
||||
class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
@@ -151,7 +152,6 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
# 1. Convolution sequence transformation
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if spec_sequence_masks is not None:
|
||||
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
||||
mixed_qkv_spec = mixed_qkv
|
||||
@@ -211,14 +211,18 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
)
|
||||
else:
|
||||
mixed_qkv_non_spec = None
|
||||
|
||||
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_spec)
|
||||
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_non_spec)
|
||||
|
||||
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
|
||||
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
|
||||
is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
|
||||
if (is_cuda_graph):
|
||||
g, beta = fused_gdn_gating_patch(self.A_log, a, b,
|
||||
self.dt_bias)
|
||||
else:
|
||||
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
|
||||
|
||||
if spec_sequence_masks is not None:
|
||||
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
||||
|
||||
Reference in New Issue
Block a user