From c4e81e64fb4233fbd76b5bb24e5fe6dd2a87832e Mon Sep 17 00:00:00 2001
From: ykcombat <99869808+ykcombat@users.noreply.github.com>
Date: Mon, 20 Oct 2025 10:58:20 +0800
Subject: [PATCH] [Feature] Use current greenctx stream to communicate in PD-Multiplexing. (#11594)

---
 .../device_communicators/pynccl.py           | 36 ++++++++++++-------
 .../sglang/srt/distributed/parallel_state.py | 32 ++++++++++++-----
 2 files changed, 48 insertions(+), 20 deletions(-)

diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py
index fbb59c477..6180e64d6 100644
--- a/python/sglang/srt/distributed/device_communicators/pynccl.py
+++ b/python/sglang/srt/distributed/device_communicators/pynccl.py
@@ -30,6 +30,7 @@ class PyNcclCommunicator:
         group: Union[ProcessGroup, StatelessProcessGroup],
         device: Union[int, str, torch.device],
         library_path: Optional[str] = None,
+        use_current_stream: bool = False,
     ):
         """
         Args:
@@ -74,6 +75,7 @@ class PyNcclCommunicator:
 
         self.available = True
         self.disabled = False
+        self.use_current_stream = use_current_stream
 
         self.nccl_version = self.nccl.ncclGetRawVersion()
         if self.rank == 0:
@@ -123,6 +125,21 @@ class PyNcclCommunicator:
         # when we are using CUDA graph.
         self.disabled = True
 
+    def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
+        """Return the stream to use for NCCL calls.
+
+        Behavior mirrors the previous inline logic:
+        - if an explicit stream is provided, return it
+        - if stream is None and self.use_current_stream is True, return
+          torch.cuda.current_stream()
+        - otherwise return the communicator's default stream (self.stream)
+        """
+        if stream is not None:
+            return stream
+        if self.use_current_stream:
+            return torch.cuda.current_stream()
+        return self.stream
+
     def all_reduce(
         self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
     ):
@@ -135,8 +152,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclAllReduce(
             buffer_type(tensor.data_ptr()),
             buffer_type(tensor.data_ptr()),
@@ -163,8 +179,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
 
         if sizes is not None:
             split_offset = 0
@@ -210,8 +225,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}"
        )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
 
         if sizes is not None:
             split_offset = 0
@@ -249,8 +263,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclSend(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),
@@ -267,8 +280,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclRecv(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),
@@ -285,8 +297,8 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
+
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 013b3f785..15ab65589 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -239,6 +239,7 @@ class GroupCoordinator:
         use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
+        pynccl_use_current_stream: bool = False,
         torch_compile: Optional[bool] = None,
         gloo_timeout: timedelta = timedelta(seconds=120 * 60),
     ):
@@ -289,6 +290,7 @@ class GroupCoordinator:
 
         # Import communicators
         self.use_pynccl = use_pynccl
+        self.pynccl_use_current_stream = pynccl_use_current_stream
        self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
         self.use_torch_symm_mem = use_torch_symm_mem
@@ -322,6 +324,7 @@ class GroupCoordinator:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
+                use_current_stream=pynccl_use_current_stream,
             )
 
         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
@@ -449,10 +452,13 @@ class GroupCoordinator:
 
     @contextmanager
     def graph_capture(
-        self, graph_capture_context: Optional[GraphCaptureContext] = None
+        self,
+        graph_capture_context: Optional[GraphCaptureContext] = None,
+        stream: Optional[torch.cuda.Stream] = None,
     ):
         if graph_capture_context is None:
-            stream = self.device_module.Stream()
+            if stream is None:
+                stream = self.device_module.Stream()
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
@@ -1278,6 +1284,7 @@ def init_model_parallel_group(
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
+    pynccl_use_current_stream: bool = True,
     use_symm_mem_allreduce: Optional[bool] = None,
     torch_compile: Optional[bool] = None,
 ) -> GroupCoordinator:
@@ -1300,6 +1307,7 @@
         use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
+        pynccl_use_current_stream=pynccl_use_current_stream,
         torch_compile=torch_compile,
     )
 
@@ -1357,7 +1365,7 @@ get_pipeline_model_parallel_group = get_pp_group
 
 
 @contextmanager
-def graph_capture():
+def graph_capture(stream: Optional[torch.cuda.Stream] = None):
     """
     `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the
@@ -1364,16 +1372,16 @@
     some operations will be run after the graph is captured, before the graph
     is replayed. It returns the `GraphCaptureContext` object which contains the
     necessary data for the graph capture. Currently, it only contains the
     stream that the graph capture is running on. This stream is set to the
     current CUDA stream when the context manager is entered and reset to the
     default stream when the context manager is exited. This is to ensure that
     the graph capture is running on a separate stream from the default stream,
     in order to explicitly distinguish the kernels to capture
     from other kernels possibly launched on background in the default stream.
""" - with get_tp_group().graph_capture() as context, get_pp_group().graph_capture( - context - ): + with get_tp_group().graph_capture( + stream=stream + ) as context, get_pp_group().graph_capture(context): yield context @@ -1527,6 +1535,7 @@ def initialize_model_parallel( "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" ), group_name="tp", + pynccl_use_current_stream=duplicate_tp_group, torch_compile=torch_compile, ) @@ -1543,10 +1552,12 @@ def initialize_model_parallel( "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" ), group_name="pdmux_prefill_tp", + pynccl_use_current_stream=True, torch_compile=torch_compile, ) - _TP.pynccl_comm.disabled = False - _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False + if _TP.pynccl_comm: + _TP.pynccl_comm.disabled = False + _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False moe_ep_size = expert_model_parallel_size moe_tp_size = tensor_model_parallel_size // moe_ep_size @@ -1737,6 +1748,11 @@ def destroy_model_parallel(): _PP.destroy() _PP = None + global _PDMUX_PREFILL_TP_GROUP + if _PDMUX_PREFILL_TP_GROUP: # type: ignore[union-attr] + _PDMUX_PREFILL_TP_GROUP.destroy() + _PDMUX_PREFILL_TP_GROUP = None + def destroy_distributed_environment(): global _WORLD