[Feature] Use current greenctx stream to communicate in PD-Multiplexing. (#11594)
@@ -30,6 +30,7 @@ class PyNcclCommunicator:
         group: Union[ProcessGroup, StatelessProcessGroup],
         device: Union[int, str, torch.device],
         library_path: Optional[str] = None,
+        use_current_stream: bool = False,
     ):
         """
         Args:
@@ -74,6 +75,7 @@ class PyNcclCommunicator:

         self.available = True
         self.disabled = False
+        self.use_current_stream = use_current_stream

         self.nccl_version = self.nccl.ncclGetRawVersion()
         if self.rank == 0:
@@ -123,6 +125,21 @@ class PyNcclCommunicator:
             # when we are using CUDA graph.
             self.disabled = True

+    def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
+        """Return the stream to use for NCCL calls.
+
+        Behavior mirrors the previous inline logic:
+        - if an explicit stream is provided, return it
+        - if stream is None and self.use_current_stream is True, return
+          torch.cuda.current_stream()
+        - otherwise return the communicator's default stream (self.stream)
+        """
+        if stream is not None:
+            return stream
+        if self.use_current_stream:
+            return torch.cuda.current_stream()
+        return self.stream
+
     def all_reduce(
         self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
     ):
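The helper above replaces the `if stream is None: stream = self.stream` pattern that each collective carried inline. For reference, the precedence it encodes, as a minimal runnable sketch (the `resolve` function and names here are illustrative, not the module's API):

import torch

if torch.cuda.is_available():
    default = torch.cuda.Stream()   # stands in for the communicator's self.stream
    explicit = torch.cuda.Stream()

    def resolve(stream, use_current_stream):
        # Same precedence as PyNcclCommunicator._resolve_stream above.
        if stream is not None:
            return stream
        if use_current_stream:
            return torch.cuda.current_stream()
        return default

    assert resolve(explicit, True) is explicit   # an explicit stream always wins
    assert resolve(None, False) is default       # legacy communicator-owned stream
    with torch.cuda.stream(torch.cuda.Stream()):  # e.g. a greenctx stream
        assert resolve(None, True) == torch.cuda.current_stream()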
@@ -135,8 +152,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclAllReduce(
             buffer_type(tensor.data_ptr()),
             buffer_type(tensor.data_ptr()),
@@ -163,8 +179,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {input_tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)

         if sizes is not None:
             split_offset = 0
@@ -210,8 +225,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {input_tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)

         if sizes is not None:
             split_offset = 0
@@ -249,8 +263,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
        self.nccl.ncclSend(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),
@@ -267,8 +280,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclRecv(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),
@@ -285,8 +297,8 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer

@@ -239,6 +239,7 @@ class GroupCoordinator:
         use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
+        pynccl_use_current_stream: bool = False,
         torch_compile: Optional[bool] = None,
         gloo_timeout: timedelta = timedelta(seconds=120 * 60),
     ):
@@ -289,6 +290,7 @@ class GroupCoordinator:

         # Import communicators
         self.use_pynccl = use_pynccl
+        self.pynccl_use_current_stream = pynccl_use_current_stream
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
         self.use_torch_symm_mem = use_torch_symm_mem
@@ -322,6 +324,7 @@ class GroupCoordinator:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
+                use_current_stream=pynccl_use_current_stream,
             )

         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
@@ -449,10 +452,13 @@ class GroupCoordinator:

     @contextmanager
     def graph_capture(
-        self, graph_capture_context: Optional[GraphCaptureContext] = None
+        self,
+        graph_capture_context: Optional[GraphCaptureContext] = None,
+        stream: Optional[torch.cuda.Stream] = None,
     ):
         if graph_capture_context is None:
-            stream = self.device_module.Stream()
+            if stream is None:
+                stream = self.device_module.Stream()
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
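`graph_capture` now lets the caller supply the capture stream while keeping creation of a fresh stream as the fallback; note that an explicitly passed `GraphCaptureContext` still takes precedence, since the `else` branch overwrites `stream`. A condensed, runnable sketch of that reuse-or-create rule, with illustrative names and the coordinator's remaining capture plumbing omitted:

from contextlib import contextmanager
from dataclasses import dataclass

import torch


@dataclass
class _CaptureCtx:  # stand-in for GraphCaptureContext
    stream: "torch.cuda.Stream"


@contextmanager
def _capture(ctx=None, stream=None):
    if ctx is None:
        if stream is None:            # create only if nothing was supplied
            stream = torch.cuda.Stream()
        ctx = _CaptureCtx(stream)
    else:
        stream = ctx.stream           # an existing context wins over `stream`
    with torch.cuda.stream(stream):
        yield ctx


if torch.cuda.is_available():
    mine = torch.cuda.Stream()
    with _capture(stream=mine) as ctx:
        assert ctx.stream is mine     # the caller's stream is used for capture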
@@ -1278,6 +1284,7 @@ def init_model_parallel_group(
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
+    pynccl_use_current_stream: bool = True,
     use_symm_mem_allreduce: Optional[bool] = None,
     torch_compile: Optional[bool] = None,
 ) -> GroupCoordinator:
@@ -1300,6 +1307,7 @@ def init_model_parallel_group(
         use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
+        pynccl_use_current_stream=pynccl_use_current_stream,
         torch_compile=torch_compile,
     )

@@ -1357,7 +1365,7 @@ get_pipeline_model_parallel_group = get_pp_group


 @contextmanager
-def graph_capture():
+def graph_capture(stream: Optional[torch.cuda.Stream] = None):
     """
     `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the
@@ -1371,9 +1379,9 @@ def graph_capture():
     in order to explicitly distinguish the kernels to capture
     from other kernels possibly launched on background in the default stream.
     """
-    with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(
-        context
-    ):
+    with get_tp_group().graph_capture(
+        stream=stream
+    ) as context, get_pp_group().graph_capture(context):
         yield context

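With the `stream` parameter exposed at module level, a capture pass can run on a stream the scheduler owns. A hedged usage sketch: the import path is an assumption, an initialized distributed environment is required, and a plain stream stands in for a green-context stream:

import torch

from sglang.srt.distributed import graph_capture  # assumed import path

capture_stream = torch.cuda.Stream()  # in PD-multiplexing: a greenctx stream
with graph_capture(stream=capture_stream) as ctx:
    static_x = torch.zeros(8, device="cuda")
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, stream=ctx.stream):  # record kernels on that stream
        static_y = static_x * 2
g.replay()  # re-issue the captured kernels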
@@ -1527,6 +1535,7 @@ def initialize_model_parallel(
             "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
         ),
         group_name="tp",
+        pynccl_use_current_stream=duplicate_tp_group,
         torch_compile=torch_compile,
     )

@@ -1543,10 +1552,12 @@ def initialize_model_parallel(
                 "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
             ),
             group_name="pdmux_prefill_tp",
+            pynccl_use_current_stream=True,
             torch_compile=torch_compile,
         )
-        _TP.pynccl_comm.disabled = False
-        _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
+        if _TP.pynccl_comm:
+            _TP.pynccl_comm.disabled = False
+            _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

     moe_ep_size = expert_model_parallel_size
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
@@ -1737,6 +1748,11 @@ def destroy_model_parallel():
         _PP.destroy()
         _PP = None

+    global _PDMUX_PREFILL_TP_GROUP
+    if _PDMUX_PREFILL_TP_GROUP:  # type: ignore[union-attr]
+        _PDMUX_PREFILL_TP_GROUP.destroy()
+        _PDMUX_PREFILL_TP_GROUP = None
+

 def destroy_distributed_environment():
     global _WORLD
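Taken together, the intended shape of a PD-multiplexing step: prefill and decode each launch on their own SM-partitioned (green-context) stream, and because the prefill TP group's pynccl communicator is created with `use_current_stream=True`, its collectives follow that stream automatically. Everything below is a stand-in sketch, not API from this diff; plain streams replace greenctx streams so it runs anywhere:

import torch


def run_prefill_step():
    pass  # stand-in: model forward + all_reduce on pdmux_prefill_tp


def run_decode_step():
    pass  # stand-in: decode forward on the regular TP group


if torch.cuda.is_available():
    prefill_stream = torch.cuda.Stream()  # would come from a green context
    decode_stream = torch.cuda.Stream()
    with torch.cuda.stream(prefill_stream):
        run_prefill_step()  # NCCL lands on prefill_stream (use_current_stream=True)
    with torch.cuda.stream(decode_stream):
        run_decode_step()
    torch.cuda.synchronize()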