feat: implement high-performance Triton kernels for rejection sampling (#4830)
### What this PR does / why we need it?
This PR introduces optimized Triton implementations of
rejection_greedy_sample_kernel and expand_kernel. The new kernels
maintain full functional accuracy while delivering significant
speedups over the existing Triton implementations across various
batch sizes and MTP configurations.
### Does this PR introduce _any_ user-facing change?
Yes. This PR modifies rejection_sampler.py to use the optimized Triton
kernels:
- rejection_greedy_sample_kernel is replaced by the
rejection_greedy_sample_spec_len_1_triton and
rejection_greedy_sample_triton implementations
- expand_kernel receives a performance-optimized Triton version

These changes provide substantial performance improvements while
maintaining backward compatibility.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: yuxingcyx <yuxingchen.math@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
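
Both new launch sites share the same tiling heuristic: start from tiny two-element blocks, and once the element count reaches the NPU's vector-core count, pin the grid to one program per core and grow the block size instead. A minimal sketch for reviewers, assuming nothing beyond the diff below; `pick_launch` is a hypothetical name, and `cdiv`/`next_power_of_2` stand in for `triton.cdiv`/`triton.next_power_of_2`:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, as triton.cdiv computes.
    return -(-a // b)

def next_power_of_2(x: int) -> int:
    # Smallest power of two >= x, as triton.next_power_of_2 computes.
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def pick_launch(n: int, vectorcore_num: int) -> tuple[int, int]:
    # Hypothetical helper mirroring the heuristic in the diff below.
    BLOCK_SIZE = 2
    grid = cdiv(n, BLOCK_SIZE)
    if n >= vectorcore_num:
        # Cap the grid at one program per vector core and widen the
        # block so grid * BLOCK_SIZE still covers all n elements.
        grid = vectorcore_num
        BLOCK_SIZE = next_power_of_2(cdiv(n, grid))
    return grid, BLOCK_SIZE

assert pick_launch(8, 48) == (4, 2)     # small batch: many tiny blocks
assert pick_launch(100, 48) == (48, 4)  # large batch: one program per core
```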
@@ -16,6 +16,16 @@ GREEDY_TEMPERATURE = -1
 # step. This value is chosen to be large enough to handle typical use cases.
 MAX_SPEC_LEN = 32
 
+vectorcore_num = None
+device_properties = None
+
+if HAS_TRITON:
+    from triton.runtime import driver  # type: ignore
+    device_properties = driver.active.utils.get_device_properties(
+        torch.npu.current_device())
+    vectorcore_num = device_properties['num_vectorcore']
+    # Get the vector core count for later tiling.
+
 
 def apply_sampling_constraints(
     logits: torch.Tensor,  # [num_tokens, vocab_size]
@@ -128,15 +138,36 @@ def rejection_sample(
     # Rejection sampling for greedy sampling requests.
     target_argmax = target_probs.argmax(dim=-1)
     if HAS_TRITON:
-        rejection_greedy_sample_kernel[(batch_size, )](
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            target_argmax,
-            bonus_token_ids,
-            is_greedy,
-            max_spec_len,
-        )
+        vec_len = batch_size
+        n = cu_num_draft_tokens.numel()
+        BLOCK_SIZE = 2
+        grid = triton.cdiv(n, BLOCK_SIZE)
+        if n >= vectorcore_num:
+            grid = vectorcore_num  # Empirically tuned value
+            BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
+
+        if min(num_draft_tokens) == 1 and max(
+                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+            rejection_greedy_sample_spec_len_1_triton[(grid, )](
+                output_token_ids,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                vec_len,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
+        else:
+            rejection_greedy_sample_triton[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                is_greedy,
+                vec_len,
+                max_spec_len,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
     else:
         if min(num_draft_tokens) == 1 and max(
                 num_draft_tokens) == 1 and sampling_metadata.all_greedy:
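
For reference, the greedy path above is semantically equivalent to this hypothetical loop-based PyTorch helper (illustrative only, not the optimized implementation): accept draft tokens while they match the target argmax, keep the target token at the first mismatch, and append the bonus token only if every draft token was accepted.

```python
import torch

def rejection_greedy_sample_ref(cu_num_draft_tokens, draft_token_ids,
                                target_argmax, bonus_token_ids, max_spec_len,
                                placeholder=-1):
    # Hypothetical reference; unused slots keep the placeholder value.
    batch_size = cu_num_draft_tokens.numel()
    out = torch.full((batch_size, max_spec_len + 1), placeholder,
                     dtype=torch.long)
    start = 0
    for req in range(batch_size):
        end = int(cu_num_draft_tokens[req])
        rejected = False
        for pos in range(end - start):
            # Always emit the target's argmax token for this position.
            out[req, pos] = target_argmax[start + pos]
            if draft_token_ids[start + pos] != target_argmax[start + pos]:
                rejected = True  # First mismatch: stop accepting.
                break
        if not rejected:
            # All draft tokens accepted: append the bonus token.
            out[req, end - start] = bonus_token_ids[req]
        start = end
    return out
```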
@@ -247,13 +278,23 @@ def expand_batch_to_tokens(
     assert cu_num_tokens.shape[0] == batch_size
     expanded_x = x.new_empty(num_tokens)
     if HAS_TRITON:
-        expand_kernel[(batch_size, )](
+        vec_len = batch_size
+        n = cu_num_tokens.numel()
+        BLOCK_SIZE = 2
+        grid = triton.cdiv(n, BLOCK_SIZE)
+        if n >= vectorcore_num:
+            grid = vectorcore_num
+            BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
+
+        expand_kernel[(grid, )](
             expanded_x,
             x,
             cu_num_tokens,
             replace_from,
             replace_to,
+            vec_len,
             MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+            BLOCK_SIZE=BLOCK_SIZE,
         )
     else:
         expand_pytorch(
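
The expand step repeats each per-request value once per generated token, mapping `replace_from` to `replace_to` on the way; a hypothetical vectorized PyTorch equivalent:

```python
import torch

def expand_ref(x, cu_num_tokens, replace_from, replace_to):
    # Hypothetical reference for expand_batch_to_tokens' semantics.
    counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
    expanded = torch.repeat_interleave(x, counts)  # x[i] repeated counts[i] times
    return torch.where(expanded == replace_from,
                       torch.full_like(expanded, replace_to), expanded)
```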
@@ -536,50 +577,112 @@ def sample_recovered_tokens_pytorch(
 
 
 @triton.jit(do_not_specialize=["max_spec_len"])
-def rejection_greedy_sample_kernel(
+def bonus_renew_1(
+    bonus_token_ids_ptr,
+    position,
+    output_token_ids_ptr,
+):
+    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
+    tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_spec_len_1_triton(
+    output_token_ids_ptr,  # [batch_size, 2]
+    draft_token_ids_ptr,  # [num_tokens]
+    target_argmax_ptr,  # [num_tokens]
+    bonus_token_ids_ptr,
+    vec_len,
+    BLOCK_SIZE: tl.constexpr,
+):
+    block_idx = tl.program_id(0)
+    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < vec_len
+
+    draft_token_id = tl.load(draft_token_ids_ptr + offset, mask)
+    target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
+    tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
+
+    for pos in tl.range(0, BLOCK_SIZE):
+        draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
+        target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
+        position = block_idx * BLOCK_SIZE + pos
+        if draft_token_id1 == target_argmax1:
+            bonus_renew_1(
+                bonus_token_ids_ptr,
+                position,
+                output_token_ids_ptr,
+            )
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def bonus_renew(
+    bonus_token_ids_ptr,
+    position,
+    output_token_ids_ptr,
+    max_spec_len,
+    num_tokens1,
+):
+    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
+    tl.store(
+        output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
+        bonus_token_id)
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_triton(
     output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
     cu_num_draft_tokens_ptr,  # [batch_size]
     draft_token_ids_ptr,  # [num_tokens]
     target_argmax_ptr,  # [num_tokens]
     bonus_token_ids_ptr,  # [batch_size]
     is_greedy_ptr,  # [batch_size] or None
+    vec_len,
     max_spec_len,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    req_idx = tl.program_id(0)
-    # Because is_greedy_ptr is not None at profiling run,
-    # re-compilation may happen during runtime when is_greedy_ptr is None.
-    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
-                                                           req_idx)
-    if not is_greedy:
-        # Early exit for non-greedy sampling requests
-        return
+    block_idx = tl.program_id(0)
+    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < vec_len
 
-    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
-                                               req_idx - 1)
-    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    if is_greedy_ptr is None:
+        is_greedy_mask = mask
+    else:
+        is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
+        is_greedy_mask = mask & (is_greedy != 0)
+
+    start_idx = tl.where(
+        offset == 0, 0,
+        tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
+    end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
     num_draft_tokens = end_idx - start_idx
 
-    rejected = False
-    for pos in range(num_draft_tokens):
-        if not rejected:
-            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
-            tl.store(
-                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
-                target_argmax_id,
-            )
-            if draft_token_id != target_argmax_id:
-                # Reject
-                rejected = True
+    for pos in tl.range(0, BLOCK_SIZE):
+        num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
+        rejected = False
+        start_idx1 = tl.get_element(start_idx, (pos, ))
+        is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, ))
+        position = block_idx * BLOCK_SIZE + pos
+        for i in range(num_tokens1):
+            if not rejected:
+                draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
+                target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
+                tl.store(
+                    output_token_ids_ptr + position * (max_spec_len + 1) + i,
+                    target_argmax_id,
+                )
+                if draft_token_id != target_argmax_id:
+                    # Reject.
+                    rejected = True
 
-    if not rejected:
-        # If all tokens are accepted, append the bonus token
-        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
-        tl.store(
-            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
-            num_draft_tokens,
-            bonus_token_id,
-        )
+        if not rejected and is_greedy_mask1:
+            bonus_renew(
+                bonus_token_ids_ptr,
+                position,
+                output_token_ids_ptr,
+                max_spec_len,
+                num_tokens1,
+            )
 
 
 @triton.jit(do_not_specialize=["max_spec_len"])
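
When every request carries exactly one draft token and all requests are greedy, the `rejection_greedy_sample_spec_len_1_triton` fast path above reduces to two vectorized steps; a hypothetical PyTorch equivalent for intuition:

```python
import torch

def spec_len_1_ref(draft_token_ids, target_argmax, bonus_token_ids,
                   placeholder=-1):
    # Hypothetical reference: column 0 always gets the target argmax;
    # column 1 gets the bonus token only where the draft was accepted.
    batch_size = draft_token_ids.numel()
    out = torch.full((batch_size, 2), placeholder, dtype=torch.long)
    out[:, 0] = target_argmax
    accepted = draft_token_ids == target_argmax
    out[accepted, 1] = bonus_token_ids[accepted]
    return out
```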
@@ -649,22 +752,30 @@ def expand_kernel(
     cu_num_tokens_ptr,  # [batch_size]
     replace_from,
     replace_to,
+    vec_len,
     MAX_NUM_TOKENS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
-    if req_idx == 0:
-        start_idx = 0
-    else:
-        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
-    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
+    offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    len_mask = offset < vec_len
+
+    start_idx = tl.where(offset == 0, 0,
+                         tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
+    end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
     num_tokens = end_idx - start_idx
 
-    src_val = tl.load(input_ptr + req_idx)
+    src_val = tl.load(input_ptr + offset, len_mask)
     src_val = tl.where(src_val == replace_from, replace_to, src_val)
-    offset = tl.arange(0, MAX_NUM_TOKENS)
-    tl.store(output_ptr + start_idx + offset,
-             src_val,
-             mask=offset < num_tokens)
+
+    for i in tl.range(0, BLOCK_SIZE):
+        num_tokens1 = tl.get_element(num_tokens, (i, ))
+        start_idx1 = tl.get_element(start_idx, (i, ))
+        src_val1 = tl.get_element(src_val, (i, ))
+        offset1 = tl.arange(0, MAX_NUM_TOKENS)
+        tl.store(output_ptr + start_idx1 + offset1,
+                 src_val1,
+                 mask=offset1 < num_tokens1)
 
 
 @triton.jit