Fix the --allow-auto-truncate argument in tokenizer manager. (#9391)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Liangsheng Yin
2025-08-20 16:56:47 +08:00
committed by GitHub
parent 42c8704560
commit 08ebdf79d0

View File

@@ -565,14 +565,24 @@ class TokenizerManager:
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int] self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
) -> None: ) -> None:
"""Validates that the input token count and the requested token count doesn't exceed the model's context length.""" """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
# FIXME: unify the length validation logic with the one in the scheduler.
_max_req_len = self.context_len - 1
input_token_num = len(input_ids) if input_ids is not None else 0 input_token_num = len(input_ids) if input_ids is not None else 0
# Check if input alone exceeds context length
if input_token_num >= self.context_len: if input_token_num >= self.context_len:
raise ValueError( if self.server_args.allow_auto_truncate:
f"The input ({input_token_num} tokens) is longer than the " logger.warning(
f"model's context length ({self.context_len} tokens)." f"The input ({input_token_num} tokens) is longer than the "
) f"model's context length ({self.context_len} tokens). "
"Truncating the input."
)
input_ids = input_ids[:_max_req_len]
input_token_num = len(input_ids)
else:
raise ValueError(
f"The input ({input_token_num} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
)
if isinstance(obj, EmbeddingReqInput) and self.is_generation: if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError( raise ValueError(
@@ -584,17 +594,27 @@ class TokenizerManager:
max_new_tokens = obj.sampling_params.get("max_new_tokens") max_new_tokens = obj.sampling_params.get("max_new_tokens")
if ( if (
max_new_tokens is not None max_new_tokens is not None
and (max_new_tokens + input_token_num) >= self.context_len and (max_new_tokens + input_token_num) >= _max_req_len
): ):
total_tokens = max_new_tokens + input_token_num if self.server_args.allow_auto_truncate:
error_msg = ( logger.warning(
f"Requested token count exceeds the model's maximum context length " f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
f"of {self.context_len} tokens. You requested a total of {total_tokens} " f"exceeds the model's context length ({self.context_len} tokens). "
f"tokens: {input_token_num} tokens from the input messages and " "Truncating max_new_tokens."
f"{max_new_tokens} tokens for the completion. Please reduce the number " )
f"of tokens in the input messages or the completion to fit within the limit." obj.sampling_params["max_new_tokens"] = max(
) 0, _max_req_len - input_token_num
raise ValueError(error_msg) )
else:
total_tokens = max_new_tokens + input_token_num
error_msg = (
f"Requested token count exceeds the model's maximum context length "
f"of {self.context_len} tokens. You requested a total of {total_tokens} "
f"tokens: {input_token_num} tokens from the input messages and "
f"{max_new_tokens} tokens for the completion. Please reduce the number "
f"of tokens in the input messages or the completion to fit within the limit."
)
raise ValueError(error_msg)
if isinstance(obj, GenerateReqInput): if isinstance(obj, GenerateReqInput):
if ( if (