Improve dp attention port assignment scheme (#5889)
Co-authored-by: Cheng Wan <cwan@x.ai>
This commit is contained in:
@@ -21,7 +21,7 @@ import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from enum import Enum, auto
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import psutil
|
||||
import setproctitle
|
||||
@@ -36,7 +36,11 @@ from sglang.srt.managers.io_struct import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.server_args import (
|
||||
DP_ATTENTION_HANDSHAKE_PORT_DELTA,
|
||||
PortArgs,
|
||||
ServerArgs,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
bind_port,
|
||||
configure_logger,
|
||||
@@ -140,22 +144,12 @@ class DataParallelController:
|
||||
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
|
||||
|
||||
if server_args.enable_dp_attention:
|
||||
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
|
||||
self.launch_dp_attention_schedulers(server_args, port_args)
|
||||
self.control_message_step = server_args.tp_size
|
||||
else:
|
||||
dp_port_args = self.launch_dp_schedulers(server_args, port_args)
|
||||
self.launch_dp_schedulers(server_args, port_args)
|
||||
self.control_message_step = 1
|
||||
|
||||
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
|
||||
if server_args.node_rank == 0:
|
||||
for dp_rank in range(server_args.dp_size):
|
||||
self.workers[dp_rank] = get_zmq_socket(
|
||||
self.context,
|
||||
zmq.PUSH,
|
||||
dp_port_args[dp_rank].scheduler_input_ipc_name,
|
||||
True,
|
||||
)
|
||||
|
||||
self.max_req_input_len = None
|
||||
|
||||
self.init_dispatcher()
|
||||
@@ -188,13 +182,11 @@ class DataParallelController:
|
||||
|
||||
threads = []
|
||||
sockets = []
|
||||
dp_port_args = []
|
||||
ready_events = []
|
||||
for dp_rank in range(server_args.dp_size):
|
||||
tmp_port_args = PortArgs.init_new(server_args)
|
||||
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
||||
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
|
||||
dp_port_args.append(tmp_port_args)
|
||||
|
||||
# This port is checked free in PortArgs.init_new.
|
||||
# We hold it first so that the next dp worker gets a different port
|
||||
@@ -213,6 +205,14 @@ class DataParallelController:
|
||||
server_args.tp_size * server_args.pp_size * server_args.gpu_id_step
|
||||
)
|
||||
|
||||
if server_args.node_rank == 0:
|
||||
self.workers[dp_rank] = get_zmq_socket(
|
||||
self.context,
|
||||
zmq.PUSH,
|
||||
tmp_port_args.scheduler_input_ipc_name,
|
||||
True,
|
||||
)
|
||||
|
||||
# Free all sockets before starting the threads to launch TP workers
|
||||
for sock in sockets:
|
||||
sock.close()
|
||||
@@ -223,8 +223,6 @@ class DataParallelController:
|
||||
for event in ready_events:
|
||||
event.wait()
|
||||
|
||||
return dp_port_args
|
||||
|
||||
def launch_tensor_parallel_group_thread(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
@@ -241,19 +239,115 @@ class DataParallelController:
|
||||
while True:
|
||||
time.sleep(30 * 24 * 3600)
|
||||
|
||||
def launch_dp_attention_schedulers(self, server_args, port_args):
|
||||
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
|
||||
dp_port_args = []
|
||||
for dp_rank in range(server_args.dp_size):
|
||||
dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
|
||||
return dp_port_args
|
||||
def _broadcast_worker_ports(
|
||||
self, server_args: ServerArgs, worker_ports: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""Broadcast worker ports from node 0 to all other nodes.
|
||||
|
||||
Node 0 acts as the server, waiting for all other nodes to connect and
|
||||
sending them the pre-allocated worker ports. Other nodes act as clients,
|
||||
connecting to node 0 to receive their copy of the worker ports.
|
||||
|
||||
Args:
|
||||
server_args: Server arguments containing node configuration.
|
||||
worker_ports: Pre-allocated worker ports to broadcast.
|
||||
|
||||
Returns:
|
||||
List of worker ports (same on all nodes after broadcast).
|
||||
"""
|
||||
# Determine the endpoint for inter-node communication
|
||||
if server_args.dist_init_addr is None:
|
||||
endpoint = f"tcp://127.0.0.1:{server_args.port + DP_ATTENTION_HANDSHAKE_PORT_DELTA}"
|
||||
else:
|
||||
endpoint = f"tcp://{server_args.dist_init_addr}"
|
||||
|
||||
if server_args.node_rank == 0:
|
||||
# Node 0: Broadcast worker ports to all other nodes
|
||||
return self._broadcast_ports_as_server(
|
||||
endpoint, server_args.nnodes - 1, worker_ports
|
||||
)
|
||||
else:
|
||||
# Other nodes: Receive worker ports from node 0
|
||||
return self._receive_ports_as_client(endpoint, server_args.node_rank)
|
||||
|
||||
def _broadcast_ports_as_server(
|
||||
self, endpoint: str, expected_clients: int, worker_ports: List[int]
|
||||
) -> List[int]:
|
||||
"""Broadcast worker ports to all client nodes."""
|
||||
logger.debug(f"Broadcasting worker ports to {expected_clients} client nodes")
|
||||
logger.debug(f"Worker ports: {worker_ports}")
|
||||
|
||||
rep_socket = get_zmq_socket(self.context, zmq.REP, endpoint, True)
|
||||
|
||||
try:
|
||||
connected_clients = 0
|
||||
while connected_clients < expected_clients:
|
||||
# Wait for client handshake
|
||||
client_rank = rep_socket.recv().decode()
|
||||
logger.debug(f"Received handshake from node {client_rank}")
|
||||
|
||||
# Send worker ports to client
|
||||
rep_socket.send_pyobj(worker_ports)
|
||||
connected_clients += 1
|
||||
logger.debug(
|
||||
f"Sent worker ports to {connected_clients}/{expected_clients} nodes"
|
||||
)
|
||||
|
||||
logger.debug("Worker port broadcast completed")
|
||||
return worker_ports
|
||||
finally:
|
||||
rep_socket.close()
|
||||
|
||||
def _receive_ports_as_client(self, endpoint: str, node_rank: int) -> List[int]:
|
||||
"""Receive worker ports from the server node."""
|
||||
logger.debug(f"Connecting to node 0 to receive worker ports")
|
||||
|
||||
req_socket = get_zmq_socket(self.context, zmq.REQ, endpoint, False)
|
||||
req_socket.setsockopt(zmq.RCVTIMEO, 60 * 1000) # 1 minute timeout
|
||||
req_socket.setsockopt(zmq.SNDTIMEO, 60 * 1000)
|
||||
|
||||
try:
|
||||
# Send handshake with our node rank
|
||||
req_socket.send(str(node_rank).encode())
|
||||
|
||||
# Receive worker ports
|
||||
worker_ports = req_socket.recv_pyobj()
|
||||
logger.debug(f"Received {len(worker_ports)} worker ports from node 0")
|
||||
return worker_ports
|
||||
except zmq.Again:
|
||||
logger.error("Timeout waiting for worker ports from node 0")
|
||||
raise RuntimeError(
|
||||
"Failed to receive worker ports from node 0 within timeout"
|
||||
)
|
||||
finally:
|
||||
req_socket.close()
|
||||
|
||||
def launch_dp_attention_schedulers(
|
||||
self, server_args: ServerArgs, port_args: PortArgs
|
||||
):
|
||||
# Pre-allocate worker ports on node 0 to avoid conflicts
|
||||
worker_ports = []
|
||||
if server_args.node_rank == 0:
|
||||
for dp_rank in range(server_args.dp_size):
|
||||
port_and_socket = get_zmq_socket(self.context, zmq.PUSH)
|
||||
worker_ports.append(port_and_socket[0])
|
||||
self.workers[dp_rank] = port_and_socket[1]
|
||||
logger.debug(f"Assigned port {port_and_socket[0]} to worker {dp_rank}")
|
||||
|
||||
broadcasted_ports = self._broadcast_worker_ports(
|
||||
server_args, worker_ports if worker_ports else None
|
||||
)
|
||||
self.launch_tensor_parallel_group(
|
||||
server_args, port_args, 0, None, broadcasted_ports
|
||||
)
|
||||
|
||||
def launch_tensor_parallel_group(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
base_gpu_id: int,
|
||||
dp_rank: int,
|
||||
dp_rank: Optional[int],
|
||||
worker_ports: Optional[List[int]] = None,
|
||||
):
|
||||
if not server_args.enable_dp_attention:
|
||||
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
|
||||
@@ -290,7 +384,9 @@ class DataParallelController:
|
||||
server_args.dp_size,
|
||||
)
|
||||
# compute zmq ports for this dp rank
|
||||
rank_port_args = PortArgs.init_new(server_args, dp_rank)
|
||||
rank_port_args = PortArgs.init_new(
|
||||
server_args, dp_rank, worker_ports
|
||||
)
|
||||
# Data parallelism reuses the tensor parallelism group,
|
||||
# so all dp ranks should use the same nccl port.
|
||||
rank_port_args.nccl_port = port_args.nccl_port
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# ==============================================================================
|
||||
"""The arguments of the server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
@@ -3362,6 +3364,7 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
|
||||
|
||||
ZMQ_TCP_PORT_DELTA = 233
|
||||
DP_ATTENTION_HANDSHAKE_PORT_DELTA = 5
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -3386,7 +3389,11 @@ class PortArgs:
|
||||
tokenizer_worker_ipc_name: Optional[str]
|
||||
|
||||
@staticmethod
|
||||
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
|
||||
def init_new(
|
||||
server_args: ServerArgs,
|
||||
dp_rank: Optional[int] = None,
|
||||
worker_ports: Optional[List[int]] = None,
|
||||
) -> PortArgs:
|
||||
if server_args.nccl_port is None:
|
||||
nccl_port = server_args.port + random.randint(100, 1000)
|
||||
while True:
|
||||
@@ -3433,8 +3440,8 @@ class PortArgs:
|
||||
# TokenizerManager to DataParallelController
|
||||
scheduler_input_port = port_base + 4
|
||||
else:
|
||||
scheduler_input_port = port_base + 4 + 1 + dp_rank
|
||||
|
||||
assert worker_ports is not None
|
||||
scheduler_input_port = worker_ports[dp_rank]
|
||||
return PortArgs(
|
||||
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
|
||||
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
|
||||
|
||||
@@ -1291,8 +1291,46 @@ def pytorch_profile(name, func, *args, data_size=-1):
|
||||
|
||||
|
||||
def get_zmq_socket(
|
||||
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
|
||||
) -> zmq.Socket:
|
||||
context: zmq.Context,
|
||||
socket_type: zmq.SocketType,
|
||||
endpoint: Optional[str] = None,
|
||||
bind: bool = True,
|
||||
) -> Union[zmq.Socket, Tuple[int, zmq.Socket]]:
|
||||
"""Create and configure a ZeroMQ socket.
|
||||
|
||||
Args:
|
||||
context: ZeroMQ context to create the socket from.
|
||||
socket_type: Type of ZeroMQ socket to create.
|
||||
endpoint: Optional endpoint to bind/connect to. If None, binds to a random TCP port.
|
||||
bind: Whether to bind (True) or connect (False) to the endpoint. Ignored if endpoint is None.
|
||||
|
||||
Returns:
|
||||
If endpoint is None: Tuple of (port, socket) where port is the randomly assigned TCP port.
|
||||
If endpoint is provided: The configured ZeroMQ socket.
|
||||
"""
|
||||
socket = context.socket(socket_type)
|
||||
|
||||
if endpoint is None:
|
||||
# Bind to random TCP port
|
||||
config_socket(socket, socket_type)
|
||||
port = socket.bind_to_random_port("tcp://*")
|
||||
return port, socket
|
||||
else:
|
||||
# Handle IPv6 if endpoint contains brackets
|
||||
if endpoint.find("[") != -1:
|
||||
socket.setsockopt(zmq.IPV6, 1)
|
||||
|
||||
config_socket(socket, socket_type)
|
||||
|
||||
if bind:
|
||||
socket.bind(endpoint)
|
||||
else:
|
||||
socket.connect(endpoint)
|
||||
|
||||
return socket
|
||||
|
||||
|
||||
def config_socket(socket, socket_type: zmq.SocketType):
|
||||
mem = psutil.virtual_memory()
|
||||
total_mem = mem.total / 1024**3
|
||||
available_mem = mem.available / 1024**3
|
||||
@@ -1301,10 +1339,6 @@ def get_zmq_socket(
|
||||
else:
|
||||
buf_size = -1
|
||||
|
||||
socket = context.socket(socket_type)
|
||||
if endpoint.find("[") != -1:
|
||||
socket.setsockopt(zmq.IPV6, 1)
|
||||
|
||||
def set_send_opt():
|
||||
socket.setsockopt(zmq.SNDHWM, 0)
|
||||
socket.setsockopt(zmq.SNDBUF, buf_size)
|
||||
@@ -1317,19 +1351,12 @@ def get_zmq_socket(
|
||||
set_send_opt()
|
||||
elif socket_type == zmq.PULL:
|
||||
set_recv_opt()
|
||||
elif socket_type == zmq.DEALER:
|
||||
elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP]:
|
||||
set_send_opt()
|
||||
set_recv_opt()
|
||||
else:
|
||||
raise ValueError(f"Unsupported socket type: {socket_type}")
|
||||
|
||||
if bind:
|
||||
socket.bind(endpoint)
|
||||
else:
|
||||
socket.connect(endpoint)
|
||||
|
||||
return socket
|
||||
|
||||
|
||||
def dump_to_file(dirpath, name, value):
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
Reference in New Issue
Block a user