diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index 695ae649..04552b97 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -19,6 +19,7 @@ from vllm_ascend.ops.triton.reject_sample import ( sample_recovered_tokens_kernel, ) from vllm_ascend.sample.sampler import apply_top_k_top_p +from vllm_ascend.utils import vllm_version_is def apply_sampling_constraints( @@ -166,7 +167,10 @@ def rejection_sample( return output_token_ids # Compute probability distribution from target logits. - target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) + if vllm_version_is("0.15.0"): + target_probs = target_logits + else: + target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) assert target_probs.is_contiguous() # Generate uniform probabilities for rejection sampling.