Eagle speculative decoding part 3: small modifications to the general scheduler (#2709)

Co-authored-by: kavioyu <kavioyu@tencent.com>
This commit is contained in:
Lianmin Zheng
2025-01-02 02:09:08 -08:00
committed by GitHub
parent 9183c23eca
commit ad20b7957e
13 changed files with 224 additions and 69 deletions

View File

@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
broadcast_pyobj,
configure_logger,
@@ -116,6 +117,14 @@ class Scheduler:
self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.decode_mem_cache_buf_multiplier = (
self.server_args.speculative_num_draft_tokens
if not self.spec_algorithm.is_none()
else 1
)
# Init inter-process communication
context = zmq.Context(2)
@@ -199,6 +208,21 @@ class Scheduler:
nccl_port=port_args.nccl_port,
)
# Launch worker for speculative decoding if needed
if self.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker import EAGLEWorker
self.draft_worker = EAGLEWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
else:
self.draft_worker = None
# Get token and memory info from the model worker
(
self.max_total_num_tokens,
@@ -855,6 +879,7 @@ class Scheduler:
self.tree_cache,
self.model_config,
self.enable_overlap,
self.spec_algorithm,
)
new_batch.prepare_for_extend()
@@ -888,11 +913,15 @@ class Scheduler:
return None
# Check if decode out of memory
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
test_retract and batch.batch_size() > 10
):
old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio = batch.retract_decode()
self.new_token_ratio = new_token_ratio
if self.draft_worker:
self.draft_worker.finish_request(retracted_reqs)
logger.info(
"Decode out of memory happened. "
@@ -926,11 +955,17 @@ class Scheduler:
self.forward_ct += 1
if self.is_generation:
model_worker_batch = batch.get_model_worker_batch()
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
else:
logits_output, next_token_ids, model_worker_batch, spec_info = (
self.draft_worker.forward_batch_speculative_generation(batch)
)
batch.spec_info = spec_info
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -1077,7 +1112,10 @@ class Scheduler:
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
continue
req.output_ids.append(next_token_id)
if batch.spec_algorithm.is_none():
# The speculative worker handles output_ids itself during speculative decoding
req.output_ids.append(next_token_id)
req.check_finished()
if req.finished():
@@ -1252,6 +1290,9 @@ class Scheduler:
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
or (not req.stream and len(req.output_ids) % 50 == 0)
):
if self.draft_worker and req.finished():
self.draft_worker.finish_request(req)
rids.append(req.rid)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
@@ -1383,6 +1424,7 @@ class Scheduler:
self.tree_cache,
self.model_config,
self.enable_overlap,
self.spec_algorithm,
)
idle_batch.prepare_for_idle()
return idle_batch