[Feature] Use current greenctx stream to communicate in PD-Multiplexing. (#11594)

Author: ykcombat
Date: 2025-10-20 10:58:20 +08:00 (committed by GitHub)
Parent: c726d44cc7
Commit: c4e81e64fb
2 changed files with 48 additions and 20 deletions
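PD-Multiplexing runs prefill and decode concurrently on separate green-context (greenctx) streams that partition the GPU's SMs. Before this change, PyNcclCommunicator always issued NCCL calls on its own internal stream, so communication launched from a greenctx stream would escape that stream. The new `use_current_stream` flag makes NCCL calls follow `torch.cuda.current_stream()` instead. A minimal runnable sketch of the resulting behavior (run under `torchrun --nproc_per_node=2`; the import path is an assumption based on sglang's layout, and a plain CUDA stream stands in for a greenctx stream):

```python
import torch
import torch.distributed as dist

# Assumed import path; not shown in this diff.
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator

dist.init_process_group(backend="gloo")  # bootstrap group for the NCCL id exchange
rank = dist.get_rank()
torch.cuda.set_device(rank)

comm = PyNcclCommunicator(
    group=dist.group.WORLD,
    device=rank,
    use_current_stream=True,  # the flag added by this commit
)

stream = torch.cuda.Stream()  # stands in for a greenctx stream
with torch.cuda.stream(stream):
    x = torch.ones(16, device="cuda")
    comm.all_reduce(x)  # with no stream argument, this now runs on `stream`
stream.synchronize()
```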

File 1: pynccl.py (PyNcclCommunicator)

@@ -30,6 +30,7 @@ class PyNcclCommunicator:
         group: Union[ProcessGroup, StatelessProcessGroup],
         device: Union[int, str, torch.device],
         library_path: Optional[str] = None,
+        use_current_stream: bool = False,
     ):
         """
         Args:

@@ -74,6 +75,7 @@ class PyNcclCommunicator:
         self.available = True
         self.disabled = False
+        self.use_current_stream = use_current_stream
         self.nccl_version = self.nccl.ncclGetRawVersion()
         if self.rank == 0:
@@ -123,6 +125,21 @@ class PyNcclCommunicator:
             # when we are using CUDA graph.
             self.disabled = True

+    def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
+        """Return the stream to use for NCCL calls.
+
+        Behavior mirrors the previous inline logic:
+        - if an explicit stream is provided, return it
+        - if stream is None and self.use_current_stream is True, return
+          torch.cuda.current_stream()
+        - otherwise return the communicator's default stream (self.stream)
+        """
+        if stream is not None:
+            return stream
+        if self.use_current_stream:
+            return torch.cuda.current_stream()
+        return self.stream
+
     def all_reduce(
         self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
     ):
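Continuing the sketch above, the helper gives every collective one precedence order: explicit argument, then the caller's current stream (if enabled), then the communicator's own stream.

```python
explicit = torch.cuda.Stream()

s1 = comm._resolve_stream(explicit)  # an explicit argument always wins
assert s1 is explicit

with torch.cuda.stream(torch.cuda.Stream()):
    s2 = comm._resolve_stream(None)            # use_current_stream=True:
    assert s2 == torch.cuda.current_stream()   # follow the caller's stream
```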
@@ -135,8 +152,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclAllReduce(
             buffer_type(tensor.data_ptr()),
             buffer_type(tensor.data_ptr()),

@@ -163,8 +179,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)

         if sizes is not None:
             split_offset = 0

@@ -210,8 +225,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)

         if sizes is not None:
             split_offset = 0

@@ -249,8 +263,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclSend(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),

@@ -267,8 +280,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclRecv(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),

@@ -285,8 +297,8 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
+
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
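All six call sites (all-reduce, the gather/scatter variants, send, recv, and broadcast) now share this resolution, so per-call `stream=` plumbing is only needed to override it. Continuing the sketch above (the point-to-point signatures are assumed to be the usual `send(tensor, dst, stream=None)` / `recv(tensor, src, stream=None)`):

```python
t = torch.randn(1024, device="cuda")
work_stream = torch.cuda.Stream()

with torch.cuda.stream(work_stream):
    comm.all_reduce(t)        # issued on work_stream
    if comm.rank == 0:
        comm.send(t, dst=1)   # send/recv resolve their stream the same way
    else:
        comm.recv(t, src=0)

comm.all_reduce(t, stream=comm.stream)  # an explicit override still works
```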

File 2: parallel_state.py (GroupCoordinator and model-parallel setup)

@@ -239,6 +239,7 @@ class GroupCoordinator:
         use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
+        pynccl_use_current_stream: bool = False,
         torch_compile: Optional[bool] = None,
         gloo_timeout: timedelta = timedelta(seconds=120 * 60),
     ):

@@ -289,6 +290,7 @@ class GroupCoordinator:
         # Import communicators
         self.use_pynccl = use_pynccl
+        self.pynccl_use_current_stream = pynccl_use_current_stream
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
         self.use_torch_symm_mem = use_torch_symm_mem

@@ -322,6 +324,7 @@ class GroupCoordinator:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
+                use_current_stream=pynccl_use_current_stream,
             )

         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
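The flag is stored on the coordinator and forwarded when the pynccl communicator is created. After initialization the wiring can be checked like so (accessors from the surrounding code; `pynccl_comm` is None when pynccl is unused):

```python
tp = get_tp_group()
if tp.pynccl_comm is not None:
    # coordinator-level setting and communicator-level flag agree
    assert tp.pynccl_comm.use_current_stream == tp.pynccl_use_current_stream
```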
@@ -449,10 +452,13 @@ class GroupCoordinator:
     @contextmanager
     def graph_capture(
-        self, graph_capture_context: Optional[GraphCaptureContext] = None
+        self,
+        graph_capture_context: Optional[GraphCaptureContext] = None,
+        stream: Optional[torch.cuda.Stream] = None,
     ):
         if graph_capture_context is None:
-            stream = self.device_module.Stream()
+            if stream is None:
+                stream = self.device_module.Stream()
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
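`graph_capture` previously always allocated a fresh stream; callers can now capture on a stream they own, which pdmux uses to keep capture on a greenctx stream. Illustrative usage:

```python
capture_stream = torch.cuda.Stream()  # e.g. a greenctx stream in pdmux
with get_tp_group().graph_capture(stream=capture_stream) as ctx:
    assert ctx.stream is capture_stream  # the supplied stream is used as-is
```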
@@ -1278,6 +1284,7 @@ def init_model_parallel_group(
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
+    pynccl_use_current_stream: bool = True,
     use_symm_mem_allreduce: Optional[bool] = None,
     torch_compile: Optional[bool] = None,
 ) -> GroupCoordinator:

@@ -1300,6 +1307,7 @@ def init_model_parallel_group(
         use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
+        pynccl_use_current_stream=pynccl_use_current_stream,
         torch_compile=torch_compile,
     )
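Note that `init_model_parallel_group` defaults the new flag to True (unlike the `GroupCoordinator` constructor, which defaults to False), so groups created through this helper follow the caller's stream unless explicitly opted out. An abridged, hypothetical call (leading positional arguments assumed, other options elided):

```python
group_ranks = [[0, 1]]  # illustrative: one group of two ranks
coordinator = init_model_parallel_group(
    group_ranks,
    0,                                # local_rank
    "nccl",                           # backend
    group_name="my_group",            # hypothetical group name
    pynccl_use_current_stream=False,  # opt out: keep a dedicated comm stream
)
```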
@@ -1357,7 +1365,7 @@ get_pipeline_model_parallel_group = get_pp_group

 @contextmanager
-def graph_capture():
+def graph_capture(stream: Optional[torch.cuda.Stream] = None):
     """
     `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the

@@ -1371,9 +1379,9 @@ def graph_capture():
     in order to explicitly distinguish the kernels to capture
     from other kernels possibly launched on background in the default stream.
     """
-    with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(
-        context
-    ):
+    with get_tp_group().graph_capture(
+        stream=stream
+    ) as context, get_pp_group().graph_capture(context):
         yield context
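The module-level wrapper forwards the stream to the TP group's capture context and reuses that context for PP, so both groups capture on the same stream:

```python
s = torch.cuda.Stream()
with graph_capture(stream=s) as ctx:
    # all collectives captured here are launched on `s`
    assert ctx.stream is s
```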
@@ -1527,6 +1535,7 @@ def initialize_model_parallel(
             "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
         ),
         group_name="tp",
+        pynccl_use_current_stream=duplicate_tp_group,
         torch_compile=torch_compile,
     )

@@ -1543,10 +1552,12 @@ def initialize_model_parallel(
             "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
         ),
         group_name="pdmux_prefill_tp",
+        pynccl_use_current_stream=True,
         torch_compile=torch_compile,
     )
-    _TP.pynccl_comm.disabled = False
-    _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
+    if _TP.pynccl_comm:
+        _TP.pynccl_comm.disabled = False
+        _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

     moe_ep_size = expert_model_parallel_size
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
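Net effect: the dedicated `pdmux_prefill_tp` group always follows the caller's stream, while the regular TP group does so only when `duplicate_tp_group` is set; the new `if _TP.pynccl_comm:` guard also avoids an `AttributeError` when the communicator was never created (e.g. pynccl disabled or single-process runs). A hedged sketch of the intended runtime pattern (plain streams stand in for greenctx streams; `_PDMUX_PREFILL_TP_GROUP` is the module global from this file):

```python
prefill_stream = torch.cuda.Stream()
decode_stream = torch.cuda.Stream()
x = torch.randn(1024, device="cuda")

with torch.cuda.stream(prefill_stream):
    _PDMUX_PREFILL_TP_GROUP.all_reduce(x)  # always on prefill_stream

with torch.cuda.stream(decode_stream):
    get_tp_group().all_reduce(x)  # on decode_stream when duplicate_tp_group=True
```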
@@ -1737,6 +1748,11 @@ def destroy_model_parallel():
         _PP.destroy()
     _PP = None

+    global _PDMUX_PREFILL_TP_GROUP
+    if _PDMUX_PREFILL_TP_GROUP:  # type: ignore[union-attr]
+        _PDMUX_PREFILL_TP_GROUP.destroy()
+    _PDMUX_PREFILL_TP_GROUP = None
+

 def destroy_distributed_environment():
     global _WORLD
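With the new block, a full shutdown releases the pdmux prefill group as well:

```python
destroy_model_parallel()           # now also destroys _PDMUX_PREFILL_TP_GROUP
destroy_distributed_environment()  # then tears down the global world state
```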