[pref] qwen3_next add triton ops : fused_sigmoid_gating_delta_rule_update (#4818)

### What this PR does / why we need it?
qwen3_next add fused_sigmoid_gating_delta_rule_update op which fused
fused_gdn_gating+fused_recurrent_gated_delta_rule

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
XiaoxinWang
2025-12-19 16:34:11 +08:00
committed by GitHub
parent 118b0ed346
commit 0cc3fc357f
5 changed files with 539 additions and 1 deletions

View File

@@ -0,0 +1,65 @@
import torch
from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule
from vllm.model_executor.models.qwen3_next import fused_gdn_gating
from vllm_ascend.ops.triton.fla.sigmoid_gating import \
fused_sigmoid_gating_delta_rule_update
def test_triton_fusion_ops():
q = torch.randn(1, 1, 4, 128, dtype=torch.bfloat16).npu()
k = torch.randn(1, 1, 4, 128, dtype=torch.bfloat16).npu()
v = torch.randn(1, 1, 8, 128, dtype=torch.bfloat16).npu()
a = torch.tensor([[
-2.6094, -0.2617, -0.3848, 2.2656, 3.6250, -0.7383, -1.0938, -0.0505
]]).bfloat16().npu()
b = torch.tensor(
[[0.4277, 0.8906, 1.6875, 2.3750, 4.1562, 0.3809, 1.0625,
3.6719]]).bfloat16().npu()
ssm_state = torch.randn(1, 8, 128, 128, dtype=torch.bfloat16).npu()
non_spec_state_indices_tensor = torch.tensor([2]).int().npu()
non_spec_query_start_loc = torch.tensor([0, 1]).int().npu()
a_log = torch.tensor([
-2.6875, -3.2031, -3.3438, -2.7812, -3.0625, -4.0312, -5.3750, 5.7188
]).bfloat16().npu()
dt_bias = torch.tensor(
[-4.7812, -5.0938, -5.5000, 9.4375, 7.6250, -4.3750, -3.0938,
0.9688]).bfloat16().npu()
core_attn_out_non_spec_fused = fused_sigmoid_gating_delta_rule_update(
A_log=a_log.contiguous(),
dt_bias=dt_bias.contiguous(),
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
a=a.contiguous(),
b=b.contiguous(),
initial_state_source=ssm_state,
initial_state_indices=non_spec_state_indices_tensor,
cu_seqlens=non_spec_query_start_loc,
use_qk_l2norm_in_kernel=True,
softplus_beta=1.0,
softplus_threshold=20.0,
)
g, beta = fused_gdn_gating(a_log, a, b, dt_bias)
g_non_spec = g
beta_non_spec = beta
core_attn_out_non_spec_split, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
q=q,
k=k,
v=v,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc,
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
))
torch.testing.assert_close(core_attn_out_non_spec_fused,
core_attn_out_non_spec_split,
rtol=1e-02,
atol=1e-02,
equal_nan=True)

View File

