From f719d9aebc1820bad70be738b8473fbf2f1dd370 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 16 Nov 2024 17:13:36 -0800
Subject: [PATCH] Launch dp ranks in parallel (#2053)

Co-authored-by: Haotian Liu <6631389+haotian-liu@users.noreply.github.com>
---
 .../srt/layers/attention/triton_backend.py    |  1 -
 .../srt/managers/data_parallel_controller.py  | 75 +++++++++++++------
 python/sglang/srt/server_args.py              |  4 +-
 python/sglang/srt/utils.py                    |  9 +++
 test/srt/test_dp_attention.py                 |  2 -
 5 files changed, 63 insertions(+), 28 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 0c99d1ec4..69b96fdd0 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 import torch
-import torch.nn as nn
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 472d9174e..10e8cd801 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -17,6 +17,7 @@ limitations under the License.
 
 import logging
 import multiprocessing as mp
+import threading
 from enum import Enum, auto
 
 import zmq
@@ -28,6 +29,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    bind_port,
     configure_logger,
     get_zmq_socket,
     kill_parent_process,
@@ -80,35 +82,62 @@ class DataParallelController:
 
         # Start data parallel workers
         base_gpu_id = 0
-        self.workers = []
-        scheduler_pipe_readers = []
+        self.workers = [None] * server_args.dp_size
+
+        threads = []
+        sockets = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
 
             if server_args.enable_dp_attention:
-                # Share workers for DP and TP
-                send_to, reader = self.launch_tensor_parallel_process(
-                    server_args,
-                    tmp_port_args,
-                    base_gpu_id,
-                    dp_rank,
-                )
-                base_gpu_id += 1
-                scheduler_pipe_readers.append(reader)
+                # Data parallelism reuses the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                tmp_port_args.nccl_port = port_args.nccl_port
             else:
-                send_to = self.launch_tensor_parallel_group(
-                    server_args,
-                    tmp_port_args,
-                    base_gpu_id,
-                    dp_rank,
-                )
-                base_gpu_id += server_args.tp_size
-            self.workers.append(send_to)
+                # This port is checked to be free in PortArgs.init_new.
+                # We hold it first so that the next dp worker gets a different port.
+                sockets.append(bind_port(tmp_port_args.nccl_port))
 
-        for reader in scheduler_pipe_readers:
-            reader.recv()
+            # Create a thread for each worker
+            thread = threading.Thread(
+                target=self.launch_worker_func,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+            )
+            threads.append(thread)
+            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+
+        # Free all sockets before starting the threads to launch TP workers
+        for sock in sockets:
+            sock.close()
+
+        # Start all threads
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+    def launch_worker_func(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
+        launch_func_ = (
+            self.launch_tensor_parallel_process
+            if server_args.enable_dp_attention
+            else self.launch_tensor_parallel_group
+        )
+        self.workers[dp_rank] = launch_func_(
+            server_args,
+            port_args,
+            base_gpu_id,
+            dp_rank,
+        )
 
     def launch_tensor_parallel_group(
         self,
@@ -164,8 +193,8 @@ class DataParallelController:
         send_to = get_zmq_socket(
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
-
-        return send_to, reader
+        reader.recv()
+        return send_to
 
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index cb4d19192..26c339e64 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -159,7 +159,7 @@ class ServerArgs:
         if self.tp_size >= 16:
             self.mem_fraction_static = 0.79
         elif self.tp_size >= 8:
-            self.mem_fraction_static = 0.83
+            self.mem_fraction_static = 0.82
         elif self.tp_size >= 4:
             self.mem_fraction_static = 0.85
         elif self.tp_size >= 2:
@@ -211,7 +211,7 @@ class ServerArgs:
             self.enable_overlap_schedule = False
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
-                "The CUDA graph is disabled."
+                "The CUDA graph is disabled. Data parallel size is adjusted to be the same as tensor parallel size."
             )
 
         if self.enable_overlap_schedule:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index e04ec7ddf..32317ec2e 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -794,6 +794,15 @@ def add_prometheus_middleware(app):
     app.routes.append(metrics_route)
 
 
+def bind_port(port):
+    """Bind to a specific port, assuming it's available."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
+    sock.bind(("", port))
+    sock.listen(1)
+    return sock
+
+
 def get_amdgpu_memory_capacity():
     try:
         # Run rocm-smi and capture the output
diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py
index 4cfdac228..32fe75a59 100644
--- a/test/srt/test_dp_attention.py
+++ b/test/srt/test_dp_attention.py
@@ -24,8 +24,6 @@ class TestDPAttention(unittest.TestCase):
                 "--trust-remote-code",
                 "--tp",
                 "2",
-                "--dp",
-                "2",
                 "--enable-dp-attention",
             ],
         )
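
---

For readers following along, here is a minimal standalone sketch (not part of the patch) of the port-reservation pattern the change above relies on: the parent picks a free port for each dp rank, binds and holds it so the next rank's allocation cannot hand out the same port, then releases all reservations right before the worker threads start. Only bind_port mirrors the helper added to python/sglang/srt/utils.py; find_free_port is a stand-in for the availability check inside PortArgs.init_new, and launch_worker is a hypothetical placeholder for the scheduler launch.

    import socket
    import threading


    def find_free_port() -> int:
        # Stand-in for the free-port check done in PortArgs.init_new:
        # ask the kernel for an unused port and return it.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


    def bind_port(port: int) -> socket.socket:
        # Same trick as the bind_port helper in the patch: hold the port
        # so later find_free_port() calls cannot hand it out again.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # worker can rebind right after we close
        sock.bind(("", port))
        sock.listen(1)
        return sock


    def launch_worker(dp_rank: int, port: int, workers: list) -> None:
        # Hypothetical worker launch; the real code spawns scheduler
        # processes here. Writing into a pre-sized slot keeps rank order
        # even though threads finish in arbitrary order.
        workers[dp_rank] = f"worker-{dp_rank} on port {port}"


    dp_size = 4
    workers = [None] * dp_size
    holders, ports = [], []
    for _ in range(dp_size):
        port = find_free_port()
        holders.append(bind_port(port))  # reserve the port for this rank
        ports.append(port)

    # Free all reservations just before the workers start binding.
    for sock in holders:
        sock.close()

    threads = [
        threading.Thread(target=launch_worker, args=(rank, port, workers))
        for rank, port in enumerate(ports)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(workers)  # dp_size workers, each with a distinct port

Because each reserved port stays bound until all ranks have picked theirs, the sequential check-then-bind race between dp ranks disappears, which is what lets the launches run in parallel threads.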