diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 7fb2365ca..9305be298 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -752,7 +752,6 @@ class SchedulerDisaggregationDecodeMixin: self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend while True: - self.launch_last_batch_sample_if_needed() recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) @@ -764,6 +763,7 @@ class SchedulerDisaggregationDecodeMixin: prepare_mlp_sync_flag = require_mlp_sync(self.server_args) + batch_result = None if batch: # Generate fake extend output. if batch.forward_mode.is_extend(): @@ -772,25 +772,25 @@ class SchedulerDisaggregationDecodeMixin: batch.reqs, any(req.return_logprob for req in batch.reqs) ) if prepare_mlp_sync_flag: - batch_, result = self._prepare_idle_batch_and_run( + batch_, batch_result = self._prepare_idle_batch_and_run( None, delay_process=True ) if batch_: - self.result_queue.append((batch_.copy(), result)) + self.result_queue.append((batch_.copy(), batch_result)) last_batch_in_queue = True else: if prepare_mlp_sync_flag: self.prepare_mlp_sync_batch(batch) - result = self.run_batch(batch) - self.result_queue.append((batch.copy(), result)) + batch_result = self.run_batch(batch) + self.result_queue.append((batch.copy(), batch_result)) last_batch_in_queue = True elif prepare_mlp_sync_flag: - batch, result = self._prepare_idle_batch_and_run( + batch, batch_result = self._prepare_idle_batch_and_run( None, delay_process=True ) if batch: - self.result_queue.append((batch.copy(), result)) + self.result_queue.append((batch.copy(), batch_result)) last_batch_in_queue = True # Process the results of the previous batch but skip if the last batch is extend @@ -798,6 +798,8 @@ class SchedulerDisaggregationDecodeMixin: tmp_batch, tmp_result = self.result_queue.popleft() self.process_batch_result(tmp_batch, tmp_result) + self.launch_batch_sample_if_needed(batch_result) + queue_size = ( len(self.waiting_queue) + len(self.disagg_decode_transfer_queue.queue) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 020d3f5aa..b9884414c 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -321,8 +321,6 @@ class SchedulerDisaggregationPrefillMixin: self.result_queue = deque() while True: - self.launch_last_batch_sample_if_needed() - recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) self.waiting_queue.extend( @@ -334,9 +332,11 @@ class SchedulerDisaggregationPrefillMixin: if require_mlp_sync(self.server_args): batch = self.prepare_mlp_sync_batch(batch) self.cur_batch = batch + + batch_result = None if batch: - result = self.run_batch(batch) - self.result_queue.append((batch.copy(), result)) + batch_result = self.run_batch(batch) + self.result_queue.append((batch.copy(), batch_result)) if self.last_batch: tmp_batch, tmp_result = self.result_queue.popleft() @@ -345,6 +345,8 @@ class SchedulerDisaggregationPrefillMixin: if len(self.disagg_prefill_inflight_queue) > 0: self.process_disagg_prefill_inflight_queue() + self.launch_batch_sample_if_needed(batch_result) + if batch is None and len(self.disagg_prefill_inflight_queue) == 0: self.self_check_during_idle() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ee427cce2..5ebf3f61a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1907,8 +1907,5 @@ class ModelWorkerBatch: capture_hidden_mode: CaptureHiddenMode = None hicache_consumer_index: int = -1 - # Overlap scheduler related - delay_sample_launch: bool = False - # Whether this batch is prefill-only (no token generation needed) is_prefill_only: bool = False diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e19f83f24..10110751c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -148,7 +148,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_executor.forward_batch_info import PPProxyTensors from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args from sglang.srt.speculative.eagle_info import EagleDraftInput @@ -212,8 +212,7 @@ class GenerationBatchResult: # For overlap scheduling copy_done: Optional[torch.cuda.Event] = None - delay_sample_launch: bool = False - forward_batch: Optional[ForwardBatch] = None + delay_sample_func: Optional[callable] = None future_indices: Optional[FutureIndices] = None # FIXME(lsyin): maybe move to ? @@ -1036,17 +1035,16 @@ class Scheduler( self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque() while True: - self.launch_last_batch_sample_if_needed() - recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) batch = self.get_next_batch_to_run() self.cur_batch = batch + batch_result = None if batch: - result = self.run_batch(batch) - self.result_queue.append((batch.copy(), result)) + batch_result = self.run_batch(batch) + self.result_queue.append((batch.copy(), batch_result)) if self.last_batch: # Process the results of the last batch @@ -1056,6 +1054,7 @@ class Scheduler( # When the server is idle, do self-check and re-init some states self.self_check_during_idle() + self.launch_batch_sample_if_needed(batch_result) self.last_batch = batch @DynamicGradMode() @@ -2207,8 +2206,6 @@ class Scheduler( with self.forward_stream_ctx: self.forward_stream.wait_stream(self.default_stream) self.future_map.resolve_future(model_worker_batch) - if batch.sampling_info.grammars is not None: - model_worker_batch.delay_sample_launch = True batch_result = self.model_worker.forward_batch_generation( model_worker_batch ) @@ -2216,7 +2213,7 @@ class Scheduler( batch_result.copy_done = torch.get_device_module( self.device ).Event() - if not model_worker_batch.delay_sample_launch: + if batch_result.delay_sample_func is None: self.future_map.store_to_map(future_indices, batch_result) batch_result.copy_to_cpu() else: @@ -2280,29 +2277,20 @@ class Scheduler( ret = EmbeddingBatchResult(embeddings=embeddings) return ret - def launch_last_batch_sample_if_needed( - self, + def launch_batch_sample_if_needed( + self, batch_result: GenerationBatchResult ) -> Union[GenerationBatchResult, EmbeddingBatchResult]: - if len(self.result_queue) == 0: - return - - tmp_batch, tmp_result = self.result_queue.popleft() - - tmp_result: GenerationBatchResult - if not tmp_result.delay_sample_launch: - self.result_queue.appendleft((tmp_batch, tmp_result)) + # TODO(lsyin): make the delayed sample a default behavior after + # unifying the forward_batch_generation interface (related to spec V2). + if batch_result is None or batch_result.delay_sample_func is None: return with self.forward_stream_ctx: self.forward_stream.wait_stream(self.default_stream) - tmp_result.next_token_ids = self.model_worker.model_runner.sample( - tmp_result.logits_output, - tmp_result.forward_batch, - ) - future_indices = tmp_result.future_indices - self.future_map.store_to_map(future_indices, tmp_result) - tmp_result.copy_to_cpu() - self.result_queue.appendleft((tmp_batch, tmp_result)) + _batch_result = batch_result.delay_sample_func() + assert _batch_result is batch_result + self.future_map.store_to_map(batch_result.future_indices, batch_result) + batch_result.copy_to_cpu() def process_batch_result( self, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 3485a0357..6546781de 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -168,6 +168,7 @@ class TpModelWorker: )[0] set_random_seed(self.random_seed) + self.enable_overlap = not server_args.disable_overlap_schedule self.hicache_layer_transfer_counter = None def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter): @@ -266,9 +267,18 @@ class TpModelWorker: # Skip sampling and return logits for target forward return batch_result - if model_worker_batch.delay_sample_launch: - batch_result.delay_sample_launch = True - batch_result.forward_batch = forward_batch + if ( + self.enable_overlap + and model_worker_batch.sampling_info.grammars is not None + ): + + def sample_batch_func(): + batch_result.next_token_ids = self.model_runner.sample( + logits_output, forward_batch + ) + return batch_result + + batch_result.delay_sample_func = sample_batch_func return batch_result if model_worker_batch.is_prefill_only: