[BugFix][Triton] Fix ub overflow bug of sample_recover_tokens_kernel (#4673)

### What this PR does / why we need it?
Original `sample_recover_tokens_kernel` of reject sampler didn't tile
the vocab size dim, whitch will cause ub overflow problem for models
with big vocab size like deepseek. This PR adds tiling to the vocab size
dim to avoid this problem.

Note that currently we just use a emperical `SUB_BLOCK_SIZE` of `4*1024`
for functionality. If in the future this kernel becomes performance
bottle neck, we can use triton autotune to optimize this. What's more,
we have to disable multibuffer of this kernel due to some accuracy
issues.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0

Signed-off-by: whx-sjtu <2952154980@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
whx
2025-12-05 15:16:19 +08:00
committed by GitHub
parent 7f33838e6e
commit a7f91079b8

View File

@@ -215,7 +215,7 @@ def rejection_sample(
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
uniform_probs.to(torch.float32),
is_greedy,
max_spec_len,
vocab_size,
@@ -331,6 +331,9 @@ def sample_recovered_tokens(
vocab_size,
triton.next_power_of_2(vocab_size),
NO_DRAFT_PROBS=draft_probs is None,
SUB_BLOCK=4 * 1024,
# TODO: enable multibuffer when accuracy problem is solved.
multibuffer=False,
)
else:
sample_recovered_tokens_pytorch(
@@ -698,6 +701,7 @@ def sample_recovered_tokens_kernel(
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
SUB_BLOCK: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
@@ -705,42 +709,74 @@ def sample_recovered_tokens_kernel(
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
# Early exit for out-of-range positions
# Early exit for out-of-range positions.
pos = tl.program_id(1)
if pos >= num_draft_tokens:
return
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
global_recovered_id = -1
global_max_p = -1.0
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=((vocab_offset < vocab_size) &
(vocab_offset != draft_token_id)),
other=0,
)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
# Temporarily zero out the probability of the draft token.
# This is essentially the same as target_prob - draft_prob, except that
# n-gram does not have draft_prob. We regard it as 1.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
0)
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
else:
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
prob = tl.maximum(target_prob - draft_prob, 0)
# We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q = tl.load(
q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"),
)
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id)
if NO_DRAFT_PROBS:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)
rs.expand_batch_to_tokens = expand_batch_to_tokens