Beta spec-overlap for EAGLE (#11398)
Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
@@ -148,13 +148,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.tracing.trace import (
|
||||
process_tracing_init,
|
||||
@@ -219,6 +216,14 @@ class GenerationBatchResult:
|
||||
forward_batch: Optional[ForwardBatch] = None
|
||||
future_indices: Optional[FutureIndices] = None
|
||||
|
||||
# FIXME(lsyin): maybe move to <BetterPlace> ?
|
||||
# sync path: forward stream -> output processor
|
||||
accept_lens: Optional[torch.Tensor] = None
|
||||
last_batch_allocate_lens: Optional[torch.Tensor] = None
|
||||
|
||||
# relay path: forward stream -> next step forward
|
||||
next_draft_input: Optional[EagleDraftInput] = None
|
||||
|
||||
def copy_to_cpu(self, return_logprob: bool = False):
|
||||
"""Copy tensors to CPU in overlap scheduling.
|
||||
Only the tensors which are needed for processing results are copied,
|
||||
@@ -238,6 +243,15 @@ class GenerationBatchResult:
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
|
||||
|
||||
if self.accept_lens is not None:
|
||||
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
|
||||
|
||||
if self.last_batch_allocate_lens is not None:
|
||||
self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
|
||||
self.copy_done.record()
|
||||
|
||||
@classmethod
|
||||
@@ -273,48 +287,6 @@ class Scheduler(
|
||||
):
|
||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||
|
||||
def launch_draft_worker(
    self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
):
    """Create the speculative-decoding draft worker and store it on ``self.draft_worker``.

    The worker class is chosen from ``self.spec_algorithm``:

    * ``is_eagle()``      -> ``EAGLEWorker``
    * ``is_standalone()`` -> ``StandaloneWorker``
    * ``is_ngram()``      -> ``NGRAMWorker``
    * otherwise           -> no draft worker (``self.draft_worker = None``)

    Every worker class is constructed with the same keyword arguments;
    ``port_args`` is only read for its ``nccl_port`` attribute, and
    ``self.tp_worker`` is passed as ``target_worker``.

    Returns:
        None. The result is stored on ``self.draft_worker``.
    """
    # Each worker module is imported inside its own branch, so only the module
    # for the selected algorithm is imported by this call.
    if self.spec_algorithm.is_eagle():
        from sglang.srt.speculative.eagle_worker import EAGLEWorker

        worker_cls = EAGLEWorker
    elif self.spec_algorithm.is_standalone():
        from sglang.srt.speculative.standalone_worker import StandaloneWorker

        worker_cls = StandaloneWorker
    elif self.spec_algorithm.is_ngram():
        from sglang.srt.speculative.ngram_worker import NGRAMWorker

        worker_cls = NGRAMWorker
    else:
        # No speculative algorithm selected: nothing to construct.
        self.draft_worker = None
        return

    # Single construction site shared by all algorithms (the argument list is
    # identical for every worker class).
    self.draft_worker = worker_cls(
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        moe_ep_rank=moe_ep_rank,
        server_args=server_args,
        nccl_port=port_args.nccl_port,
        target_worker=self.tp_worker,
        dp_rank=dp_rank,
    )
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
@@ -454,6 +426,7 @@ class Scheduler(
|
||||
)
|
||||
|
||||
# Launch a draft worker for speculative decoding
|
||||
|
||||
self.launch_draft_worker(
|
||||
gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
|
||||
)
|
||||
@@ -683,6 +656,51 @@ class Scheduler(
|
||||
]
|
||||
)
|
||||
|
||||
def launch_draft_worker(
    self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
):
    """Create the speculative-decoding draft worker and store it on ``self.draft_worker``.

    The worker class is chosen from ``self.spec_algorithm``:

    * ``is_eagle()``      -> ``EAGLEWorkerV2`` when ``self.enable_overlap`` is
                             truthy, otherwise ``EAGLEWorker``
    * ``is_standalone()`` -> ``StandaloneWorker``
    * ``is_ngram()``      -> ``NGRAMWorker``
    * otherwise           -> no draft worker (``self.draft_worker = None``)

    Every worker class is constructed with the same keyword arguments;
    ``port_args`` is only read for its ``nccl_port`` attribute, and
    ``self.tp_worker`` is passed as ``target_worker``.

    Returns:
        None. The result is stored on ``self.draft_worker``.
    """
    # Each worker module is imported inside its own branch, so only the
    # module(s) for the selected algorithm are imported by this call.
    if self.spec_algorithm.is_eagle():
        from sglang.srt.speculative.eagle_worker import EAGLEWorker
        from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2

        # Overlap scheduling uses the V2 EAGLE worker.
        worker_cls = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
    elif self.spec_algorithm.is_standalone():
        from sglang.srt.speculative.standalone_worker import StandaloneWorker

        worker_cls = StandaloneWorker
    elif self.spec_algorithm.is_ngram():
        from sglang.srt.speculative.ngram_worker import NGRAMWorker

        worker_cls = NGRAMWorker
    else:
        # No speculative algorithm selected: nothing to construct.
        self.draft_worker = None
        return

    # Single construction site shared by all algorithms (the argument list is
    # identical for every worker class).
    self.draft_worker = worker_cls(
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        moe_ep_rank=moe_ep_rank,
        server_args=server_args,
        nccl_port=port_args.nccl_port,
        target_worker=self.tp_worker,
        dp_rank=dp_rank,
    )
|
||||
|
||||
def init_deterministic_inference_config(self):
|
||||
"""Initialize deterministic inference configuration for different attention backends."""
|
||||
if not self.server_args.enable_deterministic_inference:
|
||||
@@ -965,7 +983,9 @@ class Scheduler(
|
||||
self.device
|
||||
).stream(self.copy_stream)
|
||||
|
||||
self.future_map = FutureMap(self.max_running_requests, self.device)
|
||||
self.future_map = FutureMap(
|
||||
self.max_running_requests, self.device, self.spec_algorithm
|
||||
)
|
||||
self.batch_record_buf = [None] * 2
|
||||
self.batch_record_ct = 0
|
||||
|
||||
@@ -2096,7 +2116,7 @@ class Scheduler(
|
||||
|
||||
batch_or_worker_batch = batch
|
||||
|
||||
if self.spec_algorithm.is_none():
|
||||
if self.enable_overlap or self.spec_algorithm.is_none():
|
||||
# FIXME(lsyin): remove this if and finally unify the abstraction
|
||||
batch_or_worker_batch = batch.get_model_worker_batch()
|
||||
|
||||
@@ -2120,39 +2140,49 @@ class Scheduler(
|
||||
if batch.sampling_info.grammars is not None:
|
||||
model_worker_batch.delay_sample_launch = True
|
||||
batch_result = self.model_worker.forward_batch_generation(
|
||||
batch_or_worker_batch
|
||||
model_worker_batch
|
||||
)
|
||||
# FIXME(lsyin): maybe move this to forward_batch_generation
|
||||
batch_result.copy_done = torch.get_device_module(
|
||||
self.device
|
||||
).Event()
|
||||
if not model_worker_batch.delay_sample_launch:
|
||||
self.future_map.store_to_map(
|
||||
future_indices, batch_result.next_token_ids
|
||||
)
|
||||
self.future_map.store_to_map(future_indices, batch_result)
|
||||
batch_result.copy_to_cpu()
|
||||
else:
|
||||
batch_result.future_indices = future_indices
|
||||
|
||||
# FIXME(lsyin): move this assignment elsewhere
|
||||
maybe_future_next_token_ids = -future_indices.indices
|
||||
future_indices_or_next_token_ids = -future_indices.indices
|
||||
|
||||
if batch.is_v2_eagle:
|
||||
# FIXME(lsyin): tmp code for eagle v2
|
||||
# We only keep future indices for next draft input
|
||||
|
||||
batch.spec_info = batch_result.next_draft_input
|
||||
batch.spec_info.future_indices = future_indices
|
||||
|
||||
# batch.spec_info = EagleDraftInput(
|
||||
# future_indices=future_indices,
|
||||
# verify_done=batch_result.next_draft_input.verify_done,
|
||||
# # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
|
||||
# allocate_lens=batch_result.next_draft_input.allocate_lens,
|
||||
# )
|
||||
|
||||
# The future value, usually for next batch preparation
|
||||
# Current implementation strictly synchronizes the seq_lens
|
||||
batch.seq_lens = batch_result.next_draft_input.new_seq_lens
|
||||
else:
|
||||
batch_result = self.model_worker.forward_batch_generation(
|
||||
batch_or_worker_batch
|
||||
)
|
||||
maybe_future_next_token_ids = batch_result.next_token_ids
|
||||
future_indices_or_next_token_ids = batch_result.next_token_ids
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
|
||||
self.update_spec_metrics(
|
||||
batch.batch_size(), batch_result.num_accepted_tokens
|
||||
)
|
||||
|
||||
# NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
|
||||
# NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
|
||||
# which can probably be replaced by future_indices later [TODO(lsyin)].
|
||||
# we shall still keep the original outputs, e.g. next_token_ids
|
||||
# in the GenerationBatchOutput for processing after copy_done.
|
||||
batch.output_ids = maybe_future_next_token_ids
|
||||
batch.output_ids = future_indices_or_next_token_ids
|
||||
|
||||
# These 2 values are needed for processing the output, but the values can be
|
||||
# modified by overlap schedule. So we have to copy them here so that
|
||||
@@ -2200,7 +2230,7 @@ class Scheduler(
|
||||
tmp_result.forward_batch,
|
||||
)
|
||||
future_indices = tmp_result.future_indices
|
||||
self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
|
||||
self.future_map.store_to_map(future_indices, tmp_result)
|
||||
tmp_result.copy_to_cpu()
|
||||
self.result_queue.appendleft((tmp_batch, tmp_result))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user