From a93f10a7222bfe6d51a7130eb30e8ba0cf58e0e3 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sun, 19 Oct 2025 02:09:13 +0800 Subject: [PATCH] [overlap-spec] support page size > 1 (#11772) --- .../scheduler_output_processor_mixin.py | 31 ++++++++++------- .../sglang/srt/speculative/eagle_info_v2.py | 34 ++++++++++++++++--- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 64d34bf03..55ce5ebd5 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -14,6 +14,7 @@ from sglang.srt.managers.io_struct import ( BatchTokenIDOutput, ) from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch +from sglang.srt.utils.common import ceil_div if TYPE_CHECKING: from sglang.srt.managers.scheduler import ( @@ -258,22 +259,22 @@ class SchedulerOutputProcessorMixin: if self.enable_overlap and req.finished(): indices_to_free = None - if self.page_size == 1: - if batch.spec_algorithm.is_eagle(): - from sglang.srt.speculative.eagle_info import EagleDraftInput + if batch.spec_algorithm.is_eagle(): + from sglang.srt.speculative.eagle_info import EagleDraftInput - end_p = allocate_lens_list[i] - start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE - indices_to_free = self.req_to_token_pool.req_to_token[ - req.req_pool_idx - ][start_p:end_p] - else: + end_p = allocate_lens_list[i] + start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE + if self.page_size > 1: + start_p = ceil_div(start_p, self.page_size) * self.page_size + + indices_to_free = self.req_to_token_pool.req_to_token[ + req.req_pool_idx + ][start_p:end_p] + + else: + if self.page_size == 1: # Free the one extra delayed token indices_to_free = batch.out_cache_loc[i : i + 1] - else: - if batch.spec_algorithm.is_eagle(): - # TODO(spec-v2): support eagle with page_size > 1 - raise NotImplementedError() else: if ( len(req.origin_input_ids) + len(req.output_ids) - 1 @@ -299,6 +300,10 @@ class SchedulerOutputProcessorMixin: # 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration. start_p = batch.seq_lens_cpu[i] + accept_lens_list[i] end_p = allocate_lens_list[i] + + if self.page_size > 1: + start_p = ceil_div(start_p, self.page_size) * self.page_size + indices_to_free = self.req_to_token_pool.req_to_token[ req.req_pool_idx ][start_p:end_p] diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index 6ba42f326..c5fddca45 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -10,7 +10,11 @@ import triton.language as tl from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch -from sglang.srt.mem_cache.common import alloc_token_slots +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, @@ -82,9 +86,31 @@ class EagleDraftInputV2Mixin: # Now seq_lens and allocate_lens are correct batch.maybe_wait_verify_done() - new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE - num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item() - out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens) + page_size = batch.token_to_kv_pool_allocator.page_size + + if page_size == 1: + new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE + num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item() + out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens) + else: + last_loc = get_last_loc( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + self.allocate_lens, + ) + new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE + new_allocate_lens_cpu = new_allocate_lens.cpu() + allocate_lens_cpu = self.allocate_lens.cpu() + extend_num_tokens = sum(new_allocate_lens_cpu - allocate_lens_cpu).item() + out_cache_loc = alloc_paged_token_slots_extend( + batch.tree_cache, + self.allocate_lens, + allocate_lens_cpu, + new_allocate_lens, + new_allocate_lens_cpu, + last_loc, + extend_num_tokens, + ) assign_req_to_token_pool[(bs,)]( batch.req_pool_indices,