[v0.18.0][kernel] Fix recompilation triggered by unnecessary constexpr Triton kernel parameters (#7647)

### What this PR does / why we need it?
Some parameters of the Triton kernels are unnecessarily declared with the
`tl.constexpr` qualifier. Whenever the value of such a parameter changes,
Triton recompiles the kernel, which significantly degrades model
performance. This PR removes the unnecessary `constexpr` qualifiers and
marks the affected scalar arguments as `do_not_specialize`, so that
changing values no longer trigger recompilation.
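
For background, here is a minimal sketch (a toy kernel, not code from this PR) of the difference between the two parameter styles: Triton bakes each distinct `tl.constexpr` value into the compiled binary, so every new value forces a fresh compilation, whereas a plain runtime argument reuses the cached kernel.

```python
# Minimal sketch, not part of this PR: the same toy kernel in both styles.
import triton
import triton.language as tl

@triton.jit
def scale_specialized(x_ptr, n, SCALE: tl.constexpr, BLOCK: tl.constexpr):
    # SCALE is a compile-time constant: every distinct value of SCALE
    # produces (and caches) a separate compiled kernel.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * SCALE, mask=mask)

@triton.jit(do_not_specialize=["scale"])
def scale_runtime(x_ptr, n, scale, BLOCK: tl.constexpr):
    # scale is an ordinary runtime argument; do_not_specialize also stops
    # Triton from specializing on its particular value, so the kernel is
    # compiled once and reused for any scale. (Launch code omitted.)
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * scale, mask=mask)
```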

- vLLM version: v0.17.0
- vLLM main: 8b6325758c

Signed-off-by: HarpSealCC <844291270@qq.com>
Signed-off-by: l30072083 <liuchengzhuo1@h-partners.com>
Co-authored-by: l30072083 <liuchengzhuo1@h-partners.com>
Author: HarpsealCC
Date: 2026-03-26 19:10:45 +08:00 (committed by GitHub)
Parent: d781902ce9
Commit: d6661c09b6
5 changed files with 20 additions and 34 deletions


```diff
@@ -10,7 +10,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
 UNIFIED_BUFFER_SIZE = 1572864
-@triton.jit
+@triton.jit(do_not_specialize=["seq_len", "NUM_HEADS", "NUM_BATCHES", "beta", "threshold", "ROW_ITER"])
 def fused_gdn_gating_kernel(
     g,
     beta_output,
```
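
Worth noting: even plain (non-`constexpr`) scalar arguments can still trigger recompilation, because by default Triton specializes integer arguments on value properties such as equality to 1 or divisibility by 16. Listing `seq_len`, `NUM_HEADS`, `NUM_BATCHES`, `beta`, `threshold`, and `ROW_ITER` in `do_not_specialize` turns that off, so a varying sequence length or batch count reuses the same compiled binary.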
```diff
@@ -19,16 +19,17 @@ def fused_gdn_gating_kernel(
     b,
     dt_bias,
     seq_len,
-    NUM_HEADS: tl.constexpr,
-    NUM_BATCHES: tl.constexpr,
-    beta: tl.constexpr,
-    threshold: tl.constexpr,
+    NUM_HEADS,
+    NUM_BATCHES,
+    beta,
+    threshold,
     BLK_HEADS: tl.constexpr,
-    COL_ITER: tl.constexpr,
     BLK_BATCHES: tl.constexpr,
-    ROW_ITER: tl.constexpr,
+    ROW_ITER,
 ):
     i_b, i_s = tl.program_id(0), tl.program_id(1)
+    COL_ITER = tl.cdiv(NUM_HEADS, BLK_HEADS)
     for row_idx in range(0, ROW_ITER):
         batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
```
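
Since `COL_ITER` is fully determined by `NUM_HEADS` and `BLK_HEADS`, the kernel now derives it with `tl.cdiv` instead of receiving it as a `constexpr` launch argument, removing one more key from the compilation cache. `ROW_ITER` likewise becomes an ordinary runtime value; as I read it, the `for` loop then compiles as a data-dependent loop rather than being re-specialized (and potentially unrolled) for each distinct constant bound.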
```diff
@@ -69,23 +70,11 @@ def fused_gdn_gating_patch(
     num_cores = get_vectorcore_num()
     BLK_HEADS = 8
-    COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
-    elem_size = a.element_size()
-    max_ub_batches = int((UNIFIED_BUFFER_SIZE * 0.95) / (BLK_HEADS * elem_size))
-    if batch <= num_cores:
-        progs = batch
-        BLK_BATCHES = 1
-        ROW_ITER = 1
-    else:
-        progs = num_cores
-        FACTOR = 8 * num_heads
-        calc_blk_batches = (
-            triton.next_power_of_2(triton.cdiv(int(UNIFIED_BUFFER_SIZE * 0.95), FACTOR * BLK_HEADS * elem_size)) // 2
-        )
-        BLK_BATCHES = max(1, min(calc_blk_batches, max_ub_batches, 64))
-        row_per_core = triton.cdiv(batch, progs)
-        ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
+    progs = num_cores
+    row_per_core = triton.cdiv(batch, progs)
+    BLK_BATCHES = 64
+    ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
     g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
     beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
```
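
The host-side launch logic is also simplified so that the block shape no longer depends on `batch`: the grid is always `num_cores` programs and `BLK_BATCHES` is a fixed 64. A worked example with assumed numbers (the real `num_cores` comes from `get_vectorcore_num()` on the NPU):

```python
import triton

# Assumed values for illustration only:
num_cores, batch = 40, 1000

progs = num_cores                                  # grid size, fixed per device
row_per_core = triton.cdiv(batch, progs)           # cdiv(1000, 40) = 25
BLK_BATCHES = 64                                   # fixed block of batch rows
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)  # cdiv(25, 64) = 1
# Each program walks ROW_ITER blocks of BLK_BATCHES rows; rows past `batch`
# are masked out inside the kernel. Only ROW_ITER varies with batch size,
# and it is now a runtime argument rather than a compile-time constant.
```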
```diff
@@ -104,7 +93,6 @@ def fused_gdn_gating_patch(
         beta,
         threshold,
         BLK_HEADS=BLK_HEADS,
-        COL_ITER=COL_ITER,
         BLK_BATCHES=BLK_BATCHES,
         ROW_ITER=ROW_ITER,
     )
```
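
With the call site updated, the only `tl.constexpr` arguments left are `BLK_HEADS` and `BLK_BATCHES`, and both are fixed host-side constants (8 and 64). For a given device and fixed input dtypes, the kernel should therefore compile once and be reused across changing batch sizes, sequence lengths, `beta`, and `threshold` values.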