From 471650dee0435d42658a9588956cea7c15104070 Mon Sep 17 00:00:00 2001 From: lambert0312 Date: Tue, 15 Apr 2025 17:47:26 +0800 Subject: [PATCH] Fix broadcast use cuda device lead to memory capacity unbalanced (#5416) --- .../srt/disaggregation/mooncake/conn.py | 40 ++++++++++++++----- python/sglang/srt/entrypoints/verl_engine.py | 1 + python/sglang/srt/utils.py | 5 ++- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 91903cf85..957fdc559 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -31,6 +31,7 @@ from sglang.srt.utils import is_port_available logger = logging.getLogger(__name__) + def find_available_ports(base_port: int, count: int) -> List[int]: """Find consecutive available ports starting from base_port.""" available_ports = [] @@ -43,6 +44,7 @@ def find_available_ports(base_port: int, count: int) -> List[int]: return available_ports + def group_concurrent_contiguous( src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: @@ -265,7 +267,9 @@ class MooncakeKVManager(BaseKVManager): ) if ret != 0: self.request_status[kv_chunk.room] = KVPoll.Failed - self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room) + self.sync_status_to_decode_endpoint( + req.endpoint, req.dst_port, req.room + ) continue if kv_chunk.is_last: @@ -279,7 +283,9 @@ class MooncakeKVManager(BaseKVManager): self.request_status[req.room] = ( KVPoll.Success if ret == 0 else KVPoll.Failed ) - self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room) + self.sync_status_to_decode_endpoint( + req.endpoint, req.dst_port, req.room + ) self.transfer_infos.pop(req.room) except queue.Empty: @@ -443,13 +449,14 @@ class MooncakeKVReceiver(BaseKVReceiver): prefill_info = response.json() return prefill_info else: - logger.error(f"Failed to get prefill server info: {response.status_code}, {response.text}") + logger.error( + f"Failed to get prefill server info: {response.status_code}, {response.text}" + ) return None except Exception as e: logger.error(f"Error fetching prefill info from bootstrap: {e}") return None - @cache def _connect(self, endpoint: str): socket = zmq.Context().socket(zmq.PUSH) @@ -466,17 +473,25 @@ class MooncakeKVReceiver(BaseKVReceiver): ) if prefill_info is None: logger.error( - logger.error(f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}") + logger.error( + f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}" + ) ) else: - self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = prefill_info + self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = ( + prefill_info + ) else: prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] if prefill_info: - self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}" + self.prefill_server_url = ( + f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}" + ) - logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}") + logger.info( + f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}" + ) self.handshake_prefill_server(kv_indices, aux_index) def handshake_prefill_server( @@ -598,8 +613,13 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): # Add lock to make sure thread-safe if role == "Prefill": async with self.lock: - self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port} - logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}") + self.prefill_port_table[tp_rank] = { + "serve_ip": serve_ip, + "serve_port": serve_port, + } + logger.info( + f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}" + ) return web.Response(text="OK", status=200) diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index d49392f4c..ef139af27 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -118,6 +118,7 @@ class VerlEngine: rank=self._tp_rank, dist_group=self._device_mesh_cpu.get_group(), src=self._device_mesh_cpu.mesh[0].item(), + force_cpu_device=False, ) return output diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 2c0159a1f..b2dca816a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -846,9 +846,12 @@ def broadcast_pyobj( rank: int, dist_group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, + force_cpu_device: bool = True, ): """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device( + "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" + ) if rank == 0: if len(data) == 0: