From 140fcaffc3f5bb2b1e92fef830653f542a0b16c8 Mon Sep 17 00:00:00 2001
From: Zetong Li <48438720+slippersss@users.noreply.github.com>
Date: Wed, 11 Feb 2026 21:31:40 +0800
Subject: [PATCH] [Bugfix] Update target probs to target logits in rejection sample (#6685)

### What this PR does / why we need it?
This PR aims to update `target_probs` to `target_logits` in `rejection_sample`, for catching up with https://github.com/vllm-project/vllm/pull/32852. Otherwise, sampling with temperature will incur an accuracy problem where tokens can be accepted or rejected unreasonably.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/13397841ab469cecf1ed425c3f52a9ffc38139b5

Signed-off-by: Zetong Li <48438720+slippersss@users.noreply.github.com>
---
 vllm_ascend/sample/rejection_sampler.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index f048dcce..695ae649 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -90,7 +90,7 @@ def rejection_sample(
     # [num_tokens, vocab_size]
     draft_probs: torch.Tensor | None,
     # [num_tokens, vocab_size]
-    target_probs: torch.Tensor,
+    target_logits: torch.Tensor,
     # [batch_size, 1]
     bonus_token_ids: torch.Tensor,
     sampling_metadata: SamplingMetadata,
@@ -98,17 +98,17 @@
     assert draft_token_ids.ndim == 1
     assert draft_probs is None or draft_probs.ndim == 2
     assert cu_num_draft_tokens.ndim == 1
-    assert target_probs.ndim == 2
+    assert target_logits.ndim == 2
 
     batch_size = len(num_draft_tokens)
     num_tokens = draft_token_ids.shape[0]
-    vocab_size = target_probs.shape[-1]
-    device = target_probs.device
+    vocab_size = target_logits.shape[-1]
+    device = target_logits.device
     assert draft_token_ids.is_contiguous()
     assert draft_probs is None or draft_probs.is_contiguous()
-    assert target_probs.is_contiguous()
+    assert target_logits.is_contiguous()
     assert bonus_token_ids.is_contiguous()
-    assert target_probs.shape == (num_tokens, vocab_size)
+    assert target_logits.shape == (num_tokens, vocab_size)
 
     # When num_speculative_tokens>=3, using block verify.
     using_block_verify = max_spec_len >= 3
@@ -129,7 +129,7 @@
     grid, block_size = cal_grid_and_block_size(batch_size)
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
-        target_argmax = target_probs.argmax(dim=-1)
+        target_argmax = target_logits.argmax(dim=-1)
         if HAS_TRITON:
             rejection_greedy_sample_with_triton(
                 output_token_ids,
@@ -165,6 +165,10 @@
     if sampling_metadata.all_greedy:
         return output_token_ids
 
+    # Compute probability distribution from target logits.
+    target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
+    assert target_probs.is_contiguous()
+
     # Generate uniform probabilities for rejection sampling.
     # [num_tokens]
     uniform_probs = generate_uniform_probs(