bugfix: Update sampling_params.py (#4413)
This commit is contained in:
@@ -77,7 +77,7 @@ class SamplingParams:
|
||||
self.custom_params = custom_params
|
||||
|
||||
# Process some special cases
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
if 0 <= self.temperature < _SAMPLING_EPS:
|
||||
# top_k = 1 means greedy sampling
|
||||
self.temperature = 1.0
|
||||
self.top_k = 1
|
||||
@@ -93,9 +93,9 @@ class SamplingParams:
|
||||
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
||||
if not 0.0 <= self.min_p <= 1.0:
|
||||
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
|
||||
if self.top_k < -1 or self.top_k == 0:
|
||||
if self.top_k < 1 or self.top_k == -1:
|
||||
raise ValueError(
|
||||
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
|
||||
f"top_k must be -1 (disable) or at least 1, got {self.top_k}."
|
||||
)
|
||||
if not -2.0 <= self.frequency_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
@@ -108,12 +108,12 @@ class SamplingParams:
|
||||
)
|
||||
if not 0.0 <= self.repetition_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
"repetition_penalty must be in (0, 2], got "
|
||||
"repetition_penalty must be in [0, 2], got "
|
||||
f"{self.repetition_penalty}."
|
||||
)
|
||||
if not 0 <= self.min_new_tokens:
|
||||
raise ValueError(
|
||||
f"min_new_tokens must be in (0, max_new_tokens], got "
|
||||
f"min_new_tokens must be in [0, max_new_tokens], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
if self.max_new_tokens is not None:
|
||||
@@ -123,7 +123,7 @@ class SamplingParams:
|
||||
)
|
||||
if not self.min_new_tokens <= self.max_new_tokens:
|
||||
raise ValueError(
|
||||
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
|
||||
f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
grammars = [
|
||||
|
||||
Reference in New Issue
Block a user