diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index f048dcce..695ae649 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -90,7 +90,7 @@ def rejection_sample( # [num_tokens, vocab_size] draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] - target_probs: torch.Tensor, + target_logits: torch.Tensor, # [batch_size, 1] bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -98,17 +98,17 @@ def rejection_sample( assert draft_token_ids.ndim == 1 assert draft_probs is None or draft_probs.ndim == 2 assert cu_num_draft_tokens.ndim == 1 - assert target_probs.ndim == 2 + assert target_logits.ndim == 2 batch_size = len(num_draft_tokens) num_tokens = draft_token_ids.shape[0] - vocab_size = target_probs.shape[-1] - device = target_probs.device + vocab_size = target_logits.shape[-1] + device = target_logits.device assert draft_token_ids.is_contiguous() assert draft_probs is None or draft_probs.is_contiguous() - assert target_probs.is_contiguous() + assert target_logits.is_contiguous() assert bonus_token_ids.is_contiguous() - assert target_probs.shape == (num_tokens, vocab_size) + assert target_logits.shape == (num_tokens, vocab_size) # When num_speculative_tokens>=3, using block verify. using_block_verify = max_spec_len >= 3 @@ -129,7 +129,7 @@ def rejection_sample( grid, block_size = cal_grid_and_block_size(batch_size) if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. - target_argmax = target_probs.argmax(dim=-1) + target_argmax = target_logits.argmax(dim=-1) if HAS_TRITON: rejection_greedy_sample_with_triton( output_token_ids, @@ -165,6 +165,10 @@ def rejection_sample( if sampling_metadata.all_greedy: return output_token_ids + # Compute probability distribution from target logits. + target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) + assert target_probs.is_contiguous() + # Generate uniform probabilities for rejection sampling. # [num_tokens] uniform_probs = generate_uniform_probs(