Fix unbalanced memory usage caused by `broadcast_pyobj` defaulting to the CUDA device (#5416)
This commit is contained in:
@@ -31,6 +31,7 @@ from sglang.srt.utils import is_port_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_available_ports(base_port: int, count: int) -> List[int]:
|
||||
"""Find consecutive available ports starting from base_port."""
|
||||
available_ports = []
|
||||
@@ -43,6 +44,7 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
|
||||
|
||||
return available_ports
|
||||
|
||||
|
||||
def group_concurrent_contiguous(
|
||||
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
||||
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
||||
@@ -265,7 +267,9 @@ class MooncakeKVManager(BaseKVManager):
|
||||
)
|
||||
if ret != 0:
|
||||
self.request_status[kv_chunk.room] = KVPoll.Failed
|
||||
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
|
||||
self.sync_status_to_decode_endpoint(
|
||||
req.endpoint, req.dst_port, req.room
|
||||
)
|
||||
continue
|
||||
|
||||
if kv_chunk.is_last:
|
||||
@@ -279,7 +283,9 @@ class MooncakeKVManager(BaseKVManager):
|
||||
self.request_status[req.room] = (
|
||||
KVPoll.Success if ret == 0 else KVPoll.Failed
|
||||
)
|
||||
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
|
||||
self.sync_status_to_decode_endpoint(
|
||||
req.endpoint, req.dst_port, req.room
|
||||
)
|
||||
self.transfer_infos.pop(req.room)
|
||||
|
||||
except queue.Empty:
|
||||
@@ -443,13 +449,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
prefill_info = response.json()
|
||||
return prefill_info
|
||||
else:
|
||||
logger.error(f"Failed to get prefill server info: {response.status_code}, {response.text}")
|
||||
logger.error(
|
||||
f"Failed to get prefill server info: {response.status_code}, {response.text}"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@cache
|
||||
def _connect(self, endpoint: str):
|
||||
socket = zmq.Context().socket(zmq.PUSH)
|
||||
@@ -466,17 +473,25 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
)
|
||||
if prefill_info is None:
|
||||
logger.error(
|
||||
logger.error(f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}")
|
||||
logger.error(
|
||||
f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = prefill_info
|
||||
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = (
|
||||
prefill_info
|
||||
)
|
||||
else:
|
||||
prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]
|
||||
|
||||
if prefill_info:
|
||||
self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
|
||||
self.prefill_server_url = (
|
||||
f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
|
||||
)
|
||||
|
||||
logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}")
|
||||
logger.info(
|
||||
f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
|
||||
)
|
||||
self.handshake_prefill_server(kv_indices, aux_index)
|
||||
|
||||
def handshake_prefill_server(
|
||||
@@ -598,8 +613,13 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
||||
# Add lock to make sure thread-safe
|
||||
if role == "Prefill":
|
||||
async with self.lock:
|
||||
self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port}
|
||||
logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}")
|
||||
self.prefill_port_table[tp_rank] = {
|
||||
"serve_ip": serve_ip,
|
||||
"serve_port": serve_port,
|
||||
}
|
||||
logger.info(
|
||||
f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}"
|
||||
)
|
||||
|
||||
return web.Response(text="OK", status=200)
|
||||
|
||||
|
||||
@@ -118,6 +118,7 @@ class VerlEngine:
|
||||
rank=self._tp_rank,
|
||||
dist_group=self._device_mesh_cpu.get_group(),
|
||||
src=self._device_mesh_cpu.mesh[0].item(),
|
||||
force_cpu_device=False,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@@ -846,9 +846,12 @@ def broadcast_pyobj(
|
||||
rank: int,
|
||||
dist_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||
src: int = 0,
|
||||
force_cpu_device: bool = True,
|
||||
):
|
||||
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
if len(data) == 0:
|
||||
|
||||
Reference in New Issue
Block a user