Beta spec-overlap for EAGLE (#11398)
Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
@@ -61,8 +61,12 @@ from sglang.srt.mem_cache.allocator import (
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.common import (
|
||||
alloc_for_decode,
|
||||
alloc_for_extend,
|
||||
alloc_token_slots,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||
@@ -71,6 +75,7 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import flatten_nested_list
|
||||
from sglang.srt.utils.common import next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
@@ -1067,6 +1072,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
def is_empty(self):
|
||||
return len(self.reqs) == 0
|
||||
|
||||
def allocate_for_eagle_v2(self):
    """Pre-allocate KV-cache token slots for one EAGLE v2 (overlap) decode step.

    Grows every request's reserved region from its current
    ``allocate_lens`` up to ``seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE``,
    writes the freshly allocated slot indices into ``req_to_token``, and
    refreshes the CPU-side copies of the sequence lengths.  Requires that
    the previous verify step has completed, since ``seq_lens`` is only
    correct after verification (hence the ``maybe_wait_verify_done`` call).
    """
    # Local imports avoid a circular dependency with the speculative modules.
    from sglang.srt.speculative.eagle_info import EagleDraftInput
    from sglang.srt.speculative.spec_utils import assign_req_to_token_pool

    bs = self.batch_size()

    # This path only runs while spec_info holds draft-model input.
    assert self.spec_info.is_draft_input()
    draft_input: EagleDraftInput = self.spec_info

    # FIXME(lsyin): now implementation does not enable over-allocation
    # Now seq_lens and allocate_lens are correct
    self.maybe_wait_verify_done()

    # Each request gets a fixed extra budget per decode step; only the delta
    # between the new target and what is already reserved is allocated.
    new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
    num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
    out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)

    # Triton kernel (one program per request): scatter the new slot indices
    # into req_to_token over [allocate_lens, new_allocate_lens) per request.
    assign_req_to_token_pool[(bs,)](
        self.req_pool_indices,
        self.req_to_token_pool.req_to_token,
        draft_input.allocate_lens,
        new_allocate_lens,
        out_cache_loc,
        self.req_to_token_pool.req_to_token.shape[1],  # row stride of the table
        next_power_of_2(bs),  # padded block size for the kernel
    )
    draft_input.allocate_lens = new_allocate_lens

    # FIXME(lsyin): remove seq_lens_sum calculation
    self.seq_lens_cpu = self.seq_lens.cpu()
    self.seq_lens_sum = self.seq_lens_cpu.sum().item()
|
||||
|
||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||
self.encoder_lens_cpu = []
|
||||
self.encoder_cached = []
|
||||
@@ -1507,15 +1544,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.model_config.vocab_size,
|
||||
)
|
||||
|
||||
@property
def is_v2_eagle(self):
    """True when this batch uses the overlap-scheduled ("v2") EAGLE path.

    Requires both the overlap scheduler and the EAGLE speculative
    algorithm to be enabled.
    """
    # FIXME: finally deprecate is_v2_eagle
    # Unrolled short-circuit `and`: return enable_overlap itself when it is
    # falsy, otherwise defer to the speculative-algorithm check.
    result = self.enable_overlap
    if result:
        result = self.spec_algorithm.is_eagle()
    return result
|
||||
|
||||
def prepare_for_decode(self):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
bs = len(self.reqs)
|
||||
|
||||
if (
|
||||
self.spec_algorithm.is_eagle()
|
||||
or self.spec_algorithm.is_standalone()
|
||||
or self.spec_algorithm.is_ngram()
|
||||
):
|
||||
if self.is_v2_eagle:
|
||||
# FIXME(lsyin): make this sync optional
|
||||
self.allocate_for_eagle_v2()
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
# if spec decoding is used, the decode batch is prepared inside
|
||||
# `forward_batch_speculative_generation` after running draft models.
|
||||
return
|
||||
@@ -1566,11 +1608,23 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.orig_seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
|
||||
def maybe_wait_verify_done(self):
    """Block until the in-flight EAGLE verify step has finished.

    No-op outside the v2 EAGLE (overlap) path, or when no verify event
    has been recorded yet.
    """
    # Guard clause: only the overlap EAGLE path records a verify event.
    if not self.is_v2_eagle:
        return

    from sglang.srt.speculative.eagle_info import EagleDraftInput

    draft_input: EagleDraftInput = self.spec_info
    if draft_input.verify_done is not None:
        draft_input.verify_done.synchronize()
|
||||
|
||||
def filter_batch(
|
||||
self,
|
||||
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
|
||||
keep_indices: Optional[List[int]] = None,
|
||||
):
|
||||
# FIXME(lsyin): used here to get the correct seq_lens
|
||||
# The batch has been launched but we need it verified to get correct next batch info
|
||||
self.maybe_wait_verify_done()
|
||||
|
||||
if keep_indices is None:
|
||||
if isinstance(chunked_req_to_exclude, Req):
|
||||
chunked_req_to_exclude = [chunked_req_to_exclude]
|
||||
@@ -1633,6 +1687,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
)
|
||||
|
||||
def merge_batch(self, other: "ScheduleBatch"):
|
||||
# NOTE: in v2 eagle mode, we do not need wait verify here because
|
||||
# 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future
|
||||
# 2) other batch is always decode, which is finished in previous step
|
||||
|
||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
|
||||
# needs to be called with pre-merged Batch.reqs.
|
||||
@@ -1757,6 +1815,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
is_extend_in_batch=self.is_extend_in_batch,
|
||||
is_prefill_only=self.is_prefill_only,
|
||||
seq_lens_cpu=self.seq_lens_cpu,
|
||||
enable_overlap=self.enable_overlap,
|
||||
)
|
||||
|
||||
def _evict_tree_cache_if_needed(self, num_tokens: int):
|
||||
|
||||
Reference in New Issue
Block a user