init
This commit is contained in:
0
vllm/model_executor/parallel_utils/__init__.py
Normal file
0
vllm/model_executor/parallel_utils/__init__.py
Normal file
213
vllm/model_executor/parallel_utils/communication_op.py
Normal file
213
vllm/model_executor/parallel_utils/communication_op.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.model_executor.parallel_utils import cupy_utils
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_group,
|
||||
is_cupy_nccl_enabled_for_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce
|
||||
from ixformer.contrib.torch.extension.ixformer_torch.distributed import (
|
||||
create_ixformer_group_from_pg,
|
||||
)
|
||||
from ixformer.distributed import all_reduce
|
||||
# Lazily-created ixformer communicator mirroring the torch tensor-parallel
# process group; built on first use inside tensor_model_parallel_all_reduce.
_IXFORMER_TENSOR_MODEL_PARALLEL_GROUP = None


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group.

    NOTE: This operation will be applied in-place on the input tensor if
    disable_custom_all_reduce is set to True. Otherwise, this operation may or
    may not be applied in place depending on whether custom all reduce is
    invoked for a particular tensor, which further depends on the tensor size
    and GPU topology.

    TLDR: always assume this function modifies its input, but use the return
    value as the output.
    """
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_
    # Create the ixformer communicator once from the torch process group and
    # cache it at module level.
    global _IXFORMER_TENSOR_MODEL_PARALLEL_GROUP
    if _IXFORMER_TENSOR_MODEL_PARALLEL_GROUP is None:
        _IXFORMER_TENSOR_MODEL_PARALLEL_GROUP = create_ixformer_group_from_pg(get_tensor_model_parallel_group())
    # Fast path: custom all-reduce kernel. Returns None when it declines
    # (unsupported size/topology), in which case we fall through.
    out = custom_all_reduce(input_)
    if out is not None:
        return out
    if is_cupy_nccl_enabled_for_all_reduce():
        # CuPy path is used during CUDA-graph capture (see cupy_utils).
        # TODO: support multiple parallel groups.
        cupy_utils.all_reduce(input_)
    else:
        # NOTE(review): async_op=True and the returned handle is never
        # waited on — presumably the ixformer op is enqueued on the current
        # CUDA stream so later kernels observe the reduced values; confirm,
        # otherwise this is a race.
        all_reduce(input_,group=_IXFORMER_TENSOR_MODEL_PARALLEL_GROUP,async_op=True)
        # TODO use our all reduce..
        # torch.distributed.all_reduce(input_,
        #                              group=get_tensor_model_parallel_group())
    return input_
|
||||
|
||||
|
||||
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group.

    Each rank contributes `input_`; the result concatenates all ranks'
    tensors along `dim` (which may be negative).
    """
    tp_size = get_tensor_model_parallel_world_size()
    # Single GPU: nothing to gather.
    if tp_size == 1:
        return input_
    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalize a negative dim to its positive equivalent.
    dim = dim % ndim
    shape = input_.size()
    # Gather along a fresh leading axis, then fold that axis into `dim`.
    gathered = torch.empty((tp_size, *shape),
                           dtype=input_.dtype,
                           device=input_.device)
    torch.distributed.all_gather_into_tensor(
        gathered, input_, group=get_tensor_model_parallel_group())
    gathered = gathered.movedim(0, dim)
    new_shape = (*shape[:dim], tp_size * shape[dim], *shape[dim + 1:])
    return gathered.reshape(new_shape)
|
||||
|
||||
|
||||
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> torch.Tensor:
    """Gather the input tensor across model parallel group.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks. Only rank `dst` receives the concatenated result;
    other ranks get None.
    """
    tp_size = get_tensor_model_parallel_world_size()
    # Single GPU: nothing to gather.
    if tp_size == 1:
        return input_
    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalize a negative dim to its positive equivalent.
    dim = dim % ndim
    # Only the destination rank materializes receive buffers.
    is_dst = get_tensor_model_parallel_rank() == dst
    gather_list = ([torch.empty_like(input_) for _ in range(tp_size)]
                   if is_dst else None)
    torch.distributed.gather(input_,
                             gather_list,
                             dst=dst,
                             group=get_tensor_model_parallel_group())
    return torch.cat(gather_list, dim=dim) if is_dst else None
