diff --git a/vllm_ascend/ops/triton/fla/chunk_o.py b/vllm_ascend/ops/triton/fla/chunk_o.py
index 258b40eb..66093fe4 100644
--- a/vllm_ascend/ops/triton/fla/chunk_o.py
+++ b/vllm_ascend/ops/triton/fla/chunk_o.py
@@ -22,7 +22,7 @@ from .utils import prepare_chunk_offsets, safe_exp
         "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
     }
 )
-@triton.jit(do_not_specialize=["T"])
+@triton.jit(do_not_specialize=["chunk_offsets", "scale", "T", "H", "Hg", "K", "V"])
 def chunk_fwd_kernel_o(
     q,
     k,
@@ -34,10 +34,10 @@ def chunk_fwd_kernel_o(
     chunk_offsets,
     scale,
     T,
-    H: tl.constexpr,
-    Hg: tl.constexpr,
-    K: tl.constexpr,
-    V: tl.constexpr,
+    H,
+    Hg,
+    K,
+    V,
     BT: tl.constexpr,
     BK: tl.constexpr,
     BV: tl.constexpr,
diff --git a/vllm_ascend/ops/triton/fla/l2norm.py b/vllm_ascend/ops/triton/fla/l2norm.py
index 9ba89faa..28abeccd 100644
--- a/vllm_ascend/ops/triton/fla/l2norm.py
+++ b/vllm_ascend/ops/triton/fla/l2norm.py
@@ -14,8 +14,8 @@ from vllm.triton_utils import tl, triton
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
 
 
-@triton.jit
-def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
+@triton.jit(do_not_specialize=["eps", "M", "NUM_CHUNKS"])
+def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS):
     base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
 
     rindex = tl.arange(0, N)[None, :]
diff --git a/vllm_ascend/ops/triton/fla/wy_fast.py b/vllm_ascend/ops/triton/fla/wy_fast.py
index 12cdfcad..82f2ab18 100644
--- a/vllm_ascend/ops/triton/fla/wy_fast.py
+++ b/vllm_ascend/ops/triton/fla/wy_fast.py
@@ -17,7 +17,7 @@ from .utils import prepare_chunk_indices
 
 
 @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
-@triton.jit(do_not_specialize=["T"])
+@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"])
 def recompute_w_u_fwd_kernel(
     k,
     v,
@@ -29,10 +29,10 @@ def recompute_w_u_fwd_kernel(
     cu_seqlens,
     chunk_indices,
     T,
-    H: tl.constexpr,
-    Hg: tl.constexpr,
-    K: tl.constexpr,
-    V: tl.constexpr,
+    H,
+    Hg,
+    K,
+    V,
     BT: tl.constexpr,
     BK: tl.constexpr,
     BV: tl.constexpr,
diff --git a/vllm_ascend/ops/triton/layernorm_gated.py b/vllm_ascend/ops/triton/layernorm_gated.py
index 76c418f2..0da63052 100644
--- a/vllm_ascend/ops/triton/layernorm_gated.py
+++ b/vllm_ascend/ops/triton/layernorm_gated.py
@@ -12,7 +12,7 @@ from vllm.triton_utils import tl, triton
 
 @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
-@triton.jit
+@triton.jit(do_not_specialize=["stride_x_row", "stride_y_row", "stride_z_row", "M", "N", "eps"])
 def _layer_norm_fwd_1pass_kernel_npu(
     X,  # pointer to the input
     Y,  # pointer to the output