Fix the --allow-auto-truncate argument in tokenizer manager. (#9391)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -565,14 +565,24 @@ class TokenizerManager:
|
||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
|
||||
) -> None:
|
||||
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
|
||||
# FIXME: unify the length validation logic with the one in the scheduler.
|
||||
_max_req_len = self.context_len - 1
|
||||
|
||||
input_token_num = len(input_ids) if input_ids is not None else 0
|
||||
# Check if input alone exceeds context length
|
||||
if input_token_num >= self.context_len:
|
||||
raise ValueError(
|
||||
f"The input ({input_token_num} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens)."
|
||||
)
|
||||
if self.server_args.allow_auto_truncate:
|
||||
logger.warning(
|
||||
f"The input ({input_token_num} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens). "
|
||||
"Truncating the input."
|
||||
)
|
||||
input_ids = input_ids[:_max_req_len]
|
||||
input_token_num = len(input_ids)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The input ({input_token_num} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens)."
|
||||
)
|
||||
|
||||
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
||||
raise ValueError(
|
||||
@@ -584,17 +594,27 @@ class TokenizerManager:
|
||||
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
||||
if (
|
||||
max_new_tokens is not None
|
||||
and (max_new_tokens + input_token_num) >= self.context_len
|
||||
and (max_new_tokens + input_token_num) >= _max_req_len
|
||||
):
|
||||
total_tokens = max_new_tokens + input_token_num
|
||||
error_msg = (
|
||||
f"Requested token count exceeds the model's maximum context length "
|
||||
f"of {self.context_len} tokens. You requested a total of {total_tokens} "
|
||||
f"tokens: {input_token_num} tokens from the input messages and "
|
||||
f"{max_new_tokens} tokens for the completion. Please reduce the number "
|
||||
f"of tokens in the input messages or the completion to fit within the limit."
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
if self.server_args.allow_auto_truncate:
|
||||
logger.warning(
|
||||
f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
|
||||
f"exceeds the model's context length ({self.context_len} tokens). "
|
||||
"Truncating max_new_tokens."
|
||||
)
|
||||
obj.sampling_params["max_new_tokens"] = max(
|
||||
0, _max_req_len - input_token_num
|
||||
)
|
||||
else:
|
||||
total_tokens = max_new_tokens + input_token_num
|
||||
error_msg = (
|
||||
f"Requested token count exceeds the model's maximum context length "
|
||||
f"of {self.context_len} tokens. You requested a total of {total_tokens} "
|
||||
f"tokens: {input_token_num} tokens from the input messages and "
|
||||
f"{max_new_tokens} tokens for the completion. Please reduce the number "
|
||||
f"of tokens in the input messages or the completion to fit within the limit."
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if isinstance(obj, GenerateReqInput):
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user