[v0.18.0][Triton][Qwen3.5] delete expr for kernels args (#7646)

### What this PR does / why we need it?
Some parameters of Triton operators are unnecessarily modified with the
"constexpr" modifier. When these parameters change, recompilation is
triggered, which significantly affects the model performance. Therefore,
these parameters need to be rectified.
backport: https://github.com/vllm-project/vllm-ascend/pull/7482


Signed-off-by: w30012745 <wangxiaoshuai2@h-partners.com>
Co-authored-by: w30012745 <wangxiaoshuai2@h-partners.com>
This commit is contained in:
Mr.WXS
2026-03-25 23:31:27 +08:00
committed by GitHub
parent dd55736ee4
commit dba34d4915
4 changed files with 13 additions and 13 deletions

View File

@@ -22,7 +22,7 @@ from .utils import prepare_chunk_offsets, safe_exp
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
@triton.jit(do_not_specialize=["chunk_offsets", "scale", "T", "H", "Hg", "K", "V"])
def chunk_fwd_kernel_o(
q,
k,
@@ -34,10 +34,10 @@ def chunk_fwd_kernel_o(
chunk_offsets,
scale,
T,
H: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
H,
Hg,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,

View File

@@ -14,8 +14,8 @@ from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
@triton.jit(do_not_specialize=["eps", "M", "NUM_CHUNKS"])
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS):
base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
rindex = tl.arange(0, N)[None, :]

View File

@@ -17,7 +17,7 @@ from .utils import prepare_chunk_indices
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"])
def recompute_w_u_fwd_kernel(
k,
v,
@@ -29,10 +29,10 @@ def recompute_w_u_fwd_kernel(
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
H,
Hg,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,

View File

@@ -12,7 +12,7 @@ from vllm.triton_utils import tl, triton
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
@triton.jit(do_not_specialize=["stride_x_row", "stride_y_row", "stride_z_row", "M", "N", "eps"])
def _layer_norm_fwd_1pass_kernel_npu(
X, # pointer to the input
Y, # pointer to the output