From f8e460930ae1b8044c2e05614dd02606f76235b2 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Mon, 5 May 2025 16:02:55 -0700 Subject: [PATCH] Fix prefill OOM error in the case of large page size (#5081) --- python/sglang/srt/managers/schedule_policy.py | 10 ++++++++-- python/sglang/srt/managers/scheduler.py | 1 + python/sglang/srt/mem_cache/chunk_cache.py | 2 ++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 8922e6e9d..4b36281ac 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -455,7 +455,10 @@ class PrefillAdder: total_tokens = req.extend_input_len + min( req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION ) - input_tokens = req.extend_input_len + input_tokens = ( + -(-req.extend_input_len // self.tree_cache.page_size) + * self.tree_cache.page_size + ) prefix_len = len(req.prefix_indices) if total_tokens >= self.rem_total_tokens: @@ -477,7 +480,10 @@ class PrefillAdder: req.last_node_global, req.prefix_indices ) req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) - input_tokens = req.extend_input_len + input_tokens = ( + -(-req.extend_input_len // self.tree_cache.page_size) + * self.tree_cache.page_size + ) prefix_len = len(req.prefix_indices) if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 76196d95a..fb9566093 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -502,6 +502,7 @@ class Scheduler( self.tree_cache = ChunkCache( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + page_size=self.page_size, ) else: if self.enable_hierarchical_cache: diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 3cb540fc6..2c51cb855 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache): self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool_allocator: TokenToKVPoolAllocator, + page_size: int, ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.page_size = page_size def reset(self): pass