diff --git a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py index 7ceafe3c..864a29e6 100644 --- a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py +++ b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py @@ -21,7 +21,7 @@ from .utils import prepare_chunk_indices, safe_exp "USE_G": lambda args: args["g_cumsum"] is not None, } ) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "B"]) def chunk_scaled_dot_kkt_fwd_kernel( k, beta, # [H, B, T] diff --git a/vllm_ascend/ops/triton/fla/sigmoid_gating.py b/vllm_ascend/ops/triton/fla/sigmoid_gating.py index e5512c6d..bd683cb6 100644 --- a/vllm_ascend/ops/triton/fla/sigmoid_gating.py +++ b/vllm_ascend/ops/triton/fla/sigmoid_gating.py @@ -39,7 +39,7 @@ else: "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, } ) -@triton.jit(do_not_specialize=["N", "T"]) +@triton.jit(do_not_specialize=["scale", "N", "T", "B"]) def fused_recurrent_gated_delta_rule_fwd_kernel( q, k, @@ -53,9 +53,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ssm_state_indices, num_accepted_tokens, scale, - N: tl.constexpr, # num of sequences - T: tl.constexpr, # num of tokens - B: tl.constexpr, + N, # num of sequences + T, # num of tokens + B, H: tl.constexpr, HV: tl.constexpr, K: tl.constexpr, diff --git a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py index 1783035d..fb46ccfa 100644 --- a/vllm_ascend/ops/triton/fla/solve_tril.py +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -18,14 +18,14 @@ from .utils import prepare_chunk_indices @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "H"]) def solve_tril_16x16_kernel( A, Ad, cu_seqlens, chunk_indices, T, - H: tl.constexpr, + H, BT: tl.constexpr, IS_VARLEN: tl.constexpr, LARGE_BLOCK_T: tl.constexpr, @@ -134,7 +134,7 @@ def solve_tril_16x16_kernel( @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "H"]) def merge_16x16_to_32x32_inverse_kernel( A, Ad, @@ -142,7 +142,7 @@ def merge_16x16_to_32x32_inverse_kernel( cu_seqlens, chunk_indices, T, - H: tl.constexpr, + H, BT: tl.constexpr, IS_VARLEN: tl.constexpr, ): @@ -198,7 +198,7 @@ def merge_16x16_to_32x32_inverse_kernel( @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "H"]) def merge_16x16_to_64x64_inverse_kernel( A, Ad, @@ -206,7 +206,7 @@ def merge_16x16_to_64x64_inverse_kernel( cu_seqlens, chunk_indices, T, - H: tl.constexpr, + H, BT: tl.constexpr, IS_VARLEN: tl.constexpr, ): diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py index ba920c25..25aa727b 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py @@ -47,9 +47,9 @@ def split_qkv_rmsnorm_mrope_kernel( q_size: tl.constexpr, kv_size: tl.constexpr, eps: tl.constexpr, - mrope_section_t: tl.constexpr, - mrope_section_h: tl.constexpr, - mrope_section_w: tl.constexpr, + mrope_section_t, + mrope_section_h, + mrope_section_w, has_bias: tl.constexpr, is_interleaved: tl.constexpr, rope_dim: tl.constexpr, diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py index a06320e0..da7f4183 100644 --- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py +++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py @@ -156,7 +156,20 @@ def extract_last_width(x, start_loc, width): return x[:, indices].permute(1, 0, 2) -@triton.jit +@triton.jit( + do_not_specialize=[ + "batch", + "state_len", + "num_cache_lines", + "stride_x_seq", + "stride_x_token", + "stride_conv_state_seq", + "stride_conv_state_tok", + "stride_state_indices", + "stride_o_seq", + "stride_o_token", + ] +) def _causal_conv1d_update_kernel_npu_tiled( # Pointers x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen @@ -172,21 +185,21 @@ def _causal_conv1d_update_kernel_npu_tiled( batch: tl.int32, dim: tl.constexpr, seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen - state_len: tl.constexpr, # effective state_len computed in wrapper - num_cache_lines: tl.constexpr, + state_len, # effective state_len computed in wrapper + num_cache_lines, # Strides - stride_x_seq: tl.constexpr, + stride_x_seq, stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, + stride_x_token, stride_w_dim: tl.constexpr, stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, + stride_conv_state_seq, stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_o_seq: tl.constexpr, + stride_conv_state_tok, + stride_state_indices, + stride_o_seq, stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, + stride_o_token, # others pad_slot_id: tl.constexpr, # Meta