support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)
This commit is contained in:
@@ -190,6 +190,7 @@ class GroupCoordinator:
|
||||
cpu_group: ProcessGroup # group for CPU communication
|
||||
device_group: ProcessGroup # group for device communication
|
||||
use_pynccl: bool # a hint of whether to use PyNccl
|
||||
use_pymscclpp: bool # a hint of whether to use PyMscclpp
|
||||
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
|
||||
use_message_queue_broadcaster: (
|
||||
bool # a hint of whether to use message queue broadcaster
|
||||
@@ -205,6 +206,7 @@ class GroupCoordinator:
|
||||
local_rank: int,
|
||||
torch_distributed_backend: Union[str, Backend],
|
||||
use_pynccl: bool,
|
||||
use_pymscclpp: bool,
|
||||
use_custom_allreduce: bool,
|
||||
use_hpu_communicator: bool,
|
||||
use_xpu_communicator: bool,
|
||||
@@ -244,6 +246,7 @@ class GroupCoordinator:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_pymscclpp = use_pymscclpp
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_hpu_communicator = use_hpu_communicator
|
||||
self.use_xpu_communicator = use_xpu_communicator
|
||||
@@ -265,6 +268,17 @@ class GroupCoordinator:
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
from sglang.srt.distributed.device_communicators.pymscclpp import (
|
||||
PyMscclppCommunicator,
|
||||
)
|
||||
|
||||
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
|
||||
if use_pymscclpp and self.world_size > 1:
|
||||
self.pymscclpp_comm = PyMscclppCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
@@ -373,11 +387,15 @@ class GroupCoordinator:
|
||||
# --------------------------------------------
|
||||
# custom allreduce | enabled | enabled |
|
||||
# PyNccl | disabled| enabled |
|
||||
# PyMscclpp | disabled| enabled |
|
||||
# torch.distributed | enabled | disabled|
|
||||
#
|
||||
# Note that custom allreduce will have a runtime check, if the
|
||||
# tensor size is too large, it will fallback to the next
|
||||
# available option.
|
||||
# Note that PyMscclpp needs to register the tensor ahead of time,
|
||||
# which will introduce large overhead in the eager case,
|
||||
# therefore it is only supported in the graph case.
|
||||
# In summary: When using CUDA graph, we use
|
||||
# either custom all-reduce kernel or pynccl. When not using
|
||||
# CUDA graph, we use either custom all-reduce kernel or
|
||||
@@ -392,7 +410,14 @@ class GroupCoordinator:
|
||||
maybe_pynccl_context = pynccl_comm.change_state(
|
||||
enable=True, stream=torch.cuda.current_stream()
|
||||
)
|
||||
with maybe_pynccl_context:
|
||||
|
||||
pymscclpp_comm = self.pymscclpp_comm
|
||||
maybe_pymscclpp_context: Any
|
||||
if not pymscclpp_comm:
|
||||
maybe_pymscclpp_context = nullcontext()
|
||||
else:
|
||||
maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True)
|
||||
with maybe_pynccl_context, maybe_pymscclpp_context:
|
||||
yield graph_capture_context
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
@@ -437,6 +462,10 @@ class GroupCoordinator:
|
||||
self.ca_comm is not None
|
||||
and not self.ca_comm.disabled
|
||||
and self.ca_comm.should_custom_ar(input_)
|
||||
) or (
|
||||
self.pymscclpp_comm is not None
|
||||
and not self.pymscclpp_comm.disabled
|
||||
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
|
||||
):
|
||||
return torch.ops.sglang.outplace_all_reduce(
|
||||
input_, group_name=self.unique_name
|
||||
@@ -447,9 +476,13 @@ class GroupCoordinator:
|
||||
|
||||
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
ca_comm = self.ca_comm
|
||||
assert ca_comm is not None
|
||||
assert not ca_comm.disabled
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
pymscclpp_comm = self.pymscclpp_comm
|
||||
assert ca_comm is not None or pymscclpp_comm is not None
|
||||
if ca_comm is not None and not ca_comm.disabled:
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
else:
|
||||
assert not pymscclpp_comm.disabled
|
||||
out = pymscclpp_comm.all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
|
||||
@@ -958,6 +991,7 @@ def init_world_group(
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_pynccl=False,
|
||||
use_pymscclpp=False,
|
||||
use_custom_allreduce=False,
|
||||
use_hpu_communicator=False,
|
||||
use_xpu_communicator=False,
|
||||
@@ -973,14 +1007,18 @@ def init_model_parallel_group(
|
||||
use_custom_allreduce: Optional[bool] = None,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
use_mscclpp_allreduce: Optional[bool] = None,
|
||||
) -> GroupCoordinator:
|
||||
if use_custom_allreduce is None:
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
if use_mscclpp_allreduce is None:
|
||||
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
|
||||
return GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_pynccl=not is_npu(),
|
||||
use_pymscclpp=use_mscclpp_allreduce,
|
||||
use_custom_allreduce=use_custom_allreduce,
|
||||
use_hpu_communicator=True,
|
||||
use_xpu_communicator=True,
|
||||
@@ -1037,6 +1075,7 @@ def graph_capture():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = True
|
||||
_ENABLE_MSCCLPP_ALL_REDUCE = False
|
||||
|
||||
|
||||
def set_custom_all_reduce(enable: bool):
|
||||
@@ -1044,6 +1083,11 @@ def set_custom_all_reduce(enable: bool):
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
||||
|
||||
|
||||
# Set the module-level flag _ENABLE_MSCCLPP_ALL_REDUCE to `enable`.
# init_model_parallel_group falls back to this flag when its
# use_mscclpp_allreduce argument is None.
def set_mscclpp_all_reduce(enable: bool):
|
||||
global _ENABLE_MSCCLPP_ALL_REDUCE
|
||||
_ENABLE_MSCCLPP_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
|
||||
Reference in New Issue
Block a user