Improve dp attention port assignment scheme (#5889)
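Replace the arithmetic port assignment for DP attention workers (port_base + 4 + 1 + dp_rank) with pre-allocated ports: node 0 binds each worker's scheduler input socket to a random free TCP port, then broadcasts the resulting port list to the other nodes over a ZMQ REQ/REP handshake before launching the schedulers, so every node derives the same PortArgs without port collisions.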
Co-authored-by: Cheng Wan <cwan@x.ai>
@@ -21,7 +21,7 @@ import threading
 import time
 from collections import deque
 from enum import Enum, auto
-from typing import List
+from typing import List, Optional
 
 import psutil
 import setproctitle
@@ -36,7 +36,11 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import (
+    DP_ATTENTION_HANDSHAKE_PORT_DELTA,
+    PortArgs,
+    ServerArgs,
+)
 from sglang.srt.utils import (
     bind_port,
     configure_logger,
@@ -140,22 +144,12 @@ class DataParallelController:
         self.workers: List[zmq.Socket] = [None] * server_args.dp_size
 
         if server_args.enable_dp_attention:
-            dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+            self.launch_dp_attention_schedulers(server_args, port_args)
             self.control_message_step = server_args.tp_size
         else:
-            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+            self.launch_dp_schedulers(server_args, port_args)
             self.control_message_step = 1
 
-        # Only node rank 0 runs the real data parallel controller that dispatches the requests.
-        if server_args.node_rank == 0:
-            for dp_rank in range(server_args.dp_size):
-                self.workers[dp_rank] = get_zmq_socket(
-                    self.context,
-                    zmq.PUSH,
-                    dp_port_args[dp_rank].scheduler_input_ipc_name,
-                    True,
-                )
-
         self.max_req_input_len = None
 
         self.init_dispatcher()
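Note: the launch methods no longer return per-rank port args. The dispatcher sockets in self.workers are now created where the ports are known: inside launch_dp_schedulers for the plain DP path (next hunks) and inside launch_dp_attention_schedulers for the DP attention path.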
@@ -188,13 +182,11 @@ class DataParallelController:
 
         threads = []
         sockets = []
-        dp_port_args = []
         ready_events = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
-            dp_port_args.append(tmp_port_args)
 
             # This port is checked free in PortArgs.init_new.
             # We hold it first so that the next dp worker gets a different port
@@ -213,6 +205,14 @@ class DataParallelController:
                 server_args.tp_size * server_args.pp_size * server_args.gpu_id_step
             )
 
+            if server_args.node_rank == 0:
+                self.workers[dp_rank] = get_zmq_socket(
+                    self.context,
+                    zmq.PUSH,
+                    tmp_port_args.scheduler_input_ipc_name,
+                    True,
+                )
+
         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:
             sock.close()
@@ -223,8 +223,6 @@ class DataParallelController:
         for event in ready_events:
             event.wait()
 
-        return dp_port_args
-
     def launch_tensor_parallel_group_thread(
         self,
         server_args: ServerArgs,
@@ -241,19 +239,115 @@ class DataParallelController:
         while True:
             time.sleep(30 * 24 * 3600)
 
-    def launch_dp_attention_schedulers(self, server_args, port_args):
-        self.launch_tensor_parallel_group(server_args, port_args, 0, None)
-        dp_port_args = []
-        for dp_rank in range(server_args.dp_size):
-            dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
-        return dp_port_args
+    def _broadcast_worker_ports(
+        self, server_args: ServerArgs, worker_ports: Optional[List[int]] = None
+    ) -> List[int]:
+        """Broadcast worker ports from node 0 to all other nodes.
+
+        Node 0 acts as the server, waiting for all other nodes to connect and
+        sending them the pre-allocated worker ports. Other nodes act as clients,
+        connecting to node 0 to receive their copy of the worker ports.
+
+        Args:
+            server_args: Server arguments containing node configuration.
+            worker_ports: Pre-allocated worker ports to broadcast.
+
+        Returns:
+            List of worker ports (same on all nodes after broadcast).
+        """
+        # Determine the endpoint for inter-node communication
+        if server_args.dist_init_addr is None:
+            endpoint = f"tcp://127.0.0.1:{server_args.port + DP_ATTENTION_HANDSHAKE_PORT_DELTA}"
+        else:
+            endpoint = f"tcp://{server_args.dist_init_addr}"
+
+        if server_args.node_rank == 0:
+            # Node 0: Broadcast worker ports to all other nodes
+            return self._broadcast_ports_as_server(
+                endpoint, server_args.nnodes - 1, worker_ports
+            )
+        else:
+            # Other nodes: Receive worker ports from node 0
+            return self._receive_ports_as_client(endpoint, server_args.node_rank)
+
+    def _broadcast_ports_as_server(
+        self, endpoint: str, expected_clients: int, worker_ports: List[int]
+    ) -> List[int]:
+        """Broadcast worker ports to all client nodes."""
+        logger.debug(f"Broadcasting worker ports to {expected_clients} client nodes")
+        logger.debug(f"Worker ports: {worker_ports}")
+
+        rep_socket = get_zmq_socket(self.context, zmq.REP, endpoint, True)
+
+        try:
+            connected_clients = 0
+            while connected_clients < expected_clients:
+                # Wait for client handshake
+                client_rank = rep_socket.recv().decode()
+                logger.debug(f"Received handshake from node {client_rank}")
+
+                # Send worker ports to client
+                rep_socket.send_pyobj(worker_ports)
+                connected_clients += 1
+                logger.debug(
+                    f"Sent worker ports to {connected_clients}/{expected_clients} nodes"
+                )
+
+            logger.debug("Worker port broadcast completed")
+            return worker_ports
+        finally:
+            rep_socket.close()
+
+    def _receive_ports_as_client(self, endpoint: str, node_rank: int) -> List[int]:
+        """Receive worker ports from the server node."""
+        logger.debug(f"Connecting to node 0 to receive worker ports")
+
+        req_socket = get_zmq_socket(self.context, zmq.REQ, endpoint, False)
+        req_socket.setsockopt(zmq.RCVTIMEO, 60 * 1000)  # 1 minute timeout
+        req_socket.setsockopt(zmq.SNDTIMEO, 60 * 1000)
+
+        try:
+            # Send handshake with our node rank
+            req_socket.send(str(node_rank).encode())
+
+            # Receive worker ports
+            worker_ports = req_socket.recv_pyobj()
+            logger.debug(f"Received {len(worker_ports)} worker ports from node 0")
+            return worker_ports
+        except zmq.Again:
+            logger.error("Timeout waiting for worker ports from node 0")
+            raise RuntimeError(
+                "Failed to receive worker ports from node 0 within timeout"
+            )
+        finally:
+            req_socket.close()
+
+    def launch_dp_attention_schedulers(
+        self, server_args: ServerArgs, port_args: PortArgs
+    ):
+        # Pre-allocate worker ports on node 0 to avoid conflicts
+        worker_ports = []
+        if server_args.node_rank == 0:
+            for dp_rank in range(server_args.dp_size):
+                port_and_socket = get_zmq_socket(self.context, zmq.PUSH)
+                worker_ports.append(port_and_socket[0])
+                self.workers[dp_rank] = port_and_socket[1]
+                logger.debug(f"Assigned port {port_and_socket[0]} to worker {dp_rank}")
+
+        broadcasted_ports = self._broadcast_worker_ports(
+            server_args, worker_ports if worker_ports else None
+        )
+        self.launch_tensor_parallel_group(
+            server_args, port_args, 0, None, broadcasted_ports
+        )
 
     def launch_tensor_parallel_group(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
         base_gpu_id: int,
-        dp_rank: int,
+        dp_rank: Optional[int],
+        worker_ports: Optional[List[int]] = None,
     ):
         if not server_args.enable_dp_attention:
             logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
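For illustration, a minimal self-contained sketch of the REQ/REP handshake that _broadcast_worker_ports implements: node 0 plays the REP server and answers each client's handshake with the pre-allocated port list. run_node, the endpoint, and the port values are assumptions for the demo, not part of the sglang API; only pyzmq is required.

import threading

import zmq


def run_node(node_rank: int, nnodes: int, endpoint: str, ports=None):
    ctx = zmq.Context()
    if node_rank == 0:
        # Node 0 is the REP server: answer one handshake per client node.
        sock = ctx.socket(zmq.REP)
        sock.bind(endpoint)
        for _ in range(nnodes - 1):
            sock.recv()  # client handshake (its node rank, unused here)
            sock.send_pyobj(ports)  # reply with the pre-allocated port list
    else:
        # Other nodes are REQ clients: one round trip fetches the ports.
        sock = ctx.socket(zmq.REQ)
        sock.connect(endpoint)
        sock.send(str(node_rank).encode())
        ports = sock.recv_pyobj()
    sock.close()
    ctx.term()
    return ports


if __name__ == "__main__":
    results = {}

    def node(rank):
        results[rank] = run_node(rank, 3, "tcp://127.0.0.1:5757", [30001, 30002])

    threads = [threading.Thread(target=node, args=(r,)) for r in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(results)  # every "node" ends up with the same [30001, 30002]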
@@ -290,7 +384,9 @@ class DataParallelController:
                 server_args.dp_size,
             )
             # compute zmq ports for this dp rank
-            rank_port_args = PortArgs.init_new(server_args, dp_rank)
+            rank_port_args = PortArgs.init_new(
+                server_args, dp_rank, worker_ports
+            )
             # Data parallelism reuses the tensor parallelism group,
             # so all dp ranks should use the same nccl port.
             rank_port_args.nccl_port = port_args.nccl_port
@@ -13,6 +13,8 @@
 # ==============================================================================
 """The arguments of the server."""
 
+from __future__ import annotations
+
 import argparse
 import dataclasses
 import json
@@ -3362,6 +3364,7 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 
 
 ZMQ_TCP_PORT_DELTA = 233
+DP_ATTENTION_HANDSHAKE_PORT_DELTA = 5
 
 
 @dataclasses.dataclass
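With this delta, a deployment without --dist-init-addr derives the handshake endpoint from the server port, e.g. tcp://127.0.0.1:30005 for --port 30000; multi-node setups reuse dist_init_addr directly.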
@@ -3386,7 +3389,11 @@ class PortArgs:
     tokenizer_worker_ipc_name: Optional[str]
 
     @staticmethod
-    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
+    def init_new(
+        server_args: ServerArgs,
+        dp_rank: Optional[int] = None,
+        worker_ports: Optional[List[int]] = None,
+    ) -> PortArgs:
         if server_args.nccl_port is None:
             nccl_port = server_args.port + random.randint(100, 1000)
             while True:
@@ -3433,8 +3440,8 @@ class PortArgs:
                 # TokenizerManager to DataParallelController
                 scheduler_input_port = port_base + 4
             else:
-                scheduler_input_port = port_base + 4 + 1 + dp_rank
+                assert worker_ports is not None
+                scheduler_input_port = worker_ports[dp_rank]
 
             return PortArgs(
                 tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
                 scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
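To make the new assignment concrete, a hedged before/after comparison (all values assumed for illustration):

# Old scheme: the port is derived arithmetically from port_base, free or not.
port_base = 25000  # assumed
old_port = port_base + 4 + 1 + 2  # dp_rank = 2 -> 25007, may already be taken

# New scheme: index into ports that node 0 actually bound and broadcast.
worker_ports = [25006, 25007, 25008, 25009]  # e.g. from bind_to_random_port
new_port = worker_ports[2]  # 25008, held by node 0's PUSH socket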
@@ -1291,8 +1291,46 @@ def pytorch_profile(name, func, *args, data_size=-1):
 
 
 def get_zmq_socket(
-    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-) -> zmq.Socket:
+    context: zmq.Context,
+    socket_type: zmq.SocketType,
+    endpoint: Optional[str] = None,
+    bind: bool = True,
+) -> Union[zmq.Socket, Tuple[int, zmq.Socket]]:
+    """Create and configure a ZeroMQ socket.
+
+    Args:
+        context: ZeroMQ context to create the socket from.
+        socket_type: Type of ZeroMQ socket to create.
+        endpoint: Optional endpoint to bind/connect to. If None, binds to a random TCP port.
+        bind: Whether to bind (True) or connect (False) to the endpoint. Ignored if endpoint is None.
+
+    Returns:
+        If endpoint is None: Tuple of (port, socket) where port is the randomly assigned TCP port.
+        If endpoint is provided: The configured ZeroMQ socket.
+    """
+    socket = context.socket(socket_type)
+
+    if endpoint is None:
+        # Bind to random TCP port
+        config_socket(socket, socket_type)
+        port = socket.bind_to_random_port("tcp://*")
+        return port, socket
+    else:
+        # Handle IPv6 if endpoint contains brackets
+        if endpoint.find("[") != -1:
+            socket.setsockopt(zmq.IPV6, 1)
+
+        config_socket(socket, socket_type)
+
+        if bind:
+            socket.bind(endpoint)
+        else:
+            socket.connect(endpoint)
+
+        return socket
+
+
+def config_socket(socket, socket_type: zmq.SocketType):
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
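A hedged usage sketch of the two calling modes (the import path matches this diff; the endpoint and port values are assumptions):

import zmq

from sglang.srt.utils import get_zmq_socket

ctx = zmq.Context()

# Endpoint given: returns just the configured socket.
push = get_zmq_socket(ctx, zmq.PUSH, "tcp://127.0.0.1:25010", True)

# No endpoint: bind to a random free TCP port and return (port, socket);
# launch_dp_attention_schedulers uses this mode to pre-allocate worker ports.
port, sock = get_zmq_socket(ctx, zmq.PUSH)
print(f"scheduler input socket bound on port {port}")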
@@ -1301,10 +1339,6 @@ def get_zmq_socket(
     else:
         buf_size = -1
 
-    socket = context.socket(socket_type)
-    if endpoint.find("[") != -1:
-        socket.setsockopt(zmq.IPV6, 1)
-
     def set_send_opt():
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
@@ -1317,19 +1351,12 @@ def get_zmq_socket(
         set_send_opt()
     elif socket_type == zmq.PULL:
         set_recv_opt()
-    elif socket_type == zmq.DEALER:
+    elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP]:
         set_send_opt()
         set_recv_opt()
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")
 
-    if bind:
-        socket.bind(endpoint)
-    else:
-        socket.connect(endpoint)
-
-    return socket
-
 
 def dump_to_file(dirpath, name, value):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -75,7 +75,8 @@ class TestPortArgs(unittest.TestCase):
         server_args.nnodes = 1
         server_args.dist_init_addr = "192.168.1.1:25000"
 
-        port_args = PortArgs.init_new(server_args, dp_rank=2)
+        worker_ports = [25006, 25007, 25008, 25009]
+        port_args = PortArgs.init_new(server_args, dp_rank=2, worker_ports=worker_ports)
 
         self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25008"))
 