|
||||
|
||||
|
||||
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast the input tensor in-place from rank `src` to all ranks
    of `group` (defaults to the WORLD group). Returns the tensor."""
    if group is None:
        group = torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(group)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
        return input_
    torch.distributed.broadcast(input_, src=src, group=group)
    return input_
|
||||
|
||||
|
||||
def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast the input object list in-place from rank `src` to all
    ranks of `group` (defaults to the WORLD group). Returns the list."""
    if group is None:
        group = torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(group)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
        return obj_list
    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
    return obj_list
|
||||
|
||||
|
||||
# (dtype, size) pair describing a tensor so receiving ranks of
# broadcast_tensor_dict can allocate a matching empty tensor.
TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
|
||||
|
||||
|
||||
def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
    """Broadcast the input tensor dictionary.

    The source rank first broadcasts a metadata list — (key, TensorMetadata)
    for cuda tensors, (key, value) for plain picklable objects — then
    broadcasts each tensor payload. Receiving ranks allocate empty tensors
    from the metadata and receive into them asynchronously, waiting on all
    handles before returning.

    Returns the dictionary on every rank (receivers get a newly built one).
    """
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return tensor_dict

    rank = torch.distributed.get_rank()
    if rank == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list = []
        for key, value in tensor_dict.items():
            if isinstance(value, torch.Tensor):
                assert value.is_cuda, (
                    f"Tensor {key}: {value} is not on cuda. Currently we only "
                    f"support broadcasting tensors on cuda.")
                metadata_list.append(
                    (key, TensorMetadata(value.dtype, value.size())))
            else:
                metadata_list.append((key, value))
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=group)
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = tensor_dict[key]
                # BUGFIX: pass group=group. Previously this defaulted to the
                # WORLD group while the metadata exchange and the receiving
                # side both used `group`, which is wrong (and can hang) for
                # any non-default process group.
                torch.distributed.broadcast(tensor, src=src, group=group)
    else:
        recv_metadata_list = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=group)
        metadata_list = recv_metadata_list[0]
        tensor_dict = {}
        async_handles = []
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                # Allocate the destination and receive asynchronously so the
                # per-tensor broadcasts overlap; wait on all handles below.
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device="cuda")
                async_handle = torch.distributed.broadcast(tensor,
                                                           src=src,
                                                           async_op=True,
                                                           group=group)
                async_handles.append(async_handle)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        for async_handle in async_handles:
            async_handle.wait()
    return tensor_dict
