From f7b2853ff8b4d2728e4781b45b8c4d7394cd9be9 Mon Sep 17 00:00:00 2001
From: Guanhua Wang <59467949+WANG-GH@users.noreply.github.com>
Date: Sun, 3 Aug 2025 15:46:47 +0800
Subject: [PATCH] [feat] support minimum token load balance in dp attention
 (#7379)

---
 docs/backend/server_arguments.md              |   2 +-
 python/sglang/srt/entrypoints/engine.py       |   1 +
 .../srt/managers/data_parallel_controller.py  |  54 ++++++++-
 python/sglang/srt/managers/io_struct.py       |   5 +
 python/sglang/srt/managers/scheduler.py       | 113 +++++++++++++++++-
 python/sglang/srt/managers/utils.py           |  46 ++++++-
 python/sglang/srt/server_args.py              |   1 +
 test/srt/test_dp_attention.py                 |  55 +++++++++
 8 files changed, 271 insertions(+), 6 deletions(-)

diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index bff9dbcdc..a79911bc9 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -155,7 +155,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | Arguments | Description | Defaults |
 |-----------|-------------|----------|
 | `--dp-size` | The data parallelism size. | 1 |
-| `--load-balance-method` | The load balancing strategy for data parallelism. | round_robin |
+| `--load-balance-method` | The load balancing strategy for data parallelism. Options: 'round_robin', 'minimum_tokens'. The minimum-tokens algorithm requires DP attention to be enabled; it balances load based on the real-time token load of the DP workers. | round_robin |

 ## Multi-node distributed serving

diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 0e764081a..c2885fa78 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -732,6 +732,7 @@ def _launch_subprocesses(
                 pp_rank,
                 None,
                 writer,
+                None,
             ),
         )

diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 98173f7a6..76b9e1a01 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -16,9 +16,13 @@
 import logging
 import multiprocessing as mp
 import signal
+import struct
+import sys
 import threading
 import time
 from enum import Enum, auto
+from multiprocessing import shared_memory
+from typing import Dict, List

 import psutil
 import setproctitle
@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):

     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
+    MINIMUM_TOKENS = auto()

     @classmethod
     def from_str(cls, method: str):
@@ -58,7 +64,16 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""

-    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        dp_balance_meta: DPBalanceMeta,
+    ) -> None:
+        # For dp balance
+        self.global_balance_id = 0
+        self.balance_meta = dp_balance_meta
+
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
@@ -79,6 +94,7 @@ class DataParallelController:
         dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+            LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]

@@ -234,6 +250,7 @@ class DataParallelController:
                         pp_rank,
                         dp_rank,
                         writer,
+                        self.balance_meta,
                     ),
                 )
                 with memory_saver_adapter.configure_subprocess():
@@ -269,6 +286,33 @@ class DataParallelController:
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()

+    def minimum_tokens_scheduler(self, req):
+        # This id corresponds to the balance_id in TokenizedGenerateReqInput.
+        # We use it to track onfly requests (dispatched to workers but not yet received).
+        def get_next_global_balance_id() -> int:
+            INT32_MAX = 2147483647
+            current_id = self.global_balance_id
+            self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
+            return current_id
+
+        req.dp_balance_id = get_next_global_balance_id()
+        with self.balance_meta.mutex:
+            # 1. local_tokens counts the tokens currently being processed on each worker,
+            # while onfly covers requests dispatched by this controller but not yet received by the scheduler.
+            onfly_info = self.balance_meta.get_shared_onfly()
+            local_tokens = self.balance_meta.get_shared_local_tokens()
+            total_tokens = [
+                local_token + sum(onfly_dict.values())
+                for local_token, onfly_dict in zip(local_tokens, onfly_info)
+            ]
+            target_worker = total_tokens.index(min(total_tokens))
+            onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
+            # 2. Write the updated onfly info back to the shared state.
+            self.balance_meta.set_shared_onfly_info(onfly_info)
+
+        # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
+        self.workers[target_worker].send_pyobj(req)
+
     def event_loop(self):
         while True:
             while True:
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
     setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
+    balance_meta = DPBalanceMeta(server_args.dp_size)

     try:
-        controller = DataParallelController(server_args, port_args)
+        controller = DataParallelController(
+            server_args, port_args, dp_balance_meta=balance_meta
+        )
         pipe_writer.send(
             {
                 "status": "ready",
@@ -323,3 +370,6 @@ def run_data_parallel_controller_process(
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
+    finally:
+        # We must shut down the mp.Manager() held by balance_meta manually.
+        balance_meta.destructor()

diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 2b5f19c71..7935b4228 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -523,6 +523,9 @@ class TokenizedGenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None

+    # For dp balance
+    dp_balance_id: int = -1
+

 @dataclass
 class EmbeddingReqInput:
@@ -648,6 +651,8 @@ class TokenizedEmbeddingReqInput:
     token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
+    # For dp balance
+    dp_balance_id: int = -1


 @dataclass
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 2a0b139f6..0249acd8d 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -126,7 +126,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import validate_input_length
+from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -203,6 +203,7 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
+        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -522,6 +523,15 @@ class Scheduler(
             ]
         )

+        self.balance_meta = dp_balance_meta
+        if (
+            server_args.enable_dp_attention
+            and server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+        self.recv_dp_balance_id_this_term = []
+
     def init_tokenizer(self):
         server_args = self.server_args

@@ -1049,6 +1059,12 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
         # Create a new request
         if (
             recv_req.session_params is None
@@ -1459,6 +1475,11 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
+            if (
+                self.server_args.load_balance_method == "minimum_tokens"
+                and self.forward_ct % 40 == 0
+            ):
+                self.handle_dp_balance_data(ret)
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1786,6 +1807,86 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

+    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
+        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
+            """Gather recv_dp_balance_id_this_term and per-worker holding tokens for dp balance."""
+            recv_list = self.recv_dp_balance_id_this_term
+            assert len(recv_list) <= 511, (
+                "The number of requests received this round is too large. "
+                "Please increase gather_tensor_size and onfly_info_size."
+            )
+            # The maximum size of the tensor used for gathering data from all workers.
+            gather_tensor_size = 512
+
+            # recv_tensor layout: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids |
+            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+            recv_tensor[0] = holding_tokens_list
+            recv_tensor[1] = len(
+                recv_list
+            )  # The second element is the length of the id list.
+            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
+                recv_list, dtype=torch.int32
+            )
+
+            if self.tp_rank == 0:
+                gathered_list = [
+                    torch.zeros(gather_tensor_size, dtype=torch.int32)
+                    for _ in range(self.balance_meta.num_workers)
+                ]
+            else:
+                gathered_list = None
+
+            torch.distributed.gather(
+                recv_tensor, gathered_list, group=self.tp_cpu_group
+            )
+
+            gathered_id_list_per_worker = None
+            if self.tp_rank == 0:
+                gathered_id_list_per_worker = []
+                holding_tokens_list = []
+                for tensor in gathered_list:
+                    holding_tokens_list.append(tensor[0].item())
+                    list_length = tensor[1].item()
+                    gathered_id_list_per_worker.append(
+                        tensor[2 : list_length + 2].tolist()
+                    )
+
+            return gathered_id_list_per_worker, holding_tokens_list
+
+        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
+            meta = self.balance_meta
+
+            with meta.mutex:
+                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+                assert len(new_recv_rid_lists) == len(
+                    onfly_list
+                ), "num_worker not equal"
+                # 1. Check whether each id received by a worker this round is present
+                # in onfly; if it is, remove the corresponding onfly item.
+                worker_id = 0
+                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                    for new_recv_rid in new_recv_rids:
+                        assert (
+                            new_recv_rid in on_fly_reqs
+                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                        del on_fly_reqs[new_recv_rid]
+                    worker_id += 1
+                # 2. Atomically write local_tokens and onfly into shm under the mutex.
+                meta.set_shared_onfly_info(onfly_list)
+                meta.set_shared_local_tokens(local_tokens)
+
+        holding_tokens = self.get_load()
+
+        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
+            holding_tokens
+        )
+
+        self.recv_dp_balance_id_this_term.clear()
+        if self.tp_rank == 0:  # only the first worker writes the info
+            write_shared_dp_balance_info(
+                new_recv_dp_balance_id_list, holding_token_list
+            )
+
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
@@ -2394,6 +2495,7 @@ def run_scheduler_process(
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
+    balance_meta: Optional[DPBalanceMeta] = None,
 ):
     # Generate the prefix
     prefix = ""
@@ -2427,7 +2529,14 @@ def run_scheduler_process(
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(
-            server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+            server_args,
+            port_args,
+            gpu_id,
+            tp_rank,
+            moe_ep_rank,
+            pp_rank,
+            dp_rank,
+            dp_balance_meta=balance_meta,
         )
         pipe_writer.send(
             {
diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py
index 2909e7597..2ab32f242 100644
--- a/python/sglang/srt/managers/utils.py
+++ b/python/sglang/srt/managers/utils.py
@@ -1,6 +1,7 @@
 import logging
+import multiprocessing as mp
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional

 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

@@ -38,3 +39,46 @@ def validate_input_length(
         return error_msg

     return None
+
+
+class DPBalanceMeta:
+    """
+    This class is used by both the scheduler and the dp controller.
+    """
+
+    def __init__(self, num_workers: int):
+        self.num_workers = num_workers
+        self._manager = mp.Manager()
+        self.mutex = self._manager.Lock()
+
+        init_local_tokens = [0] * self.num_workers
+        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
+
+        self.shared_state = self._manager.Namespace()
+        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
+        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
+
+    def destructor(self):
+        # We must shut this down manually; the manager is not cleaned up automatically.
+        self._manager.shutdown()
+
+    def get_shared_onfly(self) -> List[Dict[int, int]]:
+        return [dict(d) for d in self.shared_state.onfly_info]
+
+    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
+        self.shared_state.onfly_info = data
+
+    def get_shared_local_tokens(self) -> List[int]:
+        return list(self.shared_state.local_tokens)
+
+    def set_shared_local_tokens(self, data: List[int]):
+        self.shared_state.local_tokens = data
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_manager"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._manager = None
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 7f3fd88b1..4691b3c7c 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1171,6 +1171,7 @@ class ServerArgs:
             choices=[
                 "round_robin",
                 "shortest_queue",
+                "minimum_tokens",
             ],
         )

diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py
index af50dc780..f997382f9 100644
--- a/test/srt/test_dp_attention.py
+++ b/test/srt/test_dp_attention.py
@@ -137,5 +137,60 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
         self.assertGreater(avg_spec_accept_length, 2.5)


+class TestDPAttentionMinimumTokenLoadBalance(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "2",
+                "--enable-dp-attention",
+                "--dp",
+                "2",
+                "--enable-torch-compile",
+                "--torch-compile-max-bs",
+                "2",
+                "--load-balance-method",
+                "minimum_tokens",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["score"], 0.5)
+
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["score"], 0.8)
+
+
 if __name__ == "__main__":
     unittest.main()
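
Note for reviewers: the dispatch rule implemented by minimum_tokens_scheduler above can be summarized with the following minimal standalone sketch. The helper name pick_worker and the example numbers are illustrative and do not exist in the patch; the logic mirrors the controller's total_tokens computation. A worker's effective load is its local token count plus the token counts of its onfly (dispatched but not yet received) requests, and the new request goes to the worker with the smallest total.

# Minimal, self-contained sketch of the minimum-tokens dispatch rule
# (illustrative only; pick_worker is a hypothetical name).
from typing import Dict, List

def pick_worker(local_tokens: List[int], onfly_info: List[Dict[int, int]]) -> int:
    # Effective load = tokens currently on the worker + tokens of its onfly requests.
    total_tokens = [
        local + sum(onfly.values())
        for local, onfly in zip(local_tokens, onfly_info)
    ]
    return total_tokens.index(min(total_tokens))

# Worker 0: 100 local + 10 onfly = 110; worker 1: 80 local + 55 onfly = 135.
assert pick_worker([100, 80], [{7: 10}, {8: 25, 9: 30}]) == 0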
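
To try the feature end to end, the server can be configured the same way the new test does. Below is a hedged sketch using the offline Python engine, assuming sglang.Engine forwards these keyword arguments to ServerArgs (the names mirror the CLI flags used in TestDPAttentionMinimumTokenLoadBalance); the model path is a placeholder, not the model the test uses.

# Sketch: enabling minimum-tokens load balancing via the offline engine.
# Assumes Engine(**kwargs) maps onto ServerArgs fields; model_path is a placeholder.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MLA model
    tp_size=2,
    dp_size=2,
    enable_dp_attention=True,
    load_balance_method="minimum_tokens",
    trust_remote_code=True,
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
llm.shutdown()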