[v0.18.0][kernel] Recompilation optimization triggered by triton function parameter optimization (#7647)
### What this PR does / why we need it?
Some parameters of Triton operators are unnecessarily marked with the
`constexpr` qualifier. When the values of these parameters change, a
recompilation is triggered, which significantly degrades model performance.
These parameters therefore need to be corrected to plain (non-constexpr)
arguments.
- vLLM version: v0.17.0
- vLLM main:
8b6325758c
Signed-off-by: HarpSealCC <844291270@qq.com>
Signed-off-by: l30072083 <liuchengzhuo1@h-partners.com>
Co-authored-by: l30072083 <liuchengzhuo1@h-partners.com>
This commit is contained in:
@@ -26,7 +26,7 @@ _CONDITIONS = ("seq7168",)
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
v,
|
||||
@@ -40,10 +40,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
chunk_offsets,
|
||||
h_update,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
H,
|
||||
Hg,
|
||||
K,
|
||||
V,
|
||||
BT: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
|
||||
@@ -26,7 +26,6 @@ def chunk_local_cumsum_scalar_kernel(
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
@@ -103,7 +102,6 @@ def chunk_local_cumsum_scalar(
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=block_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BLOCK_T=OPTIM_BLOCK_SIZE,
|
||||
CHUNK_SIZE=chunk_size,
|
||||
|
||||
@@ -10,7 +10,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
UNIFIED_BUFFER_SIZE = 1572864
|
||||
|
||||
|
||||
@triton.jit
|
||||
@triton.jit(do_not_specialize=["seq_len", "NUM_HEADS", "NUM_BATCHES", "beta", "threshold", "ROW_ITER"])
|
||||
def fused_gdn_gating_kernel(
|
||||
g,
|
||||
beta_output,
|
||||
@@ -19,16 +19,17 @@ def fused_gdn_gating_kernel(
|
||||
b,
|
||||
dt_bias,
|
||||
seq_len,
|
||||
NUM_HEADS: tl.constexpr,
|
||||
NUM_BATCHES: tl.constexpr,
|
||||
beta: tl.constexpr,
|
||||
threshold: tl.constexpr,
|
||||
NUM_HEADS,
|
||||
NUM_BATCHES,
|
||||
beta,
|
||||
threshold,
|
||||
BLK_HEADS: tl.constexpr,
|
||||
COL_ITER: tl.constexpr,
|
||||
BLK_BATCHES: tl.constexpr,
|
||||
ROW_ITER: tl.constexpr,
|
||||
ROW_ITER,
|
||||
):
|
||||
i_b, i_s = tl.program_id(0), tl.program_id(1)
|
||||
COL_ITER = tl.cdiv(NUM_HEADS, BLK_HEADS)
|
||||
|
||||
for row_idx in range(0, ROW_ITER):
|
||||
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
|
||||
|
||||
@@ -69,22 +70,10 @@ def fused_gdn_gating_patch(
|
||||
num_cores = get_vectorcore_num()
|
||||
|
||||
BLK_HEADS = 8
|
||||
COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
|
||||
|
||||
elem_size = a.element_size()
|
||||
max_ub_batches = int((UNIFIED_BUFFER_SIZE * 0.95) / (BLK_HEADS * elem_size))
|
||||
if batch <= num_cores:
|
||||
progs = batch
|
||||
BLK_BATCHES = 1
|
||||
ROW_ITER = 1
|
||||
else:
|
||||
progs = num_cores
|
||||
FACTOR = 8 * num_heads
|
||||
calc_blk_batches = (
|
||||
triton.next_power_of_2(triton.cdiv(int(UNIFIED_BUFFER_SIZE * 0.95), FACTOR * BLK_HEADS * elem_size)) // 2
|
||||
)
|
||||
BLK_BATCHES = max(1, min(calc_blk_batches, max_ub_batches, 64))
|
||||
row_per_core = triton.cdiv(batch, progs)
|
||||
BLK_BATCHES = 64
|
||||
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
|
||||
|
||||
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
|
||||
@@ -104,7 +93,6 @@ def fused_gdn_gating_patch(
|
||||
beta,
|
||||
threshold,
|
||||
BLK_HEADS=BLK_HEADS,
|
||||
COL_ITER=COL_ITER,
|
||||
BLK_BATCHES=BLK_BATCHES,
|
||||
ROW_ITER=ROW_ITER,
|
||||
)
|
||||
|
||||
@@ -82,7 +82,7 @@ def bonus_renew(
|
||||
tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, bonus_token_id)
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
@triton.jit(do_not_specialize=["vec_len", "max_spec_len"])
|
||||
def rejection_greedy_sample_triton(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
@@ -196,7 +196,7 @@ def rejection_random_sample_kernel(
|
||||
)
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
||||
@triton.jit(do_not_specialize=["replace_from", "replace_to", "vec_len"])
|
||||
def expand_kernel(
|
||||
output_ptr, # [num_tokens]
|
||||
input_ptr, # [batch_size]
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
@triton.jit(do_not_specialize=["num_reqs"])
|
||||
def prepare_inputs_padded_kernel(
|
||||
cu_num_draft_tokens_ptr, # [num_reqs]
|
||||
valid_sampled_tokens_count_ptr, # [num_reqs]
|
||||
|
||||
Reference in New Issue
Block a user