[Bugfix] Update target probs to target logits in rejection sample (#6685)

### What this PR does / why we need it?
This PR aims to update `target_probs` to `target_logits` in
`rejection_sample`, for catching up with
https://github.com/vllm-project/vllm/pull/32852. Otherwise, sampling
with temperature will incur accuracy problem where tokens can be
accepted or rejected unreasonably.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

- vLLM version: v0.15.0
- vLLM main:
13397841ab

Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
Zetong Li
2026-02-11 21:31:40 +08:00
committed by GitHub
parent c0c2eb614e
commit 140fcaffc3

View File

@@ -90,7 +90,7 @@ def rejection_sample(
# [num_tokens, vocab_size]
draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
@@ -98,17 +98,17 @@ def rejection_sample(
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
assert cu_num_draft_tokens.ndim == 1
assert target_probs.ndim == 2
assert target_logits.ndim == 2
batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0]
vocab_size = target_probs.shape[-1]
device = target_probs.device
vocab_size = target_logits.shape[-1]
device = target_logits.device
assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous()
assert target_probs.is_contiguous()
assert target_logits.is_contiguous()
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)
assert target_logits.shape == (num_tokens, vocab_size)
# When num_speculative_tokens>=3, using block verify.
using_block_verify = max_spec_len >= 3
@@ -129,7 +129,7 @@ def rejection_sample(
grid, block_size = cal_grid_and_block_size(batch_size)
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
target_argmax = target_logits.argmax(dim=-1)
if HAS_TRITON:
rejection_greedy_sample_with_triton(
output_token_ids,
@@ -165,6 +165,10 @@ def rejection_sample(
if sampling_metadata.all_greedy:
return output_token_ids
# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
assert target_probs.is_contiguous()
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = generate_uniform_probs(