Fix data parallel + tensor parallel (#4499)
This commit is contained in:
1
.github/workflows/pr-test.yml
vendored
1
.github/workflows/pr-test.yml
vendored
@@ -290,6 +290,7 @@ jobs:
|
|||||||
python3 test_moe_eval_accuracy_large.py
|
python3 test_moe_eval_accuracy_large.py
|
||||||
|
|
||||||
finish:
|
finish:
|
||||||
|
if: always()
|
||||||
needs: [
|
needs: [
|
||||||
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
|
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
|
||||||
performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
|
performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
|
||||||
|
|||||||
@@ -38,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
|
|||||||
return attn_tp_rank, attn_tp_size, dp_rank
|
return attn_tp_rank, attn_tp_size, dp_rank
|
||||||
|
|
||||||
|
|
||||||
def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
|
def initialize_dp_attention(
|
||||||
|
enable_dp_attention: bool,
|
||||||
|
tp_rank: int,
|
||||||
|
tp_size: int,
|
||||||
|
dp_size: int,
|
||||||
|
):
|
||||||
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
|
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
|
||||||
|
|
||||||
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
|
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
|
||||||
@@ -46,7 +51,13 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
|
|||||||
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
|
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
|
||||||
enable_dp_attention, tp_rank, tp_size, dp_size
|
enable_dp_attention, tp_rank, tp_size, dp_size
|
||||||
)
|
)
|
||||||
_DP_SIZE = dp_size
|
|
||||||
|
if enable_dp_attention:
|
||||||
|
local_rank = tp_rank % (tp_size // dp_size)
|
||||||
|
_DP_SIZE = dp_size
|
||||||
|
else:
|
||||||
|
local_rank = tp_rank
|
||||||
|
_DP_SIZE = 1
|
||||||
|
|
||||||
tp_group = get_tp_group()
|
tp_group = get_tp_group()
|
||||||
_ATTN_TP_GROUP = GroupCoordinator(
|
_ATTN_TP_GROUP = GroupCoordinator(
|
||||||
@@ -54,7 +65,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
|
|||||||
list(range(head, head + _ATTN_TP_SIZE))
|
list(range(head, head + _ATTN_TP_SIZE))
|
||||||
for head in range(0, tp_size, _ATTN_TP_SIZE)
|
for head in range(0, tp_size, _ATTN_TP_SIZE)
|
||||||
],
|
],
|
||||||
tp_rank,
|
local_rank,
|
||||||
torch.distributed.get_backend(tp_group.device_group),
|
torch.distributed.get_backend(tp_group.device_group),
|
||||||
SYNC_TOKEN_IDS_ACROSS_TP,
|
SYNC_TOKEN_IDS_ACROSS_TP,
|
||||||
False,
|
False,
|
||||||
|
|||||||
@@ -82,10 +82,12 @@ class DataParallelController:
|
|||||||
self.scheduler_procs = []
|
self.scheduler_procs = []
|
||||||
self.workers = [None] * server_args.dp_size
|
self.workers = [None] * server_args.dp_size
|
||||||
|
|
||||||
if not server_args.enable_dp_attention:
|
if server_args.enable_dp_attention:
|
||||||
dp_port_args = self.launch_dp_schedulers(server_args, port_args)
|
|
||||||
else:
|
|
||||||
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
|
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
|
||||||
|
self.control_message_step = server_args.tp_size
|
||||||
|
else:
|
||||||
|
dp_port_args = self.launch_dp_schedulers(server_args, port_args)
|
||||||
|
self.control_message_step = 1
|
||||||
|
|
||||||
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
|
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
|
||||||
if server_args.node_rank == 0:
|
if server_args.node_rank == 0:
|
||||||
@@ -105,6 +107,7 @@ class DataParallelController:
|
|||||||
threads = []
|
threads = []
|
||||||
sockets = []
|
sockets = []
|
||||||
dp_port_args = []
|
dp_port_args = []
|
||||||
|
ready_events = []
|
||||||
for dp_rank in range(server_args.dp_size):
|
for dp_rank in range(server_args.dp_size):
|
||||||
tmp_port_args = PortArgs.init_new(server_args)
|
tmp_port_args = PortArgs.init_new(server_args)
|
||||||
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
||||||
@@ -115,10 +118,13 @@ class DataParallelController:
|
|||||||
# We hold it first so that the next dp worker gets a different port
|
# We hold it first so that the next dp worker gets a different port
|
||||||
sockets.append(bind_port(tmp_port_args.nccl_port))
|
sockets.append(bind_port(tmp_port_args.nccl_port))
|
||||||
|
|
||||||
|
ready_event = threading.Event()
|
||||||
|
ready_events.append(ready_event)
|
||||||
|
|
||||||
# Create a thread for each worker
|
# Create a thread for each worker
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=self.launch_tensor_parallel_group,
|
target=self.launch_tensor_parallel_group_thread,
|
||||||
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
|
args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
|
||||||
)
|
)
|
||||||
threads.append(thread)
|
threads.append(thread)
|
||||||
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
|
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
|
||||||
@@ -130,11 +136,27 @@ class DataParallelController:
|
|||||||
# Start all threads
|
# Start all threads
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
thread.start()
|
thread.start()
|
||||||
for thread in threads:
|
for event in ready_events:
|
||||||
thread.join()
|
event.wait()
|
||||||
|
|
||||||
return dp_port_args
|
return dp_port_args
|
||||||
|
|
||||||
|
def launch_tensor_parallel_group_thread(
|
||||||
|
self,
|
||||||
|
server_args: ServerArgs,
|
||||||
|
port_args: PortArgs,
|
||||||
|
base_gpu_id: int,
|
||||||
|
dp_rank: int,
|
||||||
|
ready_event: threading.Event,
|
||||||
|
):
|
||||||
|
self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
|
||||||
|
ready_event.set()
|
||||||
|
|
||||||
|
# This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
|
||||||
|
# function in scheduler.py will kill the scheduler.
|
||||||
|
while True:
|
||||||
|
pass
|
||||||
|
|
||||||
def launch_dp_attention_schedulers(self, server_args, port_args):
|
def launch_dp_attention_schedulers(self, server_args, port_args):
|
||||||
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
|
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
|
||||||
dp_port_args = []
|
dp_port_args = []
|
||||||
@@ -223,7 +245,7 @@ class DataParallelController:
|
|||||||
self.dispatching(recv_req)
|
self.dispatching(recv_req)
|
||||||
else:
|
else:
|
||||||
# Send other control messages to first worker of tp group
|
# Send other control messages to first worker of tp group
|
||||||
for worker in self.workers[:: self.server_args.tp_size]:
|
for worker in self.workers[:: self.control_message_step]:
|
||||||
worker.send_pyobj(recv_req)
|
worker.send_pyobj(recv_req)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1786,7 +1786,7 @@ def run_scheduler_process(
|
|||||||
prefix = f" DP{dp_rank} TP{tp_rank}"
|
prefix = f" DP{dp_rank} TP{tp_rank}"
|
||||||
|
|
||||||
# Config the process
|
# Config the process
|
||||||
# kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
|
kill_itself_when_parent_died()
|
||||||
setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
|
setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
|
||||||
faulthandler.enable()
|
faulthandler.enable()
|
||||||
parent_process = psutil.Process().parent()
|
parent_process = psutil.Process().parent()
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
@@ -81,7 +82,9 @@ def patch_model(
|
|||||||
# tp_group.ca_comm = None
|
# tp_group.ca_comm = None
|
||||||
yield torch.compile(
|
yield torch.compile(
|
||||||
torch.no_grad()(model.forward),
|
torch.no_grad()(model.forward),
|
||||||
mode="max-autotune-no-cudagraphs",
|
mode=os.environ.get(
|
||||||
|
"SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
|
||||||
|
),
|
||||||
dynamic=False,
|
dynamic=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class TestDataParallelism(unittest.TestCase):
|
|||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
other_args=["--dp", "2"],
|
other_args=["--dp", 2],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -52,7 +52,7 @@ class TestDataParallelism(unittest.TestCase):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# pause a few seconds then send again
|
# pause a few seconds then send again
|
||||||
time.sleep(5)
|
time.sleep(1)
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/update_weights_from_disk",
|
self.base_url + "/update_weights_from_disk",
|
||||||
@@ -67,7 +67,7 @@ class TestDataParallelism(unittest.TestCase):
|
|||||||
response = requests.get(self.base_url + "/get_server_info")
|
response = requests.get(self.base_url + "/get_server_info")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
time.sleep(5)
|
time.sleep(1)
|
||||||
|
|
||||||
response = requests.get(self.base_url + "/get_server_info")
|
response = requests.get(self.base_url + "/get_server_info")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
Reference in New Issue
Block a user