diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 7cab55c74..a3a099b83 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -24,6 +24,7 @@ import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large. @@ -250,23 +251,24 @@ class PrefillAdder: def __init__( self, tree_cache: BasePrefixCache, + token_to_kv_pool: BaseTokenToKVPool, running_batch: ScheduleBatch, new_token_ratio: float, - rem_total_tokens: int, rem_input_tokens: int, rem_chunk_tokens: Optional[int], mixed_with_decode_tokens: int = 0, ): self.tree_cache = tree_cache + self.token_to_kv_pool = token_to_kv_pool self.running_batch = running_batch self.new_token_ratio = new_token_ratio - self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_chunk_tokens = rem_chunk_tokens if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= mixed_with_decode_tokens - self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens + self.rem_total_token_offset = mixed_with_decode_tokens + self.cur_rem_token_offset = mixed_with_decode_tokens self.req_states = None self.can_run_list = [] @@ -275,8 +277,7 @@ class PrefillAdder: self.log_input_tokens = 0 if running_batch is not None: - # Pre-remove the tokens which will be occupied by the running requests - self.rem_total_tokens -= sum( + self.rem_total_token_offset += sum( [ min( (r.sampling_params.max_new_tokens - len(r.output_ids)), @@ -287,6 +288,22 @@ class PrefillAdder: ] ) + @property + def rem_total_tokens(self): + return ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + - self.rem_total_token_offset + ) + + @property + def cur_rem_tokens(self): + return ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + - self.cur_rem_token_offset + ) + def budget_state(self): if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: return AddReqResult.NO_TOKEN @@ -301,8 +318,8 @@ class PrefillAdder: def _prefill_one_req( self, prefix_len: int, extend_input_len: int, max_new_tokens: int ): - self.rem_total_tokens -= extend_input_len + max_new_tokens - self.cur_rem_tokens -= extend_input_len + self.rem_total_token_offset += extend_input_len + max_new_tokens + self.cur_rem_token_offset += extend_input_len self.rem_input_tokens -= extend_input_len if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= extend_input_len @@ -332,12 +349,10 @@ class PrefillAdder: @contextmanager def _lock_node(self, last_node: TreeNode): try: - delta = self.tree_cache.inc_lock_ref(last_node) - self.rem_total_tokens += delta + self.tree_cache.inc_lock_ref(last_node) yield None finally: - delta = self.tree_cache.dec_lock_ref(last_node) - self.rem_total_tokens += delta + self.tree_cache.dec_lock_ref(last_node) def add_one_req_ignore_eos(self, req: Req): def add_req_state(r, insert_sort=False): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f8bb7d334..bc963e008 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -891,9 +891,9 @@ class Scheduler: # Prefill policy adder = PrefillAdder( self.tree_cache, + self.token_to_kv_pool, self.running_batch, self.new_token_ratio, - self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(), self.max_prefill_tokens, self.chunked_prefill_size, running_bs if self.is_mixed_chunk else 0,