diff --git a/vllm_ascend/ops/triton/fla/chunk_delta_h.py b/vllm_ascend/ops/triton/fla/chunk_delta_h.py index e305878e..0189d313 100644 --- a/vllm_ascend/ops/triton/fla/chunk_delta_h.py +++ b/vllm_ascend/ops/triton/fla/chunk_delta_h.py @@ -26,7 +26,7 @@ _CONDITIONS = ("seq7168",) "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, } ) -@triton.jit(do_not_specialize=["T"]) +@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"]) def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( k, v, @@ -40,10 +40,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( chunk_offsets, h_update, T, - H: tl.constexpr, - Hg: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, + H, + Hg, + K, + V, BT: tl.constexpr, USE_G: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, diff --git a/vllm_ascend/ops/triton/fla/cumsum.py b/vllm_ascend/ops/triton/fla/cumsum.py index ea2fa143..95965617 100644 --- a/vllm_ascend/ops/triton/fla/cumsum.py +++ b/vllm_ascend/ops/triton/fla/cumsum.py @@ -26,7 +26,6 @@ def chunk_local_cumsum_scalar_kernel( cu_seqlens, chunk_indices, T, - B: tl.constexpr, H: tl.constexpr, BLOCK_T: tl.constexpr, REVERSE: tl.constexpr, @@ -103,7 +102,6 @@ def chunk_local_cumsum_scalar( cu_seqlens=cu_seqlens, chunk_indices=block_indices, T=T, - B=B, H=H, BLOCK_T=OPTIM_BLOCK_SIZE, CHUNK_SIZE=chunk_size, diff --git a/vllm_ascend/ops/triton/fused_gdn_gating.py b/vllm_ascend/ops/triton/fused_gdn_gating.py index 9c3ea9d1..69bbf699 100644 --- a/vllm_ascend/ops/triton/fused_gdn_gating.py +++ b/vllm_ascend/ops/triton/fused_gdn_gating.py @@ -10,7 +10,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num UNIFIED_BUFFER_SIZE = 1572864 -@triton.jit +@triton.jit(do_not_specialize=["seq_len", "NUM_HEADS", "NUM_BATCHES", "beta", "threshold", "ROW_ITER"]) def fused_gdn_gating_kernel( g, beta_output, @@ -19,16 +19,17 @@ def fused_gdn_gating_kernel( b, dt_bias, seq_len, - NUM_HEADS: tl.constexpr, - NUM_BATCHES: tl.constexpr, - beta: tl.constexpr, - threshold: tl.constexpr, + NUM_HEADS, + NUM_BATCHES, + beta, + threshold, BLK_HEADS: tl.constexpr, - COL_ITER: tl.constexpr, BLK_BATCHES: tl.constexpr, - ROW_ITER: tl.constexpr, + ROW_ITER, ): i_b, i_s = tl.program_id(0), tl.program_id(1) + COL_ITER = tl.cdiv(NUM_HEADS, BLK_HEADS) + for row_idx in range(0, ROW_ITER): batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES) @@ -69,23 +70,11 @@ def fused_gdn_gating_patch( num_cores = get_vectorcore_num() BLK_HEADS = 8 - COL_ITER = triton.cdiv(num_heads, BLK_HEADS) - elem_size = a.element_size() - max_ub_batches = int((UNIFIED_BUFFER_SIZE * 0.95) / (BLK_HEADS * elem_size)) - if batch <= num_cores: - progs = batch - BLK_BATCHES = 1 - ROW_ITER = 1 - else: - progs = num_cores - FACTOR = 8 * num_heads - calc_blk_batches = ( - triton.next_power_of_2(triton.cdiv(int(UNIFIED_BUFFER_SIZE * 0.95), FACTOR * BLK_HEADS * elem_size)) // 2 - ) - BLK_BATCHES = max(1, min(calc_blk_batches, max_ub_batches, 64)) - row_per_core = triton.cdiv(batch, progs) - ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES) + progs = num_cores + row_per_core = triton.cdiv(batch, progs) + BLK_BATCHES = 64 + ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) @@ -104,7 +93,6 @@ def fused_gdn_gating_patch( beta, threshold, BLK_HEADS=BLK_HEADS, - COL_ITER=COL_ITER, BLK_BATCHES=BLK_BATCHES, ROW_ITER=ROW_ITER, ) diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py index 58a0fce7..4c85a430 100644 --- a/vllm_ascend/ops/triton/reject_sample.py +++ b/vllm_ascend/ops/triton/reject_sample.py @@ -82,7 +82,7 @@ def bonus_renew( tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, bonus_token_id) -@triton.jit(do_not_specialize=["max_spec_len"]) +@triton.jit(do_not_specialize=["vec_len", "max_spec_len"]) def rejection_greedy_sample_triton( output_token_ids_ptr, # [batch_size, max_spec_len + 1] cu_num_draft_tokens_ptr, # [batch_size] @@ -196,7 +196,7 @@ def rejection_random_sample_kernel( ) -@triton.jit(do_not_specialize=["replace_from", "replace_to"]) +@triton.jit(do_not_specialize=["replace_from", "replace_to", "vec_len"]) def expand_kernel( output_ptr, # [num_tokens] input_ptr, # [batch_size] diff --git a/vllm_ascend/ops/triton/spec_decode/utils.py b/vllm_ascend/ops/triton/spec_decode/utils.py index 15d42fa0..3c7aa450 100644 --- a/vllm_ascend/ops/triton/spec_decode/utils.py +++ b/vllm_ascend/ops/triton/spec_decode/utils.py @@ -18,7 +18,7 @@ from vllm.triton_utils import tl, triton -@triton.jit +@triton.jit(do_not_specialize=["num_reqs"]) def prepare_inputs_padded_kernel( cu_num_draft_tokens_ptr, # [num_reqs] valid_sampled_tokens_count_ptr, # [num_reqs]