Eagle speculative decoding part 3: small modifications to the general scheduler (#2709)
Co-authored-by: kavioyu <kavioyu@tencent.com>
This commit is contained in:
@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
broadcast_pyobj,
|
||||
configure_logger,
|
||||
@@ -116,6 +117,14 @@ class Scheduler:
|
||||
self.enable_overlap = not server_args.disable_overlap_schedule
|
||||
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
||||
self.enable_metrics = server_args.enable_metrics
|
||||
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
self.decode_mem_cache_buf_multiplier = (
|
||||
self.server_args.speculative_num_draft_tokens
|
||||
if not self.spec_algorithm.is_none()
|
||||
else 1
|
||||
)
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
@@ -199,6 +208,21 @@ class Scheduler:
|
||||
nccl_port=port_args.nccl_port,
|
||||
)
|
||||
|
||||
# Launch worker for speculative decoding if need
|
||||
if self.spec_algorithm.is_eagle():
|
||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||
|
||||
self.draft_worker = EAGLEWorker(
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
server_args=server_args,
|
||||
nccl_port=port_args.nccl_port,
|
||||
target_worker=self.tp_worker,
|
||||
dp_rank=dp_rank,
|
||||
)
|
||||
else:
|
||||
self.draft_worker = None
|
||||
|
||||
# Get token and memory info from the model worker
|
||||
(
|
||||
self.max_total_num_tokens,
|
||||
@@ -855,6 +879,7 @@ class Scheduler:
|
||||
self.tree_cache,
|
||||
self.model_config,
|
||||
self.enable_overlap,
|
||||
self.spec_algorithm,
|
||||
)
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
@@ -888,11 +913,15 @@ class Scheduler:
|
||||
return None
|
||||
|
||||
# Check if decode out of memory
|
||||
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
|
||||
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
|
||||
test_retract and batch.batch_size() > 10
|
||||
):
|
||||
old_ratio = self.new_token_ratio
|
||||
|
||||
retracted_reqs, new_token_ratio = batch.retract_decode()
|
||||
self.new_token_ratio = new_token_ratio
|
||||
if self.draft_worker:
|
||||
self.draft_worker.finish_request(retracted_reqs)
|
||||
|
||||
logger.info(
|
||||
"Decode out of memory happened. "
|
||||
@@ -926,11 +955,17 @@ class Scheduler:
|
||||
self.forward_ct += 1
|
||||
|
||||
if self.is_generation:
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
if self.spec_algorithm.is_none():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, next_token_ids = (
|
||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||
)
|
||||
else:
|
||||
logits_output, next_token_ids, model_worker_batch, spec_info = (
|
||||
self.draft_worker.forward_batch_speculative_generation(batch)
|
||||
)
|
||||
batch.spec_info = spec_info
|
||||
elif batch.forward_mode.is_idle():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
self.tp_worker.forward_batch_idle(model_worker_batch)
|
||||
@@ -1077,7 +1112,10 @@ class Scheduler:
|
||||
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
||||
continue
|
||||
|
||||
req.output_ids.append(next_token_id)
|
||||
if batch.spec_algorithm.is_none():
|
||||
# speculative worker will solve the output_ids in speculative decoding
|
||||
req.output_ids.append(next_token_id)
|
||||
|
||||
req.check_finished()
|
||||
|
||||
if req.finished():
|
||||
@@ -1252,6 +1290,9 @@ class Scheduler:
|
||||
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
|
||||
or (not req.stream and len(req.output_ids) % 50 == 0)
|
||||
):
|
||||
if self.draft_worker and req.finished():
|
||||
self.draft_worker.finish_request(req)
|
||||
|
||||
rids.append(req.rid)
|
||||
finished_reasons.append(
|
||||
req.finished_reason.to_json() if req.finished_reason else None
|
||||
@@ -1383,6 +1424,7 @@ class Scheduler:
|
||||
self.tree_cache,
|
||||
self.model_config,
|
||||
self.enable_overlap,
|
||||
self.spec_algorithm,
|
||||
)
|
||||
idle_batch.prepare_for_idle()
|
||||
return idle_batch
|
||||
|
||||
Reference in New Issue
Block a user