v1.0
This commit is contained in:
386
distributed/device_communicators/pynccl.py
Normal file
386
distributed/device_communicators/pynccl.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
# ===================== import region =====================
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary,
|
||||
buffer_type,
|
||||
cudaStream_t,
|
||||
ncclComm_t,
|
||||
ncclDataTypeEnum,
|
||||
ncclRedOpTypeEnum,
|
||||
ncclUniqueId,
|
||||
)
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import current_stream
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_NCCL_SYMM_OPS_REGISTERED = False
|
||||
|
||||
|
||||
def register_nccl_symmetric_ops(pynccl_comm):
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
nccl_symm_mem_context,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
global _NCCL_SYMM_OPS_REGISTERED
|
||||
if _NCCL_SYMM_OPS_REGISTERED:
|
||||
return
|
||||
_NCCL_SYMM_OPS_REGISTERED = True
|
||||
|
||||
def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with nccl_symm_mem_context(pynccl_comm):
|
||||
symm_input = torch.empty_like(input_tensor)
|
||||
symm_output = torch.empty_like(input_tensor)
|
||||
symm_input.copy_(input_tensor)
|
||||
symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
|
||||
return symm_output
|
||||
|
||||
def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(input_tensor)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="all_reduce_symmetric_with_copy",
|
||||
op_func=all_reduce_symmetric_with_copy_impl,
|
||||
fake_impl=all_reduce_symmetric_with_copy_fake,
|
||||
)
|
||||
|
||||
|
||||
class PyNcclCommunicator:
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup | StatelessProcessGroup,
|
||||
device: int | str | torch.device,
|
||||
library_path: str | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the PyNcclCommunicator to. If None,
|
||||
it will be bound to f"cuda:{local_rank}".
|
||||
library_path: the path to the NCCL library. If None, it will
|
||||
use the default library path.
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device.
|
||||
"""
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"PyNcclCommunicator should be attached to a non-NCCL group."
|
||||
)
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
else:
|
||||
self.rank = group.rank
|
||||
self.world_size = group.world_size
|
||||
|
||||
self.group = group
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
try:
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing NCCL library
|
||||
# e.g. in a non-GPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
self.nccl_version = self.nccl.ncclGetRawVersion()
|
||||
if self.rank == 0:
|
||||
# get the unique id from NCCL
|
||||
self.unique_id = self.nccl.ncclGetUniqueId()
|
||||
logger.info_once(
|
||||
"vLLM is using nccl==%s", self.nccl.ncclGetVersion(), scope="local"
|
||||
)
|
||||
else:
|
||||
# construct an empty unique id
|
||||
self.unique_id = ncclUniqueId()
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
dist.broadcast(tensor, src=ranks[0], group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
else:
|
||||
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
# nccl communicator and stream will use this device
|
||||
# `torch.cuda.device` is a context manager that changes the
|
||||
# current cuda device to the specified one
|
||||
with torch.cuda.device(device):
|
||||
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank
|
||||
)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(
|
||||
self,
|
||||
in_tensor: torch.Tensor,
|
||||
out_tensor: torch.Tensor = None,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}"
|
||||
)
|
||||
|
||||
if out_tensor is None:
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllReduce(
|
||||
buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
return out_tensor
|
||||
|
||||
def all_gather(
|
||||
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
sizes: list[int],
|
||||
stream=None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
assert output_tensor.shape[0] == sum(sizes)
|
||||
split_offset = 0
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
dst_slice = output_tensor[split_offset : split_offset + split_size]
|
||||
self.nccl.ncclBroadcast(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(dst_slice.data_ptr()),
|
||||
dst_slice.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
root,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
def reduce_scatter(
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclReduceScatter(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
output_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def reduce_scatterv(
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
sizes: list[int],
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
|
||||
split_offset = 0
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
chunk = input_tensor[split_offset : split_offset + split_size, ...]
|
||||
self.nccl.ncclReduce(
|
||||
buffer_type(chunk.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
chunk.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
root,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
dst,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def recv(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if src == self.rank:
|
||||
sendbuff = buffer_type(tensor.data_ptr())
|
||||
# NCCL requires the sender also to have a receive buffer
|
||||
recvbuff = buffer_type(tensor.data_ptr())
|
||||
else:
|
||||
sendbuff = buffer_type()
|
||||
recvbuff = buffer_type(tensor.data_ptr())
|
||||
self.nccl.ncclBroadcast(
|
||||
sendbuff,
|
||||
recvbuff,
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def group_start(self):
|
||||
self.nccl.ncclGroupStart()
|
||||
|
||||
def group_end(self):
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
def register_comm_window(self, tensor: torch.Tensor):
|
||||
return self.nccl.ncclCommWindowRegister(
|
||||
self.comm,
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel() * tensor.element_size(),
|
||||
1,
|
||||
)
|
||||
|
||||
def register_comm_window_raw(self, ptr: int, size: int):
|
||||
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
|
||||
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
Reference in New Issue
Block a user