From d4d0c7c367947e95d03c35e48c348c18426073d8 Mon Sep 17 00:00:00 2001
From: ykcombat <99869808+ykcombat@users.noreply.github.com>
Date: Tue, 15 Jul 2025 02:35:46 +0800
Subject: [PATCH] [Feature]TP Group Switching for PD-Multiplexing (#7653)

---
 .../sglang/srt/distributed/parallel_state.py  | 33 +++++++++++++++++++
 .../sglang/srt/model_executor/model_runner.py |  1 +
 python/sglang/srt/server_args.py              | 15 +++++++++
 3 files changed, 49 insertions(+)

diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 191e5b0ba..509c71531 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -1065,8 +1065,23 @@ def init_model_parallel_group(
 _TP: Optional[GroupCoordinator] = None
 
+# duplicate GroupCoordinator for prefill in PD-Multiplexing
+_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None
+
+_ENABLE_PDMUX_P_TP: bool = False
+
+
+def set_pdmux_status(enable_prefill_multiplexing: bool):
+    global _ENABLE_PDMUX_P_TP
+    _ENABLE_PDMUX_P_TP = enable_prefill_multiplexing
+
 
 def get_tp_group() -> GroupCoordinator:
+    if _ENABLE_PDMUX_P_TP:
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is not None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is not initialized"
+        return _PDMUX_PREFILL_TP_GROUP
     assert _TP is not None, "tensor model parallel group is not initialized"
     return _TP
 
 
@@ -1182,6 +1197,7 @@ def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     backend: Optional[str] = None,
+    duplicate_tp_group: bool = False,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -1239,6 +1255,23 @@ def initialize_model_parallel(
         group_name="tp",
     )
 
+    if duplicate_tp_group:
+        global _PDMUX_PREFILL_TP_GROUP
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
+        _PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            use_message_queue_broadcaster=get_bool_env_var(
+                "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
+            ),
+            group_name="pdmux_prefill_tp",
+        )
+        _TP.pynccl_comm.disabled = False
+        _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
+
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
     global _PP
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index f70eccd0c..a7885a5e3 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -539,6 +539,7 @@ class ModelRunner:
         initialize_model_parallel(
             tensor_model_parallel_size=self.tp_size,
             pipeline_model_parallel_size=self.pp_size,
+            duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
             enable_dp_attention=self.server_args.enable_dp_attention,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 16ac09b16..95ba9bee6 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -251,6 +251,10 @@ class ServerArgs:
     custom_weight_loader: Optional[List[str]] = None
     weight_loader_disable_mmap: bool = False
 
+    # For PD-Multiplexing
+    enable_pdmux: bool = False
+    sm_group_num: int = 3
+
    def __post_init__(self):
        # Expert parallelism
        if self.enable_ep_moe:
@@ -1721,6 +1725,17 @@ class ServerArgs:
             default=None,
             help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
         )
+        parser.add_argument(
+            "--enable-pdmux",
+            action="store_true",
+            help="Enable PD-Multiplexing, PD running on greenctx stream.",
+        )
+        parser.add_argument(
+            "--sm-group-num",
+            type=int,
+            default=ServerArgs.sm_group_num,
+            help="Number of sm partition groups.",
+        )
         parser.add_argument(
             "--weight-loader-disable-mmap",
             action="store_true",
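
Usage sketch (hypothetical, not part of the patch): set_pdmux_status() flips
which coordinator get_tp_group() returns, so a caller can run a prefill pass
on the duplicated TP group and then fall back to the default one. A minimal
sketch of that calling pattern follows; run_prefill_step and forward_fn are
invented names for illustration, while set_pdmux_status and get_tp_group are
the functions added/modified above.

    from sglang.srt.distributed.parallel_state import get_tp_group, set_pdmux_status

    def run_prefill_step(forward_fn):
        """Run one prefill pass on the duplicated PD-Multiplexing TP group."""
        set_pdmux_status(True)  # get_tp_group() now returns _PDMUX_PREFILL_TP_GROUP
        try:
            prefill_group = get_tp_group()  # coordinator used for collectives below
            return forward_fn(prefill_group)
        finally:
            set_pdmux_status(False)  # subsequent calls return the default _TP group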