# File: sglang/python/sglang/srt/distributed/device_communicators/pymscclpp.py
import bisect
import logging
import math
import os
from contextlib import contextmanager
from enum import IntEnum
from typing import Any, Callable, List, Optional, TypeVar, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

# Whether the mscclpp fused-allreduce kernels can be used on this platform.
mscclpp_is_available = False
if _is_hip:
    # TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
    mscclpp_is_available = False
if _is_cuda:
    try:
        import sgl_kernel

        mscclpp_is_available = True
    except ImportError:
        # sgl_kernel is an optional dependency; a bare `except:` here would
        # also swallow SystemExit/KeyboardInterrupt, so catch ImportError only.
        mscclpp_is_available = False
class MscclContextSelection(IntEnum):
    """Selects which pre-built mscclpp kernel context to initialize.

    The value is passed as a plain int to ``ops.mscclpp_init_context``.
    Naming suggests "1-shot, 1/2 node, LL (presumably low-latency)
    protocol" — verify against the kernel-side definitions.
    """

    # Single-node variant (used when world_size == 8).
    MSCCL1SHOT1NODELL = 1
    # Two-node variant (used when world_size == 16).
    MSCCL1SHOT2NODELL = 2
def mscclpp_is_weak_contiguous(inp: torch.Tensor):
    """Return True if *inp* is contiguous, or if it occupies the entire tail
    of its storage (every byte from its storage offset to the end of the
    storage belongs to the tensor, with no gaps)."""
    if inp.is_contiguous():
        return True
    used_bytes = inp.numel() * inp.element_size()
    tail_bytes = inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
    return tail_bytes == used_bytes
def mscclpp_convert_to_bytes(size_str):
    """
    Converts a human-readable size string (e.g., "1MB", "2.5kb", "3 GB")
    into the equivalent number of bytes using binary units.

    A bare number or a trailing "b"/"B" is interpreted as a byte count.

    Args:
        size_str (str): A string representing size with an optional unit
            (B, KB, MB, GB), case-insensitive, whitespace tolerated.

    Returns:
        int: Number of bytes.

    Raises:
        ValueError: If the string is empty, the numeric prefix is invalid,
            or the unit is not one of B/KB/MB/GB.
    """
    size_str = size_str.strip().lower()
    if not size_str:
        raise ValueError("Empty input string")
    # Split into numeric prefix and unit suffix. The for/else is needed so
    # an all-numeric string is treated as a plain byte count: the original
    # loop mis-split such inputs (e.g. "1024" became num="102", unit="4",
    # producing a misleading "Unsupported unit" error).
    for i in range(len(size_str)):
        if not size_str[i].isdigit() and size_str[i] != ".":
            break
    else:
        i = len(size_str)
    num_str = size_str[:i]
    unit = size_str[i:].strip()
    try:
        num = float(num_str)
    except ValueError:
        raise ValueError(f"Invalid numeric value in '{size_str}'")
    # Binary (1024-based) conversion factors; "" covers unit-less input.
    factors = {"": 1, "b": 1, "kb": 1024, "mb": 1024**2, "gb": 1024**3}
    if unit not in factors:
        raise ValueError(f"Unsupported unit: {unit}, support B, KB, MB, GB only")
    return int(num * factors[unit])
def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2):
    """Benchmark *func* on the current CUDA stream.

    Runs *warmup_niter* untimed calls, synchronizes all ranks via
    dist.barrier() so the timed region starts together, then times
    *test_niter* calls with CUDA events and returns the mean per-call
    cost in microseconds."""
    for _ in range(warmup_niter):
        func()
    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    dist.barrier()
    begin.record()
    for _ in range(test_niter):
        func()
    finish.record()
    finish.synchronize()
    # elapsed_time() reports milliseconds; convert per-iteration cost to us.
    elapsed_ms = begin.elapsed_time(finish)
    return elapsed_ms / test_niter * 1000
