From 25549433e8ff9965cbbbb8efe9d450753ca23356 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 24 Jun 2025 02:12:29 +0800 Subject: [PATCH] Fix prefill OOM due to wrong token calculation when page > 1 (#7397) --- python/sglang/srt/managers/schedule_policy.py | 62 ++++++++++++------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 87a6a145b..ba3dd8d4e 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -55,6 +55,9 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int( ) +IGNORE_EOS_RESERVE_TOKENS = 1 + + class CacheAwarePolicy(Enum): """Scheduling policies that are aware of the tree cache.""" @@ -293,6 +296,7 @@ class PrefillAdder: self.can_run_list = [] self.new_chunked_req = None self.log_hit_tokens = 0 + # TODO(lsyin): report the real input tokens excluding page alignment self.log_input_tokens = 0 if running_batch is not None: @@ -323,6 +327,9 @@ class PrefillAdder: - self.cur_rem_token_offset ) + def ceil_paged_tokens(self, tokens: int) -> int: + return -(-tokens // self.page_size) * self.page_size + def budget_state(self): if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: return AddReqResult.NO_TOKEN @@ -334,9 +341,12 @@ class PrefillAdder: return AddReqResult.CONTINUE - def _prefill_one_req( + def _update_prefill_budget( self, prefix_len: int, extend_input_len: int, max_new_tokens: int ): + # TODO(lsyin): check this workaround logic, which only ensures the prefill will not run out of memory, and may be too conservative + extend_input_len = self.ceil_paged_tokens(extend_input_len) + self.rem_total_token_offset += extend_input_len + max_new_tokens self.cur_rem_token_offset += extend_input_len self.rem_input_tokens -= extend_input_len @@ -351,7 +361,7 @@ class PrefillAdder: req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.fill_ids = req.fill_ids[: 
len(req.prefix_indices) + req.extend_input_len] self.can_run_list.append(req) - self._prefill_one_req( + self._update_prefill_budget( 0, req.extend_input_len, ( @@ -373,6 +383,12 @@ class PrefillAdder: self.tree_cache.dec_lock_ref(last_node) def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool): + # Early exit if there are not enough tokens for the input tokens + if self.ceil_paged_tokens(req.extend_input_len) > min( + self.cur_rem_tokens, self.rem_total_tokens + ): + return AddReqResult.NO_TOKEN + def add_req_state(r, insert_sort=False): new_token_ratio = ( 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio @@ -382,15 +398,17 @@ class PrefillAdder: ) tokens_occupied = len(r.origin_input_ids) + len(r.output_ids) - if tokens_left > 0: - if not insert_sort: - self.req_states.append((tokens_left, tokens_occupied)) - else: - i = 0 - for i in range(len(self.req_states)): - if tokens_left <= self.req_states[i][0]: - break - self.req_states.insert(i, (tokens_left, tokens_occupied)) + if tokens_left <= 0: + return + + if not insert_sort: + self.req_states.append((tokens_left, tokens_occupied)) + else: + i = 0 + for i in range(len(self.req_states)): + if tokens_left <= self.req_states[i][0]: + break + self.req_states.insert(i, (tokens_left, tokens_occupied)) if self.req_states is None: self.req_states = [] @@ -407,13 +425,11 @@ class PrefillAdder: cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids) tokens_freed = 0 for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): - decode_steps = ( - self.req_states[i + 1][0] - if i + 1 < len(self.req_states) - else tokens_left - ) + # tokens_left gives a conservative calculation as the last token is not stored bs = len(self.req_states) - i - if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0: + min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs + # reserve tokens for corner cases + if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs: return AddReqResult.NO_TOKEN 
tokens_freed += tokens_occupied @@ -423,7 +439,7 @@ class PrefillAdder: ): # Non-chunked prefill self.can_run_list.append(req) - self._prefill_one_req( + self._update_prefill_budget( 0, req.extend_input_len, min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION), @@ -439,7 +455,7 @@ class PrefillAdder: req.fill_ids = req.fill_ids[:trunc_len] self.can_run_list.append(req) self.new_chunked_req = req - self._prefill_one_req(0, trunc_len, 0) + self._update_prefill_budget(0, trunc_len, 0) return self.budget_state() @@ -453,7 +469,7 @@ class PrefillAdder: # adjusting the input_tokens based on host_hit_length and page_size real_input_tokens = req.extend_input_len - req.host_hit_length - real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size + real_input_tokens = self.ceil_paged_tokens(real_input_tokens) prefix_len = len(req.prefix_indices) if total_tokens >= self.rem_total_tokens: @@ -475,7 +491,7 @@ class PrefillAdder: req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) prefix_len = len(req.prefix_indices) - input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size + input_tokens = self.ceil_paged_tokens(req.extend_input_len) if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0: return AddReqResult.OTHER @@ -484,7 +500,7 @@ class PrefillAdder: # Non-chunked prefill self.can_run_list.append(req) self.tree_cache.inc_lock_ref(req.last_node) - self._prefill_one_req( + self._update_prefill_budget( prefix_len, input_tokens, min( @@ -505,6 +521,6 @@ class PrefillAdder: self.can_run_list.append(req) self.new_chunked_req = req self.tree_cache.inc_lock_ref(req.last_node) - self._prefill_one_req(prefix_len, trunc_len, 0) + self._update_prefill_budget(prefix_len, trunc_len, 0) return self.budget_state()