From a2ba0bc3df90ce4527997505eb2703686c4033d7 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Mon, 20 Oct 2025 11:52:42 +0800 Subject: [PATCH] Tiny clean up for PD module and doc (#11747) Signed-off-by: Shangming Cai --- docs/advanced_features/pd_disaggregation.md | 3 +++ python/sglang/srt/disaggregation/common/conn.py | 1 + .../sglang/srt/disaggregation/mooncake/conn.py | 17 ++++++++++++----- python/sglang/srt/disaggregation/nixl/conn.py | 7 +++++++ 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/docs/advanced_features/pd_disaggregation.md b/docs/advanced_features/pd_disaggregation.md index 460503608..2c74b77d8 100644 --- a/docs/advanced_features/pd_disaggregation.md +++ b/docs/advanced_features/pd_disaggregation.md @@ -41,6 +41,7 @@ uv pip install mooncake-transfer-engine python -m sglang.launch_server \ --model-path meta-llama/Llama-3.1-8B-Instruct \ --disaggregation-mode prefill \ + --port 30000 \ --disaggregation-ib-device mlx5_roce0 python -m sglang.launch_server \ --model-path meta-llama/Llama-3.1-8B-Instruct \ @@ -179,6 +180,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx" python -m sglang.launch_server \ --model-path meta-llama/Llama-3.1-8B-Instruct \ --disaggregation-mode prefill \ + --port 30000 \ --disaggregation-transfer-backend nixl python -m sglang.launch_server \ --model-path meta-llama/Llama-3.1-8B-Instruct \ @@ -282,6 +284,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true python -m sglang.launch_server \ --model-path meta-llama/Llama-3.1-8B-Instruct \ --disaggregation-mode prefill \ + --port 30000 \ --disaggregation-transfer-backend ascend python -m sglang.launch_server \ --model-path meta-llama/Llama-3.1-8B-Instruct \ diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index e34778a38..5d0fd19c1 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -246,6 +246,7 @@ class CommonKVReceiver(BaseKVReceiver): f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", ) self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) + self.bootstrap_infos = None return else: logger.debug( diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 8013f0f0b..af2c75d83 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -174,7 +174,7 @@ class MooncakeKVManager(CommonKVManager): cpu_count = os.cpu_count() transfer_thread_pool_size = get_int_env_var( "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE", - min(max(4, int(0.75 * cpu_count) // 8), 12), + min(max(4, int(0.5 * cpu_count) // 8), 12), ) transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4) self.transfer_queues: List[FastQueue] = [ @@ -190,9 +190,6 @@ class MooncakeKVManager(CommonKVManager): ) for _ in range(transfer_queue_size) ] - self.state_executors = concurrent.futures.ThreadPoolExecutor( - transfer_thread_pool_size // transfer_queue_size - ) for queue, executor in zip(self.transfer_queues, self.executors): threading.Thread( target=self.transfer_worker, args=(queue, executor), daemon=True @@ -641,6 +638,7 @@ class MooncakeKVManager(CommonKVManager): req: TransferInfo, prefill_state_indices: list[int], dst_state_data_ptrs: list[int], + executor: concurrent.futures.ThreadPoolExecutor, ): """Send state or extra pool data with type-specific handling.""" state_type = getattr(self.kv_args, "state_type", "none") @@ -662,7 +660,7 @@ class MooncakeKVManager(CommonKVManager): item_lens=self.kv_args.state_item_lens, prefill_data_indices=prefill_state_indices, dst_data_indices=dst_state_indices, - executor=self.state_executors, + executor=executor, ) else: return 0 @@ -810,6 +808,7 @@ class MooncakeKVManager(CommonKVManager): req, kv_chunk.state_indices, target_rank_registration_info.dst_state_data_ptrs, + executor, ) if self.pp_group.is_last_rank: @@ -1257,6 +1256,14 @@ class MooncakeKVReceiver(CommonKVReceiver): aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, ): + if self.bootstrap_infos is None: + self.kv_mgr.record_failure( + self.bootstrap_room, + f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", + ) + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) + return + for bootstrap_info in self.bootstrap_infos: sock, lock = self._connect_to_bootstrap_server(bootstrap_info) is_dummy = bootstrap_info["is_dummy"] diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 8d9bdffc6..2d56a5b72 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -762,6 +762,13 @@ class NixlKVReceiver(CommonKVReceiver): aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, ): + if self.bootstrap_infos is None: + logger.error( + f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", + ) + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) + return + for bootstrap_info in self.bootstrap_infos: logger.debug( f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"