[Frontend] Fix request length check and add option to disallow auto truncation in scheduler (#2876)

This commit is contained in:
Chang Su
2025-01-16 14:51:19 -08:00
committed by GitHub
parent 0427416b59
commit a8ccacc8b8
6 changed files with 154 additions and 17 deletions

View File

@@ -78,6 +78,7 @@ from sglang.srt.managers.schedule_policy import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -690,14 +691,16 @@ class Scheduler:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(req.origin_input_ids) - 1
# Truncate prompts that are too long
if len(req.origin_input_ids) > self.max_req_input_len:
logger.warning(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated. "
f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
# Validate prompts length
error_msg = validate_input_length(
req,
self.max_req_input_len,
self.server_args.allow_auto_truncate,
)
if error_msg:
self.waiting_queue.append(req)
return
req.sampling_params.max_new_tokens = min(
(
@@ -745,13 +748,12 @@ class Scheduler:
)
req.tokenizer = self.tokenizer
# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
logger.warning(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
# Validate prompts length
validate_input_length(
req,
self.max_req_input_len,
self.server_args.allow_auto_truncate,
)
self.waiting_queue.append(req)

View File

@@ -292,12 +292,28 @@ class TokenizerManager:
SessionParams(**obj.session_params) if obj.session_params else None
)
if obj.input_ids is not None and len(input_ids) >= self.context_len:
input_token_num = len(input_ids) if input_ids is not None else 0
if input_token_num >= self.context_len:
raise ValueError(
f"The input ({len(input_ids)} tokens) is longer than the "
f"The input ({input_token_num} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
)
if (
obj.sampling_params.get("max_new_tokens") is not None
and obj.sampling_params.get("max_new_tokens") + input_token_num
>= self.context_len
):
raise ValueError(
f"Requested token count exceeds the model's maximum context length "
f"of {self.context_len} tokens. You requested a total of "
f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
f"tokens: {input_token_num} tokens from the input messages and "
f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
f"completion. Please reduce the number of tokens in the input "
f"messages or the completion to fit within the limit."
)
# Parse sampling parameters
sampling_params = SamplingParams(**obj.sampling_params)
sampling_params.normalize(self.tokenizer)

View File

@@ -0,0 +1,41 @@
import logging
from typing import Optional
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
logger = logging.getLogger(__name__)
def validate_input_length(
req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
"""Validate and potentially truncate input length.
Args:
req: The request containing input_ids to validate
max_req_input_len: Maximum allowed input length
allow_auto_truncate: Whether to truncate long inputs
Returns:
Error message if validation fails, None if successful
"""
if len(req.origin_input_ids) >= max_req_input_len:
if allow_auto_truncate:
logger.warning(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated. "
f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
)
req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
return None
else:
error_msg = (
f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
f"the maximum allowed length ({max_req_input_len} tokens). "
f"Use a shorter input or enable --allow-auto-truncate."
)
logger.error(error_msg)
req.finished_reason = FINISH_ABORT(error_msg)
return error_msg
return None

View File

@@ -157,6 +157,7 @@ class ServerArgs:
num_continuous_decode_steps: int = 1
delete_ckpt_after_loading: bool = False
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
def __post_init__(self):
# Set missing default values
@@ -859,6 +860,11 @@ class ServerArgs:
action="store_true",
help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
)
parser.add_argument(
"--allow-auto-truncate",
action="store_true",
help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):