Launch dp ranks in parallel (#2053)
Co-authored-by: Haotian Liu <6631389+haotian-liu@users.noreply.github.com>
@@ -3,7 +3,6 @@ from __future__ import annotations

 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn

 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -17,6 +17,7 @@ limitations under the License.

 import logging
 import multiprocessing as mp
+import threading
 from enum import Enum, auto

 import zmq
@@ -28,6 +29,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    bind_port,
     configure_logger,
     get_zmq_socket,
     kill_parent_process,
@@ -80,35 +82,62 @@ class DataParallelController:

         # Start data parallel workers
         base_gpu_id = 0
-        self.workers = []
-        scheduler_pipe_readers = []
+        self.workers = [None] * server_args.dp_size
+
+        threads = []
+        sockets = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

             if server_args.enable_dp_attention:
-                # Share workers for DP and TP
-                send_to, reader = self.launch_tensor_parallel_process(
-                    server_args,
-                    tmp_port_args,
-                    base_gpu_id,
-                    dp_rank,
-                )
-                base_gpu_id += 1
-                scheduler_pipe_readers.append(reader)
+                # Data parallelism reuses the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                tmp_port_args.nccl_port = port_args.nccl_port
             else:
-                send_to = self.launch_tensor_parallel_group(
-                    server_args,
-                    tmp_port_args,
-                    base_gpu_id,
-                    dp_rank,
-                )
-                base_gpu_id += server_args.tp_size
-            self.workers.append(send_to)
+                # This port is checked free in PortArgs.init_new.
+                # We hold it first so that the next dp worker gets a different port
+                sockets.append(bind_port(tmp_port_args.nccl_port))

-        for reader in scheduler_pipe_readers:
-            reader.recv()
+            # Create a thread for each worker
+            thread = threading.Thread(
+                target=self.launch_worker_func,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+            )
+            threads.append(thread)
+            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+
+        # Free all sockets before starting the threads to launch TP workers
+        for sock in sockets:
+            sock.close()
+
+        # Start all threads
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+    def launch_worker_func(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
+        launch_func_ = (
+            self.launch_tensor_parallel_process
+            if server_args.enable_dp_attention
+            else self.launch_tensor_parallel_group
+        )
+        self.workers[dp_rank] = launch_func_(
+            server_args,
+            port_args,
+            base_gpu_id,
+            dp_rank,
+        )

     def launch_tensor_parallel_group(
         self,
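
Note: the core of this commit is visible in the hunk above. The controller pre-sizes self.workers, builds one thread per dp rank, and joins them all, so the expensive per-rank startup runs concurrently instead of sequentially. A minimal, self-contained sketch of that pattern (launch_one_rank is a hypothetical stand-in for the real launch_tensor_parallel_group / launch_tensor_parallel_process helpers):

import threading
import time

def launch_one_rank(dp_rank: int) -> str:
    time.sleep(0.1)  # stands in for the slow scheduler/GPU initialization
    return f"worker-{dp_rank}"

def launch_all(dp_size: int) -> list:
    workers = [None] * dp_size  # pre-sized so each thread writes its own slot

    def worker_func(dp_rank: int):
        workers[dp_rank] = launch_one_rank(dp_rank)

    threads = [
        threading.Thread(target=worker_func, args=(rank,)) for rank in range(dp_size)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()  # all ranks are up before the controller starts routing
    return workers

if __name__ == "__main__":
    print(launch_all(4))  # ['worker-0', 'worker-1', 'worker-2', 'worker-3']

Each thread writes only its own slot of the pre-sized list, so no lock is needed; that is why self.workers changes from append() calls to [None] * server_args.dp_size.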
@@ -164,8 +193,8 @@ class DataParallelController:
         send_to = get_zmq_socket(
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
-
-        return send_to, reader
+        reader.recv()
+        return send_to

     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
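
In the hunk above, reader.recv() moves from a loop in __init__ into the launch helper itself: each launcher thread now blocks until its own scheduler process signals readiness over its pipe. A rough sketch of that handshake (assumed shape; scheduler_main stands in for the real run_scheduler_process):

import multiprocessing as mp

def scheduler_main(writer):
    # ... heavy model/scheduler initialization would happen here ...
    writer.send("ready")  # tell the parent that startup finished
    writer.close()

def launch_scheduler():
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=scheduler_main, args=(writer,))
    proc.start()
    reader.recv()  # block this launcher thread until the rank is ready
    return proc

if __name__ == "__main__":
    procs = [launch_scheduler() for _ in range(2)]
    for p in procs:
        p.join()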
@@ -159,7 +159,7 @@ class ServerArgs:
         if self.tp_size >= 16:
             self.mem_fraction_static = 0.79
         elif self.tp_size >= 8:
-            self.mem_fraction_static = 0.83
+            self.mem_fraction_static = 0.82
         elif self.tp_size >= 4:
             self.mem_fraction_static = 0.85
         elif self.tp_size >= 2:
@@ -211,7 +211,7 @@ class ServerArgs:
             self.enable_overlap_schedule = False
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
-                "The CUDA graph is disabled."
+                "The CUDA graph is disabled. Data parallel size is adjusted to be the same as tensor parallel size."
             )

         if self.enable_overlap_schedule:
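
For context, a hedged sketch of the adjustment the reworded warning describes. The field names mirror ServerArgs, but the method and the exact chunked-prefill halving are assumptions, not code taken from this diff:

from dataclasses import dataclass

@dataclass
class ArgsSketch:  # stand-in for sglang.srt.server_args.ServerArgs
    tp_size: int = 2
    dp_size: int = 1
    chunked_prefill_size: int = 8192
    enable_dp_attention: bool = True
    enable_overlap_schedule: bool = True
    disable_cuda_graph: bool = False

    def adjust_for_dp_attention(self):
        if self.enable_dp_attention:
            self.dp_size = self.tp_size          # new: dp size follows tp size
            self.chunked_prefill_size //= 2      # assumed: avoid the MoE workload issue
            self.enable_overlap_schedule = False
            self.disable_cuda_graph = True       # "The CUDA graph is disabled."

args = ArgsSketch()
args.adjust_for_dp_attention()
assert args.dp_size == args.tp_size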
@@ -794,6 +794,15 @@ def add_prometheus_middleware(app):
     app.routes.append(metrics_route)


+def bind_port(port):
+    """Bind to a specific port, assuming it's available."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
+    sock.bind(("", port))
+    sock.listen(1)
+    return sock
+
+
 def get_amdgpu_memory_capacity():
     try:
         # Run rocm-smi and capture the output
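
The controller hunk above explains why this helper exists: PortArgs.init_new only checks that a port is free, so a later dp rank could pick the same port before the first rank's workers actually bind it. Holding the port closes that race until just before launch. A small demo (is_port_free and the port number are illustrative, not part of the commit):

import socket

def bind_port(port):
    """Bind to a specific port, assuming it's available. (Same as the hunk above.)"""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
    sock.bind(("", port))
    sock.listen(1)
    return sock

def is_port_free(port):
    # Illustrative helper: try to bind, report whether it succeeded.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("", port))
            return True
        except OSError:
            return False

if __name__ == "__main__":
    holder = bind_port(29500)    # arbitrary demo port
    print(is_port_free(29500))   # False: the next rank must pick a different port
    holder.close()               # released just before the worker binds it for real
    print(is_port_free(29500))   # True again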
@@ -24,8 +24,6 @@ class TestDPAttention(unittest.TestCase):
                 "--trust-remote-code",
                 "--tp",
                 "2",
-                "--dp",
-                "2",
                 "--enable-dp-attention",
             ],
         )