Fix data parallel + tensor parallel (#4499)

Lianmin Zheng
2025-03-17 05:13:16 -07:00
committed by GitHub
parent f2ab37e500
commit 5493c3343e
6 changed files with 53 additions and 16 deletions

View File

@@ -38,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
     return attn_tp_rank, attn_tp_size, dp_rank
 
 
-def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+def initialize_dp_attention(
+    enable_dp_attention: bool,
+    tp_rank: int,
+    tp_size: int,
+    dp_size: int,
+):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
@@ -46,7 +51,13 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
        enable_dp_attention, tp_rank, tp_size, dp_size
     )
-    _DP_SIZE = dp_size
+
+    if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
+        _DP_SIZE = dp_size
+    else:
+        local_rank = tp_rank
+        _DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -54,7 +65,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        tp_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
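
The three hunks above are the core of the fix: with DP attention enabled, the `tp_size` ranks are partitioned into `dp_size` attention-TP sub-groups of `tp_size // dp_size` ranks each, and `GroupCoordinator` must receive the rank's position inside its sub-group rather than the global `tp_rank`. A minimal standalone sketch of that arithmetic (the helper name `attn_tp_layout` is hypothetical, not sglang API):

```python
def attn_tp_layout(enable_dp_attention: bool, tp_rank: int, tp_size: int, dp_size: int):
    """Return (sub-group layout, this rank's position inside its sub-group)."""
    if enable_dp_attention:
        attn_tp_size = tp_size // dp_size   # ranks per attention-TP sub-group
        local_rank = tp_rank % attn_tp_size  # position inside the sub-group (the fix)
    else:
        attn_tp_size = tp_size
        local_rank = tp_rank
    # One sub-group of consecutive ranks per DP replica, e.g. [[0, 1], [2, 3]].
    groups = [
        list(range(head, head + attn_tp_size))
        for head in range(0, tp_size, attn_tp_size)
    ]
    return groups, local_rank

# tp_size=4, dp_size=2: rank 3 sits at position 1 of sub-group [2, 3],
# whereas the old code would have passed the global rank 3.
assert attn_tp_layout(True, 3, 4, 2) == ([[0, 1], [2, 3]], 1)
```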

View File

@@ -82,10 +82,12 @@ class DataParallelController:
         self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size
 
-        if not server_args.enable_dp_attention:
-            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
-        else:
+        if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+            self.control_message_step = server_args.tp_size
+        else:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+            self.control_message_step = 1
 
         # Only node rank 0 runs the real data parallel controller that dispatches the requests.
         if server_args.node_rank == 0:
@@ -105,6 +107,7 @@ class DataParallelController:
         threads = []
         sockets = []
         dp_port_args = []
+        ready_events = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
@@ -115,10 +118,13 @@ class DataParallelController:
             # We hold it first so that the next dp worker gets a different port
             sockets.append(bind_port(tmp_port_args.nccl_port))
 
+            ready_event = threading.Event()
+            ready_events.append(ready_event)
+
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.launch_tensor_parallel_group,
-                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+                target=self.launch_tensor_parallel_group_thread,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
             )
             threads.append(thread)
             base_gpu_id += server_args.tp_size * server_args.gpu_id_step
@@ -130,11 +136,27 @@ class DataParallelController:
         # Start all threads
         for thread in threads:
             thread.start()
-        for thread in threads:
-            thread.join()
+        for event in ready_events:
+            event.wait()
 
         return dp_port_args
 
+    def launch_tensor_parallel_group_thread(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+        ready_event: threading.Event,
+    ):
+        self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
+        ready_event.set()
+
+        # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
+        # function in scheduler.py will kill the scheduler.
+        while True:
+            pass
+
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
         dp_port_args = []
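
Replacing `thread.join()` with per-worker ready events is the key behavioral change here: the launcher threads now never exit (see the busy loop above), so joining them would block forever. A self-contained sketch of the pattern with hypothetical names (not sglang code); it parks the worker on a never-set `Event` rather than the busy loop the diff uses, which is a cheaper way to stay alive:

```python
import threading

def worker(ready_event: threading.Event) -> None:
    # ... one-time setup (e.g. launching scheduler processes) would go here ...
    ready_event.set()          # tell the launcher that setup finished
    threading.Event().wait()   # park forever without burning a CPU core

events = [threading.Event() for _ in range(2)]
threads = [threading.Thread(target=worker, args=(e,), daemon=True) for e in events]
for t in threads:
    t.start()
for e in events:
    e.wait()  # unlike join(), this returns even though the threads stay alive
print("all workers ready")
```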
@@ -223,7 +245,7 @@ class DataParallelController:
                     self.dispatching(recv_req)
                 else:
                     # Send other control messages to first worker of tp group
-                    for worker in self.workers[:: self.server_args.tp_size]:
+                    for worker in self.workers[:: self.control_message_step]:
                         worker.send_pyobj(recv_req)
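
The old code always strode through `self.workers` with step `tp_size`, which skips workers in plain data-parallel mode where each entry already corresponds to one DP replica. The new `control_message_step` (set in the constructor hunk above) makes the stride match the worker layout. A toy illustration of the slicing, using a hypothetical worker layout purely for demonstration:

```python
# With DP attention, suppose one handle per (dp, tp) pair; control messages
# should reach only the first TP rank of each group.
workers = ["dp0_tp0", "dp0_tp1", "dp1_tp0", "dp1_tp1"]  # dp_size=2, tp_size=2
control_message_step = 2  # == tp_size when DP attention is enabled
assert workers[::control_message_step] == ["dp0_tp0", "dp1_tp0"]

# Plain DP: one handle per replica, so every worker must get the message.
workers = ["dp0", "dp1"]
control_message_step = 1
assert workers[::control_message_step] == ["dp0", "dp1"]
```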

View File

@@ -1786,7 +1786,7 @@ def run_scheduler_process(
         prefix = f" DP{dp_rank} TP{tp_rank}"
 
     # Config the process
-    # kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
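
Re-enabling `kill_itself_when_parent_died` is what the keep-alive launcher thread above compensates for. On Linux, helpers like this are typically built on `prctl(PR_SET_PDEATHSIG)`, and per prctl(2) the death signal fires when the parent *thread* that spawned the process terminates, not only the parent process, which is why the launcher thread must never exit. A sketch of a typical implementation (illustrative, not necessarily sglang's exact code):

```python
import ctypes
import signal
import sys

PR_SET_PDEATHSIG = 1  # constant from <sys/prctl.h>

def kill_itself_when_parent_died() -> None:
    if sys.platform == "linux":
        libc = ctypes.CDLL("libc.so.6", use_errno=True)
        # Ask the kernel to deliver SIGKILL to this process when the
        # thread that created it terminates.
        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
```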

View File

@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
@@ -81,7 +82,9 @@ def patch_model(
                 # tp_group.ca_comm = None
                 yield torch.compile(
                     torch.no_grad()(model.forward),
-                    mode="max-autotune-no-cudagraphs",
+                    mode=os.environ.get(
+                        "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                    ),
                     dynamic=False,
                 )
             else:
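
This last hunk makes the `torch.compile` mode overridable through the new `SGLANG_TORCH_COMPILE_MODE` environment variable, defaulting to the previously hard-coded value. A small sketch of how the override resolves; the variable must hold one of `torch.compile`'s documented mode strings:

```python
import os

# Set before the module above is imported, e.g. from the launch script or shell.
os.environ["SGLANG_TORCH_COMPILE_MODE"] = "reduce-overhead"

mode = os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs")
# Valid torch.compile modes: "default", "reduce-overhead",
# "max-autotune", "max-autotune-no-cudagraphs".
assert mode == "reduce-overhead"
```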