@@ -11,6 +11,7 @@
import os
import torch
from vllm.triton_utils import tl, tldevice, triton
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
@@ -169,3 +170,228 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
@triton.heuristics({
"USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
})
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
A_log,
a,
dt_bias,
softplus_beta,
softplus_threshold,
q,
k,
v,
b,
o,
h0_source,
h0_indices,
cu_seqlens,
scale,
T,
B: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
"""
Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
"""
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T
o_k = i_k * BK + tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)
p_q = q + (bos * H + i_h) * K + o_k
p_k = k + (bos * H + i_h) * K + o_k
p_v = v + (bos * HV + i_hv) * V + o_v
p_b = b + bos * HV + i_hv
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
# Gating computation pointers
p_A_log = A_log + i_hv
p_a = a + bos * HV + i_hv
p_dt_bias = dt_bias + i_hv
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_k[:, None] & mask_v[None, :]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
idx = tl.load(h0_indices + i_n)
# if idx >= 0:
tmp0 = tl.where(idx < 0, 0, idx)
p_h0 = (h0_source + tmp0 * HV * K * V + i_hv * K * V +
o_k[:, None] * V + o_v[None, :])
temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
temp2 = tl.zeros_like(temp1)
value0 = tl.where(idx < 0, temp2, temp1)
b_h += value0 # tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
for i in range(0, T):
# Load inputs
b_q = tl.load(p_q + i * H * K, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k + i * H * K, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v + i * HV * V, mask=mask_v, other=0).to(tl.float32)
b_b = tl.load(p_b + i * HV).to(tl.float32)
# Compute sigmoid gating
# Load gating parameters
b_A_log = tl.load(p_A_log).to(tl.float32)
b_a = tl.load(p_a + i * HV).to(tl.float32)
b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
# Compute g = -exp(A_log) * softplus(a + dt_bias)
x = b_a + b_dt_bias
beta_x = softplus_beta * x
# Apply softplus with numerical stability
softplus_x = tl.where(
beta_x <= softplus_threshold,
(1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
x,
)
b_g = -tl.exp(b_A_log) * softplus_x
# Compute beta = sigmoid(b)
b_beta = 1.0 / (1.0 + tl.exp(-b_b))
# Apply L2 normalization if enabled
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
b_q = b_q * scale
# Apply gating to hidden state: h *= exp(g)
b_h *= tl.exp(b_g)
# Delta rule: v -= sum(h * k, dim=0)
b_v -= tl.sum(b_h * b_k[:, None], 0)
# Apply beta gating: v *= beta
b_v *= b_beta
# Update hidden state: h += k[:, None] * v[None, :]
b_h += b_k[:, None] * b_v[None, :]
# Compute output: o = sum(h * q, dim=0)
b_o = tl.sum(b_h * b_q[:, None], 0)
tl.store(p_o + i * HV * V, b_o.to(p_o.dtype.element_ty), mask=mask_v)
# # Update pointers for next timestep
# p_q += H * K
# p_k += H * K
# p_o += HV * V
# p_v += HV * V
# p_b += HV
# p_a += HV
# Store final state back to h0_source with bounds checking
if USE_INITIAL_STATE:
idx = tl.load(h0_indices + i_n)
if idx >= 0:
p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V +
o_k[:, None] * V + o_v[None, :])
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
def fused_sigmoid_gating_delta_rule_update(
A_log: torch.Tensor,
a: torch.Tensor,
dt_bias: torch.Tensor,
softplus_beta: float,
softplus_threshold: float,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
b: torch.Tensor,
initial_state_source: torch.Tensor,
initial_state_indices: torch.Tensor,
scale: float = None,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.Tensor = None,
):
"""
Fused triton implementation of sigmoid gating delta rule update.
This function uses a single fused kernel that combines both sigmoid gating computation
and the recurrent delta rule update for better performance.
"""
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 3
num_warps = 1
if scale is None:
scale = k.shape[-1]**-0.5
else:
assert scale > 0, "scale must be positive"
o = q.new_empty(NK, *v.shape)
grid = (NK, NV, N * HV)
if not initial_state_indices.is_contiguous():
initial_state_indices = initial_state_indices.contiguous()
if not initial_state_source.is_contiguous():
initial_state_source_contiguous = initial_state_source.contiguous()
if not cu_seqlens.is_contiguous():
cu_seqlens = cu_seqlens.contiguous()
fused_sigmoid_gating_delta_rule_update_kernel[grid](
A_log=A_log,
a=a,
dt_bias=dt_bias,
softplus_beta=softplus_beta,
softplus_threshold=softplus_threshold,
q=q,
k=k,
v=v,
b=b,
o=o,
h0_source=initial_state_source_contiguous,
h0_indices=initial_state_indices,
cu_seqlens=cu_seqlens,
scale=scale,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
num_warps=num_warps,
num_stages=num_stages,
)
initial_state_source.copy_(
initial_state_source_contiguous.view_as(initial_state_source))
o = o.squeeze(0)
return o

View File

