From a8ccacc8b8118ded5c993fe42e27fa3c6533e6d8 Mon Sep 17 00:00:00 2001
From: Chang Su
Date: Thu, 16 Jan 2025 14:51:19 -0800
Subject: [PATCH] [Frontend] Fix request length check and add option to
 disallow auto truncation in scheduler (#2876)

Previously the scheduler silently truncated any prompt longer than the
maximum request input length. Move that check into a shared helper
(validate_input_length), reject over-long requests with an explicit
error by default, and add a --allow-auto-truncate server flag that
restores the old truncation behavior. The tokenizer manager now also
rejects requests whose input plus max_new_tokens would exceed the
model's context length.
---
 python/sglang/srt/managers/scheduler.py      | 32 +++++----
 .../sglang/srt/managers/tokenizer_manager.py | 20 +++++-
 python/sglang/srt/managers/utils.py          | 41 +++++++++++
 python/sglang/srt/server_args.py             |  6 ++
 test/srt/run_suite.py                        |  1 +
 test/srt/test_request_length_validation.py   | 71 +++++++++++++++++++
 6 files changed, 154 insertions(+), 17 deletions(-)
 create mode 100644 python/sglang/srt/managers/utils.py
 create mode 100644 test/srt/test_request_length_validation.py

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index e00bd980f..0619b2e98 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -78,6 +78,7 @@ from sglang.srt.managers.schedule_policy import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -690,14 +691,16 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(req.origin_input_ids) - 1

-        # Truncate prompts that are too long
-        if len(req.origin_input_ids) > self.max_req_input_len:
-            logger.warning(
-                "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated. "
-                f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
-            )
-            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        # Validate prompt length
+        error_msg = validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )
+
+        if error_msg:
+            self.waiting_queue.append(req)
+            return

         req.sampling_params.max_new_tokens = min(
             (
@@ -745,13 +748,12 @@
         )
         req.tokenizer = self.tokenizer

-        # Truncate prompts that are too long
-        if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warning(
-                "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated!!!"
-            )
-            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        # Validate prompt length
+        validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )

         self.waiting_queue.append(req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 230d4f8d0..18ac7503c 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -292,12 +292,28 @@ class TokenizerManager:
             SessionParams(**obj.session_params) if obj.session_params else None
         )

-        if obj.input_ids is not None and len(input_ids) >= self.context_len:
+        input_token_num = len(input_ids) if input_ids is not None else 0
+        if input_token_num >= self.context_len:
             raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )

+        if (
+            obj.sampling_params.get("max_new_tokens") is not None
+            and obj.sampling_params.get("max_new_tokens") + input_token_num
+            >= self.context_len
+        ):
+            raise ValueError(
+                f"Requested token count exceeds the model's maximum context length "
+                f"of {self.context_len} tokens. You requested a total of "
+                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"tokens: {input_token_num} tokens from the input messages and "
+                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
+                f"completion. Please reduce the number of tokens in the input "
+                f"messages or the completion to fit within the limit."
+            )
+
         # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
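The second ValueError above enforces the combined budget: prompt tokens plus
the requested completion must stay below the context window. A standalone
sketch of the same arithmetic, with illustrative numbers rather than values
taken from this patch:

    # Sketch of the combined-length check; all numbers are illustrative.
    context_len = 100        # model context window
    input_token_num = 60     # tokens in the prompt
    max_new_tokens = 50      # requested completion budget

    if max_new_tokens is not None and max_new_tokens + input_token_num >= context_len:
        # 60 + 50 = 110 >= 100, so the request is rejected up front
        raise ValueError(
            f"Requested token count exceeds the model's maximum context length "
            f"of {context_len} tokens."
        )
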
diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py
new file mode 100644
index 000000000..0ab5a0909
--- /dev/null
+++ b/python/sglang/srt/managers/utils.py
@@ -0,0 +1,41 @@
+import logging
+from typing import Optional
+
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+
+logger = logging.getLogger(__name__)
+
+
+def validate_input_length(
+    req: Req, max_req_input_len: int, allow_auto_truncate: bool
+) -> Optional[str]:
+    """Validate and potentially truncate input length.
+
+    Args:
+        req: The request containing input_ids to validate
+        max_req_input_len: Maximum allowed input length
+        allow_auto_truncate: Whether to truncate long inputs
+
+    Returns:
+        Error message if validation fails, None if successful
+    """
+    if len(req.origin_input_ids) >= max_req_input_len:
+        if allow_auto_truncate:
+            logger.warning(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
+            )
+            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
+            return None
+        else:
+            error_msg = (
+                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
+                f"the maximum allowed length ({max_req_input_len} tokens). "
+                f"Use a shorter input or enable --allow-auto-truncate."
+            )
+            logger.error(error_msg)
+            req.finished_reason = FINISH_ABORT(error_msg)
+            return error_msg
+
+    return None
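Since Req is heavyweight to construct outside the server, a quick sketch of
the helper's two paths can use a stand-in object carrying only the fields
validate_input_length touches (the SimpleNamespace stub is illustrative, not
part of the patch):

    from types import SimpleNamespace

    from sglang.srt.managers.utils import validate_input_length

    # Stand-in for Req with only the fields the helper reads and writes.
    req = SimpleNamespace(origin_input_ids=list(range(200)), finished_reason=None)

    # Truncation path: the input is clipped in place and no error is returned.
    assert validate_input_length(req, 100, allow_auto_truncate=True) is None
    assert len(req.origin_input_ids) == 100

    # Abort path: the request is marked FINISH_ABORT and the message returned.
    req = SimpleNamespace(origin_input_ids=list(range(200)), finished_reason=None)
    err = validate_input_length(req, 100, allow_auto_truncate=False)
    assert err is not None and req.finished_reason is not None
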
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 9d4ec90e9..5d490d3f8 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -157,6 +157,7 @@ class ServerArgs:
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
+    allow_auto_truncate: bool = False

     def __post_init__(self):
         # Set missing default values
@@ -859,6 +860,11 @@ class ServerArgs:
             action="store_true",
             help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
         )
+        parser.add_argument(
+            "--allow-auto-truncate",
+            action="store_true",
+            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index b00c866a9..83e24e3a8 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -31,6 +31,7 @@ suites = {
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
         "test_release_memory_occupation.py",
+        "test_request_length_validation.py",
         "test_retract_decode.py",
         "test_server_args.py",
         "test_session_control.py",
diff --git a/test/srt/test_request_length_validation.py b/test/srt/test_request_length_validation.py
new file mode 100644
index 000000000..713e3e21e
--- /dev/null
+++ b/test/srt/test_request_length_validation.py
@@ -0,0 +1,71 @@
+import unittest
+
+import openai
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestRequestLengthValidation(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+
+        # Start the server with auto truncation disabled
+        cls.process = popen_launch_server(
+            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=("--max-total-tokens", "1000", "--context-length", "100"),
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_input_length_validation(self):
+        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
+
+        long_text = "hello " * 100  # Will tokenize to more than the context length
+
+        with self.assertRaises(openai.BadRequestError) as cm:
+            client.chat.completions.create(
+                model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+                messages=[
+                    {"role": "user", "content": long_text},
+                ],
+                temperature=0,
+            )
+
+        self.assertIn("is longer than the model's context length", str(cm.exception))
+
+    def test_max_tokens_validation(self):
+        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
+
+        long_text = "hello "
+
+        with self.assertRaises(openai.BadRequestError) as cm:
+            client.chat.completions.create(
+                model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+                messages=[
+                    {"role": "user", "content": long_text},
+                ],
+                temperature=0,
+                max_tokens=500,
+            )
+
+        self.assertIn(
+            "Requested token count exceeds the model's maximum context",
+            str(cm.exception),
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
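
End to end, the new flag chooses between rejection (the new default) and the
old truncation behavior. A sketch against a locally launched server; the
model path, port, and API key below are placeholders, not part of this patch:

    # Default launch: over-long prompts now return a 400 instead of being clipped.
    #   python -m sglang.launch_server --model-path <model> --context-length 100
    # Opt-in truncation (pre-patch behavior):
    #   python -m sglang.launch_server --model-path <model> --context-length 100 \
    #       --allow-auto-truncate

    import openai

    client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
    resp = client.chat.completions.create(
        model="<model>",
        messages=[{"role": "user", "content": "hello " * 100}],
        temperature=0,
    )
    # Succeeds on a truncated prompt under --allow-auto-truncate; without the
    # flag the call raises openai.BadRequestError, as the tests above assert.
    print(resp.choices[0].message.content)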