From e65b9f21e3092e53c94e827bc5df8c37a6801302 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 21 Apr 2025 12:06:16 -0700 Subject: [PATCH] [PD] Support decode overlap schedule (#5608) --- python/sglang/srt/disaggregation/decode.py | 43 ++++++++++++++++++++++ python/sglang/srt/managers/scheduler.py | 5 ++- python/sglang/srt/server_args.py | 6 +-- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index db2ed2ae9..105142e69 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server from __future__ import annotations import logging +from collections import deque from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Tuple @@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin: self.last_batch = batch + @torch.no_grad() + def event_loop_overlap_disagg_decode(self): + result_queue = deque() + self.last_batch: Optional[ScheduleBatch] = None + self.last_batch_is_extend = False # last batch is modifed in-place, so we need another variable to track if it's extend + + while True: + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + # polling and allocating kv cache + self.process_decode_queue() + batch = self.get_next_disagg_decode_batch_to_run() + self.cur_batch = batch + last_batch_is_extend = False + + if batch: + # Generate fake extend output. + if batch.forward_mode.is_extend(): + # Note: Logprobs should be handled on the prefill engine. + self.stream_output(batch.reqs, False) + last_batch_is_extend = True + else: + result = self.run_batch(batch) + result_queue.append((batch.copy(), result)) + + # Process the results of the previous batch but skip if the last batch is extend + if self.last_batch and not self.last_batch_is_extend: + tmp_batch, tmp_result = result_queue.popleft() + self.process_batch_result(tmp_batch, tmp_result) + + if batch is None and ( + len(self.disagg_decode_transfer_queue.queue) + + len(self.disagg_decode_prealloc_queue.queue) + == 0 + ): + # When the server is idle, do self-check and re-init some states + self.check_memory() + self.new_token_ratio = self.init_new_token_ratio + + self.last_batch = batch + self.last_batch_is_extend = last_batch_is_extend + def get_next_disagg_decode_batch_to_run( self: Scheduler, ) -> Optional[Tuple[ScheduleBatch, bool]]: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3f96f106c..53c5ea4f9 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2016,7 +2016,10 @@ def run_scheduler_process( elif disaggregation_mode == DisaggregationMode.PREFILL: scheduler.event_loop_normal_disagg_prefill() elif disaggregation_mode == DisaggregationMode.DECODE: - scheduler.event_loop_normal_disagg_decode() + if scheduler.enable_overlap: + scheduler.event_loop_overlap_disagg_decode() + else: + scheduler.event_loop_normal_disagg_decode() except Exception: traceback = get_exception_traceback() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4fea1c824..dee87dbb2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -387,14 +387,12 @@ class ServerArgs: # PD disaggregation if self.disaggregation_mode == "prefill": self.disable_cuda_graph = True - logger.warning("KV cache is forced as chunk cache for decode server") + logger.warning("Cuda graph is disabled for prefill server") self.disable_overlap_schedule = True logger.warning("Overlap scheduler is disabled for prefill server") elif self.disaggregation_mode == "decode": self.disable_radix_cache = True - logger.warning("Cuda graph is disabled for prefill server") - self.disable_overlap_schedule = True - logger.warning("Overlap scheduler is disabled for decode server") + logger.warning("KV cache is forced as chunk cache for decode server") os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = ( "1" if self.enable_torch_compile else "0"