From 40d9b8acce88ae8abd4e48a6f5c09f409e7b41c8 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Mon, 28 Apr 2025 11:19:16 +0800
Subject: [PATCH] Improve overlap scheduling (#5788)

---
 python/sglang/srt/disaggregation/prefill.py   |  8 +++--
 python/sglang/srt/managers/schedule_batch.py  |  8 +++++
 python/sglang/srt/managers/scheduler.py       | 15 +++++---
 .../scheduler_output_processor_mixin.py       | 34 ++++++++++++++-----
 python/sglang/srt/managers/tp_worker.py       |  6 ++--
 .../srt/managers/tp_worker_overlap_thread.py  | 13 ++++---
 6 files changed, 61 insertions(+), 23 deletions(-)

diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index 1af7a9b19..d6a8fa398 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+import threading
 from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
@@ -256,7 +257,10 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_batch.batch_is_full = False
 
     def process_batch_result_disagg_prefill(
-        self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
+        self: Scheduler,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -280,7 +284,7 @@ class SchedulerDisaggregationPrefillMixin:
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
 
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 4b51a09d7..960b6e70b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 import copy
 import dataclasses
 import logging
+import threading
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False
 
+    # Events
+    launch_done: Optional[threading.Event] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -1565,6 +1569,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+            launch_done=self.launch_done,
         )
 
     def copy(self):
@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
 
+    # Overlap event
+    launch_done: Optional[threading.Event] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 5b996f8b2..e29f46974 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -645,6 +645,7 @@ class Scheduler(
             self.cur_batch = batch
 
             if batch:
+                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
 
@@ -656,7 +657,7 @@ class Scheduler(
                         forward_mode=ForwardMode.DUMMY_FIRST,
                         next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                     )
-                    self.process_batch_result(tmp_batch, None)
+                    self.process_batch_result(tmp_batch, None, batch.launch_done)
 
             if self.last_batch:
                 # Process the results of the last batch
@@ -664,7 +665,10 @@ class Scheduler(
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
-                self.process_batch_result(tmp_batch, tmp_result)
+                # NOTE: we should use the currently launched batch's launch_done event instead of the last batch's
+                self.process_batch_result(
+                    tmp_batch, tmp_result, batch.launch_done if batch else None
+                )
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
@@ -1417,14 +1421,15 @@ class Scheduler(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result)
+            self.process_batch_result_decode(batch, result, launch_done)
         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result)
+            self.process_batch_result_prefill(batch, result, launch_done)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_batch_result(result.bid)
+                self.tp_worker.resolve_last_batch_result(launch_done)
                 if batch.next_batch_sampling_info:
                     batch.next_batch_sampling_info.update_regex_vocab_mask()
                     self.current_stream.synchronize()
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 13158d937..ce570b75a 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import threading
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
         EmbeddingBatchResult,
         GenerationBatchResult,
         ScheduleBatch,
+        Scheduler,
     )
 
 
@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
     """
 
     def process_batch_result_prefill(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None
 
@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
             )
 
             if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+                logits_output, next_token_ids = (
+                    self.tp_worker.resolve_last_batch_result(
+                        launch_done,
+                    )
+                )
             else:
                 # Move next_token_ids and logprobs to cpu
                 next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
     def process_batch_result_decode(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ):
         logits_output, next_token_ids, bid = (
             result.logits_output,
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
             # spec decoding handles output logprobs inside verify process.
@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
             self.log_decode_stats()
 
     def add_input_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
         assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
 
     def add_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         pt: int,
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
         return num_input_logprobs
 
     def stream_output(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         """Stream the output to detokenizer."""
         if self.is_generation:
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
             self.stream_output_embedding(reqs)
 
     def stream_output_generation(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
                 )
             )
 
-    def stream_output_embedding(self, reqs: List[Req]):
+    def stream_output_embedding(self: Scheduler, reqs: List[Req]):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
 
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index fb3d1c6b4..a07dbfb07 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -170,13 +170,13 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
     ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        if launch_done:
-            launch_done.set()
+
+        if model_worker_batch.launch_done is not None:
+            model_worker_batch.launch_done.set()
 
         if skip_sample:
             next_token_ids = None
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index fb4fdc6d5..8aa7f3346 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -132,7 +132,6 @@ class TpModelWorkerClient:
             batch_pt += 1
 
             # Create event
-            self.launch_done = threading.Event()
             copy_done = torch.get_device_module(self.device).Event()
 
             # Resolve future tokens in the input
@@ -141,7 +140,7 @@ class TpModelWorkerClient:
 
             # Run forward
             logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch, self.launch_done
+                model_worker_batch
             )
 
             # Update the future token ids map
@@ -168,10 +167,16 @@ class TpModelWorkerClient:
 
             self.output_queue.put((copy_done, logits_output, next_token_ids))
 
-    def resolve_batch_result(self, bid: int):
+    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
+        """
+        This function is called to resolve the last batch's result and
+        wait for the current batch to be launched. Used in overlap mode.
+        """
         copy_done, logits_output, next_token_ids = self.output_queue.get()
+
+        if launch_done is not None:
+            launch_done.wait()
         copy_done.synchronize()
-        self.launch_done.wait()
 
         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
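
Note: the moving parts of this patch are easier to see in isolation. Below is a
minimal, runnable Python sketch of the pattern the patch adopts (illustrative
only, not SGLang code): one launch_done event is created per batch, a worker
thread sets it once the batch's forward pass has been launched, and results come
back through a FIFO output queue, which is why the old bid argument can be
dropped. The OverlapWorker class, the integer "batches", and the time.sleep that
stands in for device-side execution latency are all invented for this sketch.

    # sketch_overlap.py: toy model of the overlap scheduling in this patch
    import queue
    import threading
    import time
    from typing import List, Optional


    class OverlapWorker:
        """Stand-in for TpModelWorkerClient: runs batches on a background thread."""

        def __init__(self) -> None:
            self.input_queue: queue.Queue = queue.Queue()
            self.output_queue: queue.Queue = queue.Queue()
            threading.Thread(target=self._forward_thread, daemon=True).start()

        def _forward_thread(self) -> None:
            while True:
                batch, launch_done = self.input_queue.get()
                # "Launch" the forward pass. In the real system this is an async
                # GPU launch, so the event means "kernels queued", not "done".
                result = [t + 1 for t in batch]  # fake forward pass
                if launch_done is not None:
                    launch_done.set()
                time.sleep(0.01)  # stand-in for device-side execution time
                self.output_queue.put(result)

        def forward_batch_generation(
            self, batch: List[int], launch_done: threading.Event
        ) -> None:
            self.input_queue.put((batch, launch_done))

        def resolve_last_batch_result(
            self, launch_done: Optional[threading.Event] = None
        ) -> List[int]:
            # Pop the *previous* batch's result from the FIFO queue, then (as in
            # the patched code) wait until the *current* batch has been launched.
            result = self.output_queue.get()
            if launch_done is not None:
                launch_done.wait()
            return result


    worker = OverlapWorker()
    launched_before = False
    for step, batch in enumerate([[1, 2], [3, 4], [5, 6]]):
        launch_done = threading.Event()  # one event per batch, as in the patch
        worker.forward_batch_generation(batch, launch_done)
        if launched_before:
            # The event passed here belongs to the batch launched *this* step,
            # mirroring the NOTE in scheduler.py.
            print("step", step, "->", worker.resolve_last_batch_result(launch_done))
        launched_before = True
    print("final ->", worker.resolve_last_batch_result())

The ordering detail worth noting, and the point of the NOTE in scheduler.py, is
that resolve_last_batch_result receives the event of the batch launched in the
current iteration, not the event of the batch whose result it returns, so the
scheduler's CPU-side processing of the previous batch overlaps with the current
batch's device execution.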