@@ -285,3 +285,15 @@
# Future Plan:
# Remove this patch when vLLM support these operators.
#
# ** 15. File: worker/patch_qwen3_next.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet._forward_core`
# Why:
# triton ops fused_recurrent_gated_delta_rule and fused_gdn_gating in vLLM perform not good on NPU.
# How
# add a new fused triton ops in vLLM with ascend implementation.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/30860
# Future Plan:
# Remove this patch when vLLM support these operators.
#

View File

@@ -35,3 +35,4 @@ import vllm_ascend.patch.worker.patch_rope # noqa
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
import vllm_ascend.patch.worker.patch_qwen3_next # noqa

View File

@@ -18,14 +18,23 @@
import torch
from einops import rearrange
from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import (
chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.models.qwen3_next import (Qwen3NextGatedDeltaNet,
fused_gdn_gating)
from vllm.triton_utils import triton
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
fused_qkvzba_split_reshape_cat
from vllm_ascend.ops.triton.fla.sigmoid_gating import \
fused_sigmoid_gating_delta_rule_update
class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
@@ -101,5 +110,230 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out)
def _forward_core(
self,
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
):
"""
Core attention computation (called by custom op).
"""
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is None:
# V1 profile run
return
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
spec_sequence_masks = attn_metadata.spec_sequence_masks
spec_token_indx = attn_metadata.spec_token_indx
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens
mixed_qkv = mixed_qkv[:num_actual_tokens]
b = b[:num_actual_tokens]
a = a[:num_actual_tokens]
# 1. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
mixed_qkv_spec = mixed_qkv
mixed_qkv_non_spec = None
else:
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
mixed_qkv_non_spec = mixed_qkv.index_select(
0, non_spec_token_indx)
else:
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv
# 1.1: Process the multi-query part
if spec_sequence_masks is not None:
mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=spec_state_indices_tensor[:, 0]
[:attn_metadata.num_spec_decodes],
num_accepted_tokens=num_accepted_tokens,
query_start_loc=spec_query_start_loc,
max_query_len=spec_state_indices_tensor.size(-1),
validate_data=False,
)
# 1.2: Process the remaining part
if attn_metadata.num_prefills > 0:
if mixed_qkv_non_spec is not None:
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec_T,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
).transpose(0, 1)
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv_non_spec,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=
non_spec_state_indices_tensor[:attn_metadata.
num_actual_tokens],
validate_data=True,
)
else:
mixed_qkv_non_spec = None
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
mixed_qkv_spec)
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
mixed_qkv_non_spec)
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta
# 2. Recurrent attention
# 2.1: Process the multi-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_spec,
k=key_spec,
v=value_spec,
g=g_spec,
beta=beta_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[:attn_metadata.
num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out_spec, last_recurrent_state = None, None
# 2.2: Process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[
non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
# Init cache
ssm_state[
non_spec_state_indices_tensor] = last_recurrent_state.to(
ssm_state.dtype)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[:attn_metadata.
num_decodes + 1],
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
))
else:
core_attn_out_non_spec, last_recurrent_state = None, None
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log.contiguous(),
dt_bias=self.dt_bias.contiguous(),
q=query_non_spec.contiguous(),
k=key_non_spec.contiguous(),
v=value_non_spec.contiguous(),
a=a.contiguous(),
b=b.contiguous(),
initial_state_source=ssm_state,
initial_state_indices=non_spec_state_indices_tensor,
cu_seqlens=non_spec_query_start_loc,
use_qk_l2norm_in_kernel=True,
softplus_beta=1.0,
softplus_threshold=20.0,
)
# 3. Merge core attention output
if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
merged_out = torch.empty(
(1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
dtype=core_attn_out_non_spec.dtype,
device=core_attn_out_non_spec.device,
)
merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
merged_out.index_copy_(1, non_spec_token_indx,
core_attn_out_non_spec)
core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
elif spec_sequence_masks is not None:
core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(
0)
Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward
Qwen3NextGatedDeltaNet._forward_core = AscendQwen3Next_GatedDeltaNet._forward_core