|
||||
130
vllm/model_executor/parallel_utils/cupy_utils.py
Normal file
130
vllm/model_executor/parallel_utils/cupy_utils.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""CuPy utilities for all-reduce.
|
||||
|
||||
We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
|
||||
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
|
||||
CUDA graphs.
|
||||
|
||||
NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
|
||||
TODO: Remove this file when torch.distributed.all_reduce is fixed.
|
||||
"""
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
# CuPy is an optional dependency. On import failure we stash the exception
# (re-raised with a clear message in init_process_group) and define a dummy
# NCCLBackend so the subclass below can still be defined.
try:
    import cupy
    from cupy.cuda import nccl
    from cupyx.distributed import NCCLBackend
except ImportError as e:
    cupy = e
    nccl = None

    class NCCLBackend:
        ...


# Map torch.distributed reduce ops to cupyx.distributed operation names.
_OP_MAPPING = {
    ReduceOp.SUM: "sum",
    ReduceOp.PRODUCT: "prod",
    ReduceOp.MIN: "min",
    ReduceOp.MAX: "max",
}
|
||||
|
||||
|
||||
class NCCLBackendWithBFloat16(NCCLBackend):
    # This is enough to add bfloat16 support for most operations,
    # but broadcast will fail (will require changes in compiled
    # cupy code).

    def _get_nccl_dtype_and_count(self, array, count=None):
        # bfloat16 tensors are viewed as float16 before being handed to
        # CuPy (see all_reduce below) and tagged with `_torch_dtype`;
        # translate that tag back to the real NCCL dtype here.
        nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
        torch_dtype = getattr(array, "_torch_dtype", None)
        if torch_dtype is torch.bfloat16:
            nccl_dtype = nccl.NCCL_BFLOAT16
        return nccl_dtype, count

    def barrier(self) -> None:
        # The TCP store backing the backend is stopped right after init
        # (see init_process_group), so a barrier cannot be implemented.
        raise RuntimeError(
            "Currently, CuPy NCCL barrier is not supported since the TCP "
            "store is immediately stopped after the initialization.")
|
||||
|
||||
|
||||
# Module-level singleton NCCL backend and its recorded world size.
_NCCL_BACKEND = None
_WORLD_SIZE = 0


def is_initialized() -> bool:
    """Returns whether the NCCL backend is initialized."""
    return _NCCL_BACKEND is not None


@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream):
    """Set the cuda stream for communication.

    Wraps the torch stream as a CuPy ExternalStream so CuPy NCCL work is
    enqueued on the same CUDA stream torch is using.
    """
    cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
                                           stream.device_index)
    with cupy_stream:
        yield
|
||||
|
||||
|
||||
def init_process_group(world_size: int, rank: int, host: str,
                       port: int) -> None:
    """Initializes the CuPy NCCL backend.

    # TODO: handle NCCL timeouts.
    """
    assert not is_initialized()

    # Surface the import failure captured at module load time.
    if isinstance(cupy, Exception):
        raise ImportError(
            "NCCLBackend is not available. Please install cupy.") from cupy

    # TODO(woosuk): Create TP and PP process groups for CuPy.
    global _NCCL_BACKEND
    global _WORLD_SIZE
    assert world_size > 0, f"{world_size=} should be a positive integer"
    assert 0 <= rank < world_size, (
        f"{rank=} should be a integer between [0, {world_size})")

    # Bind CuPy to the device torch is using BEFORE creating the backend:
    # the NCCL communicator is per-device.
    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
    _WORLD_SIZE = world_size

    # Stop the TCP store to prevent the deadlock issues at termination time.
    # FIXME(woosuk): This is hacky. Find a more robust solution.
    if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
        _NCCL_BACKEND._store.stop()
|
||||
|
||||
|
||||
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group, in place."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    # Hack to support bfloat16
    torch_dtype = input_.dtype
    if torch_dtype is torch.bfloat16:
        # We need to view as float16, otherwise
        # cupy will fail. This will not change
        # the underlying data.
        input_ = input_.view(torch.float16)
    cupy_input = cupy.asarray(input_)
    # Tag the array with the real dtype; NCCLBackendWithBFloat16 reads this
    # tag to select NCCL_BFLOAT16 despite the float16 view above.
    cupy_input._torch_dtype = torch_dtype  # pylint: disable=protected-access
    _NCCL_BACKEND.all_reduce(in_array=cupy_input,
                             out_array=cupy_input,
                             op=_OP_MAPPING[op])
|
||||
|
||||
|
||||
def destroy_process_group() -> None:
    """Drop the NCCL backend singleton and reset the recorded world size."""
    global _NCCL_BACKEND, _WORLD_SIZE
    _NCCL_BACKEND = None
    _WORLD_SIZE = 0


def get_world_size() -> int:
    """Return the world size recorded at init (0 when uninitialized)."""
    return _WORLD_SIZE


def get_nccl_backend():
    """Return the NCCL backend singleton (None when uninitialized)."""
    return _NCCL_BACKEND
|
||||
247
vllm/model_executor/parallel_utils/custom_all_reduce.py
Normal file
247
vllm/model_executor/parallel_utils/custom_all_reduce.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank)
|
||||
|
||||
# Optional native deps: vllm._C's custom_ar extension, and pynvml which is
# used below by _nvml()/_is_full_nvlink() for NVLink topology detection.
# On platforms without them (e.g. AMD GPUs) fall back to None, which
# disables custom all-reduce.
try:
    from vllm._C import custom_ar
    # BUGFIX: pynvml was previously only bound in the except-branch, so a
    # successful vllm._C import left `pynvml` undefined and the NVLink
    # helpers raised NameError (the stray "import pynvml" comment shows the
    # import was dropped). Import it here as intended.
    import pynvml
except ImportError:
    # For AMD GPUs
    custom_ar = None
    pynvml = None
|
||||
|
||||
logger = init_logger(__name__)

# Singleton CustomAllreduce handle; None while disabled/uninitialized.
_CA_HANDLE = None
# True while inside a CUDA-graph capture window (see capture()).
_IS_CAPTURING = False
# World sizes the custom all-reduce kernels support.
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
|
||||
|
||||
def init_custom_ar() -> None:
    """Create the custom all-reduce handle for this rank, if possible.

    No-op when already initialized or on a single GPU. When the world size
    or GPU P2P topology is unsupported, logs a warning and leaves the
    handle unset so callers fall back to torch.distributed / CuPy.
    """
    global _CA_HANDLE
    if _CA_HANDLE is not None:
        return
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        # No need to initialize custom allreduce for single GPU case.
        return

    if world_size not in _SUPPORTED_WORLD_SIZES:
        # FIX: logger.warn is deprecated (use warning); also add the missing
        # space so the message no longer reads "specifydisable_...".
        logger.warning(
            "Custom allreduce is disabled due to an unsupported world size: "
            "%d. Supported world sizes: %s. To silence this warning, specify "
            "disable_custom_all_reduce=True explicitly.", world_size,
            str(_SUPPORTED_WORLD_SIZES))
        return
    if not _can_p2p(rank, world_size):
        logger.warning(
            "Custom allreduce is disabled because your platform lacks GPU P2P"
            " capability. To silence this warning, specify "
            "disable_custom_all_reduce=True explicitly.")
        return
    _CA_HANDLE = CustomAllreduce(rank, world_size)
|
||||
|
||||
|
||||
def begin_capture() -> None:
    """Mark the start of a CUDA-graph capture window."""
    global _IS_CAPTURING
    _IS_CAPTURING = True


def end_capture() -> None:
    """Mark the end of a CUDA-graph capture window."""
    global _IS_CAPTURING
    _IS_CAPTURING = False


def is_capturing() -> bool:
    """True iff a capture window is open AND custom all-reduce is enabled."""
    return _CA_HANDLE is not None and _IS_CAPTURING
|
||||
|
||||
|
||||
def get_handle() -> Optional["CustomAllreduce"]:
    # Singleton accessor; None when custom all-reduce is disabled.
    return _CA_HANDLE


def is_initialized() -> bool:
    # True once init_custom_ar() has created the handle.
    return _CA_HANDLE is not None
|
||||
|
||||
|
||||
@contextmanager
def capture():
    """Context manager bracketing a CUDA-graph capture window.

    On exit (even on error) the capture flag is cleared and any graph
    buffers allocated during capture are registered with the handle.
    """
    try:
        begin_capture()
        yield
    finally:
        end_capture()
        ca = get_handle()
        if ca is not None:
            ca.register_graph_buffers()
|
||||
|
||||
|
||||
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
    """Try to all-reduce `input` with the custom kernel.

    Returns the (out-of-place) reduced tensor, or None when the custom path
    declines — handle disabled, or should_custom_ar rejects the tensor —
    so callers can fall back to another all-reduce implementation.
    """
    ca_handle = get_handle()
    # when custom allreduce is disabled, this will be None
    if ca_handle is None:
        return
    if is_capturing():
        if torch.cuda.is_current_stream_capturing():
            # Inside an actual CUDA-graph capture: use the registered path.
            if ca_handle.should_custom_ar(input):
                return ca_handle.all_reduce_reg(input)
        else:
            if ca_handle.should_custom_ar(input):
                # if warm up, mimic the allocation pattern
                # since custom allreduce is out-of-place
                return torch.empty_like(input)
    else:
        # note: outside of cuda graph context,
        # custom allreduce incurs a cost of cudaMemcpy, which should
        # be small(<=1% of overall latency) compared to the performance
        # gains of using custom kernels
        if ca_handle.should_custom_ar(input):
            return ca_handle.all_reduce_unreg(input)
    # Falls through (implicit None) when the custom path declines.
|
||||
|
||||
|
||||
@contextmanager
def _nvml():
    """Scope NVML init/shutdown around the decorated function.

    NOTE(review): in this file `pynvml` is only bound (to None) when the
    vllm._C import fails — confirm it is imported on the success path,
    otherwise entering this context raises NameError.
    """
    try:
        pynvml.nvmlInit()
        yield
    finally:
        pynvml.nvmlShutdown()
|
||||
|
||||
|
||||
# query if the set of gpus are fully connected by nvlink (1 hop)
@_nvml()
def _is_full_nvlink(rank, world_size):
    """Return True iff this GPU reports an active NVLink to every peer.

    Runs under _nvml(), so NVML is initialized for the duration. Any NVML
    error is treated as "no NVLink" (normal on machines without NVLink).
    """
    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
    for i in range(world_size):
        if i != rank:
            try:
                link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i)
                if not link_state:
                    return False
            except pynvml.NVMLError as error:
                # FIX: lazy %-formatting so the message is only built when
                # the INFO level is enabled (was an eager f-string).
                logger.info(
                    "NVLink detection failed with message \"%s\". "
                    "This is normal if your machine has no NVLink equipped",
                    str(error))
                return False
    return True
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
    """True iff this GPU can access every other GPU's memory peer-to-peer."""
    return all(
        torch.cuda.can_device_access_peer(rank, peer)
        for peer in range(world_size) if peer != rank)
