From a7f91079b8576a846f671c9e6923805e74e35c87 Mon Sep 17 00:00:00 2001 From: whx <56632993+whx-sjtu@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:16:19 +0800 Subject: [PATCH] [BugFix][Triton] Fix ub overflow bug of sample_recover_tokens_kernel (#4673) ### What this PR does / why we need it? Original `sample_recover_tokens_kernel` of reject sampler didn't tile the vocab size dim, whitch will cause ub overflow problem for models with big vocab size like deepseek. This PR adds tiling to the vocab size dim to avoid this problem. Note that currently we just use a emperical `SUB_BLOCK_SIZE` of `4*1024` for functionality. If in the future this kernel becomes performance bottle neck, we can use triton autotune to optimize this. What's more, we have to disable multibuffer of this kernel due to some accuracy issues. - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0 Signed-off-by: whx-sjtu <2952154980@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com> --- vllm_ascend/sample/rejection_sampler.py | 94 +++++++++++++++++-------- 1 file changed, 65 insertions(+), 29 deletions(-) diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index 9bd941fc..b7905373 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -215,7 +215,7 @@ def rejection_sample( target_probs, bonus_token_ids, recovered_token_ids, - uniform_probs, + uniform_probs.to(torch.float32), is_greedy, max_spec_len, vocab_size, @@ -331,6 +331,9 @@ def sample_recovered_tokens( vocab_size, triton.next_power_of_2(vocab_size), NO_DRAFT_PROBS=draft_probs is None, + SUB_BLOCK=4 * 1024, + # TODO: enable multibuffer when accuracy problem is solved. + multibuffer=False, ) else: sample_recovered_tokens_pytorch( @@ -698,6 +701,7 @@ def sample_recovered_tokens_kernel( vocab_size, PADDED_VOCAB_SIZE: tl.constexpr, NO_DRAFT_PROBS: tl.constexpr, + SUB_BLOCK: tl.constexpr, ): req_idx = tl.program_id(0) start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + @@ -705,42 +709,74 @@ def sample_recovered_tokens_kernel( end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx - # Early exit for out-of-range positions + # Early exit for out-of-range positions. pos = tl.program_id(1) if pos >= num_draft_tokens: return - vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) + loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK + global_recovered_id = -1 + global_max_p = -1.0 if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - prob = tl.load( - target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=((vocab_offset < vocab_size) & - (vocab_offset != draft_token_id)), - other=0, - ) + orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + + draft_token_id) + # Temporarily zero out the probability of the draft token. + # This is essentially the same as target_prob - draft_prob, except that + # n-gram does not have draft_prob. We regard it as 1. + tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + 0) + for loop_i in range(loop): + vocab_start = loop_i * SUB_BLOCK + vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) + prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + + vocab_offset, + mask=vocab_offset < vocab_size, + other=0) + q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=float("-inf")) + new_p = prob / q + recovered_id = tl.argmax(new_p, axis=-1) + max_p = tl.get_element(new_p, (recovered_id, )) + if max_p > global_max_p: + global_max_p = max_p + global_recovered_id = vocab_start + recovered_id else: - draft_prob = tl.load( - draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=0, - ) - target_prob = tl.load( - target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=0, - ) - prob = tl.maximum(target_prob - draft_prob, 0) - # We don't need `prob = prob / tl.sum(prob)` here because - # `tl.argmax` will select the maximum value. + for loop_i in range(loop): + vocab_start = loop_i * SUB_BLOCK + vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) + draft_prob = tl.load(draft_probs_ptr + + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0) + target_prob = tl.load(target_probs_ptr + + (start_idx + pos) * vocab_size + + vocab_offset, + mask=vocab_offset < vocab_size, + other=0) + prob = tl.maximum(target_prob - draft_prob, 0) + # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because + # `tl.argmax` will select the maximum value. - q = tl.load( - q_ptr + req_idx * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=float("-inf"), - ) - recovered_id = tl.argmax(prob / q, axis=-1) - tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) + q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=float("-inf")) + new_p = prob / q + recovered_id = tl.argmax(new_p, axis=-1) + max_p = tl.get_element(new_p, (recovered_id, )) + if max_p > global_max_p: + global_max_p = max_p + global_recovered_id = vocab_start + recovered_id + + tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id) + + if NO_DRAFT_PROBS: + # Restore the original probability. + tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + orig_prob) rs.expand_batch_to_tokens = expand_batch_to_tokens