From dba34d4915f83c1c77825f8f2c8fda030f389d8d Mon Sep 17 00:00:00 2001 From: "Mr.WXS" <46145645+894743926@users.noreply.github.com> Date: Wed, 25 Mar 2026 23:31:27 +0800 Subject: [PATCH] [v0.18.0][Triton][Qwen3.5] delete constexpr for kernel args (#7646) ### What this PR does / why we need it? Some parameters of Triton kernels are unnecessarily marked with the "constexpr" modifier. Whenever the value of such a parameter changes, recompilation is triggered, which significantly degrades model performance. Therefore, the "constexpr" annotations on these parameters are removed (and the parameters are listed in "do_not_specialize" so that value changes no longer trigger recompilation). backport: https://github.com/vllm-project/vllm-ascend/pull/7482 Signed-off-by: w30012745 Co-authored-by: w30012745 --- vllm_ascend/ops/triton/fla/chunk_o.py | 10 +++++----- vllm_ascend/ops/triton/fla/l2norm.py | 4 ++-- vllm_ascend/ops/triton/fla/wy_fast.py | 10 +++++----- vllm_ascend/ops/triton/layernorm_gated.py | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/ops/triton/fla/chunk_o.py b/vllm_ascend/ops/triton/fla/chunk_o.py index 258b40eb..66093fe4 100644 --- a/vllm_ascend/ops/triton/fla/chunk_o.py +++ b/vllm_ascend/ops/triton/fla/chunk_o.py @@ -22,7 +22,7 @@ from .utils import prepare_chunk_offsets, safe_exp "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, } ) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["chunk_offsets", "scale", "T", "H", "Hg", "K", "V"]) def chunk_fwd_kernel_o( q, k, @@ -34,10 +34,10 @@ def chunk_fwd_kernel_o( chunk_offsets, scale, T, - H: tl.constexpr, - Hg: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, + H, + Hg, + K, + V, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, diff --git a/vllm_ascend/ops/triton/fla/l2norm.py b/vllm_ascend/ops/triton/fla/l2norm.py index 9ba89faa..28abeccd 100644 --- a/vllm_ascend/ops/triton/fla/l2norm.py +++ b/vllm_ascend/ops/triton/fla/l2norm.py @@ -14,8 +14,8 @@ from vllm.triton_utils import tl, triton from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num -@triton.jit -def 
l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr): +@triton.jit(do_not_specialize=["eps", "M", "NUM_CHUNKS"]) +def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS): base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK) rindex = tl.arange(0, N)[None, :] diff --git a/vllm_ascend/ops/triton/fla/wy_fast.py b/vllm_ascend/ops/triton/fla/wy_fast.py index 12cdfcad..82f2ab18 100644 --- a/vllm_ascend/ops/triton/fla/wy_fast.py +++ b/vllm_ascend/ops/triton/fla/wy_fast.py @@ -17,7 +17,7 @@ from .utils import prepare_chunk_indices @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"]) def recompute_w_u_fwd_kernel( k, v, @@ -29,10 +29,10 @@ def recompute_w_u_fwd_kernel( cu_seqlens, chunk_indices, T, - H: tl.constexpr, - Hg: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, + H, + Hg, + K, + V, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, diff --git a/vllm_ascend/ops/triton/layernorm_gated.py b/vllm_ascend/ops/triton/layernorm_gated.py index 76c418f2..0da63052 100644 --- a/vllm_ascend/ops/triton/layernorm_gated.py +++ b/vllm_ascend/ops/triton/layernorm_gated.py @@ -12,7 +12,7 @@ from vllm.triton_utils import tl, triton @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) -@triton.jit +@triton.jit(do_not_specialize=["stride_x_row", "stride_y_row", "stride_z_row", "M", "N", "eps"]) def _layer_norm_fwd_1pass_kernel_npu( X, # pointer to the input Y, # pointer to the output