class PyMscclppCommunicator:
_SUPPORTED_WORLD_SIZES = [8, 16]
_MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB"))
_SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]
# max_bytes: max supported mscclpp allreduce size
# in A100 mscclpp is faster than nccl only under condition of msg size smaller than1MB
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_bytes=_MAX_BYTES,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bind to f"cuda:{local_rank}".
            max_bytes: largest message size (in bytes) this communicator
                will handle; also sizes the scratch buffers below.
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        # Start disabled; any failed capability check below returns early
        # and leaves the communicator permanently disabled.
        self.disabled = True
        if not mscclpp_is_available:
            # disable because of missing mscclpp library
            # e.g. in a non-cuda environment
            return
        self.group = group
        # A NCCL-backed group cannot be used for the CPU-side object
        # broadcasts below (broadcast_object_list).
        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "CustomAllreduce should be attached to a non-NCCL group."
        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize mscclpp for single GPU case.
            return
        if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "PyMscclpp is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_mscclpp=True explicitly.",
                world_size,
                str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES),
            )
            return
        self.ranks = torch.distributed.get_process_group_ranks(group)
        self.nranks_per_node = torch.cuda.device_count()
        # for now mscclpp with stride in the communicator is not tested
        if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1):
            logger.warning(
                "PyMscclpp is disabled due to an unsupported group %s."
                "Please ensure all ranks in the group are consecutive."
                "To silence this warning, specify disable_mscclpp=True explicitly.",
                str(self.ranks),
            )
            return
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        self.max_bytes = max_bytes
        self.rank = rank
        self.world_size = world_size
        # Rank 0 creates the unique id and broadcasts it so every rank
        # joins the same mscclpp context.
        if dist.get_rank(group) == 0:
            unique_id = [ops.mscclpp_generate_unique_id()]
        else:
            unique_id = [None]
        dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group)
        self.unique_id = unique_id[0]
        # Map each rank to its node index and IB index, assuming 8 ranks
        # per node.
        # NOTE(review): rank_to_ib stores `self.rank % 8` for every entry,
        # i.e. a constant list per process, while rank_to_node uses the
        # loop variable `r`. Confirm against ops.mscclpp_init_context
        # whether `r % 8` was intended.
        self.rank_to_node, self.rank_to_ib = list(range(world_size)), list(
            range(world_size)
        )
        for r in range(world_size):
            self.rank_to_node[r] = r // 8
            self.rank_to_ib[r] = self.rank % 8
        self._context = None
        self.context_selection = None
        # Power-of-two message sizes from 1KiB up to max_bytes, used both
        # for pre-tuning and for bucketing in all_reduce.
        self.msg_size_for_finetune = [
            2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1)
        ]
        self.msg_size2best_config = {}
        if world_size == 8:
            self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL
        elif world_size == 16:
            self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL
        if not _is_hip:
            # Scratch/put buffers sized at 8x max_bytes; presumably required
            # by the kernel's staging layout — TODO confirm factor of 8.
            self.scratch = torch.empty(
                self.max_bytes * 8,
                dtype=torch.uint8,
                device=self.device,
            )
            self.put_buffer = torch.empty(
                self.max_bytes * 8 // self.nranks_per_node,
                dtype=torch.uint8,
                device=self.device,
            )
            self._context = ops.mscclpp_init_context(
                self.unique_id,
                self.rank,
                self.world_size,
                self.scratch,
                self.put_buffer,
                self.nranks_per_node,
                self.rank_to_node,
                self.rank_to_ib,
                int(self.context_selection),
            )
        else:
            raise NotImplementedError("HIP Mscclpp is not supported yet.")
        self.msg_size2best_config = {}
        self.pre_tune_config()
        # Rank 0's tuned configs are broadcast so all ranks launch kernels
        # with identical (nthreads, nblocks) parameters.
        if dist.get_rank(group) == 0:
            msg_size2best_config = [self.msg_size2best_config]
        else:
            msg_size2best_config = [None]
        dist.broadcast_object_list(
            msg_size2best_config, src=self.ranks[0], group=self.group
        )
        self.msg_size2best_config = msg_size2best_config[0]
        # PyMscclpp is enabled only in cuda graph
        self.disabled = True
def pre_tune_config(self, dtype=torch.bfloat16) -> bool:
logger.debug(f"start to pre-tune configs for rank {self.rank}")
nthreads_to_try = [256, 512, 1024]
nblocks_to_try = [21, 42, 84]
inp_randn = torch.ones(
self.msg_size_for_finetune[-1] // dtype.itemsize, dtype=dtype, device="cuda"
)
oup_randn = torch.empty_like(inp_randn)
for msg_size in self.msg_size_for_finetune:
mock_inp, mock_outp = (
inp_randn[: msg_size // dtype.itemsize],
oup_randn[: msg_size // dtype.itemsize],
)
best_config, best_time = None, None
for nthreads in nthreads_to_try:
for nblocks in nblocks_to_try:
cur_cost = mscclpp_bench_time(
lambda: ops.mscclpp_allreduce(
self._context, mock_inp, mock_outp, nthreads, nblocks
)
)
if best_time is None or cur_cost < best_time:
best_config = (nthreads, nblocks)
best_time = cur_cost
self.msg_size2best_config[msg_size] = best_config
if self.rank == 0:
logger.debug(
f"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us"
)
def should_mscclpp_allreduce(
self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM
) -> bool:
if self.disabled or self._context is None:
return False
if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE:
return False
if not mscclpp_is_weak_contiguous(inp):
return False
# only support sum op
if op != ReduceOp.SUM:
return False
if inp.numel() * inp.element_size() > self.max_bytes:
return False
return True
def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
self.graph_input_set.add((tensor.dtype, tensor.numel()))
msg_size = tensor.numel() * tensor.itemsize
index = bisect.bisect_left(self.msg_size_for_finetune, msg_size)
msg_size_finetune = self.msg_size_for_finetune[index]
nthreads, nblocks = self.msg_size2best_config[msg_size_finetune]
result = torch.empty_like(tensor)
ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks)
return result
@contextmanager
def change_state(
self,
enable: Optional[bool] = None,
):
if enable is None:
# guess a default value when not specified
enable = self.available
old_disable = self.disabled
self.disabled = not enable
yield
self.disabled = old_disable