diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
index 150fbeec..8b2a145a 100644
--- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
+++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
@@ -145,7 +145,9 @@ def test_eagle_correctness(

     sampling_params = SamplingParams(
         max_tokens=300,
-        temperature=0.0,
+        temperature=0.8,
+        top_p=0.7,
+        top_k=4,
         ignore_eos=False,
     )

diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index e6e9e791..44bf7264 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -83,8 +83,7 @@ def apply_sampling_constraints(
     if get_ascend_device_type(
     ) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
             top_k.max()) <= 1024:
-        return torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16),
-                                         top_k)
+        return torch_npu.npu_top_k_top_p(logits, top_p.to(logits.dtype), top_k)
     else:
         # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
         # which is slow for large vocab sizes. This may cause performance issues.