[BugFix] Verify that logit_bias token IDs fall within the vocabulary size to avoid an IndexError crash (#7749)
This commit is contained in:
@@ -604,7 +604,7 @@ class TokenizerManager:
|
||||
sampling_kwargs = obj.sampling_params
|
||||
sampling_params = SamplingParams(**sampling_kwargs)
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
sampling_params.verify(self.model_config.vocab_size)
|
||||
|
||||
# Build return object
|
||||
if isinstance(obj, GenerateReqInput):
|
||||
|
||||
@@ -89,7 +89,7 @@ class SamplingParams:
|
||||
if self.top_k == -1:
|
||||
self.top_k = TOP_K_ALL # whole vocabulary
|
||||
|
||||
def verify(self):
|
||||
def verify(self, vocab_size):
|
||||
if self.temperature < 0.0:
|
||||
raise ValueError(
|
||||
f"temperature must be non-negative, got {self.temperature}."
|
||||
@@ -131,6 +131,13 @@ class SamplingParams:
|
||||
f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
if self.logit_bias is not None:
|
||||
for token_id in self.logit_bias:
|
||||
if not 0 <= int(token_id) < vocab_size:
|
||||
raise ValueError(
|
||||
f"logit_bias must has keys in [0, {vocab_size - 1}], got "
|
||||
f"{token_id}."
|
||||
)
|
||||
grammars = [
|
||||
self.json_schema,
|
||||
self.regex,
|
||||
|
||||
Reference in New Issue
Block a user