From d6661c09b6f106ee37f9a69dce4bd8b82b179cce Mon Sep 17 00:00:00 2001 From: HarpsealCC <58933088+HarpsealCC@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:10:45 +0800 Subject: [PATCH] [v0.18.0][kernel] Recompilation optimization triggered by triton function parameter optimization (#7647) ### What this PR does / why we need it? Some Triton kernel parameters are unnecessarily declared as `tl.constexpr`. Each time one of these parameters takes a new value, the kernel is recompiled, which significantly degrades model performance. This PR turns them into ordinary runtime arguments (adding them to `do_not_specialize` where needed) so that a changed value no longer triggers recompilation. - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/8b6325758cce5f9c36d38f2462edbd368b97a07c Signed-off-by: HarpSealCC <844291270@qq.com> Signed-off-by: l30072083 Co-authored-by: l30072083 --- vllm_ascend/ops/triton/fla/chunk_delta_h.py | 10 +++--- vllm_ascend/ops/triton/fla/cumsum.py | 2 -- vllm_ascend/ops/triton/fused_gdn_gating.py | 36 +++++++-------------- vllm_ascend/ops/triton/reject_sample.py | 4 +-- vllm_ascend/ops/triton/spec_decode/utils.py | 2 +- 5 files changed, 20 insertions(+), 34 deletions(-) diff --git a/vllm_ascend/ops/triton/fla/chunk_delta_h.py b/vllm_ascend/ops/triton/fla/chunk_delta_h.py index e305878e..0189d313 100644 --- a/vllm_ascend/ops/triton/fla/chunk_delta_h.py +++ b/vllm_ascend/ops/triton/fla/chunk_delta_h.py @@ -26,7 +26,7 @@ _CONDITIONS = ("seq7168",) "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, } ) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"]) def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( k, v, @@ -40,10 +40,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( chunk_offsets, h_update, T, - H: tl.constexpr, - Hg: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, + H, + Hg, + K, + V, BT: tl.constexpr, USE_G: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, diff --git a/vllm_ascend/ops/triton/fla/cumsum.py 
b/vllm_ascend/ops/triton/fla/cumsum.py index ea2fa143..95965617 100644 --- a/vllm_ascend/ops/triton/fla/cumsum.py +++ b/vllm_ascend/ops/triton/fla/cumsum.py @@ -26,7 +26,6 @@ def chunk_local_cumsum_scalar_kernel( cu_seqlens, chunk_indices, T, - B: tl.constexpr, H: tl.constexpr, BLOCK_T: tl.constexpr, REVERSE: tl.constexpr, @@ -103,7 +102,6 @@ def chunk_local_cumsum_scalar( cu_seqlens=cu_seqlens, chunk_indices=block_indices, T=T, - B=B, H=H, BLOCK_T=OPTIM_BLOCK_SIZE, CHUNK_SIZE=chunk_size, diff --git a/vllm_ascend/ops/triton/fused_gdn_gating.py b/vllm_ascend/ops/triton/fused_gdn_gating.py index 9c3ea9d1..69bbf699 100644 --- a/vllm_ascend/ops/triton/fused_gdn_gating.py +++ b/vllm_ascend/ops/triton/fused_gdn_gating.py @@ -10,7 +10,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num UNIFIED_BUFFER_SIZE = 1572864 -@triton.jit +@triton.jit(do_not_specialize=["seq_len", "NUM_HEADS", "NUM_BATCHES", "beta", "threshold", "ROW_ITER"]) def fused_gdn_gating_kernel( g, beta_output, @@ -19,16 +19,17 @@ def fused_gdn_gating_kernel( b, dt_bias, seq_len, - NUM_HEADS: tl.constexpr, - NUM_BATCHES: tl.constexpr, - beta: tl.constexpr, - threshold: tl.constexpr, + NUM_HEADS, + NUM_BATCHES, + beta, + threshold, BLK_HEADS: tl.constexpr, - COL_ITER: tl.constexpr, BLK_BATCHES: tl.constexpr, - ROW_ITER: tl.constexpr, + ROW_ITER, ): i_b, i_s = tl.program_id(0), tl.program_id(1) + COL_ITER = tl.cdiv(NUM_HEADS, BLK_HEADS) + for row_idx in range(0, ROW_ITER): batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES) @@ -69,23 +70,11 @@ def fused_gdn_gating_patch( num_cores = get_vectorcore_num() BLK_HEADS = 8 - COL_ITER = triton.cdiv(num_heads, BLK_HEADS) - elem_size = a.element_size() - max_ub_batches = int((UNIFIED_BUFFER_SIZE * 0.95) / (BLK_HEADS * elem_size)) - if batch <= num_cores: - progs = batch - BLK_BATCHES = 1 - ROW_ITER = 1 - else: - progs = num_cores - FACTOR = 8 * num_heads - calc_blk_batches = ( - 
triton.next_power_of_2(triton.cdiv(int(UNIFIED_BUFFER_SIZE * 0.95), FACTOR * BLK_HEADS * elem_size)) // 2 - ) - BLK_BATCHES = max(1, min(calc_blk_batches, max_ub_batches, 64)) - row_per_core = triton.cdiv(batch, progs) - ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES) + progs = num_cores + row_per_core = triton.cdiv(batch, progs) + BLK_BATCHES = 64 + ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) @@ -104,7 +93,6 @@ def fused_gdn_gating_patch( beta, threshold, BLK_HEADS=BLK_HEADS, - COL_ITER=COL_ITER, BLK_BATCHES=BLK_BATCHES, ROW_ITER=ROW_ITER, ) diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py index 58a0fce7..4c85a430 100644 --- a/vllm_ascend/ops/triton/reject_sample.py +++ b/vllm_ascend/ops/triton/reject_sample.py @@ -82,7 +82,7 @@ def bonus_renew( tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, bonus_token_id) -@triton.jit(do_not_specialize=["max_spec_len"]) +@triton.jit(do_not_specialize=["vec_len", "max_spec_len"]) def rejection_greedy_sample_triton( output_token_ids_ptr, # [batch_size, max_spec_len + 1] cu_num_draft_tokens_ptr, # [batch_size] @@ -196,7 +196,7 @@ def rejection_random_sample_kernel( ) -@triton.jit(do_not_specialize=["replace_from", "replace_to"]) +@triton.jit(do_not_specialize=["replace_from", "replace_to", "vec_len"]) def expand_kernel( output_ptr, # [num_tokens] input_ptr, # [batch_size] diff --git a/vllm_ascend/ops/triton/spec_decode/utils.py b/vllm_ascend/ops/triton/spec_decode/utils.py index 15d42fa0..3c7aa450 100644 --- a/vllm_ascend/ops/triton/spec_decode/utils.py +++ b/vllm_ascend/ops/triton/spec_decode/utils.py @@ -18,7 +18,7 @@ from vllm.triton_utils import tl, triton -@triton.jit +@triton.jit(do_not_specialize=["num_reqs"]) def prepare_inputs_padded_kernel( cu_num_draft_tokens_ptr, # [num_reqs] 
valid_sampled_tokens_count_ptr, # [num_reqs]