From 0cc3fc357ff7fe1fac675fc8898ad904158ff12a Mon Sep 17 00:00:00 2001
From: XiaoxinWang <963372609@qq.com>
Date: Fri, 19 Dec 2025 16:34:11 +0800
Subject: [PATCH] [perf] qwen3_next: add Triton op fused_sigmoid_gating_delta_rule_update (#4818)

### What this PR does / why we need it?
Add a fused_sigmoid_gating_delta_rule_update op for qwen3_next, which fuses
fused_gdn_gating and fused_recurrent_gated_delta_rule into a single kernel.

- vLLM version: v0.12.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: wangxiaoxin-sherie
Co-authored-by: wangxiaoxin-sherie
---
 .../test_fused_sigmoid_gating_delta_rule.py   |  65 +++++
 vllm_ascend/ops/triton/fla/sigmoid_gating.py  | 226 +++++++++++++++++
 vllm_ascend/patch/__init__.py                 |  12 +
 vllm_ascend/patch/worker/__init__.py          |   1 +
 vllm_ascend/patch/worker/patch_qwen3_next.py  | 236 +++++++++++++++++-
 5 files changed, 539 insertions(+), 1 deletion(-)
 create mode 100644 tests/e2e/singlecard/test_fused_sigmoid_gating_delta_rule.py

diff --git a/tests/e2e/singlecard/test_fused_sigmoid_gating_delta_rule.py b/tests/e2e/singlecard/test_fused_sigmoid_gating_delta_rule.py
new file mode 100644
index 00000000..fd469ef3
--- /dev/null
+++ b/tests/e2e/singlecard/test_fused_sigmoid_gating_delta_rule.py
@@ -0,0 +1,65 @@
+import torch
+from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule
+from vllm.model_executor.models.qwen3_next import fused_gdn_gating
+
+from vllm_ascend.ops.triton.fla.sigmoid_gating import \
+    fused_sigmoid_gating_delta_rule_update
+
+
+def test_triton_fusion_ops():
+    q = torch.randn(1, 1, 4, 128, dtype=torch.bfloat16).npu()
+    k = torch.randn(1, 1, 4, 128, dtype=torch.bfloat16).npu()
+    v = torch.randn(1, 1, 8, 128, dtype=torch.bfloat16).npu()
+    a = torch.tensor([[
+        -2.6094, -0.2617, -0.3848, 2.2656, 3.6250, -0.7383, -1.0938, -0.0505
+    ]]).bfloat16().npu()
+    b = torch.tensor(
+        [[0.4277, 0.8906, 1.6875, 2.3750, 4.1562, 0.3809, 1.0625,
+          3.6719]]).bfloat16().npu()
+    # Four cache slots so that slot index 2 below stays in bounds. Both paths
+    # update the state cache in place, so the reference path gets its own
+    # copy, taken before the fused op runs.
+    ssm_state = torch.randn(4, 8, 128, 128, dtype=torch.bfloat16).npu()
+    ssm_state_ref = ssm_state.clone()
+    non_spec_state_indices_tensor = torch.tensor([2]).int().npu()
+    non_spec_query_start_loc = torch.tensor([0, 1]).int().npu()
+    a_log = torch.tensor([
+        -2.6875, -3.2031, -3.3438, -2.7812, -3.0625, -4.0312, -5.3750, 5.7188
+    ]).bfloat16().npu()
+    dt_bias = torch.tensor(
+        [-4.7812, -5.0938, -5.5000, 9.4375, 7.6250, -4.3750, -3.0938,
+         0.9688]).bfloat16().npu()
+
+    # Fused path: gating + recurrent delta rule in a single kernel.
+    core_attn_out_non_spec_fused = fused_sigmoid_gating_delta_rule_update(
+        A_log=a_log.contiguous(),
+        dt_bias=dt_bias.contiguous(),
+        q=q.contiguous(),
+        k=k.contiguous(),
+        v=v.contiguous(),
+        a=a.contiguous(),
+        b=b.contiguous(),
+        initial_state_source=ssm_state,
+        initial_state_indices=non_spec_state_indices_tensor,
+        cu_seqlens=non_spec_query_start_loc,
+        use_qk_l2norm_in_kernel=True,
+        softplus_beta=1.0,
+        softplus_threshold=20.0,
+    )
+
+    # Reference path: the two separate vLLM ops that the kernel fuses.
+    g, beta = fused_gdn_gating(a_log, a, b, dt_bias)
+    core_attn_out_non_spec_split, last_recurrent_state = (
+        fused_recurrent_gated_delta_rule(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            initial_state=ssm_state_ref,
+            inplace_final_state=True,
+            cu_seqlens=non_spec_query_start_loc,
+            ssm_state_indices=non_spec_state_indices_tensor,
+            use_qk_l2norm_in_kernel=True,
+        ))
+    torch.testing.assert_close(core_attn_out_non_spec_fused,
+                               core_attn_out_non_spec_split,
+                               rtol=1e-02,
+                               atol=1e-02,
+                               equal_nan=True)
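For orientation before the kernel below: the two vLLM ops being fused implement a small amount of per-token math. The following is a hedged CPU sketch of the gating half, with the formulas taken from the kernel's own comments (`g = -exp(A_log) * softplus(a + dt_bias)`, `beta = sigmoid(b)`); the function name is ours for illustration, not a vLLM API:

```python
import torch
import torch.nn.functional as F


def gdn_gating_reference(A_log, a, b, dt_bias, beta=1.0, threshold=20.0):
    """Decay gate g and update gate beta, per head, computed in fp32.

    F.softplus already switches to the identity for beta * x > threshold,
    which matches the numerical-stability shortcut in the Triton kernel.
    """
    g = -torch.exp(A_log.float()) * F.softplus(
        a.float() + dt_bias.float(), beta=beta, threshold=threshold)
    return g, torch.sigmoid(b.float())
```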
diff --git a/vllm_ascend/ops/triton/fla/sigmoid_gating.py b/vllm_ascend/ops/triton/fla/sigmoid_gating.py
index eff4c99d..dd481ec4 100644
--- a/vllm_ascend/ops/triton/fla/sigmoid_gating.py
+++ b/vllm_ascend/ops/triton/fla/sigmoid_gating.py
@@ -11,6 +11,7 @@
 import os
 
+import torch
 from vllm.triton_utils import tl, tldevice, triton
 
 if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
@@ -169,3 +170,228 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         p_ht = ht + (bos + i_t) * stride_final_state_token
         p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
         tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+
+@triton.heuristics({
+    "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
+    "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+})
+@triton.jit(do_not_specialize=["T"])
+def fused_sigmoid_gating_delta_rule_update_kernel(
+    A_log,
+    a,
+    dt_bias,
+    softplus_beta,
+    softplus_threshold,
+    q,
+    k,
+    v,
+    b,
+    o,
+    h0_source,
+    h0_indices,
+    cu_seqlens,
+    scale,
+    T,
+    B: tl.constexpr,
+    H: tl.constexpr,
+    HV: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    """
+    Fused kernel that combines the sigmoid gating computation with the
+    recurrent delta rule update.
+    """
+    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_n, i_hv = i_nh // HV, i_nh % HV
+    i_h = i_hv // (HV // H)
+
+    if IS_VARLEN:
+        bos, eos = (
+            tl.load(cu_seqlens + i_n).to(tl.int64),
+            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
+        )
+        all = T
+        T = eos - bos
+    else:
+        bos, eos = i_n * T, i_n * T + T
+        all = B * T
+
+    o_k = i_k * BK + tl.arange(0, BK)
+    o_v = i_v * BV + tl.arange(0, BV)
+
+    p_q = q + (bos * H + i_h) * K + o_k
+    p_k = k + (bos * H + i_h) * K + o_k
+    p_v = v + (bos * HV + i_hv) * V + o_v
+    p_b = b + bos * HV + i_hv
+    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+
+    # Gating computation pointers
+    p_A_log = A_log + i_hv
+    p_a = a + bos * HV + i_hv
+    p_dt_bias = dt_bias + i_hv
+
+    mask_k = o_k < K
+    mask_v = o_v < V
+    mask_h = mask_k[:, None] & mask_v[None, :]
+
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+    if USE_INITIAL_STATE:
+        idx = tl.load(h0_indices + i_n)
+        # `if idx >= 0` branching is not an option around a block load, so
+        # clamp negative (padding) indices to 0 for the load and zero out
+        # their contribution afterwards.
+        safe_idx = tl.where(idx < 0, 0, idx)
+        p_h0 = (h0_source + safe_idx * HV * K * V + i_hv * K * V +
+                o_k[:, None] * V + o_v[None, :])
+        b_h0 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+        b_h += tl.where(idx < 0, tl.zeros_like(b_h0), b_h0)
+
+    for i in range(0, T):
+        # Load inputs for timestep i
+        b_q = tl.load(p_q + i * H * K, mask=mask_k, other=0).to(tl.float32)
+        b_k = tl.load(p_k + i * H * K, mask=mask_k, other=0).to(tl.float32)
+        b_v = tl.load(p_v + i * HV * V, mask=mask_v, other=0).to(tl.float32)
+        b_b = tl.load(p_b + i * HV).to(tl.float32)
+
+        # Load gating parameters
+        b_A_log = tl.load(p_A_log).to(tl.float32)
+        b_a = tl.load(p_a + i * HV).to(tl.float32)
+        b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
+
+        # Compute g = -exp(A_log) * softplus(a + dt_bias)
+        x = b_a + b_dt_bias
+        beta_x = softplus_beta * x
+        # Softplus with the usual numerical-stability shortcut:
+        # for large beta * x, softplus(x) ~= x.
+        softplus_x = tl.where(
+            beta_x <= softplus_threshold,
+            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
+            x,
+        )
+        b_g = -tl.exp(b_A_log) * softplus_x
+
+        # Compute beta = sigmoid(b)
+        b_beta = 1.0 / (1.0 + tl.exp(-b_b))
+
+        # Apply L2 normalization to q/k if enabled
+        if USE_QK_L2NORM_IN_KERNEL:
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+
+        b_q = b_q * scale
+
+        # Apply gating to the hidden state: h *= exp(g)
+        b_h *= tl.exp(b_g)
+
+        # Delta rule: v -= sum(h * k, dim=0)
+        b_v -= tl.sum(b_h * b_k[:, None], 0)
+
+        # Apply beta gating: v *= beta
+        b_v *= b_beta
+
+        # Update hidden state: h += k[:, None] * v[None, :]
+        b_h += b_k[:, None] * b_v[None, :]
+
+        # Compute output: o = sum(h * q, dim=0)
+        b_o = tl.sum(b_h * b_q[:, None], 0)
+        tl.store(p_o + i * HV * V, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+
+    # Store the final state back to h0_source, with bounds checking
+    if USE_INITIAL_STATE:
+        idx = tl.load(h0_indices + i_n)
+        if idx >= 0:
+            p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V +
+                    o_k[:, None] * V + o_v[None, :])
+            tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
+
+
+def fused_sigmoid_gating_delta_rule_update(
+    A_log: torch.Tensor,
+    a: torch.Tensor,
+    dt_bias: torch.Tensor,
+    softplus_beta: float,
+    softplus_threshold: float,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    b: torch.Tensor,
+    initial_state_source: torch.Tensor,
+    initial_state_indices: torch.Tensor,
+    scale: float = None,
+    use_qk_l2norm_in_kernel: bool = False,
+    cu_seqlens: torch.Tensor = None,
+):
+    """
+    Fused Triton implementation of the sigmoid gating delta rule update.
+    A single kernel combines the sigmoid gating computation and the
+    recurrent delta rule update for better performance.
+    """
+    B, T, H, K, V = *k.shape, v.shape[-1]
+    HV = v.shape[2]
+    N = B if cu_seqlens is None else len(cu_seqlens) - 1
+    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
+    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1, "NK > 1 is not supported yet"
+    num_stages = 3
+    num_warps = 1
+
+    if scale is None:
+        scale = k.shape[-1]**-0.5
+    else:
+        assert scale > 0, "scale must be positive"
+
+    o = q.new_empty(NK, *v.shape)
+    grid = (NK, NV, N * HV)
+
+    if not initial_state_indices.is_contiguous():
+        initial_state_indices = initial_state_indices.contiguous()
+    # Keep a handle to the (possibly copied) contiguous state so the kernel's
+    # in-place final-state store can be written back below if a copy was made.
+    initial_state_source_contiguous = initial_state_source
+    if not initial_state_source_contiguous.is_contiguous():
+        initial_state_source_contiguous = initial_state_source.contiguous()
+    if cu_seqlens is not None and not cu_seqlens.is_contiguous():
+        cu_seqlens = cu_seqlens.contiguous()
+
+    fused_sigmoid_gating_delta_rule_update_kernel[grid](
+        A_log=A_log,
+        a=a,
+        dt_bias=dt_bias,
+        softplus_beta=softplus_beta,
+        softplus_threshold=softplus_threshold,
+        q=q,
+        k=k,
+        v=v,
+        b=b,
+        o=o,
+        h0_source=initial_state_source_contiguous,
+        h0_indices=initial_state_indices,
+        cu_seqlens=cu_seqlens,
+        scale=scale,
+        T=T,
+        B=B,
+        H=H,
+        HV=HV,
+        K=K,
+        V=V,
+        BK=BK,
+        BV=BV,
+        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    if initial_state_source_contiguous is not initial_state_source:
+        initial_state_source.copy_(
+            initial_state_source_contiguous.view_as(initial_state_source))
+    o = o.squeeze(0)
+    return o
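The kernel's inner loop is easier to read outside Triton. Below is a minimal single-head, single-sequence reference of the same recurrence, assuming `(T, K)`/`(T, V)` inputs; the naming is ours, not a vLLM API:

```python
import torch


def gated_delta_rule_reference(q, k, v, g, beta, h0, scale=None, eps=1e-6):
    """One head, one sequence: q/k are (T, K), v is (T, V), g/beta are (T,),
    h0 is the (K, V) initial state. Returns (o, h_final)."""
    T, K = q.shape
    scale = K**-0.5 if scale is None else scale
    h = h0.float().clone()
    o = torch.empty(T, v.shape[-1])
    for t in range(T):
        qt = q[t].float()
        qt = qt / (qt.norm() + eps) * scale           # qk l2norm, then scale
        kt = k[t].float()
        kt = kt / (kt.norm() + eps)
        h = h * torch.exp(g[t].float())               # decay: h *= exp(g)
        vt = v[t].float() - (h * kt[:, None]).sum(0)  # delta rule correction
        vt = vt * beta[t].float()                     # update strength
        h = h + kt[:, None] * vt[None, :]             # rank-1 state update
        o[t] = (h * qt[:, None]).sum(0)               # readout against q
    return o, h
```

Composing this with the gating sketch above reproduces what the fused op computes for one head: `g, beta = gdn_gating_reference(A_log, a, b, dt_bias)` followed by `o, h = gated_delta_rule_reference(q, k, v, g, beta, h0)`.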
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 4d3a9daf..89d7c957 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -285,3 +285,15 @@
 #    Future Plan:
 #       Remove this patch when vLLM support these operators.
 #
+# ** 15. File: worker/patch_qwen3_next.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet._forward_core`
+#    Why:
+#       The Triton ops fused_recurrent_gated_delta_rule and fused_gdn_gating in vLLM do not perform well on NPU.
+#    How:
+#       Add a new fused Triton op (fused_sigmoid_gating_delta_rule_update) with an Ascend implementation.
+#    Related PR (if no, explain why):
+#       https://github.com/vllm-project/vllm/pull/30860
+#    Future Plan:
+#       Remove this patch once vLLM supports this operator.
+#
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 23faa67a..d6e9c049 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -35,3 +35,4 @@
 import vllm_ascend.patch.worker.patch_rope  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa
 import vllm_ascend.patch.worker.patch_rejection_sampler  # noqa
+import vllm_ascend.patch.worker.patch_qwen3_next  # noqa
diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py
index 20bf4b01..172aab8f 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_next.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_next.py
@@ -18,14 +18,23 @@
 import torch
 from einops import rearrange
 from torch import nn
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CUDAGraphMode
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.fla.ops import (
+    chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
-from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
+from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
+    causal_conv1d_fn, causal_conv1d_update)
+from vllm.model_executor.models.qwen3_next import (Qwen3NextGatedDeltaNet,
+                                                   fused_gdn_gating)
 from vllm.triton_utils import triton
+from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 
 from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
     fused_qkvzba_split_reshape_cat
+from vllm_ascend.ops.triton.fla.sigmoid_gating import \
+    fused_sigmoid_gating_delta_rule_update
 
 
 class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
@@ -101,5 +110,230 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
         core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
         output[:num_tokens], _ = self.out_proj(core_attn_out)
 
+    def _forward_core(
+        self,
+        mixed_qkv: torch.Tensor,
+        b: torch.Tensor,
+        a: torch.Tensor,
+        core_attn_out: torch.Tensor,
+    ):
+        """
+        Core attention computation (called by custom op).
+ """ + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = attn_metadata.num_actual_tokens + num_accepted_tokens = attn_metadata.num_accepted_tokens + + mixed_qkv = mixed_qkv[:num_actual_tokens] + b = b[:num_actual_tokens] + a = a[:num_actual_tokens] + + # 1. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select( + 0, non_spec_token_indx) + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 1.1: Process the multi-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0] + [:attn_metadata.num_spec_decodes], + num_accepted_tokens=num_accepted_tokens, + query_start_loc=spec_query_start_loc, + max_query_len=spec_state_indices_tensor.size(-1), + validate_data=False, + ) + + # 1.2: Process the remaining part + if attn_metadata.num_prefills > 0: + if mixed_qkv_non_spec is not None: + mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec_T, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices= + non_spec_state_indices_tensor[:attn_metadata. 
+                conv_state_indices=non_spec_state_indices_tensor[
+                    :attn_metadata.num_actual_tokens],
+                validate_data=True,
+            )
+        else:
+            mixed_qkv_non_spec = None
+
+        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
+            mixed_qkv_spec)
+        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
+            mixed_qkv_non_spec)
+
+        if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
+            g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
+
+            if spec_sequence_masks is not None:
+                if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
+                    g_spec = g
+                    beta_spec = beta
+                    g_non_spec = None
+                    beta_non_spec = None
+                else:
+                    g_spec = g.index_select(1, spec_token_indx)
+                    beta_spec = beta.index_select(1, spec_token_indx)
+                    g_non_spec = g.index_select(1, non_spec_token_indx)
+                    beta_non_spec = beta.index_select(1, non_spec_token_indx)
+            else:
+                g_spec = None
+                beta_spec = None
+                g_non_spec = g
+                beta_non_spec = beta
+
+            # 2. Recurrent attention
+
+            # 2.1: Process the multi-query part
+            if spec_sequence_masks is not None:
+                core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
+                    q=query_spec,
+                    k=key_spec,
+                    v=value_spec,
+                    g=g_spec,
+                    beta=beta_spec,
+                    initial_state=ssm_state,
+                    inplace_final_state=True,
+                    cu_seqlens=spec_query_start_loc[:attn_metadata.
+                                                    num_spec_decodes + 1],
+                    ssm_state_indices=spec_state_indices_tensor,
+                    num_accepted_tokens=num_accepted_tokens,
+                    use_qk_l2norm_in_kernel=True,
+                )
+            else:
+                core_attn_out_spec, last_recurrent_state = None, None
+
+            # 2.2: Process the remaining part
+            if attn_metadata.num_prefills > 0:
+                initial_state = ssm_state[
+                    non_spec_state_indices_tensor].contiguous()
+                initial_state[~has_initial_state, ...] = 0
+                (
+                    core_attn_out_non_spec,
+                    last_recurrent_state,
+                ) = chunk_gated_delta_rule(
+                    q=query_non_spec,
+                    k=key_non_spec,
+                    v=value_non_spec,
+                    g=g_non_spec,
+                    beta=beta_non_spec,
+                    initial_state=initial_state,
+                    output_final_state=True,
+                    cu_seqlens=non_spec_query_start_loc,
+                    head_first=False,
+                    use_qk_l2norm_in_kernel=True,
+                )
+                # Write the final recurrent state back into the cache
+                ssm_state[
+                    non_spec_state_indices_tensor] = last_recurrent_state.to(
+                        ssm_state.dtype)
+            elif attn_metadata.num_decodes > 0:
+                core_attn_out_non_spec, last_recurrent_state = (
+                    fused_recurrent_gated_delta_rule(
+                        q=query_non_spec,
+                        k=key_non_spec,
+                        v=value_non_spec,
+                        g=g_non_spec,
+                        beta=beta_non_spec,
+                        initial_state=ssm_state,
+                        inplace_final_state=True,
+                        cu_seqlens=non_spec_query_start_loc[:attn_metadata.
+                                                            num_decodes + 1],
+                        ssm_state_indices=non_spec_state_indices_tensor,
+                        use_qk_l2norm_in_kernel=True,
+                    ))
+            else:
+                core_attn_out_non_spec, last_recurrent_state = None, None
+
+        elif attn_metadata.num_decodes > 0:
+            # Pure decode: use the fused sigmoid-gating + delta-rule kernel.
+            core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
+                A_log=self.A_log.contiguous(),
+                dt_bias=self.dt_bias.contiguous(),
+                q=query_non_spec.contiguous(),
+                k=key_non_spec.contiguous(),
+                v=value_non_spec.contiguous(),
+                a=a.contiguous(),
+                b=b.contiguous(),
+                initial_state_source=ssm_state,
+                initial_state_indices=non_spec_state_indices_tensor,
+                cu_seqlens=non_spec_query_start_loc,
+                use_qk_l2norm_in_kernel=True,
+                softplus_beta=1.0,
+                softplus_threshold=20.0,
+            )
+
+        # 3. Merge core attention output
+        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
+            merged_out = torch.empty(
+                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
+                dtype=core_attn_out_non_spec.dtype,
+                device=core_attn_out_non_spec.device,
+            )
+            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
+            merged_out.index_copy_(1, non_spec_token_indx,
+                                   core_attn_out_non_spec)
+            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
+        elif spec_sequence_masks is not None:
+            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
+        else:
+            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(
+                0)
+
 
 Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward
+Qwen3NextGatedDeltaNet._forward_core = AscendQwen3Next_GatedDeltaNet._forward_core
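Since the patch is applied via import side effects (the module-level assignments above), here is a hedged sanity check that the swap took effect; it assumes a working vllm + vllm_ascend install and is illustrative, not part of the PR:

```python
# Importing the patch module executes the assignments above.
import vllm_ascend.patch.worker.patch_qwen3_next  # noqa: F401
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet

# Functions keep the __qualname__ of their defining class, so this confirms
# the upstream class now points at the Ascend implementations.
assert Qwen3NextGatedDeltaNet._forward_core.__qualname__ == \
    "AscendQwen3Next_GatedDeltaNet._forward_core"
assert Qwen3NextGatedDeltaNet.forward.__qualname__ == \
    "AscendQwen3Next_GatedDeltaNet.forward"
```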