This commit is contained in:
2026-01-09 13:34:11 +08:00
parent dfa6476b58
commit b2ef04d792
538 changed files with 105693 additions and 2 deletions

View File

@@ -0,0 +1,274 @@
from contextlib import contextmanager
from typing import Any, List, Optional
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger
try:
import pynvml
from vllm_C import custom_ar
except ImportError:
# For AMD GPUs
custom_ar = None
pynvml = None
logger = init_logger(__name__)
_CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None:
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
global _CA_HANDLE
if _CA_HANDLE is not None:
return
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in _SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES))
return
num_dev = torch.musa.device_count()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if num_dev < world_size:
logger.warning(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(num_dev))
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
def begin_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = True
def end_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = False
def is_capturing() -> bool:
return _IS_CAPTURING and _CA_HANDLE is not None
def get_handle() -> Optional["CustomAllreduce"]:
return _CA_HANDLE
def is_initialized() -> bool:
return _CA_HANDLE is not None
@contextmanager
def capture():
try:
begin_capture()
yield
finally:
end_capture()
handle = get_handle()
if handle is not None:
handle.register_graph_buffers()
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return None
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_reg(input)
else:
if ca_handle.should_custom_ar(input):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)
return None
@contextmanager
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()
@_nvml()
def _is_full_nvlink(device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
so it works on real physical device ids.
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.error(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.",
exc_info=error)
return False
return True
def _can_p2p(rank: int, world_size: int) -> bool:
from vllm.distributed.utils import gpu_p2p_access_check
for i in range(world_size):
if i == rank:
continue
if not gpu_p2p_access_check(rank, i):
return False
return True
class CustomAllreduce:
# max_size: max supported allreduce size
def __init__(self,
rank,
world_size,
full_nvlink,
max_size=8192 * 1024) -> None:
# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
# allreduce results.
self.meta = torch.zeros(custom_ar.meta_size() + max_size,
dtype=torch.uint8,
device="musa")
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer = torch.empty(max_size, dtype=torch.uint8, device="musa")
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device="musa")
self.max_size = max_size
self.world_size = world_size
handles, offsets = self._get_ipc_meta(self.meta)
self.full_nvlink = full_nvlink
self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
handles, offsets, rank,
self.full_nvlink)
self.register_buffer(self.buffer)
def _get_ipc_meta(self, inp: torch.Tensor):
data = inp.untyped_storage()._share_cuda_()
shard_data = (
data[1], # ipc handle to base ptr
data[3], # offset of base ptr
)
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
all_data: List[Optional[Any]] = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)
handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0]) # type: ignore
offsets.append(all_data[i][1]) # type: ignore
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
handles, offsets = self._get_ipc_meta(inp)
custom_ar.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self):
handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset))
custom_ar.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink)
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_reg(self._ptr, inp, out)
return out
# all reduce, assuming inp tensor is NOT IPC registered
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out
def close(self):
if self._ptr:
custom_ar.dispose(self._ptr)
self._ptr = 0
def __del__(self):
self.close()

View File

