Beta spec-overlap for EAGLE (#11398)

Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
Liangsheng Yin
2025-10-12 11:02:22 +08:00
committed by GitHub
parent 47c606d3dc
commit 20a6c0a63d
21 changed files with 1567 additions and 108 deletions

View File

@@ -61,8 +61,12 @@ from sglang.srt.mem_cache.allocator import (
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.common import (
alloc_for_decode,
alloc_for_extend,
alloc_token_slots,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
@@ -71,6 +75,7 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.common import next_power_of_2
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
@@ -1067,6 +1072,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def is_empty(self):
return len(self.reqs) == 0
def allocate_for_eagle_v2(self):
    """Pre-allocate KV-cache token slots for one v2 (overlap) EAGLE decode step.

    Extends every request's allocated region by ``EagleDraftInput.ALLOC_LEN_PER_DECODE``
    tokens, writes the newly allocated slot indices into the req-to-token pool,
    and refreshes the CPU-side seq-len bookkeeping.

    Preconditions: ``self.spec_info`` must be a draft input (asserted below).
    Side effects: mutates ``draft_input.allocate_lens``, ``self.seq_lens_cpu``,
    and ``self.seq_lens_sum``; allocates from ``self.tree_cache``.
    """
    from sglang.srt.speculative.eagle_info import EagleDraftInput
    from sglang.srt.speculative.spec_utils import assign_req_to_token_pool

    bs = self.batch_size()
    assert self.spec_info.is_draft_input()
    draft_input: EagleDraftInput = self.spec_info

    # FIXME(lsyin): the current implementation does not enable over-allocation.
    # Wait for the in-flight verify step first so that seq_lens and
    # allocate_lens below are the settled (non-future) values.
    self.maybe_wait_verify_done()
    new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
    # Total fresh slots needed across the batch = sum of per-request growth.
    num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
    out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)
    # Kernel launched with grid (bs,) — presumably a Triton kernel; it scatters
    # out_cache_loc into req_to_token for each request's
    # [allocate_lens, new_allocate_lens) range. TODO(review): confirm the
    # positional argument order against assign_req_to_token_pool's signature.
    assign_req_to_token_pool[(bs,)](
        self.req_pool_indices,
        self.req_to_token_pool.req_to_token,
        draft_input.allocate_lens,
        new_allocate_lens,
        out_cache_loc,
        self.req_to_token_pool.req_to_token.shape[1],  # row stride of the pool
        next_power_of_2(bs),
    )
    draft_input.allocate_lens = new_allocate_lens

    # FIXME(lsyin): remove this seq_lens_sum calculation (host sync per step).
    self.seq_lens_cpu = self.seq_lens.cpu()
    self.seq_lens_sum = self.seq_lens_cpu.sum().item()
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
self.encoder_cached = []
@@ -1507,15 +1544,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.model_config.vocab_size,
)
@property
def is_v2_eagle(self):
    """Whether this batch runs the overlap-scheduled ("v2") EAGLE path.

    Truthy only when overlap scheduling is enabled AND the speculative
    algorithm is EAGLE.
    """
    # FIXME: finally deprecate is_v2_eagle
    overlap_on = self.enable_overlap
    if not overlap_on:
        # Preserve short-circuit semantics: return the falsy value itself.
        return overlap_on
    return self.spec_algorithm.is_eagle()
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
bs = len(self.reqs)
if (
self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
or self.spec_algorithm.is_ngram()
):
if self.is_v2_eagle:
# FIXME(lsyin): make this sync optional
self.allocate_for_eagle_v2()
if not self.spec_algorithm.is_none():
# if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models.
return
@@ -1566,11 +1608,23 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.orig_seq_lens.add_(1)
self.seq_lens_sum += bs
def maybe_wait_verify_done(self):
    """Block until the pending verify event completes, if one exists.

    No-op unless the v2 (overlap) EAGLE path is active; otherwise
    synchronizes on ``spec_info.verify_done`` when it has been set.
    """
    # Guard clause: nothing to wait on outside the v2 EAGLE path.
    if not self.is_v2_eagle:
        return

    from sglang.srt.speculative.eagle_info import EagleDraftInput

    draft_input: EagleDraftInput = self.spec_info
    verify_event = draft_input.verify_done
    if verify_event is not None:
        verify_event.synchronize()
def filter_batch(
self,
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
keep_indices: Optional[List[int]] = None,
):
# FIXME(lsyin): used here to get the correct seq_lens
# The batch has been launched but we need it verified to get correct next batch info
self.maybe_wait_verify_done()
if keep_indices is None:
if isinstance(chunked_req_to_exclude, Req):
chunked_req_to_exclude = [chunked_req_to_exclude]
@@ -1633,6 +1687,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
def merge_batch(self, other: "ScheduleBatch"):
# NOTE: in v2 eagle mode, we do not need wait verify here because
# 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future
# 2) other batch is always decode, which is finished in previous step
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
# needs to be called with pre-merged Batch.reqs.
@@ -1757,6 +1815,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
is_extend_in_batch=self.is_extend_in_batch,
is_prefill_only=self.is_prefill_only,
seq_lens_cpu=self.seq_lens_cpu,
enable_overlap=self.enable_overlap,
)
def _evict_tree_cache_if_needed(self, num_tokens: int):