support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)

This commit is contained in:
zyksir
2025-06-05 13:11:24 +08:00
committed by GitHub
parent 4474eaf552
commit 8e3797be1c
20 changed files with 2177 additions and 12 deletions

View File

@@ -190,6 +190,7 @@ class GroupCoordinator:
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
use_pynccl: bool # a hint of whether to use PyNccl
use_pymscclpp: bool # a hint of whether to use PyMscclpp
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
use_message_queue_broadcaster: (
bool # a hint of whether to use message queue broadcaster
@@ -205,6 +206,7 @@ class GroupCoordinator:
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_pynccl: bool,
use_pymscclpp: bool,
use_custom_allreduce: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
@@ -244,6 +246,7 @@ class GroupCoordinator:
self.device = torch.device("cpu")
self.use_pynccl = use_pynccl
self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
@@ -265,6 +268,17 @@ class GroupCoordinator:
device=self.device,
)
from sglang.srt.distributed.device_communicators.pymscclpp import (
PyMscclppCommunicator,
)
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
if use_pymscclpp and self.world_size > 1:
self.pymscclpp_comm = PyMscclppCommunicator(
group=self.cpu_group,
device=self.device,
)
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
@@ -373,11 +387,15 @@ class GroupCoordinator:
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# PyMscclpp | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the
# tensor size is too large, it will fallback to the next
# available option.
# Note that PyMscclpp needs to register the tensor ahead of time,
# which introduces a large overhead in the eager case;
# therefore it is only supported in the graph case.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using
# CUDA graph, we use either custom all-reduce kernel or
@@ -392,7 +410,14 @@ class GroupCoordinator:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
)
with maybe_pynccl_context:
pymscclpp_comm = self.pymscclpp_comm
maybe_pymscclpp_context: Any
if not pymscclpp_comm:
maybe_pymscclpp_context = nullcontext()
else:
maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True)
with maybe_pynccl_context, maybe_pymscclpp_context:
yield graph_capture_context
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
@@ -437,6 +462,10 @@ class GroupCoordinator:
self.ca_comm is not None
and not self.ca_comm.disabled
and self.ca_comm.should_custom_ar(input_)
) or (
self.pymscclpp_comm is not None
and not self.pymscclpp_comm.disabled
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
):
return torch.ops.sglang.outplace_all_reduce(
input_, group_name=self.unique_name
@@ -447,9 +476,13 @@ class GroupCoordinator:
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` out of place and return the result.

    The custom all-reduce communicator (``self.ca_comm``) is used when it
    exists and is not disabled; otherwise the call falls back to the
    PyMscclpp communicator (``self.pymscclpp_comm``).

    Args:
        input_: The tensor to reduce across the group.

    Returns:
        The tensor produced by whichever communicator handled the call.

    Raises:
        AssertionError: If neither communicator is usable, or if the
            selected communicator returned ``None``.
    """
    ca_comm = self.ca_comm
    pymscclpp_comm = self.pymscclpp_comm
    # At least one out-of-place backend must have been configured; the
    # caller is expected to have checked eligibility before dispatching.
    assert ca_comm is not None or pymscclpp_comm is not None
    if ca_comm is not None and not ca_comm.disabled:
        out = ca_comm.custom_all_reduce(input_)
    else:
        # Reached when ca_comm is missing *or* disabled, so pymscclpp_comm
        # may still be None here; check that before reading `.disabled`.
        assert pymscclpp_comm is not None and not pymscclpp_comm.disabled
        out = pymscclpp_comm.all_reduce(input_)
    assert out is not None
    return out
@@ -958,6 +991,7 @@ def init_world_group(
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=False,
use_pymscclpp=False,
use_custom_allreduce=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
@@ -973,14 +1007,18 @@ def init_model_parallel_group(
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
use_mscclpp_allreduce: Optional[bool] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
if use_mscclpp_allreduce is None:
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=not is_npu(),
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_hpu_communicator=True,
use_xpu_communicator=True,
@@ -1037,6 +1075,7 @@ def graph_capture():
logger = logging.getLogger(__name__)
# Module-wide defaults picked up by init_model_parallel_group() when its
# use_custom_allreduce / use_mscclpp_allreduce arguments are left as None.
# Changed at runtime through set_custom_all_reduce() / set_mscclpp_all_reduce().
_ENABLE_CUSTOM_ALL_REDUCE = True
_ENABLE_MSCCLPP_ALL_REDUCE = False
def set_custom_all_reduce(enable: bool):
@@ -1044,6 +1083,11 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
def set_mscclpp_all_reduce(enable: bool) -> None:
    """Set the module-level default for MSCCL++ all-reduce.

    Rebinds the global ``_ENABLE_MSCCLPP_ALL_REDUCE``, which
    ``init_model_parallel_group`` uses when its ``use_mscclpp_allreduce``
    argument is ``None``.

    Args:
        enable: New value for the flag.
    """
    global _ENABLE_MSCCLPP_ALL_REDUCE
    _ENABLE_MSCCLPP_ALL_REDUCE = enable
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,