[Bugfix] Update target probs to target logits in rejection sample (#6685)
### What this PR does / why we need it?
This PR updates `target_probs` to `target_logits` in
`rejection_sample` to catch up with
https://github.com/vllm-project/vllm/pull/32852. Without this change, sampling
with temperature incurs an accuracy problem in which tokens may be
accepted or rejected incorrectly.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Verified by CI.
- vLLM version: v0.15.0
- vLLM main:
13397841ab
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -90,7 +90,7 @@ def rejection_sample(
|
|||||||
# [num_tokens, vocab_size]
|
# [num_tokens, vocab_size]
|
||||||
draft_probs: torch.Tensor | None,
|
draft_probs: torch.Tensor | None,
|
||||||
# [num_tokens, vocab_size]
|
# [num_tokens, vocab_size]
|
||||||
target_probs: torch.Tensor,
|
target_logits: torch.Tensor,
|
||||||
# [batch_size, 1]
|
# [batch_size, 1]
|
||||||
bonus_token_ids: torch.Tensor,
|
bonus_token_ids: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
@@ -98,17 +98,17 @@ def rejection_sample(
|
|||||||
assert draft_token_ids.ndim == 1
|
assert draft_token_ids.ndim == 1
|
||||||
assert draft_probs is None or draft_probs.ndim == 2
|
assert draft_probs is None or draft_probs.ndim == 2
|
||||||
assert cu_num_draft_tokens.ndim == 1
|
assert cu_num_draft_tokens.ndim == 1
|
||||||
assert target_probs.ndim == 2
|
assert target_logits.ndim == 2
|
||||||
|
|
||||||
batch_size = len(num_draft_tokens)
|
batch_size = len(num_draft_tokens)
|
||||||
num_tokens = draft_token_ids.shape[0]
|
num_tokens = draft_token_ids.shape[0]
|
||||||
vocab_size = target_probs.shape[-1]
|
vocab_size = target_logits.shape[-1]
|
||||||
device = target_probs.device
|
device = target_logits.device
|
||||||
assert draft_token_ids.is_contiguous()
|
assert draft_token_ids.is_contiguous()
|
||||||
assert draft_probs is None or draft_probs.is_contiguous()
|
assert draft_probs is None or draft_probs.is_contiguous()
|
||||||
assert target_probs.is_contiguous()
|
assert target_logits.is_contiguous()
|
||||||
assert bonus_token_ids.is_contiguous()
|
assert bonus_token_ids.is_contiguous()
|
||||||
assert target_probs.shape == (num_tokens, vocab_size)
|
assert target_logits.shape == (num_tokens, vocab_size)
|
||||||
|
|
||||||
# When num_speculative_tokens>=3, using block verify.
|
# When num_speculative_tokens>=3, using block verify.
|
||||||
using_block_verify = max_spec_len >= 3
|
using_block_verify = max_spec_len >= 3
|
||||||
@@ -129,7 +129,7 @@ def rejection_sample(
|
|||||||
grid, block_size = cal_grid_and_block_size(batch_size)
|
grid, block_size = cal_grid_and_block_size(batch_size)
|
||||||
if not sampling_metadata.all_random:
|
if not sampling_metadata.all_random:
|
||||||
# Rejection sampling for greedy sampling requests.
|
# Rejection sampling for greedy sampling requests.
|
||||||
target_argmax = target_probs.argmax(dim=-1)
|
target_argmax = target_logits.argmax(dim=-1)
|
||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
rejection_greedy_sample_with_triton(
|
rejection_greedy_sample_with_triton(
|
||||||
output_token_ids,
|
output_token_ids,
|
||||||
@@ -165,6 +165,10 @@ def rejection_sample(
|
|||||||
if sampling_metadata.all_greedy:
|
if sampling_metadata.all_greedy:
|
||||||
return output_token_ids
|
return output_token_ids
|
||||||
|
|
||||||
|
# Compute probability distribution from target logits.
|
||||||
|
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
|
||||||
|
assert target_probs.is_contiguous()
|
||||||
|
|
||||||
# Generate uniform probabilities for rejection sampling.
|
# Generate uniform probabilities for rejection sampling.
|
||||||
# [num_tokens]
|
# [num_tokens]
|
||||||
uniform_probs = generate_uniform_probs(
|
uniform_probs = generate_uniform_probs(
|
||||||
|
|||||||
Reference in New Issue
Block a user