Fix the --allow-auto-truncate argument in tokenizer manager. (#9391)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -565,10 +565,20 @@ class TokenizerManager:
|
|||||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
|
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
|
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
|
||||||
|
# FIXME: unify the length validation logic with the one in the scheduler.
|
||||||
|
_max_req_len = self.context_len - 1
|
||||||
|
|
||||||
input_token_num = len(input_ids) if input_ids is not None else 0
|
input_token_num = len(input_ids) if input_ids is not None else 0
|
||||||
# Check if input alone exceeds context length
|
|
||||||
if input_token_num >= self.context_len:
|
if input_token_num >= self.context_len:
|
||||||
|
if self.server_args.allow_auto_truncate:
|
||||||
|
logger.warning(
|
||||||
|
f"The input ({input_token_num} tokens) is longer than the "
|
||||||
|
f"model's context length ({self.context_len} tokens). "
|
||||||
|
"Truncating the input."
|
||||||
|
)
|
||||||
|
input_ids = input_ids[:_max_req_len]
|
||||||
|
input_token_num = len(input_ids)
|
||||||
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The input ({input_token_num} tokens) is longer than the "
|
f"The input ({input_token_num} tokens) is longer than the "
|
||||||
f"model's context length ({self.context_len} tokens)."
|
f"model's context length ({self.context_len} tokens)."
|
||||||
@@ -584,8 +594,18 @@ class TokenizerManager:
|
|||||||
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
||||||
if (
|
if (
|
||||||
max_new_tokens is not None
|
max_new_tokens is not None
|
||||||
and (max_new_tokens + input_token_num) >= self.context_len
|
and (max_new_tokens + input_token_num) >= _max_req_len
|
||||||
):
|
):
|
||||||
|
if self.server_args.allow_auto_truncate:
|
||||||
|
logger.warning(
|
||||||
|
f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
|
||||||
|
f"exceeds the model's context length ({self.context_len} tokens). "
|
||||||
|
"Truncating max_new_tokens."
|
||||||
|
)
|
||||||
|
obj.sampling_params["max_new_tokens"] = max(
|
||||||
|
0, _max_req_len - input_token_num
|
||||||
|
)
|
||||||
|
else:
|
||||||
total_tokens = max_new_tokens + input_token_num
|
total_tokens = max_new_tokens + input_token_num
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"Requested token count exceeds the model's maximum context length "
|
f"Requested token count exceeds the model's maximum context length "
|
||||||
|
|||||||
Reference in New Issue
Block a user