@@ -0,0 +1,284 @@
# This file is a pure Python wrapper for the MCCL library.
# The main purpose is to use MCCL combined with MUSA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls MCCL correctly, but `cupy` itself
# often gets stuck when initializing the MCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential musa APIs, that are not allowed during
# capturing the MUSA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-musagraph-with-mccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for MCCL. It is usually
# doable, but we often encounter issues related with mccl versions, and need
# to switch between different versions of MCCL. See
# https://github.com/NVIDIA/mccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of MCCL by
# changing the environment variable `VLLM_MCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import platform
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch_musa
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import find_mccl_library, mccl_integrity_check
logger = init_logger(__name__)
so_file = find_mccl_library()
try:
# load the library in another process.
# if it core dumps, it will not crash the current process
mccl_integrity_check(so_file)
mccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
"Failed to load MCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the mccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"One solution is to download libmccl2 version 2.18 from "
"https://developer.download.nvidia.com/compute/musa/repos/ "
"and extract the libmccl.so.2 file. If you already have the "
"library, please set the environment variable VLLM_MCCL_SO_PATH"
" to point to the correct mccl library path.", so_file,
platform.platform())
raise e
# === export types and functions from mccl to Python ===
# for the original mccl definition, please check
# https://github.com/NVIDIA/mccl/blob/master/src/mccl.h.in
mcclResult_t = ctypes.c_int
_c_mcclGetErrorString = mccl.mcclGetErrorString
_c_mcclGetErrorString.restype = ctypes.c_char_p
_c_mcclGetErrorString.argtypes = [mcclResult_t]
def MCCL_CHECK(result: mcclResult_t) -> None:
if result != 0:
error_str = _c_mcclGetErrorString(result)
error_str = error_str.decode("utf-8")
raise RuntimeError(f"MCCL error: {error_str}")
# equivalent to c declaration:
# mcclResult_t mcclGetVersion(int *version);
_c_mcclGetVersion = mccl.mcclGetVersion
_c_mcclGetVersion.restype = ctypes.c_int
_c_mcclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def mcclGetVersion() -> str:
version = ctypes.c_int()
MCCL_CHECK(_c_mcclGetVersion(ctypes.byref(version)))
version_str = str(version.value)
return version_str
class McclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
# equivalent to c declaration:
# mcclResult_t mcclGetUniqueId(mcclUniqueId* uniqueId);
_c_mcclGetUniqueId = mccl.mcclGetUniqueId
_c_mcclGetUniqueId.restype = ctypes.c_int
_c_mcclGetUniqueId.argtypes = [ctypes.POINTER(McclUniqueId)]
def mcclGetUniqueId() -> McclUniqueId:
unique_id = McclUniqueId()
MCCL_CHECK(_c_mcclGetUniqueId(ctypes.byref(unique_id)))
return unique_id
# equivalent to c declaration:
# mcclResult_t mcclCommInitRank(
# mcclComm_t* comm, int nranks, mcclUniqueId commId, int rank);
# note that mcclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_mcclCommInitRank = mccl.mcclCommInitRank
_c_mcclCommInitRank.restype = ctypes.c_int
_c_mcclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, McclUniqueId, ctypes.c_int
]
mcclDataType_t = ctypes.c_int
class mcclDataTypeEnum:
mcclInt8 = 0
mcclChar = 0
mcclUint8 = 1
mcclInt32 = 2
mcclInt = 2
mcclUint32 = 3
mcclInt64 = 4
mcclUint64 = 5
mcclFloat16 = 6
mcclHalf = 6
mcclFloat32 = 7
mcclFloat = 7
mcclFloat64 = 8
mcclDouble = 8
mcclBfloat16 = 9
mcclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.mcclInt8
if dtype == torch.uint8:
return cls.mcclUint8
if dtype == torch.int32:
return cls.mcclInt32
if dtype == torch.int64:
return cls.mcclInt64
if dtype == torch.float16:
return cls.mcclFloat16
if dtype == torch.float32:
return cls.mcclFloat32
if dtype == torch.float64:
return cls.mcclFloat64
if dtype == torch.bfloat16:
return cls.mcclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
mcclRedOp_t = ctypes.c_int
class mcclRedOpTypeEnum:
mcclSum = 0
mcclProd = 1
mcclMax = 2
mcclMin = 3
mcclAvg = 4
mcclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.mcclSum
if op == ReduceOp.PRODUCT:
return cls.mcclProd
if op == ReduceOp.MAX:
return cls.mcclMax
if op == ReduceOp.MIN:
return cls.mcclMin
if op == ReduceOp.AVG:
return cls.mcclAvg
raise ValueError(f"Unsupported op: {op}")
# equivalent to c declaration:
# mcclResult_t mcclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# mcclDataType_t datatype, mcclRedOp_t op, mcclComm_t comm,
# udaStream_t stream);
# note that musaStream_t is a pointer type, so the last argument is a pointer
_c_mcclAllReduce = mccl.mcclAllReduce
_c_mcclAllReduce.restype = ctypes.c_int
_c_mcclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, mcclRedOp_t,
mcclDataType_t, ctypes.c_void_p, ctypes.c_void_p
]
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# equivalent to c declaration:
# mcclResult_t mcclCommDestroy(mcclComm_t comm);
_c_mcclCommDestroy = mccl.mcclCommDestroy
_c_mcclCommDestroy.restype = ctypes.c_int
_c_mcclCommDestroy.argtypes = [ctypes.c_void_p]
class MCCLCommunicator:
def __init__(
self,
group: Optional[ProcessGroup] = None,
device: Optional[Union[int, str, torch.device]] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the MCCLCommunicator to. If None,
it will be bind to f"musa:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
"""
assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.MCCL, (
"MCCLCommunicator should be attached to a non-MCCL group.")
self.group = group
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
if self.rank == 0:
self.unique_id = mcclGetUniqueId()
else:
self.unique_id = McclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
if device is None:
local_rank = get_local_rank()
device = torch.device(f"musa:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"musa:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# mccl communicator and stream will use this device
# `torch.musa.device` is a context manager that changes the
# current musa device to the specified one
with torch.musa.device(device):
MCCL_CHECK(
_c_mcclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.musa.Stream()
def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
# mccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
f"this mccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
MCCL_CHECK(
_c_mcclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
mcclDataTypeEnum.from_torch(tensor.dtype),
mcclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.musa_stream)))

View File

@@ -0,0 +1,66 @@
import contextlib
from typing import Optional
import torch
from torch.distributed import ProcessGroup, ReduceOp
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from vllm.distributed.device_communicators.pymccl import (MCCLCommunicator,
mcclGetVersion)
except Exception as e:
# in non-MTHREADS environments, we can't import the mccl module
# e.g. when running on machines with AMD GPUs
logger.info("Failed to import MCCL library: %s", e)
logger.info("It is expected if you are not running on Mthreads GPUs.")
pass
comm: Optional["MCCLCommunicator"] = None
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return comm is not None
@contextlib.contextmanager
def set_pymccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
try:
assert comm is not None
comm.stream = stream
yield
finally:
pass
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
assert not is_initialized()
global comm
logger.info("vLLM is using nccl==%s", mcclGetVersion())
comm = MCCLCommunicator(group=group)
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_musa, f"{input_} should be a musa tensor"
assert comm is not None
comm.all_reduce(input_, op)
def destroy_process_group() -> None:
global comm
comm = None
def get_world_size() -> int:
"""Returns the world size."""
assert comm is not None
return comm.world_size
def get_nccl_backend() -> Optional["MCCLCommunicator"]:
return comm

View File

@@ -0,0 +1,287 @@
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import platform
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check
logger = init_logger(__name__)
so_file = find_nccl_library()
try:
# load the library in another process.
# if it core dumps, it will not crash the current process
nccl_integrity_check(so_file)
nccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
"Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"One solution is to download libnccl2 version 2.18 from "
"https://developer.download.nvidia.com/compute/cuda/repos/ "
"and extract the libnccl.so.2 file. If you already have the "
"library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.", so_file,
platform.platform())
raise e
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.restype = ctypes.c_char_p
_c_ncclGetErrorString.argtypes = [ncclResult_t]
def NCCL_CHECK(result: ncclResult_t) -> None:
if result != 0:
error_str = _c_ncclGetErrorString(result)
error_str = error_str.decode("utf-8")
raise RuntimeError(f"NCCL error: {error_str}")
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.restype = ctypes.c_int
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def ncclGetVersion() -> str:
version = ctypes.c_int()
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
# something like 21903 --> "2.19.3"
version_str = str(version.value)
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
class NcclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.restype = ctypes.c_int
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
def ncclGetUniqueId() -> NcclUniqueId:
unique_id = NcclUniqueId()
NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
return unique_id
# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.restype = ctypes.c_int
_c_ncclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# udaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
]
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.restype = ctypes.c_int
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
class NCCLCommunicator:
def __init__(
self,
group: Optional[ProcessGroup] = None,
device: Optional[Union[int, str, torch.device]] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the NCCLCommunicator to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
"""
assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"NCCLCommunicator should be attached to a non-NCCL group.")
self.group = group
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
if self.rank == 0:
self.unique_id = ncclGetUniqueId()
else:
self.unique_id = NcclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
if device is None:
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.cuda.Stream()
def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
NCCL_CHECK(
_c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream)))