Fix the --allow-auto-truncate argument in tokenizer manager. (#9391)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Liangsheng Yin
2025-08-20 16:56:47 +08:00
committed by GitHub
parent 42c8704560
commit 08ebdf79d0

View File

@@ -565,10 +565,20 @@ class TokenizerManager:
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int] self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
) -> None: ) -> None:
"""Validates that the input token count and the requested token count doesn't exceed the model's context length.""" """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
# FIXME: unify the length validation logic with the one in the scheduler.
_max_req_len = self.context_len - 1
input_token_num = len(input_ids) if input_ids is not None else 0 input_token_num = len(input_ids) if input_ids is not None else 0
# Check if input alone exceeds context length
if input_token_num >= self.context_len: if input_token_num >= self.context_len:
if self.server_args.allow_auto_truncate:
logger.warning(
f"The input ({input_token_num} tokens) is longer than the "
f"model's context length ({self.context_len} tokens). "
"Truncating the input."
)
input_ids = input_ids[:_max_req_len]
input_token_num = len(input_ids)
else:
raise ValueError( raise ValueError(
f"The input ({input_token_num} tokens) is longer than the " f"The input ({input_token_num} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)." f"model's context length ({self.context_len} tokens)."
@@ -584,8 +594,18 @@ class TokenizerManager:
max_new_tokens = obj.sampling_params.get("max_new_tokens") max_new_tokens = obj.sampling_params.get("max_new_tokens")
if ( if (
max_new_tokens is not None max_new_tokens is not None
and (max_new_tokens + input_token_num) >= self.context_len and (max_new_tokens + input_token_num) >= _max_req_len
): ):
if self.server_args.allow_auto_truncate:
logger.warning(
f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
f"exceeds the model's context length ({self.context_len} tokens). "
"Truncating max_new_tokens."
)
obj.sampling_params["max_new_tokens"] = max(
0, _max_req_len - input_token_num
)
else:
total_tokens = max_new_tokens + input_token_num total_tokens = max_new_tokens + input_token_num
error_msg = ( error_msg = (
f"Requested token count exceeds the model's maximum context length " f"Requested token count exceeds the model's maximum context length "