From 711efe781426dad242e88fe71d6eefe866fe3866 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Wed, 23 Apr 2025 01:46:01 -0700 Subject: [PATCH] Integrating PD disaggregation with DP attention and DeepEP (#5435) Co-authored-by: Byron Hsu --- python/sglang/srt/disaggregation/decode.py | 51 +++++++++++++++++-- python/sglang/srt/disaggregation/prefill.py | 16 ++++++ .../srt/managers/data_parallel_controller.py | 13 +++-- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 105142e69..a45af7e37 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -444,6 +444,15 @@ class ScheduleBatchDisaggregationDecodeMixin: class SchedulerDisaggregationDecodeMixin: + def _prepare_idle_batch_and_run(self, batch, delay_process=False): + batch, _ = self.prepare_dp_attn_batch(batch) + result = None + if batch: + result = self.run_batch(batch) + if not delay_process: + self.process_batch_result(batch, result) + return batch, result + @torch.no_grad() def event_loop_normal_disagg_decode(self): """A normal scheduler loop for decode worker in disaggregation mode.""" @@ -456,14 +465,25 @@ class SchedulerDisaggregationDecodeMixin: batch = self.get_next_disagg_decode_batch_to_run() self.cur_batch = batch + prepare_dp_attn_flag = ( + self.server_args.enable_dp_attention + or self.server_args.enable_sp_layernorm + ) + if batch: # Generate fake extend output. if batch.forward_mode.is_extend(): # Note: Logprobs should be handled on the prefill engine. 
self.stream_output(batch.reqs, False) + if prepare_dp_attn_flag: + self._prepare_idle_batch_and_run(None) else: + if prepare_dp_attn_flag: + self.prepare_dp_attn_batch(batch) result = self.run_batch(batch) self.process_batch_result(batch, result) + elif prepare_dp_attn_flag: + batch, _ = self._prepare_idle_batch_and_run(None) if batch is None and ( len(self.disagg_decode_transfer_queue.queue) @@ -480,7 +500,7 @@ class SchedulerDisaggregationDecodeMixin: def event_loop_overlap_disagg_decode(self): result_queue = deque() self.last_batch: Optional[ScheduleBatch] = None - self.last_batch_is_extend = False # last batch is modifed in-place, so we need another variable to track if it's extend + self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track whether its result was queued while True: recv_reqs = self.recv_requests() @@ -489,20 +509,41 @@ class SchedulerDisaggregationDecodeMixin: self.process_decode_queue() batch = self.get_next_disagg_decode_batch_to_run() self.cur_batch = batch - last_batch_is_extend = False + last_batch_in_queue = False + + prepare_dp_attn_flag = ( + self.server_args.enable_dp_attention + or self.server_args.enable_sp_layernorm + ) if batch: # Generate fake extend output. if batch.forward_mode.is_extend(): # Note: Logprobs should be handled on the prefill engine. 
self.stream_output(batch.reqs, False) - last_batch_is_extend = True + if prepare_dp_attn_flag: + batch_, result = self._prepare_idle_batch_and_run( + None, delay_process=True + ) + if batch_: + result_queue.append((batch_.copy(), result)) + last_batch_in_queue = True else: + if prepare_dp_attn_flag: + self.prepare_dp_attn_batch(batch) result = self.run_batch(batch) result_queue.append((batch.copy(), result)) + last_batch_in_queue = True + elif prepare_dp_attn_flag: + batch, result = self._prepare_idle_batch_and_run( + None, delay_process=True + ) + if batch: + result_queue.append((batch.copy(), result)) + last_batch_in_queue = True # Process the results of the previous batch but skip if the last batch is extend - if self.last_batch and not self.last_batch_is_extend: + if self.last_batch and self.last_batch_in_queue: tmp_batch, tmp_result = result_queue.popleft() self.process_batch_result(tmp_batch, tmp_result) @@ -516,7 +557,7 @@ class SchedulerDisaggregationDecodeMixin: self.new_token_ratio = self.init_new_token_ratio self.last_batch = batch - self.last_batch_is_extend = last_batch_is_extend + self.last_batch_in_queue = last_batch_in_queue def get_next_disagg_decode_batch_to_run( self: Scheduler, diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 48743ef1f..7c10da219 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -187,6 +187,14 @@ class SchedulerDisaggregationPrefillMixin: ) self.process_prefill_chunk() batch = self.get_new_batch_prefill() + + # Handle DP attention + if ( + self.server_args.enable_dp_attention + or self.server_args.enable_sp_layernorm + ): + batch, _ = self.prepare_dp_attn_batch(batch) + self.cur_batch = batch if batch: @@ -217,6 +225,14 @@ class SchedulerDisaggregationPrefillMixin: ) self.process_prefill_chunk() batch = self.get_new_batch_prefill() + + # Handle DP attention + if ( + self.server_args.enable_dp_attention + or 
self.server_args.enable_sp_layernorm + ): + batch, _ = self.prepare_dp_attn_batch(batch) + self.cur_batch = batch if batch: diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 279f9e27b..ce921988b 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -23,11 +23,13 @@ import psutil import setproctitle import zmq +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.io_struct import ( TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, ) +from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter @@ -226,9 +228,14 @@ class DataParallelController: self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"] self.max_req_input_len = scheduler_info[0]["max_req_input_len"] - def round_robin_scheduler(self, req): - self.workers[self.round_robin_counter].send_pyobj(req) - self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers) + def round_robin_scheduler(self, req: Req): + if self.server_args.disaggregation_mode == "null": + self.workers[self.round_robin_counter].send_pyobj(req) + self.round_robin_counter = (self.round_robin_counter + 1) % len( + self.workers + ) + else: + self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) def shortest_queue_scheduler(self, input_requests): raise NotImplementedError()