|
||||
|
||||
|
||||
class CustomAllreduce:
    """Python wrapper around the vllm._C `custom_ar` all-reduce kernels.

    NOTE(review): large parts of the upstream implementation (IPC buffer
    setup and registration, NVLink detection, graph-buffer registration)
    are disabled below by being wrapped in triple-quoted string literals
    ("TODO align"). In particular `register_graph_buffers` is inside such
    a string while capture() in this module still calls it — confirm that
    path is unused on this platform.
    """

    # max_size: max supported allreduce size
    def __init__(self, rank, world_size, max_size=8192 * 1024) -> None:
        # Largest tensor size (bytes) the custom kernel will handle.
        self.max_size = max_size
        self.world_size = world_size
        self.full_nvlink = False
        self._ptr = None
        # No pre-registered IPC staging buffer on this platform.
        self.buffer = None
        # NOTE(review): "init_cumtom_ar" looks like a typo of
        # "init_custom_ar", but it must match the symbol exported by the
        # custom_ar extension — confirm before renaming.
        if not custom_ar.is_init():
            custom_ar.init_cumtom_ar()
        # TODO aling
        """
        # buffers memory are owned by this Python class and passed to C++
        # meta data composes of two parts: meta data for synchronization
        # (256 bytes) and a temporary buffer for storing intermediate
        # allreduce results.
        self.meta = torch.zeros(custom_ar.meta_size() + max_size,
                                dtype=torch.uint8,
                                device="cuda")
        # This is a pre-registered IPC buffer. In eager mode, input tensors
        # are first copied into this buffer before allreduce is performed
        self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has size of
        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples. The largest model I've seen only
        # needs less than 10000 of registered tuples.
        self.rank_data = torch.empty(8 * 1024 * 1024,
                                     dtype=torch.uint8,
                                     device="cuda")
        self.max_size = max_size
        self.world_size = world_size
        handles, offsets = self._get_ipc_meta(self.meta)
        self.full_nvlink = _is_full_nvlink(rank, world_size)
        self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
                                             handles, offsets, rank,
                                             self.full_nvlink)
        self.fast_cond = self.full_nvlink or world_size <= 2
        self.register_buffer(self.buffer)
        """

    #TODO align
    """
    def _get_ipc_meta(self, inp: torch.Tensor):
        data = inp.untyped_storage()._share_cuda_()
        shard_data = (
            data[1], # ipc handle to base ptr
            data[3], # offset of base ptr
        )
        return self._gather_ipc_meta(shard_data)

    def _gather_ipc_meta(self, shard_data):
        all_data = [None] * self.world_size
        dist.all_gather_object(all_data, shard_data)

        handles = []
        offsets = []
        for i in range(len(all_data)):
            handles.append(all_data[i][0])
            offsets.append(all_data[i][1])
        return handles, offsets

    def register_buffer(self, inp: torch.Tensor):
        handles, offsets = self._get_ipc_meta(inp)
        custom_ar.register_buffer(self._ptr, inp, handles, offsets)

    def register_graph_buffers(self):
        handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
        handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
        logger.info("Registering %d cuda graph addresses", len(offset))
        custom_ar.register_graph_buffers(self._ptr, handles, offsets)
    """

    def should_custom_ar(self, inp: torch.Tensor):
        # Delegates the size/topology eligibility check to the C extension.
        return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
                                          self.full_nvlink)

    # all reduce, assuming inp tensor is IPC registered with register_buffer,
    # or, in the context of cuda graphs, register_graph_buffers
    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        custom_ar.all_reduce_reg(self._ptr, inp, out)
        return out

    # all reduce, assuming inp tensor is NOT IPC registered
    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        # NOTE(review): self.buffer is None here (see __init__) — confirm
        # the extension tolerates a None staging buffer.
        custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        return out

    def close(self):
        custom_ar.dispose(self._ptr)
        # TODO align
        """
        if self._ptr:
            custom_ar.dispose(self._ptr)
            self._ptr = 0
        """

    def __del__(self):
        self.close()
