Fix the --allow-auto-truncate argument in tokenizer manager. (#9391)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -565,14 +565,24 @@ class TokenizerManager:
|
|||||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
|
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
|
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
|
||||||
|
# FIXME: unify the length validation logic with the one in the scheduler.
|
||||||
|
_max_req_len = self.context_len - 1
|
||||||
|
|
||||||
input_token_num = len(input_ids) if input_ids is not None else 0
|
input_token_num = len(input_ids) if input_ids is not None else 0
|
||||||
# Check if input alone exceeds context length
|
|
||||||
if input_token_num >= self.context_len:
|
if input_token_num >= self.context_len:
|
||||||
raise ValueError(
|
if self.server_args.allow_auto_truncate:
|
||||||
f"The input ({input_token_num} tokens) is longer than the "
|
logger.warning(
|
||||||
f"model's context length ({self.context_len} tokens)."
|
f"The input ({input_token_num} tokens) is longer than the "
|
||||||
)
|
f"model's context length ({self.context_len} tokens). "
|
||||||
|
"Truncating the input."
|
||||||
|
)
|
||||||
|
input_ids = input_ids[:_max_req_len]
|
||||||
|
input_token_num = len(input_ids)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"The input ({input_token_num} tokens) is longer than the "
|
||||||
|
f"model's context length ({self.context_len} tokens)."
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -584,17 +594,27 @@ class TokenizerManager:
|
|||||||
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
||||||
if (
|
if (
|
||||||
max_new_tokens is not None
|
max_new_tokens is not None
|
||||||
and (max_new_tokens + input_token_num) >= self.context_len
|
and (max_new_tokens + input_token_num) >= _max_req_len
|
||||||
):
|
):
|
||||||
total_tokens = max_new_tokens + input_token_num
|
if self.server_args.allow_auto_truncate:
|
||||||
error_msg = (
|
logger.warning(
|
||||||
f"Requested token count exceeds the model's maximum context length "
|
f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
|
||||||
f"of {self.context_len} tokens. You requested a total of {total_tokens} "
|
f"exceeds the model's context length ({self.context_len} tokens). "
|
||||||
f"tokens: {input_token_num} tokens from the input messages and "
|
"Truncating max_new_tokens."
|
||||||
f"{max_new_tokens} tokens for the completion. Please reduce the number "
|
)
|
||||||
f"of tokens in the input messages or the completion to fit within the limit."
|
obj.sampling_params["max_new_tokens"] = max(
|
||||||
)
|
0, _max_req_len - input_token_num
|
||||||
raise ValueError(error_msg)
|
)
|
||||||
|
else:
|
||||||
|
total_tokens = max_new_tokens + input_token_num
|
||||||
|
error_msg = (
|
||||||
|
f"Requested token count exceeds the model's maximum context length "
|
||||||
|
f"of {self.context_len} tokens. You requested a total of {total_tokens} "
|
||||||
|
f"tokens: {input_token_num} tokens from the input messages and "
|
||||||
|
f"{max_new_tokens} tokens for the completion. Please reduce the number "
|
||||||
|
f"of tokens in the input messages or the completion to fit within the limit."
|
||||||
|
)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
if isinstance(obj, GenerateReqInput):
|
if isinstance(obj, GenerateReqInput):
|
||||||
if (
|
if (
|
||||||
|
|||||||
Reference in New Issue
Block a user