[Bugfix] Update target probs to target logits in rejection sample (#6685)
### What this PR does / why we need it?
This PR aims to update `target_probs` to `target_logits` in
`rejection_sample`, to catch up with
https://github.com/vllm-project/vllm/pull/32852. Otherwise, sampling
with temperature will incur accuracy problems in which tokens can be
accepted or rejected unreasonably.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Verified by CI.
- vLLM version: v0.15.0
- vLLM main:
13397841ab
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -90,7 +90,7 @@ def rejection_sample(
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: torch.Tensor | None,
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
target_logits: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
@@ -98,17 +98,17 @@ def rejection_sample(
|
||||
assert draft_token_ids.ndim == 1
|
||||
assert draft_probs is None or draft_probs.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
assert target_probs.ndim == 2
|
||||
assert target_logits.ndim == 2
|
||||
|
||||
batch_size = len(num_draft_tokens)
|
||||
num_tokens = draft_token_ids.shape[0]
|
||||
vocab_size = target_probs.shape[-1]
|
||||
device = target_probs.device
|
||||
vocab_size = target_logits.shape[-1]
|
||||
device = target_logits.device
|
||||
assert draft_token_ids.is_contiguous()
|
||||
assert draft_probs is None or draft_probs.is_contiguous()
|
||||
assert target_probs.is_contiguous()
|
||||
assert target_logits.is_contiguous()
|
||||
assert bonus_token_ids.is_contiguous()
|
||||
assert target_probs.shape == (num_tokens, vocab_size)
|
||||
assert target_logits.shape == (num_tokens, vocab_size)
|
||||
|
||||
# When num_speculative_tokens>=3, using block verify.
|
||||
using_block_verify = max_spec_len >= 3
|
||||
@@ -129,7 +129,7 @@ def rejection_sample(
|
||||
grid, block_size = cal_grid_and_block_size(batch_size)
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_probs.argmax(dim=-1)
|
||||
target_argmax = target_logits.argmax(dim=-1)
|
||||
if HAS_TRITON:
|
||||
rejection_greedy_sample_with_triton(
|
||||
output_token_ids,
|
||||
@@ -165,6 +165,10 @@ def rejection_sample(
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
|
||||
# Compute probability distribution from target logits.
|
||||
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
|
||||
assert target_probs.is_contiguous()
|
||||
|
||||
# Generate uniform probabilities for rejection sampling.
|
||||
# [num_tokens]
|
||||
uniform_probs = generate_uniform_probs(
|
||||
|
||||
Reference in New Issue
Block a user