diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py index 925b74b91..4fd0ea290 100644 --- a/python/sglang/srt/managers/policy_scheduler.py +++ b/python/sglang/srt/managers/policy_scheduler.py @@ -18,12 +18,16 @@ limitations under the License. import random from collections import defaultdict from contextlib import contextmanager -from typing import Dict, List +from typing import Dict, List, Optional from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.radix_cache import TreeNode +# Clip the max new tokens for the request whose max_new_tokens is very large. +# This can prevent the server from being too conservative. +CLIP_MAX_NEW_TOKENS = 4096 + class PolicyScheduler: def __init__(self, policy: str, tree_cache: BasePrefixCache): @@ -98,7 +102,7 @@ class PrefillAdder: tree_cache: BasePrefixCache, rem_total_tokens: int, rem_input_tokens: int, - rem_chunk_tokens: int, + rem_chunk_tokens: Optional[int], ): self.tree_cache = tree_cache self.rem_total_tokens = rem_total_tokens @@ -126,7 +130,11 @@ class PrefillAdder: ): self.rem_total_tokens -= sum( [ - (r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio + min( + (r.sampling_params.max_new_tokens - len(r.output_ids)), + CLIP_MAX_NEW_TOKENS, + ) + * new_token_ratio for r in running_batch.reqs ] ) @@ -151,7 +159,11 @@ class PrefillAdder: self._prefill_one_req( len(req.prefix_indices), req.extend_input_len, - req.sampling_params.max_new_tokens if not truncated else 0, + ( + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS) + if not truncated + else 0 + ), ) # Return if chunked prefill not finished @@ -168,7 +180,9 @@ class PrefillAdder: self.rem_total_tokens += delta def add_one_req(self, req: Req): - total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens + total_tokens = req.extend_input_len + min( + req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS + ) input_tokens = req.extend_input_len prefix_len = len(req.prefix_indices) @@ -191,7 +205,9 @@ class PrefillAdder: self.can_run_list.append(req) self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req( - prefix_len, input_tokens, req.sampling_params.max_new_tokens + prefix_len, + input_tokens, + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS), ) else: # Chunked prefill