From 55a6e644b08a31272a02e5c4f1a7b736376a399e Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sun, 14 Sep 2025 10:18:23 +0800 Subject: [PATCH] [Hack] Add pd-disaggregation decode polling interval (#10411) --- python/sglang/srt/disaggregation/decode.py | 21 +++++++++++++++------ python/sglang/srt/server_args.py | 9 +++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index b79c8ca87..0bddf3dcc 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -886,9 +886,18 @@ class SchedulerDisaggregationDecodeMixin: # if there are still retracted requests, we do not allocate new requests return - req_conns = self.disagg_decode_prealloc_queue.pop_preallocated() - self.disagg_decode_transfer_queue.extend(req_conns) - alloc_reqs = ( - self.disagg_decode_transfer_queue.pop_transferred() - ) # the requests which kv has arrived - self.waiting_queue.extend(alloc_reqs) + if not hasattr(self, "polling_count"): + self.polling_count = 0 + self.polling_interval = ( + self.server_args.disaggregation_decode_polling_interval + ) + + self.polling_count = (self.polling_count + 1) % self.polling_interval + + if self.polling_count % self.polling_interval == 0: + req_conns = self.disagg_decode_prealloc_queue.pop_preallocated() + self.disagg_decode_transfer_queue.extend(req_conns) + alloc_reqs = ( + self.disagg_decode_transfer_queue.pop_transferred() + ) # the requests which kv has arrived + self.waiting_queue.extend(alloc_reqs) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2e2fc458b..ce67d1f7b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -394,6 +394,9 @@ class ServerArgs: disaggregation_ib_device: Optional[str] = None num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD + # FIXME: hack to reduce ITL when decode bs is small + disaggregation_decode_polling_interval: int = 1 + # For model weight update custom_weight_loader: Optional[List[str]] = None weight_loader_disable_mmap: bool = False @@ -2245,6 +2248,12 @@ class ServerArgs: default=ServerArgs.num_reserved_decode_tokens, help="Number of decode tokens that will have memory reserved when adding new request to the running batch.", ) + parser.add_argument( + "--disaggregation-decode-polling-interval", + type=int, + default=ServerArgs.disaggregation_decode_polling_interval, + help="The interval to poll requests in decode server. Can be set to >1 to reduce the overhead of this.", + ) # Custom weight loader parser.add_argument(