diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 9fb7a27aa..22c18f2e4 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -19,6 +19,7 @@ import os import random from collections import defaultdict from contextlib import contextmanager +from enum import Enum, auto from typing import Dict, List, Optional from sglang.srt.managers.schedule_batch import Req, ScheduleBatch @@ -104,6 +105,12 @@ class SchedulePolicy: q.extend(last_node_to_reqs[cur_node]) +class AddReqResult(Enum): + CONTINUE = auto() # Continue to add requests + NO_TOKEN = auto() # No token left + OTHER = auto() # Other reasons to stop adding requests + + class PrefillAdder: def __init__( self, @@ -145,17 +152,16 @@ class PrefillAdder: ] ) - def no_remaining_tokens(self): - return ( - self.rem_total_tokens <= 0 - or self.rem_input_tokens <= 0 - or ( - self.rem_chunk_tokens <= 0 - if self.rem_chunk_tokens is not None - else False - ) - or self.cur_rem_tokens <= 0 - ) + def budget_state(self): + if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: + return AddReqResult.NO_TOKEN + + if self.rem_input_tokens <= 0 or ( + self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0 + ): + return AddReqResult.OTHER + + return AddReqResult.CONTINUE def _prefill_one_req( self, prefix_len: int, extend_input_len: int, max_new_tokens: int @@ -239,7 +245,7 @@ class PrefillAdder: ) bs = len(self.req_states) - i if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0: - return False + return AddReqResult.NO_TOKEN tokens_freed += tokens_occupied if req.extend_input_len <= self.rem_chunk_tokens: @@ -258,7 +264,7 @@ class PrefillAdder: self.new_inflight_req = req self._prefill_one_req(0, trunc_len, 0) - return True + return self.budget_state() def add_one_req(self, req: Req): if req.sampling_params.ignore_eos and self.tree_cache.disable: @@ -271,14 +277,14 @@ class PrefillAdder: prefix_len = len(req.prefix_indices) if total_tokens >= self.rem_total_tokens: - return False + return AddReqResult.NO_TOKEN if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0: - return False + return AddReqResult.OTHER with self._lock_node(req.last_node): if total_tokens > self.rem_total_tokens: - return False + return AddReqResult.NO_TOKEN if ( self.rem_chunk_tokens is None @@ -297,7 +303,7 @@ class PrefillAdder: # Chunked prefill trunc_len = self.rem_chunk_tokens if trunc_len == 0: - return False + return AddReqResult.OTHER req.extend_input_len = trunc_len req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] @@ -306,4 +312,4 @@ class PrefillAdder: self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req(prefix_len, trunc_len, 0) - return True and not self.no_remaining_tokens() + return self.budget_state() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9b3fc5cef..f4dcbb650 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -50,7 +50,11 @@ from sglang.srt.managers.schedule_batch import ( Req, ScheduleBatch, ) -from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy +from sglang.srt.managers.schedule_policy import ( + AddReqResult, + PrefillAdder, + SchedulePolicy, +) from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache @@ -493,16 +497,15 @@ class Scheduler: self.batch_is_full = True break - if adder.no_remaining_tokens(): + if running_bs + len(adder.can_run_list) >= self.max_running_requests: self.batch_is_full = True break + req.init_next_round_input(None if prefix_computed else self.tree_cache) res = adder.add_one_req(req) - if ( - not res - or running_bs + len(adder.can_run_list) >= self.max_running_requests - ): - self.batch_is_full = True + if res != AddReqResult.CONTINUE: + if res == AddReqResult.NO_TOKEN: + self.batch_is_full = True break can_run_list = adder.can_run_list