diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index fb0fef464..d106e42d4 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -419,6 +419,38 @@ class ScheduleBatchDisaggregationDecodeMixin: class SchedulerDisaggregationDecodeMixin: + @torch.no_grad() + def event_loop_normal_disagg_decode(self): + """A normal scheduler loop for decode worker in disaggregation mode.""" + + while True: + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + # polling and allocating kv cache + self.process_decode_queue() + batch = self.get_next_disagg_decode_batch_to_run() + self.cur_batch = batch + + if batch: + # Generate fake extend output. + if batch.forward_mode.is_extend(): + # Note: Logprobs should be handled on the prefill engine. + self.stream_output(batch.reqs, False) + else: + result = self.run_batch(batch) + self.process_batch_result(batch, result) + + if batch is None and ( + len(self.disagg_decode_transfer_queue.queue) + + len(self.disagg_decode_prealloc_queue.queue) + == 0 + ): + # When the server is idle, do self-check and re-init some states + self.check_memory() + self.new_token_ratio = self.init_new_token_ratio + + self.last_batch = batch + def get_next_disagg_decode_batch_to_run( self: Scheduler, ) -> Optional[Tuple[ScheduleBatch, bool]]: diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 7ad548ccc..d513b13dd 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -171,6 +171,36 @@ class SchedulerDisaggregationPrefillMixin: Mixin for Scheduler to handle disaggregation prefill """ + @torch.no_grad() + def event_loop_normal_disagg_prefill(self): + """A normal scheduler loop for prefill worker in disaggregation mode.""" + + while True: + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + self.waiting_queue.extend( + self.disagg_prefill_pending_queue.pop_bootstrapped() + ) + self.process_prefill_chunk() + batch = self.get_new_batch_prefill() + self.cur_batch = batch + + if batch: + result = self.run_batch(batch) + self.process_batch_result_disagg_prefill(batch, result) + + if len(self.disagg_prefill_inflight_queue) > 0: + self.process_disagg_prefill_inflight_queue() + + if batch is None and len(self.disagg_prefill_inflight_queue) == 0: + self.check_memory() + self.new_token_ratio = self.init_new_token_ratio + + self.last_batch = batch + # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it + # Otherwise, it hangs under high concurrency + self.running_batch.batch_is_full = False + def process_batch_result_disagg_prefill( self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult ) -> None: @@ -210,7 +240,7 @@ class SchedulerDisaggregationPrefillMixin: polls = poll_and_all_reduce( [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue], - self.tp_worker.get_tp_cpu_group(), + self.attn_tp_cpu_group, ) undone_reqs: List[Req] = [] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7fa979c0e..5fb0a749a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -484,7 +484,7 @@ class Scheduler( self.tree_cache = HiRadixCache( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - tp_cache_group=self.tp_worker.get_tp_cpu_group(), + tp_cache_group=self.tp_cpu_group, page_size=self.page_size, hicache_ratio=server_args.hicache_ratio, ) @@ -553,7 +553,7 @@ class Scheduler( # The decode requests polling kv cache self.disagg_decode_transfer_queue = DecodeTransferQueue( - gloo_group=self.tp_worker.get_attention_tp_cpu_group(), + gloo_group=self.attn_tp_cpu_group, req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, metadata_buffers=metadata_buffers, ) @@ -568,7 +568,7 @@ class Scheduler( scheduler=self, transfer_queue=self.disagg_decode_transfer_queue, tree_cache=self.tree_cache, - gloo_group=self.tp_worker.get_attention_tp_cpu_group(), + gloo_group=self.attn_tp_cpu_group, tp_rank=self.tp_rank, tp_size=self.tp_size, bootstrap_port=self.server_args.disaggregation_bootstrap_port, @@ -597,7 +597,7 @@ class Scheduler( tp_rank=self.tp_rank, tp_size=self.tp_size, bootstrap_port=self.server_args.disaggregation_bootstrap_port, - gloo_group=self.tp_worker.get_attention_tp_cpu_group(), + gloo_group=self.attn_tp_cpu_group, transfer_backend=self.transfer_backend, scheduler=self, ) @@ -664,70 +664,6 @@ class Scheduler( self.last_batch = batch - @torch.no_grad() - def event_loop_normal_disagg_prefill(self): - """A normal scheduler loop for prefill worker in disaggregation mode.""" - - while True: - recv_reqs = self.recv_requests() - self.process_input_requests(recv_reqs) - self.waiting_queue.extend( - self.disagg_prefill_pending_queue.pop_bootstrapped() - ) - self.process_prefill_chunk() - batch = self.get_new_batch_prefill() - self.cur_batch = batch - - if batch: - result = self.run_batch(batch) - self.process_batch_result_disagg_prefill(batch, result) - - if len(self.disagg_prefill_inflight_queue) > 0: - self.process_disagg_prefill_inflight_queue() - - if batch is None and len(self.disagg_prefill_inflight_queue) == 0: - self.check_memory() - self.new_token_ratio = self.init_new_token_ratio - - self.last_batch = batch - # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it - # Otherwise, it hangs under high concurrency - self.running_batch.batch_is_full = False - - @torch.no_grad() - def event_loop_normal_disagg_decode(self): - """A normal scheduler loop for decode worker in disaggregation mode.""" - - while True: - recv_reqs = self.recv_requests() - self.process_input_requests(recv_reqs) - # polling and allocating kv cache - self.process_decode_queue() - batch = self.get_next_disagg_decode_batch_to_run() - self.cur_batch = batch - - if batch: - # Generate fake extend output. - if batch.forward_mode.is_extend(): - # Note: Logprobs should be handled on the prefill engine. - self.stream_output( - batch.reqs, [False for _ in range(len(batch.reqs))] - ) - else: - result = self.run_batch(batch) - self.process_batch_result(batch, result) - - if batch is None and ( - len(self.disagg_decode_transfer_queue.queue) - + len(self.disagg_decode_prealloc_queue.queue) - == 0 - ): - # When the server is idle, do self-check and re-init some states - self.check_memory() - self.new_token_ratio = self.init_new_token_ratio - - self.last_batch = batch - def recv_requests(self) -> List[Req]: """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" if self.attn_tp_rank == 0: