diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3a81a3636..b0416a065 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -565,14 +565,24 @@ class TokenizerManager: self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int] ) -> None: """Validates that the input token count and the requested token count doesn't exceed the model's context length.""" + # FIXME: unify the length validation logic with the one in the scheduler. + _max_req_len = self.context_len - 1 input_token_num = len(input_ids) if input_ids is not None else 0 - # Check if input alone exceeds context length if input_token_num >= self.context_len: - raise ValueError( - f"The input ({input_token_num} tokens) is longer than the " - f"model's context length ({self.context_len} tokens)." - ) + if self.server_args.allow_auto_truncate: + logger.warning( + f"The input ({input_token_num} tokens) is longer than the " + f"model's context length ({self.context_len} tokens). " + "Truncating the input." + ) + input_ids = input_ids[:_max_req_len] + input_token_num = len(input_ids) + else: + raise ValueError( + f"The input ({input_token_num} tokens) is longer than the " + f"model's context length ({self.context_len} tokens)." + ) if isinstance(obj, EmbeddingReqInput) and self.is_generation: raise ValueError( @@ -584,17 +594,27 @@ class TokenizerManager: max_new_tokens = obj.sampling_params.get("max_new_tokens") if ( max_new_tokens is not None - and (max_new_tokens + input_token_num) >= self.context_len + and (max_new_tokens + input_token_num) >= _max_req_len ): - total_tokens = max_new_tokens + input_token_num - error_msg = ( - f"Requested token count exceeds the model's maximum context length " - f"of {self.context_len} tokens. You requested a total of {total_tokens} " - f"tokens: {input_token_num} tokens from the input messages and " - f"{max_new_tokens} tokens for the completion. Please reduce the number " - f"of tokens in the input messages or the completion to fit within the limit." - ) - raise ValueError(error_msg) + if self.server_args.allow_auto_truncate: + logger.warning( + f"Requested token count ({input_token_num} input + {max_new_tokens} new) " + f"exceeds the model's context length ({self.context_len} tokens). " + "Truncating max_new_tokens." + ) + obj.sampling_params["max_new_tokens"] = max( + 0, _max_req_len - input_token_num + ) + else: + total_tokens = max_new_tokens + input_token_num + error_msg = ( + f"Requested token count exceeds the model's maximum context length " + f"of {self.context_len} tokens. You requested a total of {total_tokens} " + f"tokens: {input_token_num} tokens from the input messages and " + f"{max_new_tokens} tokens for the completion. Please reduce the number " + f"of tokens in the input messages or the completion to fit within the limit." + ) + raise ValueError(error_msg) if isinstance(obj, GenerateReqInput): if (