131 lines
3.9 KiB
Python
131 lines
3.9 KiB
Python
"""CuPy utilities for all-reduce.
|
|
|
|
We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
|
|
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
|
|
CUDA graphs.
|
|
|
|
NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
|
|
TODO: Remove this file when torch.distributed.all_reduce is fixed.
|
|
"""
|
|
import contextlib
|
|
|
|
import torch
|
|
from torch.distributed import ReduceOp
|
|
|
|
try:
|
|
import cupy
|
|
from cupy.cuda import nccl
|
|
from cupyx.distributed import NCCLBackend
|
|
except ImportError as e:
|
|
cupy = e
|
|
nccl = None
|
|
|
|
class NCCLBackend:
|
|
...
|
|
|
|
|
|
_OP_MAPPING = {
|
|
ReduceOp.SUM: "sum",
|
|
ReduceOp.PRODUCT: "prod",
|
|
ReduceOp.MIN: "min",
|
|
ReduceOp.MAX: "max",
|
|
}
|
|
|
|
|
|
class NCCLBackendWithBFloat16(NCCLBackend):
|
|
# This is enough to add bfloat16 support for most operations,
|
|
# but broadcast will fail (will require changes in compiled
|
|
# cupy code).
|
|
def _get_nccl_dtype_and_count(self, array, count=None):
|
|
nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
|
|
torch_dtype = getattr(array, "_torch_dtype", None)
|
|
if torch_dtype is torch.bfloat16:
|
|
nccl_dtype = nccl.NCCL_BFLOAT16
|
|
return nccl_dtype, count
|
|
|
|
def barrier(self) -> None:
|
|
raise RuntimeError(
|
|
"Currently, CuPy NCCL barrier is not supported since the TCP "
|
|
"store is immediately stopped after the initialization.")
|
|
|
|
|
|
_NCCL_BACKEND = None
|
|
_WORLD_SIZE = 0
|
|
|
|
|
|
def is_initialized() -> bool:
|
|
"""Returns whether the NCCL backend is initialized."""
|
|
return _NCCL_BACKEND is not None
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def set_cupy_stream(stream: torch.cuda.Stream):
|
|
"""Set the cuda stream for communication"""
|
|
cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
|
|
stream.device_index)
|
|
with cupy_stream:
|
|
yield
|
|
|
|
|
|
def init_process_group(world_size: int, rank: int, host: str,
|
|
port: int) -> None:
|
|
"""Initializes the CuPy NCCL backend.
|
|
|
|
# TODO: handle NCCL timeouts.
|
|
"""
|
|
assert not is_initialized()
|
|
|
|
if isinstance(cupy, Exception):
|
|
raise ImportError(
|
|
"NCCLBackend is not available. Please install cupy.") from cupy
|
|
|
|
# TODO(woosuk): Create TP and PP process groups for CuPy.
|
|
global _NCCL_BACKEND
|
|
global _WORLD_SIZE
|
|
assert world_size > 0, f"{world_size=} should be a positive integer"
|
|
assert 0 <= rank < world_size, (
|
|
f"{rank=} should be a integer between [0, {world_size})")
|
|
|
|
cupy.cuda.runtime.setDevice(torch.cuda.current_device())
|
|
_NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
|
|
_WORLD_SIZE = world_size
|
|
|
|
# Stop the TCP store to prevent the deadlock issues at termination time.
|
|
# FIXME(woosuk): This is hacky. Find a more robust solution.
|
|
if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
|
|
_NCCL_BACKEND._store.stop()
|
|
|
|
|
|
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
|
|
"""All-reduces the input tensor across the process group."""
|
|
assert input_.is_cuda, f"{input_} should be a cuda tensor"
|
|
# Hack to support bfloat16
|
|
torch_dtype = input_.dtype
|
|
if torch_dtype is torch.bfloat16:
|
|
# We need to view as float16, otherwise
|
|
# cupy will fail. This will not change
|
|
# the underlying data.
|
|
input_ = input_.view(torch.float16)
|
|
cupy_input = cupy.asarray(input_)
|
|
cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access
|
|
_NCCL_BACKEND.all_reduce(in_array=cupy_input,
|
|
out_array=cupy_input,
|
|
op=_OP_MAPPING[op])
|
|
|
|
|
|
def destroy_process_group() -> None:
|
|
"""Destroys the NCCL backend."""
|
|
global _NCCL_BACKEND
|
|
global _WORLD_SIZE
|
|
_NCCL_BACKEND = None
|
|
_WORLD_SIZE = 0
|
|
|
|
|
|
def get_world_size() -> int:
|
|
"""Returns the world size."""
|
|
return _WORLD_SIZE
|
|
|
|
|
|
def get_nccl_backend():
|
|
return _NCCL_BACKEND
|