|
||||
245
vllm/model_executor/parallel_utils/parallel_state.py
Normal file
245
vllm/model_executor/parallel_utils/parallel_state.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
"""Tensor and pipeline parallel groups."""
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils import cupy_utils
|
||||
|
||||
# Tensor model parallel group that the current rank belongs to.
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
# Pipeline model parallel group that the current rank belongs to.
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
|
||||
# A list of global ranks for each pipeline group to ease calculation of the
|
||||
# source rank when broadcasting from the first or last pipeline stage.
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
|
||||
|
||||
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    rank = torch.distributed.get_rank()

    # Build the tensor model-parallel groups.
    # NOTE: torch.distributed.new_group is collective — EVERY rank must
    # create every group in the same order; each rank keeps only the group
    # it belongs to.
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
        "tensor model parallel group is already initialized")
    for i in range(num_tensor_model_parallel_groups):
        # TP groups are contiguous rank blocks: [0..tp-1], [tp..2tp-1], ...
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups.
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
        "pipeline model parallel group is already initialized")
    for i in range(num_pipeline_model_parallel_groups):
        # PP groups stride across TP blocks: [i, i+stride, i+2*stride, ...]
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            # Keep the full rank list for first/last/next/prev lookups.
            _PIPELINE_GLOBAL_RANKS = ranks
