Files
xc-llm-ascend/vllm_ascend/ops/triton/fused_gdn_gating.py
xmpp777 9216e1b050 [fix] Add support for Qwen3.5 Dense and MoE on Ascend (#6933)
### What this PR does / why we need it?

This pull request introduces support for the Qwen3.5 Dense and MoE models on
Ascend devices. The key changes are:

* **Quantization Configuration for Qwen3.5 MoE**: Adds necessary prefix
mappings and packed module definitions for `qwen3_5_moe` in
`vllm_ascend/quantization/modelslim_config.py` to enable ModelSlim
quantization.
* **Triton Kernel Fix**: Corrects a bug in the launch configuration of the
`fused_gdn_gating` Triton kernel. The calculation of `BLK_BATCHES` had an
operator-precedence issue, which is now resolved. The calculation has also
been made more robust with added clamping to prevent potential out-of-bounds
memory access in the unified buffer.

These changes enable the correct and efficient execution of Qwen3.5 MoE
models on Ascend hardware.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI should be used to verify the correctness of these changes. It is
recommended to run tests with the Qwen3.5 MoE model to ensure the new
configurations and the kernel fix work as expected.

Signed-off-by: xmpp777 <yangming2@huawei.com>
2026-03-10 09:09:31 +08:00

112 lines
3.4 KiB
Python

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_next.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
# Unified-buffer byte budget: 1536 KiB (1.5 MiB) == 1572864 bytes.
# fused_gdn_gating_patch scales this by 0.95 and divides by tile size times
# element size to bound BLK_BATCHES.
# NOTE(review): presumably the per-core unified-buffer capacity of the target
# Ascend SoC; confirm for each supported device.
UNIFIED_BUFFER_SIZE = 1536 * 1024
# Elementwise GDN gating kernel.
#
# Program axis 0 (i_b) selects a chunk of ROW_ITER * BLK_BATCHES consecutive
# batch rows starting at i_b * ROW_ITER * BLK_BATCHES. Program axis 1 (i_s)
# selects the sequence position. Each chunk is processed as ROW_ITER tiles of
# BLK_BATCHES rows; each row tile walks the head dimension in COL_ITER chunks
# of BLK_HEADS. Rows >= NUM_BATCHES and heads >= NUM_HEADS are masked off.
#
# For every unmasked (row, head) element at flat offset `off`:
#   x                = float32(a[off]) + float32(dt_bias[head])
#   g[off]           = -exp(A_log[head]) * softplus(x)
#   beta_output[off] = sigmoid(b[off])
# where softplus(x) = (1 / beta) * log(1 + exp(beta * x)) if beta * x <= threshold,
# and x otherwise.
#
# NOTE(review): NUM_BATCHES, ROW_ITER and BLK_BATCHES are tl.constexpr. Under
# Triton's constexpr specialization each distinct value presumably compiles a
# separate kernel variant. NUM_BATCHES changes with every batch size, so
# confirm that this recompilation cost is acceptable.
@triton.jit
def fused_gdn_gating_kernel(
    g,  # output; written at flat offset `off`, cast to g's element type
    beta_output,  # output; written at flat offset `off`
    A_log,  # per-head values, indexed by head_off
    a,  # input read at flat offset `off`
    b,  # input read at flat offset `off`
    dt_bias,  # per-head bias, indexed by head_off
    seq_len,
    NUM_HEADS: tl.constexpr,
    NUM_BATCHES: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
    COL_ITER: tl.constexpr,
    BLK_BATCHES: tl.constexpr,
    ROW_ITER: tl.constexpr,
):
    i_b, i_s = tl.program_id(0), tl.program_id(1)
    for row_idx in range(0, ROW_ITER):
        # Row indices of this BLK_BATCHES-row tile inside program i_b's chunk.
        batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
        for col_idx in range(0, COL_ITER):
            head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS)
            # 2-D (BLK_BATCHES, BLK_HEADS) tile of flat offsets, assuming a dense
            # row-major (batch, seq_len, NUM_HEADS) layout for a, b, g and beta_output.
            off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :]
            head_mask = head_off < NUM_HEADS
            mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
            # Loads pass no `other=`, so masked lanes hold unspecified values.
            # That is harmless because both stores below use the same mask.
            blk_A_log = tl.load(A_log + head_off, mask=head_mask)
            blk_a = tl.load(a + off, mask=mask)
            blk_b = tl.load(b + off, mask=mask)
            blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
            # Math is done in float32; the bias row is broadcast across the batch rows.
            x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
            # Softplus with linear cutoff: use x directly once beta * x exceeds threshold.
            softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
            blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
            tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
            # compute beta_output = sigmoid(b)
            blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
            tl.store(beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask)
def fused_gdn_gating_patch(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute GDN gating terms by launching ``fused_gdn_gating_kernel``.

    For every ``(row, head)`` element::

        x           = float32(a) + float32(dt_bias[head])
        g           = -exp(A_log[head]) * softplus(x)
        beta_output = sigmoid(b)

    where ``softplus(x) = log(1 + exp(beta * x)) / beta`` when
    ``beta * x <= threshold`` and ``x`` otherwise.

    Args:
        A_log: Per-head values, read at indices ``0 .. num_heads - 1``.
        a: 2-D ``(batch, num_heads)`` tensor.
        b: Read with the same flat offsets as ``a``, so it is expected to have
            the same ``(batch, num_heads)`` shape.
        dt_bias: Per-head bias, read at indices ``0 .. num_heads - 1``.
        beta: Softplus ``beta``.
        threshold: Cutoff on ``beta * x`` above which softplus is taken as ``x``.

    Returns:
        ``(g, beta_output)``, each of shape ``(1, batch, num_heads)``. ``g`` is
        float32 on ``a.device``; ``beta_output`` has ``b``'s dtype and device.
    """
    batch, num_heads = a.shape
    seq_len = 1
    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
    if batch == 0:
        # Nothing to compute; avoid launching a kernel over an empty grid.
        return g, beta_output

    # The kernel addresses a and b with dense row-major offsets
    # (row * num_heads + head). .contiguous() returns the tensor itself when it
    # is already contiguous, so only strided views are copied.
    a = a.contiguous()
    b = b.contiguous()

    num_cores = get_vectorcore_num()
    BLK_HEADS = 8
    COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
    if batch <= num_cores:
        # Enough cores for one program per row.
        progs = batch
        BLK_BATCHES = 1
        ROW_ITER = 1
    else:
        progs = num_cores
        elem_size = a.element_size()
        # Keep ~5% of the unified buffer as headroom.
        ub_budget = int(UNIFIED_BUFFER_SIZE * 0.95)
        max_ub_batches = ub_budget // (BLK_HEADS * elem_size)
        FACTOR = 8 * num_heads
        calc_blk_batches = triton.next_power_of_2(triton.cdiv(ub_budget, FACTOR * BLK_HEADS * elem_size)) // 2
        row_per_core = triton.cdiv(batch, progs)
        # Program i_b covers rows [i_b * ROW_ITER * BLK_BATCHES,
        # (i_b + 1) * ROW_ITER * BLK_BATCHES). Without the row_per_core cap, a
        # 64-row tile makes every program span 64 rows even when row_per_core is
        # small. For example, batch=100 on 40 cores would leave only 2 programs
        # with unmasked rows. next_power_of_2 keeps BLK_BATCHES a power of two
        # (required by tl.arange). ROW_ITER * BLK_BATCHES >= row_per_core, so
        # progs programs still cover all batch rows.
        BLK_BATCHES = max(
            1,
            min(calc_blk_batches, max_ub_batches, 64, triton.next_power_of_2(row_per_core)),
        )
        ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
    grid = (progs, seq_len)
    fused_gdn_gating_kernel[grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        seq_len,
        num_heads,
        batch,
        beta,
        threshold,
        BLK_HEADS=BLK_HEADS,
        COL_ITER=COL_ITER,
        BLK_BATCHES=BLK_BATCHES,
        ROW_ITER=ROW_ITER,
    )
    return g, beta_output