From 2db33868a41efbf9dbdc3b31dbe666ab8b5b0c90 Mon Sep 17 00:00:00 2001 From: cvSoldier <43463523+cvSoldier@users.noreply.github.com> Date: Thu, 26 Mar 2026 16:31:34 +0800 Subject: [PATCH] [kernel] Recompilation optimization triggered by triton function parameter optimization (#7645) ### What this PR does / why we need it? Some parameters of Triton operators are unnecessarily annotated with the "constexpr" modifier. When the value of such a parameter changes, the kernel is recompiled, which significantly affects model performance. Therefore, this patch removes the "constexpr" annotation from those parameters. main branch: https://github.com/vllm-project/vllm-ascend/pull/7483 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? --------- Signed-off-by: cvSoldier <610496306@qq.com> --- .../ops/triton/fla/chunk_scaled_dot_kkt.py | 2 +- vllm_ascend/ops/triton/fla/sigmoid_gating.py | 8 ++--- vllm_ascend/ops/triton/fla/solve_tril.py | 12 +++---- .../linearnorm/split_qkv_rmsnorm_mrope.py | 6 ++-- vllm_ascend/ops/triton/mamba/causal_conv1d.py | 33 +++++++++++++------ 5 files changed, 37 insertions(+), 24 deletions(-) diff --git a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py index 7ceafe3c..864a29e6 100644 --- a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py +++ b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py @@ -21,7 +21,7 @@ from .utils import prepare_chunk_indices, safe_exp "USE_G": lambda args: args["g_cumsum"] is not None, } ) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "B"]) def chunk_scaled_dot_kkt_fwd_kernel( k, beta, # [H, B, T] diff --git a/vllm_ascend/ops/triton/fla/sigmoid_gating.py b/vllm_ascend/ops/triton/fla/sigmoid_gating.py index e5512c6d..bd683cb6 100644 --- a/vllm_ascend/ops/triton/fla/sigmoid_gating.py +++ b/vllm_ascend/ops/triton/fla/sigmoid_gating.py @@ -39,7 +39,7 @@ 
else: "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, } ) -@triton.jit(do_not_specialize=["N", "T"]) +@triton.jit(do_not_specialize=["scale", "N", "T", "B"]) def fused_recurrent_gated_delta_rule_fwd_kernel( q, k, @@ -53,9 +53,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ssm_state_indices, num_accepted_tokens, scale, - N: tl.constexpr, # num of sequences - T: tl.constexpr, # num of tokens - B: tl.constexpr, + N, # num of sequences + T, # num of tokens + B, H: tl.constexpr, HV: tl.constexpr, K: tl.constexpr, diff --git a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py index 1783035d..fb46ccfa 100644 --- a/vllm_ascend/ops/triton/fla/solve_tril.py +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -18,14 +18,14 @@ from .utils import prepare_chunk_indices @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "H"]) def solve_tril_16x16_kernel( A, Ad, cu_seqlens, chunk_indices, T, - H: tl.constexpr, + H, BT: tl.constexpr, IS_VARLEN: tl.constexpr, LARGE_BLOCK_T: tl.constexpr, @@ -134,7 +134,7 @@ def solve_tril_16x16_kernel( @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "H"]) def merge_16x16_to_32x32_inverse_kernel( A, Ad, @@ -142,7 +142,7 @@ def merge_16x16_to_32x32_inverse_kernel( cu_seqlens, chunk_indices, T, - H: tl.constexpr, + H, BT: tl.constexpr, IS_VARLEN: tl.constexpr, ): @@ -198,7 +198,7 @@ def merge_16x16_to_32x32_inverse_kernel( @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "H"]) def merge_16x16_to_64x64_inverse_kernel( A, Ad, @@ -206,7 +206,7 @@ def merge_16x16_to_64x64_inverse_kernel( cu_seqlens, chunk_indices, T, - H: tl.constexpr, + H, BT: tl.constexpr, IS_VARLEN: tl.constexpr, 
): diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py index ba920c25..25aa727b 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py @@ -47,9 +47,9 @@ def split_qkv_rmsnorm_mrope_kernel( q_size: tl.constexpr, kv_size: tl.constexpr, eps: tl.constexpr, - mrope_section_t: tl.constexpr, - mrope_section_h: tl.constexpr, - mrope_section_w: tl.constexpr, + mrope_section_t, + mrope_section_h, + mrope_section_w, has_bias: tl.constexpr, is_interleaved: tl.constexpr, rope_dim: tl.constexpr, diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py index a06320e0..da7f4183 100644 --- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py +++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py @@ -156,7 +156,20 @@ def extract_last_width(x, start_loc, width): return x[:, indices].permute(1, 0, 2) -@triton.jit +@triton.jit( + do_not_specialize=[ + "batch", + "state_len", + "num_cache_lines", + "stride_x_seq", + "stride_x_token", + "stride_conv_state_seq", + "stride_conv_state_tok", + "stride_state_indices", + "stride_o_seq", + "stride_o_token", + ] +) def _causal_conv1d_update_kernel_npu_tiled( # Pointers x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen @@ -172,21 +185,21 @@ def _causal_conv1d_update_kernel_npu_tiled( batch: tl.int32, dim: tl.constexpr, seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen - state_len: tl.constexpr, # effective state_len computed in wrapper - num_cache_lines: tl.constexpr, + state_len, # effective state_len computed in wrapper + num_cache_lines, # Strides - stride_x_seq: tl.constexpr, + stride_x_seq, stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, + stride_x_token, stride_w_dim: tl.constexpr, stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, + stride_conv_state_seq, 
stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_o_seq: tl.constexpr, + stride_conv_state_tok, + stride_state_indices, + stride_o_seq, stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, + stride_o_token, # others pad_slot_id: tl.constexpr, # Meta