diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index 64a6116c2..8be1be85a 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): device: str, kvcache: KVCache, need_sort: bool, - max_num_extend_tokens: int, ): super().__init__(size, page_size, dtype, device, kvcache, need_sort) self.num_pages = size // page_size - self.max_num_extend_tokens_next_power_of_2 = next_power_of_2( - max_num_extend_tokens - ) self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL") self.ret_values = torch.empty((), dtype=torch.int64, device=self.device) + self.seen_max_num_extend_tokens_next_power_of_2 = 1 self.clear() def alloc(self, need_size: int): @@ -480,17 +477,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): (last_loc + 1) % self.page_size == prefix_lens % self.page_size ) + self.seen_max_num_extend_tokens_next_power_of_2 = max( + self.seen_max_num_extend_tokens_next_power_of_2, + next_power_of_2(extend_num_tokens), + ) + bs = len(prefix_lens) if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len( self.free_pages ): self.merge_and_sort_free() - assert self.max_num_extend_tokens_next_power_of_2 >= extend_num_tokens, ( - f"{self.max_num_extend_tokens_next_power_of_2=} >= {extend_num_tokens=} does not hold. " - f"If this happens in PD, consider letting chunked_prefill_size in D be as large as in P" - ) - out_indices = torch.empty( (extend_num_tokens,), dtype=torch.int64, device=self.device ) @@ -503,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): self.ret_values, next_power_of_2(bs), self.page_size, - self.max_num_extend_tokens_next_power_of_2, + self.seen_max_num_extend_tokens_next_power_of_2, ) if self.debug_mode: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 41b9ce93f..b05973c81 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1353,11 +1353,6 @@ class ModelRunner: # Initialize token_to_kv_pool_allocator need_sort = self.server_args.disaggregation_mode in ("decode", "prefill") - max_num_extend_tokens = ( - self.server_args.chunked_prefill_size - if self.server_args.chunked_prefill_size > 0 - else self.server_args.max_prefill_tokens - ) if self.token_to_kv_pool_allocator is None: if self.server_args.attention_backend == "ascend": self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator( @@ -1396,7 +1391,6 @@ class ModelRunner: device=self.device, kvcache=self.token_to_kv_pool, need_sort=need_sort, - max_num_extend_tokens=max_num_extend_tokens, ) else: assert self.is_draft_worker