From 2892b9bb970c0abc95454f9a766fdf4433089ffe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wang=20Ran=20=28=E6=B1=AA=E7=84=B6=29?= Date: Sun, 16 Mar 2025 07:39:19 +0800 Subject: [PATCH] bugfix: Update sampling_params.py (#4413) --- python/sglang/srt/sampling/sampling_params.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index ffa2875e9..7c77a204f 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -77,7 +77,7 @@ class SamplingParams: self.custom_params = custom_params # Process some special cases - if self.temperature < _SAMPLING_EPS: + if 0 <= self.temperature < _SAMPLING_EPS: # top_k = 1 means greedy sampling self.temperature = 1.0 self.top_k = 1 @@ -93,9 +93,9 @@ class SamplingParams: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") if not 0.0 <= self.min_p <= 1.0: raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.") - if self.top_k < -1 or self.top_k == 0: + if self.top_k < 1 or self.top_k == -1: raise ValueError( - f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." + f"top_k must be -1 (disable) or at least 1, got {self.top_k}." ) if not -2.0 <= self.frequency_penalty <= 2.0: raise ValueError( @@ -108,12 +108,12 @@ class SamplingParams: ) if not 0.0 <= self.repetition_penalty <= 2.0: raise ValueError( - "repetition_penalty must be in (0, 2], got " + "repetition_penalty must be in [0, 2], got " f"{self.repetition_penalty}." ) if not 0 <= self.min_new_tokens: raise ValueError( - f"min_new_tokens must be in (0, max_new_tokens], got " + f"min_new_tokens must be in [0, max_new_tokens], got " f"{self.min_new_tokens}." ) if self.max_new_tokens is not None: @@ -123,7 +123,7 @@ class SamplingParams: ) if not self.min_new_tokens <= self.max_new_tokens: raise ValueError( - f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got " + f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got " f"{self.min_new_tokens}." ) grammars = [