[Frontend] Fix request length check and add option to disallow auto truncation in scheduler (#2876)
This commit is contained in:
@@ -78,6 +78,7 @@ from sglang.srt.managers.schedule_policy import (
|
||||
from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
@@ -690,14 +691,16 @@ class Scheduler:
|
||||
# By default, only return the logprobs for output tokens
|
||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||
|
||||
# Truncate prompts that are too long
|
||||
if len(req.origin_input_ids) > self.max_req_input_len:
|
||||
logger.warning(
|
||||
"Request length is longer than the KV cache pool size or "
|
||||
"the max context length. Truncated. "
|
||||
f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
|
||||
)
|
||||
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
||||
# Validate prompts length
|
||||
error_msg = validate_input_length(
|
||||
req,
|
||||
self.max_req_input_len,
|
||||
self.server_args.allow_auto_truncate,
|
||||
)
|
||||
|
||||
if error_msg:
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
|
||||
req.sampling_params.max_new_tokens = min(
|
||||
(
|
||||
@@ -745,13 +748,12 @@ class Scheduler:
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
# Truncate prompts that are too long
|
||||
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||
logger.warning(
|
||||
"Request length is longer than the KV cache pool size or "
|
||||
"the max context length. Truncated!!!"
|
||||
)
|
||||
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
||||
# Validate prompts length
|
||||
validate_input_length(
|
||||
req,
|
||||
self.max_req_input_len,
|
||||
self.server_args.allow_auto_truncate,
|
||||
)
|
||||
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user