[kernel] Avoid recompilation triggered by Triton function parameter specialization (#7645)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? Some parameters of the Triton kernels are unnecessarily annotated with the `constexpr` modifier (or specialized by the JIT). Whenever the value of such a parameter changes, Triton triggers a recompilation of the kernel, which significantly degrades model performance. This PR removes the unnecessary `constexpr` annotations and adds the affected parameters to `do_not_specialize` so that value changes no longer cause recompilation. Corresponding main-branch PR: https://github.com/vllm-project/vllm-ascend/pull/7483 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: cvSoldier <610496306@qq.com>
This commit is contained in:
@@ -21,7 +21,7 @@ from .utils import prepare_chunk_indices, safe_exp
|
|||||||
"USE_G": lambda args: args["g_cumsum"] is not None,
|
"USE_G": lambda args: args["g_cumsum"] is not None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@triton.jit(do_not_specialize=["T"])
|
@triton.jit(do_not_specialize=["T", "B"])
|
||||||
def chunk_scaled_dot_kkt_fwd_kernel(
|
def chunk_scaled_dot_kkt_fwd_kernel(
|
||||||
k,
|
k,
|
||||||
beta, # [H, B, T]
|
beta, # [H, B, T]
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ else:
|
|||||||
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
|
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@triton.jit(do_not_specialize=["N", "T"])
|
@triton.jit(do_not_specialize=["scale", "N", "T", "B"])
|
||||||
def fused_recurrent_gated_delta_rule_fwd_kernel(
|
def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
@@ -53,9 +53,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
|||||||
ssm_state_indices,
|
ssm_state_indices,
|
||||||
num_accepted_tokens,
|
num_accepted_tokens,
|
||||||
scale,
|
scale,
|
||||||
N: tl.constexpr, # num of sequences
|
N, # num of sequences
|
||||||
T: tl.constexpr, # num of tokens
|
T, # num of tokens
|
||||||
B: tl.constexpr,
|
B,
|
||||||
H: tl.constexpr,
|
H: tl.constexpr,
|
||||||
HV: tl.constexpr,
|
HV: tl.constexpr,
|
||||||
K: tl.constexpr,
|
K: tl.constexpr,
|
||||||
|
|||||||
@@ -18,14 +18,14 @@ from .utils import prepare_chunk_indices
|
|||||||
|
|
||||||
|
|
||||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||||
@triton.jit(do_not_specialize=["T"])
|
@triton.jit(do_not_specialize=["T", "H"])
|
||||||
def solve_tril_16x16_kernel(
|
def solve_tril_16x16_kernel(
|
||||||
A,
|
A,
|
||||||
Ad,
|
Ad,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
chunk_indices,
|
chunk_indices,
|
||||||
T,
|
T,
|
||||||
H: tl.constexpr,
|
H,
|
||||||
BT: tl.constexpr,
|
BT: tl.constexpr,
|
||||||
IS_VARLEN: tl.constexpr,
|
IS_VARLEN: tl.constexpr,
|
||||||
LARGE_BLOCK_T: tl.constexpr,
|
LARGE_BLOCK_T: tl.constexpr,
|
||||||
@@ -134,7 +134,7 @@ def solve_tril_16x16_kernel(
|
|||||||
|
|
||||||
|
|
||||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||||
@triton.jit(do_not_specialize=["T"])
|
@triton.jit(do_not_specialize=["T", "H"])
|
||||||
def merge_16x16_to_32x32_inverse_kernel(
|
def merge_16x16_to_32x32_inverse_kernel(
|
||||||
A,
|
A,
|
||||||
Ad,
|
Ad,
|
||||||
@@ -142,7 +142,7 @@ def merge_16x16_to_32x32_inverse_kernel(
|
|||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
chunk_indices,
|
chunk_indices,
|
||||||
T,
|
T,
|
||||||
H: tl.constexpr,
|
H,
|
||||||
BT: tl.constexpr,
|
BT: tl.constexpr,
|
||||||
IS_VARLEN: tl.constexpr,
|
IS_VARLEN: tl.constexpr,
|
||||||
):
|
):
|
||||||
@@ -198,7 +198,7 @@ def merge_16x16_to_32x32_inverse_kernel(
|
|||||||
|
|
||||||
|
|
||||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||||
@triton.jit(do_not_specialize=["T"])
|
@triton.jit(do_not_specialize=["T", "H"])
|
||||||
def merge_16x16_to_64x64_inverse_kernel(
|
def merge_16x16_to_64x64_inverse_kernel(
|
||||||
A,
|
A,
|
||||||
Ad,
|
Ad,
|
||||||
@@ -206,7 +206,7 @@ def merge_16x16_to_64x64_inverse_kernel(
|
|||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
chunk_indices,
|
chunk_indices,
|
||||||
T,
|
T,
|
||||||
H: tl.constexpr,
|
H,
|
||||||
BT: tl.constexpr,
|
BT: tl.constexpr,
|
||||||
IS_VARLEN: tl.constexpr,
|
IS_VARLEN: tl.constexpr,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -47,9 +47,9 @@ def split_qkv_rmsnorm_mrope_kernel(
|
|||||||
q_size: tl.constexpr,
|
q_size: tl.constexpr,
|
||||||
kv_size: tl.constexpr,
|
kv_size: tl.constexpr,
|
||||||
eps: tl.constexpr,
|
eps: tl.constexpr,
|
||||||
mrope_section_t: tl.constexpr,
|
mrope_section_t,
|
||||||
mrope_section_h: tl.constexpr,
|
mrope_section_h,
|
||||||
mrope_section_w: tl.constexpr,
|
mrope_section_w,
|
||||||
has_bias: tl.constexpr,
|
has_bias: tl.constexpr,
|
||||||
is_interleaved: tl.constexpr,
|
is_interleaved: tl.constexpr,
|
||||||
rope_dim: tl.constexpr,
|
rope_dim: tl.constexpr,
|
||||||
|
|||||||
@@ -156,7 +156,20 @@ def extract_last_width(x, start_loc, width):
|
|||||||
return x[:, indices].permute(1, 0, 2)
|
return x[:, indices].permute(1, 0, 2)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit(
|
||||||
|
do_not_specialize=[
|
||||||
|
"batch",
|
||||||
|
"state_len",
|
||||||
|
"num_cache_lines",
|
||||||
|
"stride_x_seq",
|
||||||
|
"stride_x_token",
|
||||||
|
"stride_conv_state_seq",
|
||||||
|
"stride_conv_state_tok",
|
||||||
|
"stride_state_indices",
|
||||||
|
"stride_o_seq",
|
||||||
|
"stride_o_token",
|
||||||
|
]
|
||||||
|
)
|
||||||
def _causal_conv1d_update_kernel_npu_tiled(
|
def _causal_conv1d_update_kernel_npu_tiled(
|
||||||
# Pointers
|
# Pointers
|
||||||
x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen
|
x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen
|
||||||
@@ -172,21 +185,21 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
|||||||
batch: tl.int32,
|
batch: tl.int32,
|
||||||
dim: tl.constexpr,
|
dim: tl.constexpr,
|
||||||
seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen
|
seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen
|
||||||
state_len: tl.constexpr, # effective state_len computed in wrapper
|
state_len, # effective state_len computed in wrapper
|
||||||
num_cache_lines: tl.constexpr,
|
num_cache_lines,
|
||||||
# Strides
|
# Strides
|
||||||
stride_x_seq: tl.constexpr,
|
stride_x_seq,
|
||||||
stride_x_dim: tl.constexpr,
|
stride_x_dim: tl.constexpr,
|
||||||
stride_x_token: tl.constexpr,
|
stride_x_token,
|
||||||
stride_w_dim: tl.constexpr,
|
stride_w_dim: tl.constexpr,
|
||||||
stride_w_width: tl.constexpr,
|
stride_w_width: tl.constexpr,
|
||||||
stride_conv_state_seq: tl.constexpr,
|
stride_conv_state_seq,
|
||||||
stride_conv_state_dim: tl.constexpr,
|
stride_conv_state_dim: tl.constexpr,
|
||||||
stride_conv_state_tok: tl.constexpr,
|
stride_conv_state_tok,
|
||||||
stride_state_indices: tl.constexpr,
|
stride_state_indices,
|
||||||
stride_o_seq: tl.constexpr,
|
stride_o_seq,
|
||||||
stride_o_dim: tl.constexpr,
|
stride_o_dim: tl.constexpr,
|
||||||
stride_o_token: tl.constexpr,
|
stride_o_token,
|
||||||
# others
|
# others
|
||||||
pad_slot_id: tl.constexpr,
|
pad_slot_id: tl.constexpr,
|
||||||
# Meta
|
# Meta
|
||||||
|
|||||||
Reference in New Issue
Block a user