From c55550cbf09f9e51eda7a2a5c3be8b118d4d05f2 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Fri, 25 Apr 2025 17:25:45 +0800 Subject: [PATCH] [PD] Better logs (#5715) --- python/sglang/srt/disaggregation/decode.py | 17 ++++--- python/sglang/srt/disaggregation/prefill.py | 17 ++++--- python/sglang/srt/managers/scheduler.py | 50 ++++++++++++--------- 3 files changed, 50 insertions(+), 34 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index a45af7e37..103b6ad05 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -307,7 +307,7 @@ class DecodeTransferQueue: def extend(self, req_conns) -> None: self.queue.extend(req_conns) - def pop_transferred(self) -> List[Req]: + def pop_transferred(self) -> List[DecodeRequest]: if not self.queue: return [] @@ -330,7 +330,7 @@ class DecodeTransferQueue: assert len(decode_req.req.output_ids) == 0 assert decode_req.req.transferred_output_id is None decode_req.req.transferred_output_id = output_id - transferred_reqs.append(decode_req.req) + transferred_reqs.append(decode_req) indices_to_remove.add(i) elif poll in [ KVPoll.Bootstrapping, @@ -454,7 +454,7 @@ class SchedulerDisaggregationDecodeMixin: return batch, result @torch.no_grad() - def event_loop_normal_disagg_decode(self): + def event_loop_normal_disagg_decode(self: Scheduler): """A normal scheduler loop for decode worker in disaggregation mode.""" while True: @@ -497,7 +497,7 @@ class SchedulerDisaggregationDecodeMixin: self.last_batch = batch @torch.no_grad() - def event_loop_overlap_disagg_decode(self): + def event_loop_overlap_disagg_decode(self: Scheduler): result_queue = deque() self.last_batch: Optional[ScheduleBatch] = None self.last_batch_in_queue = False # last batch is modifed in-place, so we need another variable to track if it's extend @@ -641,8 +641,15 @@ class SchedulerDisaggregationDecodeMixin: def process_decode_queue(self: Scheduler): req_conns = self.disagg_decode_prealloc_queue.pop_preallocated() + + def _num_pre_alloc(req): + return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0) + + self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns) self.disagg_decode_transfer_queue.extend(req_conns) alloc_reqs = ( self.disagg_decode_transfer_queue.pop_transferred() ) # the requests which kv has arrived - self.waiting_queue.extend(alloc_reqs) + self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs) + + self.waiting_queue.extend([req.req for req in alloc_reqs]) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 7c10da219..1af7a9b19 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -176,14 +176,14 @@ class SchedulerDisaggregationPrefillMixin: """ @torch.no_grad() - def event_loop_normal_disagg_prefill(self): + def event_loop_normal_disagg_prefill(self: Scheduler): """A normal scheduler loop for prefill worker in disaggregation mode.""" while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) self.waiting_queue.extend( - self.disagg_prefill_pending_queue.pop_bootstrapped() + self.disagg_prefill_bootstrap_queue.pop_bootstrapped() ) self.process_prefill_chunk() batch = self.get_new_batch_prefill() @@ -214,14 +214,14 @@ class SchedulerDisaggregationPrefillMixin: self.running_batch.batch_is_full = False @torch.no_grad() - def event_loop_overlap_disagg_prefill(self): + def event_loop_overlap_disagg_prefill(self: Scheduler): self.result_queue = deque() while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) self.waiting_queue.extend( - self.disagg_prefill_pending_queue.pop_bootstrapped() + self.disagg_prefill_bootstrap_queue.pop_bootstrapped() ) self.process_prefill_chunk() batch = self.get_new_batch_prefill() @@ -326,7 +326,7 @@ class SchedulerDisaggregationPrefillMixin: raise Exception("Transferring failed") for req in done_reqs: - self.disagg_prefill_pending_queue.req_to_metadata_buffer_idx_allocator.free( + self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free( req.metadata_buffer_index ) @@ -342,9 +342,8 @@ class SchedulerDisaggregationPrefillMixin: # only finished requests to running_batch. self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req) self.tree_cache.cache_unfinished_req(self.chunked_req) - if ( - self.enable_overlap - ): # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved + if self.enable_overlap: + # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved self.chunked_req.tmp_end_idx = min( len(self.chunked_req.fill_ids), len(self.chunked_req.origin_input_ids), @@ -390,7 +389,7 @@ class SchedulerDisaggregationPrefillMixin: .numpy() ) if last_chunk is True: - self.disagg_prefill_pending_queue.store_prefill_results( + self.disagg_prefill_bootstrap_queue.store_prefill_results( req.metadata_buffer_index, token_id ) page_indices = kv_to_page_indices(kv_indices, page_size) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 303c22059..732f86453 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -578,6 +578,10 @@ class Scheduler( bootstrap_port=self.server_args.disaggregation_bootstrap_port, transfer_backend=self.transfer_backend, ) + + # Metric for pre-allocation + self.num_tokens_pre_allocated = 0 + elif self.disaggregation_mode == DisaggregationMode.PREFILL: # *2 for the headroom. buffer_size = self.max_running_requests * 2 @@ -593,7 +597,7 @@ class Scheduler( ) metadata_buffers = [output_id_buffer] - self.disagg_prefill_pending_queue = PrefillBootstrapQueue( + self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue( token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(), req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, metadata_buffers=metadata_buffers, @@ -901,7 +905,7 @@ class Scheduler( def _add_request_to_queue(self, req: Req): req.queue_time_start = time.time() if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.disagg_prefill_pending_queue.add(req) + self.disagg_prefill_bootstrap_queue.add(req) elif self.disaggregation_mode == DisaggregationMode.DECODE: self.disagg_decode_prealloc_queue.add(req) else: @@ -991,8 +995,15 @@ class Scheduler( f"#cached-token: {adder.log_hit_tokens}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"#running-req: {running_bs}, " - f"#queue-req: {len(self.waiting_queue)}, " ) + + if self.disaggregation_mode == DisaggregationMode.PREFILL: + f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, " + f += f"#queue-req: {len(self.waiting_queue)}, " + f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} " + else: + f += f"#queue-req: {len(self.waiting_queue)}" + logger.info(f) if self.enable_metrics: @@ -1028,15 +1039,14 @@ class Scheduler( gap_latency / self.server_args.decode_log_interval ) + msg = ( + f"Decode batch. " + f"#running-req: {num_running_reqs}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + ) + if self.spec_algorithm.is_none(): - msg = ( - f"Decode batch. " - f"#running-req: {num_running_reqs}, " - f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}, " - ) spec_accept_length = 0 else: spec_accept_length = ( @@ -1045,15 +1055,15 @@ class Scheduler( self.cum_spec_accept_length += self.spec_num_total_accepted_tokens self.cum_spec_accept_count += self.spec_num_total_forward_ct self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 - msg = ( - f"Decode batch. " - f"#running-req: {num_running_reqs}, " - f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"accept len: {spec_accept_length:.2f}, " - f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}, " - ) + msg += f"accept len: {spec_accept_length:.2f}, " + + if self.disaggregation_mode == DisaggregationMode.DECODE: + msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, " + + msg += ( + f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}" + ) logger.info(msg) if self.enable_metrics: