[Frontend] Fix request length check and add option to disallow auto truncation in scheduler (#2876)
This commit is contained in:
@@ -78,6 +78,7 @@ from sglang.srt.managers.schedule_policy import (
|
||||
from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
@@ -690,14 +691,16 @@ class Scheduler:
|
||||
# By default, only return the logprobs for output tokens
|
||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||
|
||||
# Truncate prompts that are too long
|
||||
if len(req.origin_input_ids) > self.max_req_input_len:
|
||||
logger.warning(
|
||||
"Request length is longer than the KV cache pool size or "
|
||||
"the max context length. Truncated. "
|
||||
f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
|
||||
)
|
||||
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
||||
# Validate prompts length
|
||||
error_msg = validate_input_length(
|
||||
req,
|
||||
self.max_req_input_len,
|
||||
self.server_args.allow_auto_truncate,
|
||||
)
|
||||
|
||||
if error_msg:
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
|
||||
req.sampling_params.max_new_tokens = min(
|
||||
(
|
||||
@@ -745,13 +748,12 @@ class Scheduler:
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
# Truncate prompts that are too long
|
||||
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||
logger.warning(
|
||||
"Request length is longer than the KV cache pool size or "
|
||||
"the max context length. Truncated!!!"
|
||||
)
|
||||
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
||||
# Validate prompts length
|
||||
validate_input_length(
|
||||
req,
|
||||
self.max_req_input_len,
|
||||
self.server_args.allow_auto_truncate,
|
||||
)
|
||||
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
|
||||
@@ -292,12 +292,28 @@ class TokenizerManager:
|
||||
SessionParams(**obj.session_params) if obj.session_params else None
|
||||
)
|
||||
|
||||
if obj.input_ids is not None and len(input_ids) >= self.context_len:
|
||||
input_token_num = len(input_ids) if input_ids is not None else 0
|
||||
if input_token_num >= self.context_len:
|
||||
raise ValueError(
|
||||
f"The input ({len(input_ids)} tokens) is longer than the "
|
||||
f"The input ({input_token_num} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens)."
|
||||
)
|
||||
|
||||
if (
|
||||
obj.sampling_params.get("max_new_tokens") is not None
|
||||
and obj.sampling_params.get("max_new_tokens") + input_token_num
|
||||
>= self.context_len
|
||||
):
|
||||
raise ValueError(
|
||||
f"Requested token count exceeds the model's maximum context length "
|
||||
f"of {self.context_len} tokens. You requested a total of "
|
||||
f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
|
||||
f"tokens: {input_token_num} tokens from the input messages and "
|
||||
f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
|
||||
f"completion. Please reduce the number of tokens in the input "
|
||||
f"messages or the completion to fit within the limit."
|
||||
)
|
||||
|
||||
# Parse sampling parameters
|
||||
sampling_params = SamplingParams(**obj.sampling_params)
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
|
||||
41
python/sglang/srt/managers/utils.py
Normal file
41
python/sglang/srt/managers/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Check a request's input length against the scheduler's limit.

    If the input is at or over ``max_req_input_len``, either truncate it in
    place (when ``allow_auto_truncate`` is set) or mark the request aborted
    and report the problem.

    Args:
        req: The request whose ``origin_input_ids`` are checked (and possibly
            truncated in place).
        max_req_input_len: Maximum allowed number of input tokens.
        allow_auto_truncate: When True, over-long inputs are shortened instead
            of rejected.

    Returns:
        An error message string when the request is rejected, otherwise None.
    """
    input_len = len(req.origin_input_ids)

    # Within the limit: nothing to do.
    if input_len < max_req_input_len:
        return None

    if not allow_auto_truncate:
        # Reject: record the reason on the request so the scheduler can
        # surface it to the client, and return it to the caller.
        error_msg = (
            f"Input length ({input_len} tokens) exceeds "
            f"the maximum allowed length ({max_req_input_len} tokens). "
            f"Use a shorter input or enable --allow-auto-truncate."
        )
        logger.error(error_msg)
        req.finished_reason = FINISH_ABORT(error_msg)
        return error_msg

    # Truncate in place; log before slicing so the original length is shown.
    logger.warning(
        "Request length is longer than the KV cache pool size or "
        "the max context length. Truncated. "
        f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
    )
    req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
    return None
|
||||
@@ -157,6 +157,7 @@ class ServerArgs:
|
||||
num_continuous_decode_steps: int = 1
|
||||
delete_ckpt_after_loading: bool = False
|
||||
enable_memory_saver: bool = False
|
||||
allow_auto_truncate: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
@@ -859,6 +860,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow-auto-truncate",
|
||||
action="store_true",
|
||||
help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user