diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index ab9674425..a7aa55cc9 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -290,6 +290,7 @@ jobs:
           python3 test_moe_eval_accuracy_large.py
 
   finish:
+    if: always()
     needs: [
       unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
       performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index 42d4a1457..0d593e048 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -38,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
     return attn_tp_rank, attn_tp_size, dp_rank
 
 
-def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+def initialize_dp_attention(
+    enable_dp_attention: bool,
+    tp_rank: int,
+    tp_size: int,
+    dp_size: int,
+):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
@@ -46,7 +51,13 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
-    _DP_SIZE = dp_size
+
+    if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
+        _DP_SIZE = dp_size
+    else:
+        local_rank = tp_rank
+        _DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -54,7 +65,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        tp_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 627d72c7b..fb0264a6e 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -82,10 +82,12 @@ class DataParallelController:
         self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size
 
-        if not server_args.enable_dp_attention:
-            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
-        else:
+        if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+            self.control_message_step = server_args.tp_size
+        else:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+            self.control_message_step = 1
 
         # Only node rank 0 runs the real data parallel controller that dispatches the requests.
         if server_args.node_rank == 0:
@@ -105,6 +107,7 @@ class DataParallelController:
         threads = []
         sockets = []
         dp_port_args = []
+        ready_events = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
@@ -115,10 +118,13 @@ class DataParallelController:
             # We hold it first so that the next dp worker gets a different port
             sockets.append(bind_port(tmp_port_args.nccl_port))
 
+            ready_event = threading.Event()
+            ready_events.append(ready_event)
+
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.launch_tensor_parallel_group,
-                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+                target=self.launch_tensor_parallel_group_thread,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
             )
             threads.append(thread)
             base_gpu_id += server_args.tp_size * server_args.gpu_id_step
@@ -130,11 +136,27 @@ class DataParallelController:
         # Start all threads
         for thread in threads:
             thread.start()
-        for thread in threads:
-            thread.join()
+        for event in ready_events:
+            event.wait()
 
         return dp_port_args
 
+    def launch_tensor_parallel_group_thread(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+        ready_event: threading.Event,
+    ):
+        self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
+        ready_event.set()
+
+        # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
+        # function in scheduler.py will kill the scheduler.
+        while True:
+            pass
+
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
         dp_port_args = []
@@ -223,7 +245,7 @@ class DataParallelController:
                     self.dispatching(recv_req)
                 else:
                     # Send other control messages to first worker of tp group
-                    for worker in self.workers[:: self.server_args.tp_size]:
+                    for worker in self.workers[:: self.control_message_step]:
                         worker.send_pyobj(recv_req)
 
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index fa4a49ce8..a743af97a 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1786,7 +1786,7 @@ def run_scheduler_process(
         prefix = f" DP{dp_rank} TP{tp_rank}"
 
     # Config the process
-    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 445476f07..9add51eef 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -16,6 +16,7 @@ from __future__ import annotations
 
 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
@@ -81,7 +82,9 @@ def patch_model(
             # tp_group.ca_comm = None
             yield torch.compile(
                 torch.no_grad()(model.forward),
-                mode="max-autotune-no-cudagraphs",
+                mode=os.environ.get(
+                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                ),
                 dynamic=False,
             )
         else:
diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py
index 1998fee2f..1c674c327 100644
--- a/test/srt/test_data_parallelism.py
+++ b/test/srt/test_data_parallelism.py
@@ -23,7 +23,7 @@ class TestDataParallelism(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--dp", "2"],
+            other_args=["--dp", 2],
         )
 
     @classmethod
@@ -52,7 +52,7 @@ class TestDataParallelism(unittest.TestCase):
         assert response.status_code == 200
 
         # pause a few seconds then send again
-        time.sleep(5)
+        time.sleep(1)
 
         response = requests.post(
             self.base_url + "/update_weights_from_disk",
@@ -67,7 +67,7 @@ class TestDataParallelism(unittest.TestCase):
         response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200
 
-        time.sleep(5)
+        time.sleep(1)
 
         response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200
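
Note (not part of the patch): a minimal standalone sketch of the rank math that the dp_attention.py hunks introduce. The helper name attn_group_layout and the example values (tp_size=4, dp_size=2) are illustrative only; the sketch mirrors the local_rank = tp_rank % (tp_size // dp_size) formula and the sub-group construction passed to GroupCoordinator above.

def attn_group_layout(enable_dp_attention: bool, tp_rank: int, tp_size: int, dp_size: int):
    # With DP attention, the TP world is split into dp_size sub-groups of
    # tp_size // dp_size ranks each, and local_rank is the offset of tp_rank
    # inside its sub-group (the same value as attn_tp_rank) rather than the
    # global tp_rank. Without DP attention there is a single group.
    if enable_dp_attention:
        attn_tp_size = tp_size // dp_size
        local_rank = tp_rank % attn_tp_size
        dp_size_used = dp_size
    else:
        attn_tp_size = tp_size
        local_rank = tp_rank
        dp_size_used = 1
    # Same group-ranks layout as the GroupCoordinator call in the diff.
    groups = [
        list(range(head, head + attn_tp_size))
        for head in range(0, tp_size, attn_tp_size)
    ]
    return local_rank, groups, dp_size_used

# Example: tp_size=4, dp_size=2 -> groups [[0, 1], [2, 3]];
# tp_rank=3 falls in group [2, 3] with local_rank=1.
print(attn_group_layout(True, tp_rank=3, tp_size=4, dp_size=2))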
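
The cuda_graph_runner.py hunk makes the torch.compile mode overridable through the SGLANG_TORCH_COMPILE_MODE environment variable, falling back to "max-autotune-no-cudagraphs" when unset. A minimal usage sketch, assuming the variable is set before the server process imports the module; the choice of "default" here is illustrative (it is one of the standard torch.compile modes).

import os

# Illustrative only: pick a lighter torch.compile mode to reduce autotuning
# time at startup; when unset, the patched code uses "max-autotune-no-cudagraphs".
os.environ["SGLANG_TORCH_COMPILE_MODE"] = "default"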