From f4871c6ab98af9ab2766779b2167de04805f38e2 Mon Sep 17 00:00:00 2001
From: MidnightSun
Date: Mon, 1 Dec 2025 17:41:58 +0800
Subject: [PATCH] [Kernel] add triton kernels for sampling (#4550)

### What this PR does / why we need it?
Replace the PyTorch implementation of sampling with Triton kernels.

### Does this PR introduce _any_ user-facing change?
No

- vLLM version: v0.11.2

---------

Signed-off-by: Lord_of_Ironhill
Signed-off-by: whx-sjtu <2952154980@qq.com>
Co-authored-by: Lord_of_Ironhill
Co-authored-by: whx-sjtu <2952154980@qq.com>
---
 vllm_ascend/sample/rejection_sampler.py | 329 ++++++++++++++++++++----
 1 file changed, 284 insertions(+), 45 deletions(-)

diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index a17f5340..9bd941fc 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import vllm.v1.sample.rejection_sampler as rs
+from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (RejectionSampler,
                                               apply_sampling_constraints,
@@ -149,25 +150,36 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        if min(num_draft_tokens) == 1 and max(
-                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
-            rejection_greedy_sample_spec_len_1_pytorch(
-                output_token_ids,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-            )
-        else:
-            rejection_greedy_sample_pytorch(
+        if HAS_TRITON:
+            rejection_greedy_sample_kernel[(batch_size, )](
                 output_token_ids,
                 cu_num_draft_tokens,
                 draft_token_ids,
                 target_argmax,
                 bonus_token_ids,
-                num_draft_tokens,
-                max_spec_len,
                 is_greedy,
+                max_spec_len,
             )
+        else:
+            if min(num_draft_tokens) == 1 and max(
+                    num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+                rejection_greedy_sample_spec_len_1_pytorch(
+                    output_token_ids,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                )
+            else:
+                rejection_greedy_sample_pytorch(
+                    output_token_ids,
+                    cu_num_draft_tokens,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                    num_draft_tokens,
+                    max_spec_len,
+                    is_greedy,
+                )
 
     if sampling_metadata.all_greedy:
         return output_token_ids
@@ -194,21 +206,37 @@
     )
     # Rejection sampling for random sampling requests.
-    rejection_random_sample_pytorch(
-        output_token_ids,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        bonus_token_ids,
-        recovered_token_ids,
-        uniform_probs,
-        is_greedy,
-        max_spec_len,
-        vocab_size,
-        IS_NGRAM=draft_probs is None,
-        # num_warps=1,
-    )
+    if HAS_TRITON:
+        rejection_random_sample_kernel[(batch_size, )](
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            NO_DRAFT_PROBS=draft_probs is None,
+        )
+    else:
+        rejection_random_sample_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            IS_NGRAM=draft_probs is None,
+            # num_warps=1,
+        )
 
     return output_token_ids
 
@@ -241,14 +269,24 @@
     batch_size = x.shape[0]
     assert cu_num_tokens.shape[0] == batch_size
     expanded_x = x.new_empty(num_tokens)
-    expand_pytorch(
-        expanded_x,
-        x,
-        cu_num_tokens,
-        replace_from,
-        replace_to,
-        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
-    )
+    if HAS_TRITON:
+        expand_kernel[(batch_size, )](
+            expanded_x,
+            x,
+            cu_num_tokens,
+            replace_from,
+            replace_to,
+            MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+        )
+    else:
+        expand_pytorch(
+            expanded_x,
+            x,
+            cu_num_tokens,
+            replace_from,
+            replace_to,
+            MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+        )
 
     return expanded_x
 
@@ -282,16 +320,29 @@
             q[i].exponential_(generator=generator)
 
     recovered_token_ids = torch.empty_like(draft_token_ids)
-    sample_recovered_tokens_pytorch(
-        recovered_token_ids,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        q,
-        vocab_size,
-        IS_NGRAM=draft_probs is None,
-    )
+    if HAS_TRITON:
+        sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
+            recovered_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            q,
+            vocab_size,
+            triton.next_power_of_2(vocab_size),
+            NO_DRAFT_PROBS=draft_probs is None,
+        )
+    else:
+        sample_recovered_tokens_pytorch(
+            recovered_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            q,
+            vocab_size,
+            IS_NGRAM=draft_probs is None,
+        )
 
     return recovered_token_ids
 
@@ -504,4 +555,192 @@ def sample_recovered_tokens_pytorch(
             target_probs[token_idx, draft_token_id] = orig_prob
 
 
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_kernel(
+        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens_ptr,  # [batch_size]
+        draft_token_ids_ptr,  # [num_tokens]
+        target_argmax_ptr,  # [num_tokens]
+        bonus_token_ids_ptr,  # [batch_size]
+        is_greedy_ptr,  # [batch_size] or None
+        max_spec_len,
+):
+    req_idx = tl.program_id(0)
+    # Because is_greedy_ptr is not None during the profiling run,
+    # re-compilation may happen at runtime when is_greedy_ptr is None.
+    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
+                                                            req_idx)
+    if not is_greedy:
+        # Early exit for non-greedy sampling requests
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                target_argmax_id,
+            )
+            if draft_token_id != target_argmax_id:
+                # Reject
+                rejected = True
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            num_draft_tokens,
+            bonus_token_id,
+        )
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_random_sample_kernel(
+        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens_ptr,  # [batch_size]
+        draft_token_ids_ptr,  # [num_tokens]
+        draft_probs_ptr,  # [num_tokens, vocab_size] or None
+        target_probs_ptr,  # [num_tokens, vocab_size]
+        bonus_token_ids_ptr,  # [batch_size]
+        recovered_token_ids_ptr,  # [num_tokens]
+        uniform_probs_ptr,  # [num_tokens]
+        is_greedy_ptr,  # [batch_size]
+        max_spec_len,
+        vocab_size,
+        NO_DRAFT_PROBS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    is_greedy = tl.load(is_greedy_ptr + req_idx)
+    if is_greedy:
+        # Early exit for greedy sampling requests
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            if NO_DRAFT_PROBS:
+                draft_prob = 1
+            else:
+                draft_prob = tl.load(draft_probs_ptr +
+                                     (start_idx + pos) * vocab_size +
+                                     draft_token_id)
+            target_prob = tl.load(target_probs_ptr +
+                                  (start_idx + pos) * vocab_size +
+                                  draft_token_id)
+            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
+                # Accept
+                token_id = draft_token_id
+            else:
+                # Reject. Use recovered token
+                rejected = True
+                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
+            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                     token_id)
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            num_draft_tokens,
+            bonus_token_id,
+        )
+
+
+@triton.jit(do_not_specialize=["replace_from", "replace_to"])
+def expand_kernel(
+        output_ptr,  # [num_tokens]
+        input_ptr,  # [batch_size]
+        cu_num_tokens_ptr,  # [batch_size]
+        replace_from,
+        replace_to,
+        MAX_NUM_TOKENS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    if req_idx == 0:
+        start_idx = 0
+    else:
+        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
+    num_tokens = end_idx - start_idx
+
+    src_val = tl.load(input_ptr + req_idx)
+    src_val = tl.where(src_val == replace_from, replace_to, src_val)
+    offset = tl.arange(0, MAX_NUM_TOKENS)
+    tl.store(output_ptr + start_idx + offset,
+             src_val,
+             mask=offset < num_tokens)
+
+
+@triton.jit
+def sample_recovered_tokens_kernel(
+        output_token_ids_ptr,  # [num_tokens]
+        cu_num_draft_tokens_ptr,  # [batch_size]
+        draft_token_ids_ptr,  # [num_tokens]
+        draft_probs_ptr,  # [num_tokens, vocab_size] or None
+        target_probs_ptr,  # [num_tokens, vocab_size]
+        q_ptr,  # [batch_size, vocab_size]
+        vocab_size,
+        PADDED_VOCAB_SIZE: tl.constexpr,
+        NO_DRAFT_PROBS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    # Early exit for out-of-range positions
+    pos = tl.program_id(1)
+    if pos >= num_draft_tokens:
+        return
+
+    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
+    if NO_DRAFT_PROBS:
+        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+        prob = tl.load(
+            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=((vocab_offset < vocab_size) &
+                  (vocab_offset != draft_token_id)),
+            other=0,
+        )
+    else:
+        draft_prob = tl.load(
+            draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=0,
+        )
+        target_prob = tl.load(
+            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=0,
+        )
+        prob = tl.maximum(target_prob - draft_prob, 0)
+        # We don't need `prob = prob / tl.sum(prob)` here because
+        # `tl.argmax` will select the maximum value.
+
+    q = tl.load(
+        q_ptr + req_idx * vocab_size + vocab_offset,
+        mask=vocab_offset < vocab_size,
+        other=float("-inf"),
+    )
+    recovered_id = tl.argmax(prob / q, axis=-1)
+    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
+
+
 rs.expand_batch_to_tokens = expand_batch_to_tokens
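
For reviewers who want to exercise the greedy kernel in isolation, the following is a minimal, illustrative sketch (not part of the patch) of launching `rejection_greedy_sample_kernel` on a toy batch. It assumes the kernels are importable from `vllm_ascend.sample.rejection_sampler` and that a Triton-capable device is available (e.g. `npu` via torch_npu on Ascend, or `cuda`); the expected outputs follow directly from the kernel's accept/reject rule, and the `device` choice is an assumption, not something this PR prescribes.

```python
# Illustrative sketch only: launch the greedy rejection kernel on two toy
# requests and check the accept/reject behavior by hand.
import torch

from vllm_ascend.sample.rejection_sampler import rejection_greedy_sample_kernel


def toy_greedy_rejection_demo(device: str = "cuda") -> torch.Tensor:
    # Two requests, each with 2 draft tokens (max_spec_len = 2).
    max_spec_len = 2
    cu_num_draft_tokens = torch.tensor([2, 4], dtype=torch.int32, device=device)
    # Request 0: both drafts match the target argmax -> bonus token appended.
    # Request 1: second draft mismatches -> replaced by argmax, no bonus token.
    draft_token_ids = torch.tensor([5, 7, 11, 13], dtype=torch.int32, device=device)
    target_argmax = torch.tensor([5, 7, 11, 99], dtype=torch.int32, device=device)
    bonus_token_ids = torch.tensor([100, 101], dtype=torch.int32, device=device)
    # -1 marks positions the kernel never writes (rejected tail / unused bonus).
    output_token_ids = torch.full((2, max_spec_len + 1), -1,
                                  dtype=torch.int32, device=device)

    rejection_greedy_sample_kernel[(2, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        target_argmax,
        bonus_token_ids,
        None,  # is_greedy_ptr=None: treat every request as greedy.
        max_spec_len,
    )
    # Expected result: [[5, 7, 100], [11, 99, -1]]
    return output_token_ids
```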