[Feature] Use current greenctx stream to communicate in PD-Multiplexing. (#11594)
This commit is contained in:
@@ -30,6 +30,7 @@ class PyNcclCommunicator:
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
use_current_stream: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -74,6 +75,7 @@ class PyNcclCommunicator:
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
self.use_current_stream = use_current_stream
|
||||
|
||||
self.nccl_version = self.nccl.ncclGetRawVersion()
|
||||
if self.rank == 0:
|
||||
@@ -123,6 +125,21 @@ class PyNcclCommunicator:
|
||||
# when we are using CUDA graph.
|
||||
self.disabled = True
|
||||
|
||||
def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
|
||||
"""Return the stream to use for NCCL calls.
|
||||
|
||||
Behavior mirrors the previous inline logic:
|
||||
- if an explicit stream is provided, return it
|
||||
- if stream is None and self.use_current_stream is True, return
|
||||
torch.cuda.current_stream()
|
||||
- otherwise return the communicator's default stream (self.stream)
|
||||
"""
|
||||
if stream is not None:
|
||||
return stream
|
||||
if self.use_current_stream:
|
||||
return torch.cuda.current_stream()
|
||||
return self.stream
|
||||
|
||||
def all_reduce(
|
||||
self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
|
||||
):
|
||||
@@ -135,8 +152,7 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
stream = self._resolve_stream(stream)
|
||||
self.nccl.ncclAllReduce(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
buffer_type(tensor.data_ptr()),
|
||||
@@ -163,8 +179,7 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
stream = self._resolve_stream(stream)
|
||||
|
||||
if sizes is not None:
|
||||
split_offset = 0
|
||||
@@ -210,8 +225,7 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
stream = self._resolve_stream(stream)
|
||||
|
||||
if sizes is not None:
|
||||
split_offset = 0
|
||||
@@ -249,8 +263,7 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
stream = self._resolve_stream(stream)
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
@@ -267,8 +280,7 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
stream = self._resolve_stream(stream)
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
@@ -285,8 +297,8 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
stream = self._resolve_stream(stream)
|
||||
|
||||
if src == self.rank:
|
||||
sendbuff = buffer_type(tensor.data_ptr())
|
||||
# NCCL requires the sender also to have a receive buffer
|
||||
|
||||
@@ -239,6 +239,7 @@ class GroupCoordinator:
|
||||
use_npu_communicator: bool,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
pynccl_use_current_stream: bool = False,
|
||||
torch_compile: Optional[bool] = None,
|
||||
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
|
||||
):
|
||||
@@ -289,6 +290,7 @@ class GroupCoordinator:
|
||||
|
||||
# Import communicators
|
||||
self.use_pynccl = use_pynccl
|
||||
self.pynccl_use_current_stream = pynccl_use_current_stream
|
||||
self.use_pymscclpp = use_pymscclpp
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
@@ -322,6 +324,7 @@ class GroupCoordinator:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
use_current_stream=pynccl_use_current_stream,
|
||||
)
|
||||
|
||||
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
|
||||
@@ -449,10 +452,13 @@ class GroupCoordinator:
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(
|
||||
self, graph_capture_context: Optional[GraphCaptureContext] = None
|
||||
self,
|
||||
graph_capture_context: Optional[GraphCaptureContext] = None,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
):
|
||||
if graph_capture_context is None:
|
||||
stream = self.device_module.Stream()
|
||||
if stream is None:
|
||||
stream = self.device_module.Stream()
|
||||
graph_capture_context = GraphCaptureContext(stream)
|
||||
else:
|
||||
stream = graph_capture_context.stream
|
||||
@@ -1278,6 +1284,7 @@ def init_model_parallel_group(
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
use_mscclpp_allreduce: Optional[bool] = None,
|
||||
pynccl_use_current_stream: bool = True,
|
||||
use_symm_mem_allreduce: Optional[bool] = None,
|
||||
torch_compile: Optional[bool] = None,
|
||||
) -> GroupCoordinator:
|
||||
@@ -1300,6 +1307,7 @@ def init_model_parallel_group(
|
||||
use_npu_communicator=True,
|
||||
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
||||
group_name=group_name,
|
||||
pynccl_use_current_stream=pynccl_use_current_stream,
|
||||
torch_compile=torch_compile,
|
||||
)
|
||||
|
||||
@@ -1357,7 +1365,7 @@ get_pipeline_model_parallel_group = get_pp_group
|
||||
|
||||
|
||||
@contextmanager
|
||||
def graph_capture():
|
||||
def graph_capture(stream: Optional[torch.cuda.Stream] = None):
|
||||
"""
|
||||
`graph_capture` is a context manager which should surround the code that
|
||||
is capturing the CUDA graph. Its main purpose is to ensure that the
|
||||
@@ -1371,9 +1379,9 @@ def graph_capture():
|
||||
in order to explicitly distinguish the kernels to capture
|
||||
from other kernels possibly launched on background in the default stream.
|
||||
"""
|
||||
with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(
|
||||
context
|
||||
):
|
||||
with get_tp_group().graph_capture(
|
||||
stream=stream
|
||||
) as context, get_pp_group().graph_capture(context):
|
||||
yield context
|
||||
|
||||
|
||||
@@ -1527,6 +1535,7 @@ def initialize_model_parallel(
|
||||
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||
),
|
||||
group_name="tp",
|
||||
pynccl_use_current_stream=duplicate_tp_group,
|
||||
torch_compile=torch_compile,
|
||||
)
|
||||
|
||||
@@ -1543,10 +1552,12 @@ def initialize_model_parallel(
|
||||
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||
),
|
||||
group_name="pdmux_prefill_tp",
|
||||
pynccl_use_current_stream=True,
|
||||
torch_compile=torch_compile,
|
||||
)
|
||||
_TP.pynccl_comm.disabled = False
|
||||
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
|
||||
if _TP.pynccl_comm:
|
||||
_TP.pynccl_comm.disabled = False
|
||||
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
|
||||
|
||||
moe_ep_size = expert_model_parallel_size
|
||||
moe_tp_size = tensor_model_parallel_size // moe_ep_size
|
||||
@@ -1737,6 +1748,11 @@ def destroy_model_parallel():
|
||||
_PP.destroy()
|
||||
_PP = None
|
||||
|
||||
global _PDMUX_PREFILL_TP_GROUP
|
||||
if _PDMUX_PREFILL_TP_GROUP: # type: ignore[union-attr]
|
||||
_PDMUX_PREFILL_TP_GROUP.destroy()
|
||||
_PDMUX_PREFILL_TP_GROUP = None
|
||||
|
||||
|
||||
def destroy_distributed_environment():
|
||||
global _WORLD
|
||||
|
||||
Reference in New Issue
Block a user