Beta spec-overlap for EAGLE (#11398)

Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
commit 20a6c0a63d (parent 47c606d3dc)
Author: Liangsheng Yin
Date:   2025-10-12 11:02:22 +08:00
Committed by: GitHub

21 changed files with 1567 additions and 108 deletions


@@ -148,13 +148,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.eagle_info import EagleDraftInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.tracing.trace import (
     process_tracing_init,
@@ -219,6 +216,14 @@ class GenerationBatchResult:
     forward_batch: Optional[ForwardBatch] = None
     future_indices: Optional[FutureIndices] = None

+    # FIXME(lsyin): maybe move to <BetterPlace> ?
+    # sync path: forward stream -> output processor
+    accept_lens: Optional[torch.Tensor] = None
+    last_batch_allocate_lens: Optional[torch.Tensor] = None
+
+    # relay path: forward stream -> next step forward
+    next_draft_input: Optional[EagleDraftInput] = None
+
     def copy_to_cpu(self, return_logprob: bool = False):
         """Copy tensors to CPU in overlap scheduling.

         Only the tensors which are needed for processing results are copied,
@@ -238,6 +243,15 @@ class GenerationBatchResult:
"cpu", non_blocking=True
)
self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
if self.accept_lens is not None:
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
if self.last_batch_allocate_lens is not None:
self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
"cpu", non_blocking=True
)
self.copy_done.record()
@classmethod
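
Aside on the pattern above: the two new branches mirror the existing next_token_ids handling. An asynchronous device-to-host copy is launched on the forward stream, then copy_done is recorded so the output processor can wait before reading. A minimal self-contained sketch of that pattern, with illustrative names rather than the real GenerationBatchResult, assuming a CUDA device:

    import torch

    class ToyBatchResult:
        """Illustrative only: async D2H copy plus event, as in copy_to_cpu."""

        def __init__(self, next_token_ids: torch.Tensor):
            self.next_token_ids = next_token_ids
            self.copy_done = torch.cuda.Event()

        def copy_to_cpu(self):
            # non_blocking=True only overlaps when the CPU destination is
            # pinned; otherwise the copy silently becomes synchronous.
            self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
            self.copy_done.record()  # marks the post-copy point on this stream

    result = ToyBatchResult(torch.randint(0, 32000, (8,), device="cuda"))
    result.copy_to_cpu()            # producer side: forward stream
    result.copy_done.synchronize()  # consumer side: wait before reading
    print(result.next_token_ids.tolist())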
@@ -273,48 +287,6 @@ class Scheduler(
 ):
     """A scheduler that manages a tensor parallel GPU worker."""

-    def launch_draft_worker(
-        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
-    ):
-        if self.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_worker import EAGLEWorker
-
-            self.draft_worker = EAGLEWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_standalone():
-            from sglang.srt.speculative.standalone_worker import StandaloneWorker
-
-            self.draft_worker = StandaloneWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_worker import NGRAMWorker
-
-            self.draft_worker = NGRAMWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        else:
-            self.draft_worker = None
-
     def __init__(
         self,
         server_args: ServerArgs,
@@ -454,6 +426,7 @@ class Scheduler(
         )

         # Launch a draft worker for speculative decoding
         self.launch_draft_worker(
             gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
         )
@@ -683,6 +656,51 @@ class Scheduler(
             ]
         )

+    def launch_draft_worker(
+        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+    ):
+        if self.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_worker import EAGLEWorker
+            from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
+
+            WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
+            self.draft_worker = WorkerClass(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+            self.draft_worker = NGRAMWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        else:
+            self.draft_worker = None
+
     def init_deterministic_inference_config(self):
         """Initialize deterministic inference configuration for different attention backends."""
         if not self.server_args.enable_deterministic_inference:
@@ -965,7 +983,9 @@ class Scheduler(
             self.device
         ).stream(self.copy_stream)

-        self.future_map = FutureMap(self.max_running_requests, self.device)
+        self.future_map = FutureMap(
+            self.max_running_requests, self.device, self.spec_algorithm
+        )

         self.batch_record_buf = [None] * 2
         self.batch_record_ct = 0
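
For context, the future map that gains a spec_algorithm argument here is what lets the scheduler enqueue the next batch before the current forward pass has produced its token ids: each pending token gets a slot, the placeholder handed to the next batch is the negated slot index, and the forward stream writes the real ids into the slots once they exist. A toy version of the idea, with illustrative names rather than the real FutureMap API (the actual spec-aware slot sizing is more involved):

    import torch

    class MiniFutureMap:
        """Illustrative buffer of "future" token ids.

        Slot i is referenced by the placeholder value -i; a placeholder can
        be scheduled into the next batch before its value exists.
        """

        def __init__(self, max_slots: int, device: str = "cpu"):
            self.buf = torch.zeros(max_slots, dtype=torch.int64, device=device)
            self.ct = 1  # slot 0 is skipped so -0 is never a placeholder

        def alloc(self, bs: int) -> torch.Tensor:
            # Hand out bs fresh slot indices. A real implementation wraps ct
            # around and guards against overwriting unconsumed slots.
            idx = torch.arange(self.ct, self.ct + bs, device=self.buf.device)
            self.ct += bs
            return idx

        def store(self, indices: torch.Tensor, next_token_ids: torch.Tensor):
            # Called on the forward stream once real token ids exist.
            self.buf[indices] = next_token_ids

        def resolve(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Replace negative placeholders (-i) with the stored values buf[i].
            mask = input_ids < 0
            input_ids[mask] = self.buf[-input_ids[mask]]
            return input_ids

Threading self.spec_algorithm into the constructor presumably lets the map size its slots per request when speculative decoding can accept several tokens per step instead of exactly one.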
@@ -2096,7 +2116,7 @@ class Scheduler(
         batch_or_worker_batch = batch

-        if self.spec_algorithm.is_none():
+        if self.enable_overlap or self.spec_algorithm.is_none():
             # FIXME(lsyin): remove this if and finally unify the abstraction
             batch_or_worker_batch = batch.get_model_worker_batch()
@@ -2120,39 +2140,49 @@ class Scheduler(
             if batch.sampling_info.grammars is not None:
                 model_worker_batch.delay_sample_launch = True
             batch_result = self.model_worker.forward_batch_generation(
-                batch_or_worker_batch
+                model_worker_batch
             )

             # FIXME(lsyin): maybe move this to forward_batch_generation
             batch_result.copy_done = torch.get_device_module(
                 self.device
             ).Event()

             if not model_worker_batch.delay_sample_launch:
-                self.future_map.store_to_map(
-                    future_indices, batch_result.next_token_ids
-                )
+                self.future_map.store_to_map(future_indices, batch_result)
                 batch_result.copy_to_cpu()
             else:
                 batch_result.future_indices = future_indices

             # FIXME(lsyin): move this assignment elsewhere
-            maybe_future_next_token_ids = -future_indices.indices
+            future_indices_or_next_token_ids = -future_indices.indices
+
+            if batch.is_v2_eagle:
+                # FIXME(lsyin): tmp code for eagle v2
+                # We only keep future indices for next draft input
+                batch.spec_info = batch_result.next_draft_input
+                batch.spec_info.future_indices = future_indices
+                # batch.spec_info = EagleDraftInput(
+                #     future_indices=future_indices,
+                #     verify_done=batch_result.next_draft_input.verify_done,
+                #     # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
+                #     allocate_lens=batch_result.next_draft_input.allocate_lens,
+                # )
+
+                # The future value, usually for next batch preparation
+                # Current implementation strictly synchronizes the seq_lens
+                batch.seq_lens = batch_result.next_draft_input.new_seq_lens
         else:
             batch_result = self.model_worker.forward_batch_generation(
                 batch_or_worker_batch
             )
-            maybe_future_next_token_ids = batch_result.next_token_ids
+            future_indices_or_next_token_ids = batch_result.next_token_ids

         if not self.spec_algorithm.is_none():
             # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
             self.update_spec_metrics(
                 batch.batch_size(), batch_result.num_accepted_tokens
             )

-        # NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
+        # NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
         # which can probably be replaced by future_indices later [TODO(lsyin)].
         # we shall still keep the original outputs, e.g. next_token_ids
         # in the GenerationBatchOutput for processing after copy_done.
-        batch.output_ids = maybe_future_next_token_ids
+        batch.output_ids = future_indices_or_next_token_ids

         # These 2 values are needed for processing the output, but the values can be
         # modified by overlap schedule. So we have to copy them here so that
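
To make the sign convention in future_indices_or_next_token_ids = -future_indices.indices concrete, here is a toy trace reusing the MiniFutureMap sketch from above (again illustrative, not the real code path):

    import torch

    fmap = MiniFutureMap(max_slots=64)  # from the earlier sketch

    # Step N: allocate slots and hand out placeholders immediately.
    idx = fmap.alloc(2)                 # tensor([1, 2])
    output_ids = -idx                   # what lands in batch.output_ids

    # Forward stream finishes step N and publishes the real token ids.
    fmap.store(idx, torch.tensor([101, 202]))

    # Step N+1 preparation: placeholders resolve to the stored ids.
    print(fmap.resolve(output_ids.clone()))  # tensor([101, 202])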
@@ -2200,7 +2230,7 @@ class Scheduler(
                     tmp_result.forward_batch,
                 )
                 future_indices = tmp_result.future_indices
-                self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
+                self.future_map.store_to_map(future_indices, tmp_result)
                 tmp_result.copy_to_cpu()
             self.result_queue.appendleft((tmp_batch, tmp_result))
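
Finally, this tail hunk is the delayed-sample path: when grammar-constrained sampling sets delay_sample_launch, token ids do not exist yet at schedule time, so the result is parked with its future_indices and only published (store_to_map plus copy_to_cpu) once sampling has actually run. A rough sketch of that deferral, under the same illustrative names as the earlier sketches:

    from collections import deque

    pending = deque()  # (batch, result) pairs whose sampling was deferred

    def park(batch, result, future_indices):
        # Sampling not launched yet (grammar masks pending); remember the slots.
        result.future_indices = future_indices
        pending.appendleft((batch, result))

    def publish(fmap):
        # Called once result.next_token_ids exists on the forward stream.
        batch, result = pending.pop()
        fmap.store(result.future_indices, result.next_token_ids)
        result.copy_to_cpu()  # async D2H + copy_done.record(), as sketched earlier
        return batch, result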