diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index ee6b0f0d2..56a87516d 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -21,7 +21,7 @@ import threading import time from collections import deque from enum import Enum, auto -from typing import List +from typing import List, Optional import psutil import setproctitle @@ -36,7 +36,11 @@ from sglang.srt.managers.io_struct import ( ) from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.server_args import ( + DP_ATTENTION_HANDSHAKE_PORT_DELTA, + PortArgs, + ServerArgs, +) from sglang.srt.utils import ( bind_port, configure_logger, @@ -140,22 +144,12 @@ class DataParallelController: self.workers: List[zmq.Socket] = [None] * server_args.dp_size if server_args.enable_dp_attention: - dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args) + self.launch_dp_attention_schedulers(server_args, port_args) self.control_message_step = server_args.tp_size else: - dp_port_args = self.launch_dp_schedulers(server_args, port_args) + self.launch_dp_schedulers(server_args, port_args) self.control_message_step = 1 - # Only node rank 0 runs the real data parallel controller that dispatches the requests. - if server_args.node_rank == 0: - for dp_rank in range(server_args.dp_size): - self.workers[dp_rank] = get_zmq_socket( - self.context, - zmq.PUSH, - dp_port_args[dp_rank].scheduler_input_ipc_name, - True, - ) - self.max_req_input_len = None self.init_dispatcher() @@ -188,13 +182,11 @@ class DataParallelController: threads = [] sockets = [] - dp_port_args = [] ready_events = [] for dp_rank in range(server_args.dp_size): tmp_port_args = PortArgs.init_new(server_args) tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name - dp_port_args.append(tmp_port_args) # This port is checked free in PortArgs.init_new. # We hold it first so that the next dp worker gets a different port @@ -213,6 +205,14 @@ class DataParallelController: server_args.tp_size * server_args.pp_size * server_args.gpu_id_step ) + if server_args.node_rank == 0: + self.workers[dp_rank] = get_zmq_socket( + self.context, + zmq.PUSH, + tmp_port_args.scheduler_input_ipc_name, + True, + ) + # Free all sockets before starting the threads to launch TP workers for sock in sockets: sock.close() @@ -223,8 +223,6 @@ class DataParallelController: for event in ready_events: event.wait() - return dp_port_args - def launch_tensor_parallel_group_thread( self, server_args: ServerArgs, @@ -241,19 +239,115 @@ class DataParallelController: while True: time.sleep(30 * 24 * 3600) - def launch_dp_attention_schedulers(self, server_args, port_args): - self.launch_tensor_parallel_group(server_args, port_args, 0, None) - dp_port_args = [] - for dp_rank in range(server_args.dp_size): - dp_port_args.append(PortArgs.init_new(server_args, dp_rank)) - return dp_port_args + def _broadcast_worker_ports( + self, server_args: ServerArgs, worker_ports: Optional[List[int]] = None + ) -> List[int]: + """Broadcast worker ports from node 0 to all other nodes. + + Node 0 acts as the server, waiting for all other nodes to connect and + sending them the pre-allocated worker ports. 
+        Other nodes act as clients,
+        connecting to node 0 to receive their copy of the worker ports.
+
+        Args:
+            server_args: Server arguments containing node configuration.
+            worker_ports: Pre-allocated worker ports to broadcast (node 0 only).
+
+        Returns:
+            List of worker ports (same on all nodes after broadcast).
+        """
+        # Determine the endpoint for inter-node communication
+        if server_args.dist_init_addr is None:
+            endpoint = f"tcp://127.0.0.1:{server_args.port + DP_ATTENTION_HANDSHAKE_PORT_DELTA}"
+        else:
+            endpoint = f"tcp://{server_args.dist_init_addr}"
+
+        if server_args.node_rank == 0:
+            # Node 0: Broadcast worker ports to all other nodes
+            return self._broadcast_ports_as_server(
+                endpoint, server_args.nnodes - 1, worker_ports
+            )
+        else:
+            # Other nodes: Receive worker ports from node 0
+            return self._receive_ports_as_client(endpoint, server_args.node_rank)
+
+    def _broadcast_ports_as_server(
+        self, endpoint: str, expected_clients: int, worker_ports: List[int]
+    ) -> List[int]:
+        """Broadcast worker ports to all client nodes."""
+        logger.debug(f"Broadcasting worker ports to {expected_clients} client nodes")
+        logger.debug(f"Worker ports: {worker_ports}")
+
+        rep_socket = get_zmq_socket(self.context, zmq.REP, endpoint, True)
+
+        try:
+            connected_clients = 0
+            while connected_clients < expected_clients:
+                # Wait for client handshake
+                client_rank = rep_socket.recv().decode()
+                logger.debug(f"Received handshake from node {client_rank}")
+
+                # Send worker ports to client
+                rep_socket.send_pyobj(worker_ports)
+                connected_clients += 1
+                logger.debug(
+                    f"Sent worker ports to {connected_clients}/{expected_clients} nodes"
+                )
+
+            logger.debug("Worker port broadcast completed")
+            return worker_ports
+        finally:
+            rep_socket.close()
+
+    def _receive_ports_as_client(self, endpoint: str, node_rank: int) -> List[int]:
+        """Receive worker ports from the server node."""
+        logger.debug("Connecting to node 0 to receive worker ports")
+
+        req_socket = get_zmq_socket(self.context, zmq.REQ, endpoint, False)
+        req_socket.setsockopt(zmq.RCVTIMEO, 60 * 1000)  # 1 minute timeout
+        req_socket.setsockopt(zmq.SNDTIMEO, 60 * 1000)
+
+        try:
+            # Send handshake with our node rank
+            req_socket.send(str(node_rank).encode())
+
+            # Receive worker ports
+            worker_ports = req_socket.recv_pyobj()
+            logger.debug(f"Received {len(worker_ports)} worker ports from node 0")
+            return worker_ports
+        except zmq.Again:
+            logger.error("Timeout waiting for worker ports from node 0")
+            raise RuntimeError(
+                "Failed to receive worker ports from node 0 within timeout"
+            )
+        finally:
+            req_socket.close()
+
+    def launch_dp_attention_schedulers(
+        self, server_args: ServerArgs, port_args: PortArgs
+    ):
+        # Pre-allocate worker ports on node 0 to avoid conflicts
+        worker_ports = []
+        if server_args.node_rank == 0:
+            for dp_rank in range(server_args.dp_size):
+                worker_port, worker_socket = get_zmq_socket(self.context, zmq.PUSH)
+                worker_ports.append(worker_port)
+                self.workers[dp_rank] = worker_socket
+                logger.debug(f"Assigned port {worker_port} to worker {dp_rank}")
+
+        broadcasted_ports = self._broadcast_worker_ports(
+            server_args, worker_ports if worker_ports else None
+        )
+        self.launch_tensor_parallel_group(
+            server_args, port_args, 0, None, broadcasted_ports
+        )
 
     def launch_tensor_parallel_group(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
         base_gpu_id: int,
-        dp_rank: int,
+        dp_rank: Optional[int],
+        worker_ports: Optional[List[int]] = None,
     ):
         if not server_args.enable_dp_attention:
             logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
@@ -290,7 +384,9 @@ class DataParallelController:
                 server_args.dp_size,
             )
             # compute zmq ports for this dp rank
-            rank_port_args = PortArgs.init_new(server_args, dp_rank)
+            rank_port_args = PortArgs.init_new(
+                server_args, dp_rank, worker_ports
+            )
             # Data parallelism reuses the tensor parallelism group,
             # so all dp ranks should use the same nccl port.
             rank_port_args.nccl_port = port_args.nccl_port
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 8053be39d..b19b7bb32 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -13,6 +13,8 @@
 # ==============================================================================
 """The arguments of the server."""
 
+from __future__ import annotations
+
 import argparse
 import dataclasses
 import json
@@ -3362,6 +3364,7 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 
 ZMQ_TCP_PORT_DELTA = 233
+DP_ATTENTION_HANDSHAKE_PORT_DELTA = 5
 
 
 @dataclasses.dataclass
@@ -3386,7 +3389,11 @@ class PortArgs:
     tokenizer_worker_ipc_name: Optional[str]
 
     @staticmethod
-    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
+    def init_new(
+        server_args: ServerArgs,
+        dp_rank: Optional[int] = None,
+        worker_ports: Optional[List[int]] = None,
+    ) -> PortArgs:
         if server_args.nccl_port is None:
             nccl_port = server_args.port + random.randint(100, 1000)
             while True:
@@ -3433,8 +3440,8 @@ class PortArgs:
             # TokenizerManager to DataParallelController
             scheduler_input_port = port_base + 4
         else:
-            scheduler_input_port = port_base + 4 + 1 + dp_rank
-
+            assert worker_ports is not None
+            scheduler_input_port = worker_ports[dp_rank]
         return PortArgs(
             tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
             scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index b65c311f9..084065b61 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -1291,8 +1291,46 @@ def pytorch_profile(name, func, *args, data_size=-1):
 
 
 def get_zmq_socket(
-    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-) -> zmq.Socket:
+    context: zmq.Context,
+    socket_type: zmq.SocketType,
+    endpoint: Optional[str] = None,
+    bind: bool = True,
+) -> Union[zmq.Socket, Tuple[int, zmq.Socket]]:
+    """Create and configure a ZeroMQ socket.
+
+    Args:
+        context: ZeroMQ context to create the socket from.
+        socket_type: Type of ZeroMQ socket to create.
+        endpoint: Optional endpoint to bind/connect to. If None, binds to a random TCP port.
+        bind: Whether to bind (True) or connect (False) to the endpoint. Ignored if endpoint is None.
+
+    Returns:
+        If endpoint is None: Tuple of (port, socket) where port is the randomly assigned TCP port.
+        If endpoint is provided: The configured ZeroMQ socket.
+    """
+    socket = context.socket(socket_type)
+
+    if endpoint is None:
+        # Bind to a random free TCP port and hand the port back to the caller
+        config_socket(socket, socket_type)
+        port = socket.bind_to_random_port("tcp://*")
+        return port, socket
+    else:
+        # Handle IPv6 if endpoint contains brackets
+        if "[" in endpoint:
+            socket.setsockopt(zmq.IPV6, 1)
+
+        config_socket(socket, socket_type)
+
+        if bind:
+            socket.bind(endpoint)
+        else:
+            socket.connect(endpoint)
+
+        return socket
+
+
+def config_socket(socket: zmq.Socket, socket_type: zmq.SocketType) -> None:
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -1301,10 +1339,6 @@
     else:
         buf_size = -1
 
-    socket = context.socket(socket_type)
-    if endpoint.find("[") != -1:
-        socket.setsockopt(zmq.IPV6, 1)
-
     def set_send_opt():
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
@@ -1317,19 +1351,12 @@
         set_send_opt()
     elif socket_type == zmq.PULL:
         set_recv_opt()
-    elif socket_type == zmq.DEALER:
+    elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP]:
         set_send_opt()
         set_recv_opt()
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")
 
-    if bind:
-        socket.bind(endpoint)
-    else:
-        socket.connect(endpoint)
-
-    return socket
-
 
 def dump_to_file(dirpath, name, value):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
diff --git a/test/srt/test_server_args.py b/test/srt/test_server_args.py
index 6096bc13b..4a0cee42b 100644
--- a/test/srt/test_server_args.py
+++ b/test/srt/test_server_args.py
@@ -75,7 +75,8 @@ class TestPortArgs(unittest.TestCase):
         server_args.nnodes = 1
         server_args.dist_init_addr = "192.168.1.1:25000"
 
-        port_args = PortArgs.init_new(server_args, dp_rank=2)
+        worker_ports = [25006, 25007, 25008, 25009]
+        port_args = PortArgs.init_new(server_args, dp_rank=2, worker_ports=worker_ports)
 
         self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25008"))
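
Reviewer note: the REQ/REP handshake added in this patch is easier to evaluate in isolation, so below is a minimal, self-contained sketch of the same flow: "node 0" pre-binds one PUSH socket per dp rank on a random free port (what get_zmq_socket now does when endpoint is None), then serves the resulting port list to every other node over a REP socket. Everything in the sketch is illustrative only: the endpoint, DP_SIZE, NNODES, and the function names node0/other_node are made up here and do not appear in the patch, and all "nodes" run as threads in one process instead of separate machines.

import threading

import zmq

HANDSHAKE_ENDPOINT = "tcp://127.0.0.1:5757"  # placeholder port, demo only
DP_SIZE = 2  # worker sockets owned by node 0
NNODES = 3   # node 0 plus two client nodes


def node0(ctx: zmq.Context) -> None:
    # Pre-allocate one PUSH socket per dp rank on a random free port,
    # mirroring launch_dp_attention_schedulers + get_zmq_socket(ctx, zmq.PUSH).
    worker_sockets, worker_ports = [], []
    for _ in range(DP_SIZE):
        sock = ctx.socket(zmq.PUSH)
        worker_ports.append(sock.bind_to_random_port("tcp://*"))
        worker_sockets.append(sock)

    # Answer one REQ handshake per client node, sending the same port list,
    # as _broadcast_ports_as_server does in the patch.
    rep = ctx.socket(zmq.REP)
    rep.bind(HANDSHAKE_ENDPOINT)
    try:
        for _ in range(NNODES - 1):
            client_rank = rep.recv().decode()
            rep.send_pyobj(worker_ports)
            print(f"node 0 sent {worker_ports} to node {client_rank}")
    finally:
        rep.close()
        for sock in worker_sockets:
            sock.close()


def other_node(ctx: zmq.Context, node_rank: int) -> None:
    # Mirrors _receive_ports_as_client: handshake with our rank, get ports back.
    req = ctx.socket(zmq.REQ)
    req.setsockopt(zmq.RCVTIMEO, 60 * 1000)  # fail loudly instead of hanging
    req.connect(HANDSHAKE_ENDPOINT)
    try:
        req.send(str(node_rank).encode())
        ports = req.recv_pyobj()
        print(f"node {node_rank} received {ports}")
    finally:
        req.close()


if __name__ == "__main__":
    ctx = zmq.Context.instance()
    threads = [threading.Thread(target=node0, args=(ctx,))]
    threads += [
        threading.Thread(target=other_node, args=(ctx, r)) for r in range(1, NNODES)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    ctx.term()

As in the patch, the REP socket answers exactly one request at a time, so each client either receives the complete port list or times out; there is no partial state to clean up on failure.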