support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)

This commit is contained in:
zyksir
2025-06-05 13:11:24 +08:00
committed by GitHub
parent 4474eaf552
commit 8e3797be1c
20 changed files with 2177 additions and 12 deletions

View File

@@ -113,3 +113,37 @@ else:
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
    """Return the IPC handle of the meta buffer associated with ``inp``."""
    handle = sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
    return handle
def mscclpp_generate_unique_id() -> bytes:
    """Generate a fresh unique id used to bootstrap a mscclpp communicator."""
    unique_id = sgl_kernel.allreduce.mscclpp_generate_unique_id()
    return unique_id
def mscclpp_init_context(
    unique_id: bytes,
    rank: int,
    world_size: int,
    scratch: torch.Tensor,
    put_buffer: torch.Tensor,
    nranks_per_node: int,
    rank_to_node: List[int],
    rank_to_ib: List[int],
    context_selection: int,
) -> int:
    """Initialize a mscclpp context and return its opaque integer handle.

    The arguments are forwarded verbatim to the ``sgl_kernel`` binding:
    the bootstrap ``unique_id``, this process's ``rank`` and ``world_size``,
    the preallocated ``scratch`` and ``put_buffer`` tensors, the topology
    hints (``nranks_per_node``, ``rank_to_node``, ``rank_to_ib``) and the
    algorithm ``context_selection``.
    """
    context = sgl_kernel.allreduce.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        context_selection,
    )
    return context
def mscclpp_allreduce(
    context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
    """Run the mscclpp all-reduce kernel on ``inp``, writing into ``out``."""
    kernel = sgl_kernel.allreduce.mscclpp_allreduce
    return kernel(context, inp, out, nthreads, nblocks)

View File

@@ -0,0 +1,315 @@
import bisect
import logging
import math
import os
from contextlib import contextmanager
from enum import IntEnum
from typing import Any, Callable, List, Optional, TypeVar, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

# Whether the mscclpp bindings (shipped with sgl_kernel) can be used.
mscclpp_is_available = False
if _is_hip:
    # TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
    mscclpp_is_available = False
if _is_cuda:
    try:
        import sgl_kernel

        mscclpp_is_available = True
    except ImportError:
        # sgl_kernel (and hence the mscclpp bindings) is not installed;
        # catch only ImportError so real errors are not silently hidden.
        mscclpp_is_available = False
class MscclContextSelection(IntEnum):
    """Selects which mscclpp all-reduce context/kernel to initialize."""

    # 1-shot LL all-reduce within a single node (world_size == 8 below).
    MSCCL1SHOT1NODELL = 1
    # 1-shot LL all-reduce across two nodes (world_size == 16 below).
    MSCCL1SHOT2NODELL = 2
def mscclpp_is_weak_contiguous(inp: torch.Tensor):
    """Return True if ``inp`` is contiguous, or occupies the tail of its storage.

    A "weakly contiguous" tensor is one whose elements fill the storage
    from its offset to the end, even if ``is_contiguous()`` is False.
    """
    if inp.is_contiguous():
        return True
    used_bytes = inp.numel() * inp.element_size()
    bytes_after_offset = (
        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
    )
    return bytes_after_offset == used_bytes
def mscclpp_convert_to_bytes(size_str):
    """
    Converts a human-readable size string (e.g., "1MB", "2.5kb", "3 GB")
    into the equivalent number of bytes using binary units.

    A bare number without a unit (e.g. "1024") is interpreted as bytes.

    Args:
        size_str (str): A string representing size with unit (B, KB, MB, GB).

    Returns:
        int: Number of bytes.

    Raises:
        ValueError: on empty input, a malformed number, or an unknown unit.
    """
    size_str = size_str.strip().lower()
    if not size_str:
        raise ValueError("Empty input string")
    # Extract numeric part and unit. The original loop left ``i`` at the
    # last scanned index when the string was all digits (e.g. "1024"),
    # mis-splitting number and unit; the for/else handles that case.
    for i in range(len(size_str)):
        if not size_str[i].isdigit() and size_str[i] != ".":
            break
    else:
        i = len(size_str)
    num_str = size_str[:i]
    unit = size_str[i:].strip()
    try:
        num = float(num_str)
    except ValueError:
        raise ValueError(f"Invalid numeric value in '{size_str}'")
    # Conversion factors (binary units).
    if unit in ("", "b"):
        return int(num)
    elif unit == "kb":
        return int(num * 1024)
    elif unit == "mb":
        return int(num * 1024 * 1024)
    elif unit == "gb":
        return int(num * 1024 * 1024 * 1024)
    else:
        raise ValueError(f"Unsupported unit: {unit}, support B, KB, MB, GB only")
def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2):
    """Return the average latency of ``func`` in microseconds.

    Warms up ``warmup_niter`` times, synchronizes the device and all ranks,
    then times ``test_niter`` back-to-back calls with CUDA events.
    """
    for _ in range(warmup_niter):
        func()
    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    # Align all ranks so the collective being timed starts together.
    dist.barrier()
    begin.record()
    for _ in range(test_niter):
        func()
    finish.record()
    finish.synchronize()
    # elapsed_time() is in milliseconds; convert the per-call mean to us.
    return begin.elapsed_time(finish) / test_niter * 1000