|
||||
|
||||
|
||||
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
) -> None:
    """Initialize model parallel groups if needed, or verify existing ones.

    When the groups already exist, asserts that their sizes match the
    requested tensor- and pipeline-parallel sizes.
    """
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size)
        return

    # Already initialized: verify the existing sizes match the request.
    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    assert (get_pipeline_model_parallel_world_size(
    ) == pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{get_pipeline_model_parallel_world_size()=} vs. "
        f"{pipeline_model_parallel_size=}")
|
||||
|
||||
|
||||
def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None:
        return False
    return _PIPELINE_MODEL_PARALLEL_GROUP is not None
|
||||
|
||||
|
||||
def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
        "tensor model parallel group is not initialized")
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
        "pipeline model parallel group is not initialized")
    return _PIPELINE_MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return torch.distributed.get_world_size(
        group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    return torch.distributed.get_world_size(
        group=get_pipeline_model_parallel_group())


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    return torch.distributed.get_rank(
        group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    tp_size = get_tensor_model_parallel_world_size()
    # Floor the global rank down to the start of its TP block.
    return global_rank - (global_rank % tp_size)
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the pipeline for the
    current tensor parallel group"""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the pipeline for the
    current tensor parallel group"""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    return _PIPELINE_GLOBAL_RANKS[
        get_pipeline_model_parallel_world_size() - 1]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    here = get_pipeline_model_parallel_rank()
    pp_size = get_pipeline_model_parallel_world_size()
    # Wrap around so the last stage points back to the first.
    return _PIPELINE_GLOBAL_RANKS[(here + 1) % pp_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    here = get_pipeline_model_parallel_rank()
    pp_size = get_pipeline_model_parallel_world_size()
    # Wrap around so the first stage points back to the last.
    return _PIPELINE_GLOBAL_RANKS[(here - 1) % pp_size]
|
||||
|
||||
|
||||
def destroy_model_parallel():
    """Tear down the model-parallel process groups and reset module state."""
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    if _TENSOR_MODEL_PARALLEL_GROUP:
        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
    _TENSOR_MODEL_PARALLEL_GROUP = None
    if _PIPELINE_MODEL_PARALLEL_GROUP:
        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    _PIPELINE_GLOBAL_RANKS = None

    # Destroy the cupy states if any.
    cupy_utils.destroy_process_group()
|
||||
|
||||
|
||||
# Whether to use cupy for nccl all reduce.
# We use cupy for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
_ENABLE_CUPY_FOR_ALL_REDUCE = False


@contextlib.contextmanager
def with_cupy_nccl_for_all_reduce():
    """use CuPy nccl instead of torch.distributed for all reduce"""
    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        # No-op.
        # NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
        yield
    else:
        global _ENABLE_CUPY_FOR_ALL_REDUCE
        old = _ENABLE_CUPY_FOR_ALL_REDUCE
        _ENABLE_CUPY_FOR_ALL_REDUCE = True

        stream = torch.cuda.current_stream()
        try:
            with cupy_utils.set_cupy_stream(stream):
                yield
        finally:
            # BUGFIX: restore the previous flag even when the body raises;
            # previously an exception inside the `with` left the flag stuck
            # at True, silently routing all later all-reduces through CuPy.
            _ENABLE_CUPY_FOR_ALL_REDUCE = old
|
||||
|
||||
|
||||
def is_cupy_nccl_enabled_for_all_reduce():
    """check if CuPy nccl is enabled for all reduce"""
    # Read-only access to the module flag; no `global` declaration needed.
    return _ENABLE_CUPY_FOR_ALL_REDUCE
|
||||
48
vllm/model_executor/parallel_utils/utils.py
Normal file
48
vllm/model_executor/parallel_utils/utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
    """Assert that ``numerator`` is evenly divisible by ``denominator``."""
    assert numerator % denominator == 0, (
        f"{numerator} is not divisible by {denominator}")


def divide(numerator, denominator):
    """Return ``numerator // denominator`` after checking divisibility."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.

    Returns:
        A list of Tensors
    """
    # The last dimension must split evenly into `num_partitions` chunks.
    chunk_size = divide(tensor.size(-1), num_partitions)
    chunks = torch.split(tensor, chunk_size, dim=-1)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in chunks)
    return chunks
|
||||
Reference in New Issue
Block a user