feat: implement high-performance Triton kernels for rejection sampling (#4830)

### What this PR does / why we need it?
This PR introduces optimized Triton implementations of
rejection_greedy_sample_kernel and expand_kernel that outperform the
existing Triton implementations. The new kernels maintain full
functional accuracy while significantly improving performance across
various batch sizes and MTP configurations.

### Does this PR introduce _any_ user-facing change?
Yes, this PR modifies rejection_sampler.py to use optimized Triton
kernels:

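- rejection_greedy_sample_kernel is replaced by two kernels:
rejection_greedy_sample_spec_len_1_triton (used when every request has
exactly one draft token and all requests are greedy) and
rejection_greedy_sample_triton (used in all other cases)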

- expand_kernel receives a performance-optimized Triton version

These changes provide substantial performance improvements while
maintaining backward compatibility.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: yuxingcyx <yuxingchen.math@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
yuxingcyx
2025-12-18 19:42:10 +08:00
committed by GitHub
parent 0f571c347b
commit 5a88e3333b

View File

@@ -16,6 +16,16 @@ GREEDY_TEMPERATURE = -1
# step. This value is chosen to be large enough to handle typical use cases. # step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32 MAX_SPEC_LEN = 32
vectorcore_num = None
device_properties = None
if HAS_TRITON:
from triton.runtime import driver # type: ignore
device_properties = driver.active.utils.get_device_properties(
torch.npu.current_device())
vectorcore_num = device_properties['num_vectorcore']
# Query the number of vector cores once; it is used later to choose the launch grid and BLOCK_SIZE (tiling).
def apply_sampling_constraints( def apply_sampling_constraints(
logits: torch.Tensor, # [num_tokens, vocab_size] logits: torch.Tensor, # [num_tokens, vocab_size]
@@ -128,15 +138,36 @@ def rejection_sample(
# Rejection sampling for greedy sampling requests. # Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1) target_argmax = target_probs.argmax(dim=-1)
if HAS_TRITON: if HAS_TRITON:
rejection_greedy_sample_kernel[(batch_size, )]( vec_len = batch_size
output_token_ids, n = cu_num_draft_tokens.numel()
cu_num_draft_tokens, BLOCK_SIZE = 2
draft_token_ids, grid = triton.cdiv(n, BLOCK_SIZE)
target_argmax, if n >= vectorcore_num:
bonus_token_ids, grid = vectorcore_num # Empirically tuned value
is_greedy, BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
max_spec_len,
) if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_triton[(grid, )](
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
vec_len,
BLOCK_SIZE=BLOCK_SIZE,
)
else:
rejection_greedy_sample_triton[(grid, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
vec_len,
max_spec_len,
BLOCK_SIZE=BLOCK_SIZE,
)
else: else:
if min(num_draft_tokens) == 1 and max( if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy: num_draft_tokens) == 1 and sampling_metadata.all_greedy:
@@ -247,13 +278,23 @@ def expand_batch_to_tokens(
assert cu_num_tokens.shape[0] == batch_size assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens) expanded_x = x.new_empty(num_tokens)
if HAS_TRITON: if HAS_TRITON:
expand_kernel[(batch_size, )]( vec_len = batch_size
n = cu_num_tokens.numel()
BLOCK_SIZE = 2
grid = triton.cdiv(n, BLOCK_SIZE)
if n >= vectorcore_num:
grid = vectorcore_num
BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
expand_kernel[(grid, )](
expanded_x, expanded_x,
x, x,
cu_num_tokens, cu_num_tokens,
replace_from, replace_from,
replace_to, replace_to,
vec_len,
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
BLOCK_SIZE=BLOCK_SIZE,
) )
else: else:
expand_pytorch( expand_pytorch(
@@ -536,50 +577,112 @@ def sample_recovered_tokens_pytorch(
@triton.jit(do_not_specialize=["max_spec_len"]) @triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel( def bonus_renew_1(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
):
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
# NOTE(review): "max_spec_len" is not a parameter of this kernel, so the
# do_not_specialize entry below looks like it has no effect here -- confirm.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_spec_len_1_triton(
    output_token_ids_ptr,  # [batch_size, 2]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,
    vec_len,
    BLOCK_SIZE: tl.constexpr,
):
    # Greedy rejection sampling specialised for one draft token per request.
    #
    # Each program handles BLOCK_SIZE consecutive requests. For request r it
    # writes target_argmax[r] to output[r, 0]; if the draft token equals the
    # target argmax it additionally writes bonus_token_ids[r] to output[r, 1].
    #
    # vec_len is the number of valid requests; lanes with index >= vec_len
    # are padding and must not touch memory.
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < vec_len

    # Vectorised part: load both token streams and emit the first output
    # column for every valid lane.
    draft_token_id = tl.load(draft_token_ids_ptr + offset, mask)
    target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
    tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)

    # Scalar part: the bonus token is written lane by lane.
    for pos in tl.range(0, BLOCK_SIZE):
        draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
        target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
        position = block_idx * BLOCK_SIZE + pos
        # The masked loads above leave padding lanes with unspecified
        # values, so the equality test alone could be true for
        # position >= vec_len and trigger an out-of-bounds load/store in
        # bonus_renew_1. Check the bound explicitly.
        if position < vec_len and draft_token_id1 == target_argmax1:
            bonus_renew_1(
                bonus_token_ids_ptr,
                position,
                output_token_ids_ptr,
            )
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew(
    bonus_token_ids_ptr,
    position,
    output_token_ids_ptr,
    max_spec_len,
    num_tokens1,
):
    # Copy bonus_token_ids[position] into column `num_tokens1` of output row
    # `position`; output rows are (max_spec_len + 1) elements apart.
    row_stride = max_spec_len + 1
    dst_ptr = output_token_ids_ptr + position * row_stride + num_tokens1
    bonus_token = tl.load(bonus_token_ids_ptr + position)
    tl.store(dst_ptr, bonus_token)
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_triton(
output_token_ids_ptr, # [batch_size, max_spec_len + 1] output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size] cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens] draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens] target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size] bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None is_greedy_ptr, # [batch_size] or None
vec_len,
max_spec_len, max_spec_len,
BLOCK_SIZE: tl.constexpr,
): ):
req_idx = tl.program_id(0) block_idx = tl.program_id(0)
# Because is_greedy_ptr is not Nonr at profiling run, offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# re-comilation may happen during runtime when is_greedy_ptr is None. mask = offset < vec_len
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
req_idx)
if not is_greedy:
# Early exit for non-greedy sampling requests
return
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + if is_greedy_ptr is None:
req_idx - 1) is_greedy_mask = mask
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) else:
is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
is_greedy_mask = mask & (is_greedy != 0)
start_idx = tl.where(
offset == 0, 0,
tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
num_draft_tokens = end_idx - start_idx num_draft_tokens = end_idx - start_idx
rejected = False for pos in tl.range(0, BLOCK_SIZE):
for pos in range(num_draft_tokens): num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
if not rejected: rejected = False
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) start_idx1 = tl.get_element(start_idx, (pos, ))
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, ))
tl.store( position = block_idx * BLOCK_SIZE + pos
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, for i in range(num_tokens1):
target_argmax_id, if not rejected:
) draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
if draft_token_id != target_argmax_id: target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
# Reject tl.store(
rejected = True output_token_ids_ptr + position * (max_spec_len + 1) + i,
target_argmax_id,
)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
if not rejected: if not rejected and is_greedy_mask1:
# If all tokens are accepted, append the bonus token bonus_renew(
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) bonus_token_ids_ptr,
tl.store( position,
output_token_ids_ptr + req_idx * (max_spec_len + 1) + output_token_ids_ptr,
num_draft_tokens, max_spec_len,
bonus_token_id, num_tokens1,
) )
@triton.jit(do_not_specialize=["max_spec_len"]) @triton.jit(do_not_specialize=["max_spec_len"])
@@ -649,22 +752,30 @@ def expand_kernel(
cu_num_tokens_ptr, # [batch_size] cu_num_tokens_ptr, # [batch_size]
replace_from, replace_from,
replace_to, replace_to,
vec_len,
MAX_NUM_TOKENS: tl.constexpr, MAX_NUM_TOKENS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
if req_idx == 0: offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
start_idx = 0 len_mask = offset < vec_len
else:
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1) start_idx = tl.where(offset == 0, 0,
end_idx = tl.load(cu_num_tokens_ptr + req_idx) tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
num_tokens = end_idx - start_idx num_tokens = end_idx - start_idx
src_val = tl.load(input_ptr + req_idx) src_val = tl.load(input_ptr + offset, len_mask)
src_val = tl.where(src_val == replace_from, replace_to, src_val) src_val = tl.where(src_val == replace_from, replace_to, src_val)
offset = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx + offset, for i in tl.range(0, BLOCK_SIZE):
src_val, num_tokens1 = tl.get_element(num_tokens, (i, ))
mask=offset < num_tokens) start_idx1 = tl.get_element(start_idx, (i, ))
src_val1 = tl.get_element(src_val, (i, ))
offset1 = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx1 + offset1,
src_val1,
mask=offset1 < num_tokens1)
@triton.jit @triton.jit