diff --git a/vllm_ascend/ops/triton/fused_gdn_gating.py b/vllm_ascend/ops/triton/fused_gdn_gating.py
new file mode 100644
index 00000000..dfd5dde2
--- /dev/null
+++ b/vllm_ascend/ops/triton/fused_gdn_gating.py
@@ -0,0 +1,120 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_next.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+from vllm.triton_utils import tl, triton
+
+from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
+
+UNIFIED_BUFFER_SIZE = 1572864
+
+
+@triton.jit
+def fused_gdn_gating_kernel(
+    g,
+    beta_output,
+    A_log,
+    a,
+    b,
+    dt_bias,
+    seq_len,
+    NUM_HEADS: tl.constexpr,
+    NUM_BATCHES: tl.constexpr,
+    beta: tl.constexpr,
+    threshold: tl.constexpr,
+    BLK_HEADS: tl.constexpr,
+    COL_ITER: tl.constexpr,
+    BLK_BATCHES: tl.constexpr,
+    ROW_ITER: tl.constexpr,
+):
+    """Compute g = -exp(A_log) * softplus(beta * (a + dt_bias)) and beta_output = sigmoid(b)."""
+    i_b, i_s = tl.program_id(0), tl.program_id(1)
+    for row_idx in range(0, ROW_ITER):
+        batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(
+            0, BLK_BATCHES)
+
+        for col_idx in range(0, COL_ITER):
+            head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS)
+
+            off = batch_off[:,
+                            None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[
+                                None, :]
+            head_mask = head_off < NUM_HEADS
+            mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
+
+            blk_A_log = tl.load(A_log + head_off, mask=head_mask)
+            blk_a = tl.load(a + off, mask=mask)
+            blk_b = tl.load(b + off, mask=mask)
+            blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
+
+            x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
+            softplus_x = tl.where(beta * x <= threshold,
+                                  (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
+
+            blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
+            tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+
+            # compute beta_output = sigmoid(b)
+            blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
+            tl.store(beta_output + off,
+                     blk_beta_output.to(beta_output.dtype.element_ty),
+                     mask=mask)
+
+
+def fused_gdn_gating_patch(
+    A_log: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    dt_bias: torch.Tensor,
+    beta: float = 1.0,
+    threshold: float = 20.0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Launch the fused GDN gating kernel; returns (g, beta_output), each shaped (1, batch, num_heads)."""
+    batch, num_heads = a.shape
+    seq_len = 1
+
+    num_cores = get_vectorcore_num()
+
+    BLK_HEADS = 8
+    COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
+
+    if batch <= num_cores:
+        progs = batch
+        BLK_BATCHES = 1
+        ROW_ITER = 1
+    else:
+        progs = num_cores
+        FACTOR = 8 * num_heads
+        row_per_core = triton.cdiv(batch, num_cores)
+        BLK_BATCHES = triton.next_power_of_2(
+            triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) //
+            a.element_size()) // 2
+        ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
+
+    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
+    beta_output = torch.empty(1,
+                              batch,
+                              num_heads,
+                              dtype=b.dtype,
+                              device=b.device)
+
+    grid = (progs, seq_len)
+    fused_gdn_gating_kernel[grid](
+        g,
+        beta_output,
+        A_log,
+        a,
+        b,
+        dt_bias,
+        seq_len,
+        num_heads,
+        batch,
+        beta,
+        threshold,
+        BLK_HEADS=BLK_HEADS,
+        COL_ITER=COL_ITER,
+        BLK_BATCHES=BLK_BATCHES,
+        ROW_ITER=ROW_ITER,
+    )
+    return g, beta_output
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 89d7c957..c3486d61 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -297,3 +297,13 @@
 # Future Plan:
 # Remove this patch when vLLM support these operators.
 #
+# 2. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet._forward_core`
+#    Why:
+#       The Qwen3Next GatedDeltaNet _forward_core cannot directly add custom operators.
+#    How:
+#       Add a branch in Qwen3NextGatedDeltaNet._forward_core to adapt to fused_gdn_gating_patch.
+#    Related PR (if no, explain why):
+#       https://github.com/vllm-project/vllm/pull/31002
+#    Future Plan:
+#       Remove this patch when vLLM supports these operators.
+#
diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py
index 172aab8f..e7604aef 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_next.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_next.py
@@ -35,6 +35,7 @@
 from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
     fused_qkvzba_split_reshape_cat
 from vllm_ascend.ops.triton.fla.sigmoid_gating import \
     fused_sigmoid_gating_delta_rule_update
+from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
 
 class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
@@ -151,7 +152,6 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
         # 1. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
-
         if spec_sequence_masks is not None:
             if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                 mixed_qkv_spec = mixed_qkv
@@ -211,14 +211,18 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
             )
         else:
             mixed_qkv_non_spec = None
-
         query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
             mixed_qkv_spec)
         query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
             mixed_qkv_non_spec)
 
         if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
-            g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
+            is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
+            if is_cuda_graph:
+                g, beta = fused_gdn_gating_patch(self.A_log, a, b,
+                                                 self.dt_bias)
+            else:
+                g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
 
         if spec_sequence_masks is not None:
             if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: