Files
xc-llm-ascend/vllm_ascend/ops/triton/fused_gdn_gating.py
HarpsealCC d6661c09b6 [v0.18.0][kernel] Recompilation optimization triggered by triton function parameter optimization (#7647)
### What this PR does / why we need it?
Some parameters of the Triton kernels were unnecessarily declared with the
`tl.constexpr` qualifier. Every new value of a constexpr parameter triggers a
recompilation, which significantly hurt model performance. This change turns
those parameters back into ordinary runtime arguments.

- vLLM version: v0.17.0
- vLLM main: 8b6325758c

Signed-off-by: HarpSealCC <844291270@qq.com>
Signed-off-by: l30072083 <liuchengzhuo1@h-partners.com>
Co-authored-by: l30072083 <liuchengzhuo1@h-partners.com>
2026-03-26 19:10:45 +08:00

100 lines
2.9 KiB
Python

# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_next.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
# 1536 * 1024 == 1572864. Presumably the NPU unified-buffer capacity in bytes
# (TODO confirm unit); not referenced by the kernel or wrapper in this module.
UNIFIED_BUFFER_SIZE = 1536 * 1024
# Triton kernel computing the GDN gating terms element-wise over a
# (NUM_BATCHES, NUM_HEADS) set of elements:
#   g           = -exp(A_log) * softplus(a + dt_bias)
#   beta_output = sigmoid(b)
# where softplus(x) = (1/beta) * log(1 + exp(beta*x)) when beta*x <= threshold,
# and x otherwise.
#
# Launch layout: program_id(0) = i_b selects a chunk of rows; program i_b
# handles the ROW_ITER * BLK_BATCHES consecutive rows starting at
# i_b * ROW_ITER * BLK_BATCHES. program_id(1) = i_s selects the position.
# a, b, g and beta_output are addressed with the flat offset
#   row * seq_len * NUM_HEADS + i_s * NUM_HEADS + head,
# i.e. as contiguous row-major (rows, seq_len, NUM_HEADS) buffers.
# A_log and dt_bias are addressed by head index only.
#
# Runtime-valued arguments are plain parameters listed in do_not_specialize
# (not tl.constexpr), so new batch sizes / head counts / scalars do not trigger
# a recompilation. BLK_HEADS and BLK_BATCHES stay tl.constexpr because they
# size tl.arange tiles, which require compile-time bounds.
@triton.jit(do_not_specialize=["seq_len", "NUM_HEADS", "NUM_BATCHES", "beta", "threshold", "ROW_ITER"])
def fused_gdn_gating_kernel(
    g,  # out: gating term, stored as g.dtype.element_ty
    beta_output,  # out: sigmoid(b), stored as beta_output.dtype.element_ty
    A_log,  # in: per-head values, indexed by head only
    a,  # in: indexed by the flat (row, position, head) offset
    b,  # in: indexed like `a`
    dt_bias,  # in: per-head bias, indexed by head only
    seq_len,  # positions per row; row-stride factor in the flat offset
    NUM_HEADS,  # number of heads; bounds the head mask
    NUM_BATCHES,  # number of rows; bounds the row mask
    beta,  # softplus beta
    threshold,  # softplus linearisation threshold (compared against beta * x)
    BLK_HEADS: tl.constexpr,  # heads per tile
    BLK_BATCHES: tl.constexpr,  # rows per tile
    ROW_ITER,  # row tiles processed by each program along axis 0
):
    i_b, i_s = tl.program_id(0), tl.program_id(1)
    # Number of BLK_HEADS-wide tiles needed to cover all heads (runtime value).
    COL_ITER = tl.cdiv(NUM_HEADS, BLK_HEADS)
    for row_idx in range(0, ROW_ITER):
        # Rows for this iteration: a BLK_BATCHES-wide slice of this program's
        # contiguous chunk of rows.
        batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
        for col_idx in range(0, COL_ITER):
            head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS)
            # Broadcast to a (BLK_BATCHES, BLK_HEADS) tile of flat element offsets.
            off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :]
            # Mask the tail of the last head tile and rows past NUM_BATCHES.
            # Masked-out lanes are loaded without `other`, but they are never stored.
            head_mask = head_off < NUM_HEADS
            mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
            blk_A_log = tl.load(A_log + head_off, mask=head_mask)
            blk_a = tl.load(a + off, mask=mask)
            blk_b = tl.load(b + off, mask=mask)
            blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
            # All gating math is done in float32; dt_bias is broadcast across rows.
            x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
            # Softplus with a linear fallback above `threshold` (presumably for
            # numerical stability, mirroring torch.nn.Softplus). tl.where selects
            # x there, so the unused exp branch does not reach the output.
            softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
            blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
            tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
            # compute beta_output = sigmoid(b)
            blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
            tl.store(beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask)
def fused_gdn_gating_patch(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the GDN gating terms with ``fused_gdn_gating_kernel``.

    Element-wise over a ``(batch, num_heads)`` grid it computes:

    * ``g = -exp(A_log) * softplus(a + dt_bias)`` in float32. Softplus uses
      ``beta`` and falls back to the identity when ``beta * x > threshold``.
    * ``beta_output = sigmoid(b)`` in ``b``'s dtype.

    Args:
        A_log: Per-head values, read as ``num_heads`` consecutive elements.
        a: 2-D tensor of shape ``(batch, num_heads)``.
        b: Tensor laid out like ``a``.
        dt_bias: Per-head bias, read as ``num_heads`` consecutive elements.
        beta: Softplus ``beta`` parameter.
        threshold: Softplus linearisation threshold.

    Returns:
        ``(g, beta_output)``, each of shape ``(1, batch, num_heads)``.
    """
    batch, num_heads = a.shape
    # The kernel addresses every input with flat row-major offsets
    # (row * num_heads + head). A strided view, e.g. a column slice of a fused
    # projection, would therefore be read incorrectly without any error.
    # .contiguous() returns the tensor itself when it is already contiguous,
    # so this costs nothing in the common case.
    A_log = A_log.contiguous()
    a = a.contiguous()
    b = b.contiguous()
    dt_bias = dt_bias.contiguous()
    # Only a single position is processed; outputs carry a leading dim of 1.
    seq_len = 1
    num_cores = get_vectorcore_num()
    BLK_HEADS = 8
    # One program per core reported by get_vectorcore_num(). Each program
    # covers ROW_ITER * BLK_BATCHES consecutive rows, and
    # progs * ROW_ITER * BLK_BATCHES >= progs * row_per_core >= batch,
    # so every row is covered. Surplus rows are masked inside the kernel.
    progs = num_cores
    row_per_core = triton.cdiv(batch, progs)
    BLK_BATCHES = 64
    ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
    grid = (progs, seq_len)
    # seq_len, num_heads, batch, beta, threshold and ROW_ITER are runtime
    # (non-constexpr) kernel arguments, so changing them does not recompile.
    fused_gdn_gating_kernel[grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        seq_len,
        num_heads,
        batch,
        beta,
        threshold,
        BLK_HEADS=BLK_HEADS,
        BLK_BATCHES=BLK_BATCHES,
        ROW_ITER=ROW_ITER,
    )
    return g, beta_output