class PyMscclppCommunicator:
    """All-reduce communicator backed by the mscclpp one-shot LL kernels.

    Supports only small SUM all-reduces (up to ``max_bytes``) on supported
    dtypes, and is kept disabled outside of CUDA-graph usage; callers toggle
    it with ``change_state``.
    """

    _SUPPORTED_WORLD_SIZES = [8, 16]
    _MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB"))
    _SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]

    # max_bytes: max supported mscclpp allreduce size.
    # On A100 mscclpp is faster than nccl only when the msg size is
    # smaller than 1MB.
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_bytes=_MAX_BYTES,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bind to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True
        # True once the context is initialized and tuned; used as the
        # default target state in change_state(enable=None).
        # (BUGFIX: this attribute was never set, so change_state(None)
        # raised AttributeError.)
        self.available = False
        # (dtype, numel) signatures observed during CUDA-graph capture.
        # (BUGFIX: previously never initialized, so all_reduce crashed
        # with AttributeError while capturing.)
        self.graph_input_set = set()
        self._context = None
        if not mscclpp_is_available:
            # disable because of missing mscclpp library
            # e.g. in a non-cuda environment
            return

        self.group = group
        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "CustomAllreduce should be attached to a non-NCCL group."
        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize mscclpp for single GPU case.
            return

        if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "PyMscclpp is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_mscclpp=True explicitly.",
                world_size,
                str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES),
            )
            return

        self.ranks = torch.distributed.get_process_group_ranks(group)
        self.nranks_per_node = torch.cuda.device_count()
        # for now mscclpp with stride in the communicator is not tested
        if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1):
            logger.warning(
                "PyMscclpp is disabled due to an unsupported group %s."
                "Please ensure all ranks in the group are consecutive."
                "To silence this warning, specify disable_mscclpp=True explicitly.",
                str(self.ranks),
            )
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        self.max_bytes = max_bytes
        self.rank = rank
        self.world_size = world_size

        # Rank 0 creates the bootstrap id and broadcasts it to the group.
        if dist.get_rank(group) == 0:
            unique_id = [ops.mscclpp_generate_unique_id()]
        else:
            unique_id = [None]
        dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group)
        self.unique_id = unique_id[0]

        self.rank_to_node, self.rank_to_ib = list(range(world_size)), list(
            range(world_size)
        )
        for r in range(world_size):
            self.rank_to_node[r] = r // 8
            # BUGFIX: was `self.rank % 8`, which assigned every rank the
            # same IB index; each rank's IB device is its local index.
            self.rank_to_ib[r] = r % 8

        self.context_selection = None
        # Power-of-two message sizes (1KB..max_bytes) to tune configs for.
        self.msg_size_for_finetune = [
            2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1)
        ]
        self.msg_size2best_config = {}
        if world_size == 8:
            self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL
        elif world_size == 16:
            self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL
        if not _is_hip:
            self.scratch = torch.empty(
                self.max_bytes * 8,
                dtype=torch.uint8,
                device=self.device,
            )
            self.put_buffer = torch.empty(
                self.max_bytes * 8 // self.nranks_per_node,
                dtype=torch.uint8,
                device=self.device,
            )
            self._context = ops.mscclpp_init_context(
                self.unique_id,
                self.rank,
                self.world_size,
                self.scratch,
                self.put_buffer,
                self.nranks_per_node,
                self.rank_to_node,
                self.rank_to_ib,
                int(self.context_selection),
            )
        else:
            raise NotImplementedError("HIP Mscclpp is not supported yet.")

        # Tune kernel launch configs, then broadcast rank 0's choices so
        # every rank launches identical configurations.
        self.pre_tune_config()
        if dist.get_rank(group) == 0:
            msg_size2best_config = [self.msg_size2best_config]
        else:
            msg_size2best_config = [None]
        dist.broadcast_object_list(
            msg_size2best_config, src=self.ranks[0], group=self.group
        )
        self.msg_size2best_config = msg_size2best_config[0]
        self.available = True
        # PyMscclpp is enabled only in cuda graph
        self.disabled = True

    def pre_tune_config(self, dtype=torch.bfloat16) -> None:
        """Benchmark (nthreads, nblocks) candidates for each tuned msg size.

        Fills ``self.msg_size2best_config``. Must be called collectively:
        ``mscclpp_bench_time`` barriers across ranks.
        """
        logger.debug("start to pre-tune configs for rank %d", self.rank)
        nthreads_to_try = [256, 512, 1024]
        nblocks_to_try = [21, 42, 84]
        inp_randn = torch.ones(
            self.msg_size_for_finetune[-1] // dtype.itemsize, dtype=dtype, device="cuda"
        )
        oup_randn = torch.empty_like(inp_randn)
        for msg_size in self.msg_size_for_finetune:
            mock_inp, mock_outp = (
                inp_randn[: msg_size // dtype.itemsize],
                oup_randn[: msg_size // dtype.itemsize],
            )
            best_config, best_time = None, None
            for nthreads in nthreads_to_try:
                for nblocks in nblocks_to_try:
                    cur_cost = mscclpp_bench_time(
                        lambda: ops.mscclpp_allreduce(
                            self._context, mock_inp, mock_outp, nthreads, nblocks
                        )
                    )
                    if best_time is None or cur_cost < best_time:
                        best_config = (nthreads, nblocks)
                        best_time = cur_cost
            self.msg_size2best_config[msg_size] = best_config
            if self.rank == 0:
                logger.debug(
                    "for msg_size %s, best_config: %s, best_time: %sus",
                    msg_size,
                    best_config,
                    best_time,
                )

    def should_mscclpp_allreduce(
        self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM
    ) -> bool:
        """Return True if ``inp``/``op`` can be handled by this communicator."""
        if self.disabled or self._context is None:
            return False
        if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE:
            return False
        if not mscclpp_is_weak_contiguous(inp):
            return False
        # only support sum op
        if op != ReduceOp.SUM:
            return False
        if inp.numel() * inp.element_size() > self.max_bytes:
            return False
        return True

    def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
        """Out-of-place all-reduce of ``tensor``; returns the result tensor.

        Callers are expected to gate on ``should_mscclpp_allreduce`` first
        (only SUM is supported).
        """
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                self.graph_input_set.add((tensor.dtype, tensor.numel()))
        msg_size = tensor.numel() * tensor.itemsize
        # Round up to the nearest tuned message size to pick a launch config.
        index = bisect.bisect_left(self.msg_size_for_finetune, msg_size)
        msg_size_finetune = self.msg_size_for_finetune[index]
        nthreads, nblocks = self.msg_size2best_config[msg_size_finetune]
        result = torch.empty_like(tensor)
        ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks)
        return result

    @contextmanager
    def change_state(
        self,
        enable: Optional[bool] = None,
    ):
        """Temporarily enable or disable the communicator.

        Args:
            enable: target state; when None, defaults to ``self.available``
                (enable only if initialization succeeded).
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available
        old_disable = self.disabled
        self.disabled = not enable
        try:
            yield
        finally:
            # Restore the previous state even if the body raises.
            self.disabled = old_disable

View File

@@ -190,6 +190,7 @@ class GroupCoordinator:
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
use_pynccl: bool # a hint of whether to use PyNccl
use_pymscclpp: bool # a hint of whether to use PyMsccl
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
use_message_queue_broadcaster: (
bool # a hint of whether to use message queue broadcaster
@@ -205,6 +206,7 @@ class GroupCoordinator:
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_pynccl: bool,
use_pymscclpp: bool,
use_custom_allreduce: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
@@ -244,6 +246,7 @@ class GroupCoordinator:
self.device = torch.device("cpu")
self.use_pynccl = use_pynccl
self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
@@ -265,6 +268,17 @@ class GroupCoordinator:
device=self.device,
)
from sglang.srt.distributed.device_communicators.pymscclpp import (
PyMscclppCommunicator,
)
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
if use_pymscclpp and self.world_size > 1:
self.pymscclpp_comm = PyMscclppCommunicator(
group=self.cpu_group,
device=self.device,
)
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
@@ -373,11 +387,15 @@ class GroupCoordinator:
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# PyMscclpp | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the
# tensor size is too large, it will fallback to the next
# available option.
# Note that PyMscclpp needs to register the tensor ahead of time,
# which would introduce a large overhead in the eager case;
# therefore it is only supported in the graph case.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using
# CUDA graph, we use either custom all-reduce kernel or
@@ -392,7 +410,14 @@ class GroupCoordinator:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
)
with maybe_pynccl_context:
pymscclpp_comm = self.pymscclpp_comm
maybe_pymscclpp_context: Any
if not pymscclpp_comm:
maybe_pymscclpp_context = nullcontext()
else:
maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True)
with maybe_pynccl_context, maybe_pymscclpp_context:
yield graph_capture_context
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
@@ -437,6 +462,10 @@ class GroupCoordinator:
self.ca_comm is not None
and not self.ca_comm.disabled
and self.ca_comm.should_custom_ar(input_)
) or (
self.pymscclpp_comm is not None
and not self.pymscclpp_comm.disabled
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
):
return torch.ops.sglang.outplace_all_reduce(
input_, group_name=self.unique_name
@@ -447,9 +476,13 @@ class GroupCoordinator:
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
ca_comm = self.ca_comm
assert ca_comm is not None
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_)
pymscclpp_comm = self.pymscclpp_comm
assert ca_comm is not None or pymscclpp_comm is not None
if ca_comm is not None and not ca_comm.disabled:
out = ca_comm.custom_all_reduce(input_)
else:
assert not pymscclpp_comm.disabled
out = pymscclpp_comm.all_reduce(input_)
assert out is not None
return out
@@ -958,6 +991,7 @@ def init_world_group(
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=False,
use_pymscclpp=False,
use_custom_allreduce=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
@@ -973,14 +1007,18 @@ def init_model_parallel_group(
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
use_mscclpp_allreduce: Optional[bool] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
if use_mscclpp_allreduce is None:
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=not is_npu(),
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_hpu_communicator=True,
use_xpu_communicator=True,
@@ -1037,6 +1075,7 @@ def graph_capture():
logger = logging.getLogger(__name__)
_ENABLE_CUSTOM_ALL_REDUCE = True
_ENABLE_MSCCLPP_ALL_REDUCE = False
def set_custom_all_reduce(enable: bool):
@@ -1044,6 +1083,11 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
def set_mscclpp_all_reduce(enable: bool):
global _ENABLE_MSCCLPP_ALL_REDUCE
_ENABLE_MSCCLPP_ALL_REDUCE = enable
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,

View File

@@ -98,11 +98,12 @@ def initialize_dp_attention(
],
local_rank,
torch.distributed.get_backend(tp_group.device_group),
SYNC_TOKEN_IDS_ACROSS_TP,
False,
False,
False,
False,
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
use_pymscclpp=False,
use_custom_allreduce=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_npu_communicator=False,
group_name="attention_tp",
)

View File

@@ -35,6 +35,7 @@ from sglang.srt.distributed import (
init_distributed_environment,
initialize_model_parallel,
set_custom_all_reduce,
set_mscclpp_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
@@ -460,6 +461,7 @@ class ModelRunner:
else:
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
if not self.is_draft_worker:
# Only initialize the distributed environment on the target model worker.

View File

@@ -165,6 +165,7 @@ class ServerArgs:
enable_tokenizer_batch_encode: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
enable_mscclpp: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
@@ -1168,6 +1169,11 @@ class ServerArgs:
action="store_true",
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--enable-mscclpp",
action="store_true",
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--disable-overlap-schedule",
action="store_true",