[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
0
vllm/distributed/device_communicators/__init__.py
Normal file
0
vllm/distributed/device_communicators/__init__.py
Normal file
264
vllm/distributed/device_communicators/all2all.py
Normal file
264
vllm/distributed/device_communicators/all2all.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.util
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
else:
|
||||
FusedMoE = None
|
||||
|
||||
|
||||
class NaiveAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
A naive implementation of all2all communication.
|
||||
It uses all-reduce under the hood, which is not
|
||||
efficient at all. The main purpose is for testing and
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(self.dp_world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
self.dp_group.broadcast(buffer[start:end, :], idx)
|
||||
|
||||
return buffer
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_dp_cpu)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_dp_cpu)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
|
||||
all_hidden_states = self.dp_group.all_reduce(hidden_states)
|
||||
hidden_states = all_hidden_states[start:end, :]
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||
assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init)
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: "
|
||||
"rank=%d, world size=%d", self.rank, self.world_size)
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
dist.broadcast(uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
has_deepep = importlib.util.find_spec("deep_ep") is not None
|
||||
assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
# reasonable defaults based on profiling.
|
||||
self.num_sms = 20
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_rdma_bytes = None
|
||||
num_qps_per_rank = None
|
||||
|
||||
if self.internode:
|
||||
num_rdma_bytes = 1024 * 1024 * 1024
|
||||
num_qps_per_rank = self.num_sms // 2
|
||||
else:
|
||||
num_rdma_bytes = 0
|
||||
num_qps_per_rank = 1
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
assert num_qps_per_rank is not None
|
||||
return dict(group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
|
||||
assert len(kwargs) == 0, (
|
||||
"DeepEPHTAll2AllManager expects no arguments. All the required "
|
||||
"args are computed in the Manager itself.")
|
||||
|
||||
import deep_ep
|
||||
buffer_kwargs = self._make_all2all_kwargs()
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
# It is dangerous to set num sms outside this function. num_sms is not
|
||||
# a part of the hash-key that identifies this object. If we are in a
|
||||
# situation where we make objects with different num_sms, the hash key
|
||||
# in get_or_create must be updated.
|
||||
handle.set_num_sms(self.num_sms)
|
||||
return handle
|
||||
|
||||
|
||||
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
max_num_tokens_per_dp_rank: int,
|
||||
token_hidden_size: int,
|
||||
num_ep_ranks: int,
|
||||
num_global_experts: int,
|
||||
num_local_experts: int,
|
||||
) -> dict[Any, Any]:
|
||||
"""
|
||||
max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank
|
||||
can dispatch all the ranks must hold the same value.
|
||||
token_hidden_size: the hidden dimension of each token.
|
||||
num_ep_ranks: the number of EP group ranks.
|
||||
num_global_experts: Number of experts in the model.
|
||||
num_local_experts: Number of experts in an EP rank.
|
||||
"""
|
||||
import deep_ep
|
||||
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_qps_per_rank = num_local_experts
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
hidden=token_hidden_size,
|
||||
num_ranks=num_ep_ranks,
|
||||
num_experts=num_global_experts)
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
return dict(group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
"""
|
||||
The kwargs for DeepEPLLAll2AllManager is dictated by
|
||||
_make_all2all_kwargs.
|
||||
"""
|
||||
import deep_ep
|
||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
# It is dangerous to set num sms outside this function. num_sms is not
|
||||
# a part of the hash-key that identifies this object. If we are in a
|
||||
# situation where we make objects with different num_sms, the hash key
|
||||
# in get_or_create must be updated.
|
||||
handle.set_num_sms(self.num_sms)
|
||||
return handle
|
||||
@@ -0,0 +1,260 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Optional
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class Cache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
|
||||
def get_or_create(self, kwargs, func):
|
||||
# Create a hashable key from the kwargs
|
||||
key = tuple(sorted((k, v) for k, v in kwargs.items()))
|
||||
|
||||
with self._lock:
|
||||
instance = self._cache.get(key)
|
||||
if instance is None:
|
||||
instance = func(**kwargs)
|
||||
self._cache[key] = instance
|
||||
return instance
|
||||
|
||||
|
||||
class All2AllManagerBase:
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (get_dp_group,
|
||||
get_tp_group,
|
||||
in_the_same_node_as)
|
||||
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
# no self.ep_group since self.ep_group is still in construction
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
# based on the kwargs.
|
||||
# different layers can have different configs,
|
||||
# e.g. one layer has hidden size 1024, another has 2048.
|
||||
# usually the underlying implementation caches the handle
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceCommunicatorBase:
|
||||
"""
|
||||
Base class for device-specific communicator.
|
||||
It can use the `cpu_group` to initialize the communicator.
|
||||
If the device has PyTorch integration (PyTorch can recognize its
|
||||
communication backend), the `device_group` will also be given.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
self.device = device or torch.device("cpu")
|
||||
self.cpu_group = cpu_group
|
||||
self.device_group = device_group
|
||||
self.unique_name = unique_name
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group,
|
||||
self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
from vllm.config import get_current_vllm_config
|
||||
config = get_current_vllm_config()
|
||||
if config is not None:
|
||||
# as long as we use data parallel (coupled data parallel
|
||||
# where all data parallel ranks execute forward together),
|
||||
# we initialize the all2all manager used in expert parallel.
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
|
||||
self.use_all2all = "ep" in unique_name and use_ep
|
||||
self.all2all_manager: Optional[All2AllManagerBase] = None
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
def reduce_scatter(self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size, ) + input_tensor.shape[1:]
|
||||
|
||||
output_tensor = torch.empty(output_shape,
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
|
||||
# Perform reduce-scatter operation
|
||||
torch.distributed.reduce_scatter_tensor(output_tensor,
|
||||
input_tensor,
|
||||
group=self.device_group)
|
||||
|
||||
# Reshape before returning
|
||||
return output_tensor.movedim(0, dim).contiguous()
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Allocate output tensor.
|
||||
if self.rank_in_group == dst:
|
||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
else:
|
||||
gather_list = None
|
||||
# Gather.
|
||||
torch.distributed.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a non-blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(self,
|
||||
size: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
src: Optional[int] = None) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
def prepare_communication_buffer_for_model(self,
|
||||
model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepare the communication buffer for the model.
|
||||
"""
|
||||
if not self.use_all2all:
|
||||
return
|
||||
|
||||
moe_modules = [
|
||||
module for module in model.modules()
|
||||
if module.__class__.__name__ == "FusedMoE"
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.quant_method.init_prepare_finalize(module.moe_config,
|
||||
module.quant_config)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states
|
||||
145
vllm/distributed/device_communicators/cpu_communicator.py
Normal file
145
vllm/distributed/device_communicators/cpu_communicator.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
|
||||
class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
self.dist_module = torch.distributed
|
||||
|
||||
if (current_platform.get_cpu_architecture()
|
||||
== CpuArchEnum.X86) and hasattr(
|
||||
torch.ops._C,
|
||||
"init_shm_manager") and unique_name.startswith("tp"):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
|
||||
def all_reduce(self, input_):
|
||||
self.dist_module.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Allocate output tensor.
|
||||
if self.rank_in_group == dst:
|
||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
else:
|
||||
gather_list = None
|
||||
|
||||
# Gather.
|
||||
self.dist_module.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
self.dist_module.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
unique_name = communicator.unique_name
|
||||
instance_identifier = f"{instance_identifier}-{unique_name}"
|
||||
self.communicator = communicator
|
||||
|
||||
group_ranks = [str(rank) for rank in self.communicator.ranks]
|
||||
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
|
||||
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
handle = torch.ops._C.init_shm_manager(
|
||||
self.group_name,
|
||||
self.communicator.world_size,
|
||||
self.communicator.rank,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
torch.ops._C.join_shm_manager(
|
||||
handle,
|
||||
self.group_name,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
|
||||
return handle
|
||||
|
||||
def all_reduce(self,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
torch.ops._C.shm_allreduce(self.handle, input)
|
||||
|
||||
def gather(self,
|
||||
input: torch.Tensor,
|
||||
gather_list: Optional[list[torch.Tensor]],
|
||||
dst: int = -1,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
# Note: different from the torch gather, here we use local dst rank.
|
||||
torch.ops._C.shm_gather(self.handle, input, gather_list,
|
||||
torch.distributed.get_group_rank(group, dst))
|
||||
|
||||
def all_gather_into_tensor(self,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
torch.ops._C.shm_all_gather(self.handle, input, output)
|
||||
176
vllm/distributed/device_communicators/cuda_communicator.py
Normal file
176
vllm/distributed/device_communicators/cuda_communicator.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
if "tp" not in unique_name:
|
||||
# only tp uses custom allreduce
|
||||
use_custom_allreduce = False
|
||||
else:
|
||||
from vllm.distributed.parallel_state import (
|
||||
_ENABLE_CUSTOM_ALL_REDUCE)
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
|
||||
# ep does not use pynccl
|
||||
use_pynccl = "ep" not in unique_name
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce)
|
||||
from vllm.distributed.device_communicators.pynccl import (
|
||||
PyNcclCommunicator)
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
if use_pynccl and self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
self.ca_comm = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self.use_all2all:
|
||||
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
if all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
elif all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
logger.info("Using PPLX all2all manager.")
|
||||
elif all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
logger.info("Using DeepEP High-Throughput all2all manager.")
|
||||
elif all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
logger.info("Using DeepEP Low-Latency all2all manager.")
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
|
||||
|
||||
def all_reduce(self, input_):
|
||||
# always try custom allreduce first,
|
||||
# and then pynccl.
|
||||
ca_comm = self.ca_comm
|
||||
if ca_comm is not None and not ca_comm.disabled and \
|
||||
ca_comm.should_custom_ar(input_):
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None
|
||||
out = pynccl_comm.all_reduce(input_)
|
||||
if out is None:
|
||||
# fall back to the default all-reduce using PyTorch.
|
||||
# this usually happens during testing.
|
||||
# when we run the model, allreduce only happens for the TP
|
||||
# group, where we always have either custom allreduce or pynccl.
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size, ) + input_tensor.shape[1:]
|
||||
|
||||
output = torch.empty(output_shape,
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
|
||||
pynccl_comm.reduce_scatter(output, input_)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a non-blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.send(tensor, dst)
|
||||
else:
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(self,
|
||||
size: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
src: Optional[int] = None) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.recv(tensor, src)
|
||||
else:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
if self.pynccl_comm is not None:
|
||||
self.pynccl_comm = None
|
||||
if self.ca_comm is not None:
|
||||
self.ca_comm = None
|
||||
if self.all2all_manager is not None:
|
||||
self.all2all_manager.destroy()
|
||||
self.all2all_manager = None
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
return hidden_states
|
||||
180
vllm/distributed/device_communicators/cuda_wrapper.py
Normal file
180
vllm/distributed/device_communicators/cuda_wrapper.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This file is a pure Python wrapper for the cudart library.
|
||||
It avoids the need to compile a separate shared library, and is
|
||||
convenient for use when we just need to call a few functions.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === export types and functions from cudart to Python ===
|
||||
# for the original cudart definition, please check
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
|
||||
|
||||
cudaError_t = ctypes.c_int
|
||||
cudaMemcpyKind = ctypes.c_int
|
||||
|
||||
|
||||
class cudaIpcMemHandle_t(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
"""
|
||||
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
shared libraries loaded by the process. We can use this file to find the path of the
|
||||
a loaded library.
|
||||
""" # noqa
|
||||
found = False
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
if lib_name in line:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
# the library is not loaded in the current process
|
||||
return None
|
||||
# if lib_name is libcudart, we need to match a line with:
|
||||
# address /path/to/libcudart-hash.so.11.0
|
||||
start = line.index("/")
|
||||
path = line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(lib_name), \
|
||||
f"Unexpected filename: {filename} for library {lib_name}"
|
||||
return path
|
||||
|
||||
|
||||
class CudaRTLibrary:
|
||||
exported_functions = [
|
||||
# cudaError_t cudaSetDevice ( int device )
|
||||
Function("mcSetDevice", cudaError_t, [ctypes.c_int]),
|
||||
# cudaError_t cudaDeviceSynchronize ( void )
|
||||
Function("mcDeviceSynchronize", cudaError_t, []),
|
||||
# cudaError_t cudaDeviceReset ( void )
|
||||
Function("mcDeviceReset", cudaError_t, []),
|
||||
|
||||
# const char* cudaGetErrorString ( cudaError_t error )
|
||||
Function("mcGetErrorString", ctypes.c_char_p, [cudaError_t]),
|
||||
|
||||
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
|
||||
Function("mcMalloc", cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
|
||||
# cudaError_t cudaFree ( void* devPtr )
|
||||
Function("mcFree", cudaError_t, [ctypes.c_void_p]),
|
||||
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
|
||||
Function("mcMemset", cudaError_t,
|
||||
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
|
||||
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
|
||||
Function("mcMemcpy", cudaError_t, [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
|
||||
]),
|
||||
|
||||
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
|
||||
Function("mcIpcGetMemHandle", cudaError_t,
|
||||
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
|
||||
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
|
||||
Function("mcIpcOpenMemHandle", cudaError_t, [
|
||||
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
|
||||
]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
if so_file is None:
|
||||
so_file = find_loaded_library("libmcruntime")
|
||||
if so_file is None:
|
||||
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
|
||||
assert so_file is not None, \
|
||||
(
|
||||
"libcudart is not loaded in the current process, "
|
||||
"try setting VLLM_CUDART_SO_PATH"
|
||||
)
|
||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in CudaRTLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in CudaRTLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def CUDART_CHECK(self, result: cudaError_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.cudaGetErrorString(result)
|
||||
raise RuntimeError(f"CUDART error: {error_str}")
|
||||
|
||||
def cudaGetErrorString(self, error: cudaError_t) -> str:
|
||||
return self.funcs["mcGetErrorString"](error).decode("utf-8")
|
||||
|
||||
def cudaSetDevice(self, device: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcSetDevice"](device))
|
||||
|
||||
def cudaDeviceSynchronize(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcDeviceSynchronize"]())
|
||||
|
||||
def cudaDeviceReset(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcDeviceReset"]())
|
||||
|
||||
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["mcMalloc"](ctypes.byref(devPtr), size))
|
||||
return devPtr
|
||||
|
||||
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcFree"](devPtr))
|
||||
|
||||
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
|
||||
count: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcMemset"](devPtr, value, count))
|
||||
|
||||
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
|
||||
count: int) -> None:
|
||||
cudaMemcpyDefault = 4
|
||||
kind = cudaMemcpyDefault
|
||||
self.CUDART_CHECK(self.funcs["mcMemcpy"](dst, src, count, kind))
|
||||
|
||||
def cudaIpcGetMemHandle(self,
|
||||
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
||||
handle = cudaIpcMemHandle_t()
|
||||
self.CUDART_CHECK(self.funcs["mcIpcGetMemHandle"](
|
||||
ctypes.byref(handle), devPtr))
|
||||
return handle
|
||||
|
||||
def cudaIpcOpenMemHandle(self,
|
||||
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
||||
cudaIpcMemLazyEnablePeerAccess = 1
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["mcIpcOpenMemHandle"](
|
||||
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
|
||||
return devPtr
|
||||
304
vllm/distributed/device_communicators/custom_all_reduce.py
Normal file
304
vllm/distributed/device_communicators/custom_all_reduce.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check)
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
try:
|
||||
ops.meta_size()
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
# For CPUs
|
||||
custom_ar = False
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
continue
|
||||
if envs.VLLM_SKIP_P2P_CHECK:
|
||||
logger.info(
|
||||
"Skipping P2P check and trusting the driver's P2P report.")
|
||||
return torch.cuda.can_device_access_peer(rank, i)
|
||||
if not gpu_p2p_access_check(rank, i):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (inp.storage().nbytes() -
|
||||
inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size())
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
|
||||
# max_size: max supported allreduce size
|
||||
def __init__(self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
max_size=8192 * 1024) -> None:
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self._IS_CAPTURING = False
|
||||
self.disabled = True
|
||||
|
||||
if not custom_ar:
|
||||
# disable because of missing custom allreduce library
|
||||
# e.g. in a non-GPU environment
|
||||
logger.info("Custom allreduce is disabled because "
|
||||
"of missing custom allreduce library")
|
||||
return
|
||||
|
||||
self.group = group
|
||||
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"CustomAllreduce should be attached to a non-NCCL group.")
|
||||
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom allreduce for multi-node case.
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because this process group"
|
||||
" spans across nodes.")
|
||||
return
|
||||
|
||||
rank = dist.get_rank(group=self.group)
|
||||
self.rank = rank
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
if world_size == 1:
|
||||
# No need to initialize custom allreduce for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled due to an unsupported world"
|
||||
" size: %d. Supported world sizes: %s. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly.",
|
||||
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id],
|
||||
dtype=torch.int,
|
||||
device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu")
|
||||
for _ in range(world_size)
|
||||
]
|
||||
dist.all_gather(gather_list, tensor, group=self.group)
|
||||
physical_device_ids = [t.item() for t in gather_list]
|
||||
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
assert current_platform.is_cuda_alike()
|
||||
fully_connected = current_platform.is_fully_connected(
|
||||
physical_device_ids)
|
||||
if world_size > 2 and not fully_connected:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because it's not supported on"
|
||||
" more than two PCIe-only GPUs. To silence this warning, "
|
||||
"specify disable_custom_all_reduce=True explicitly.")
|
||||
return
|
||||
# test P2P capability, this checks software/cudaruntime support
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
||||
if not current_platform.is_rocm() and not _can_p2p(rank, world_size):
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because your platform lacks "
|
||||
"GPU P2P capability or P2P test failed. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly.")
|
||||
return
|
||||
|
||||
self.disabled = False
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
|
||||
group=group,
|
||||
uncached=True)
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
# are first copied into this buffer before allreduce is performed
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
# This is a buffer for storing the tuples of pointers pointing to
|
||||
# IPC buffers from all ranks. Each registered tuple has size of
|
||||
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
||||
# is enough for 131072 such tuples. The largest model I've seen only
|
||||
# needs less than 10000 of registered tuples.
|
||||
self.rank_data = torch.empty(8 * 1024 * 1024,
|
||||
dtype=torch.uint8,
|
||||
device=self.device)
|
||||
self.max_size = max_size
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.fully_connected = fully_connected
|
||||
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
|
||||
self.fully_connected)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
|
||||
@contextmanager
|
||||
def capture(self):
|
||||
"""
|
||||
The main responsibility of this context manager is the
|
||||
`register_graph_buffers` call at the end of the context.
|
||||
It records all the buffer addresses used in the CUDA graph.
|
||||
"""
|
||||
try:
|
||||
self._IS_CAPTURING = True
|
||||
yield
|
||||
finally:
|
||||
self._IS_CAPTURING = False
|
||||
if not self.disabled:
|
||||
self.register_graph_buffers()
|
||||
|
||||
def register_graph_buffers(self):
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
# We cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
all_data = [[None, None]
|
||||
for _ in range(dist.get_world_size(group=self.group))]
|
||||
all_data[self.rank] = [handle, offset]
|
||||
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
||||
for i, rank in enumerate(ranks):
|
||||
dist.broadcast_object_list(all_data[i],
|
||||
src=rank,
|
||||
group=self.group,
|
||||
device="cpu")
|
||||
# Unpack list of tuples to tuple of lists.
|
||||
handles = [d[0] for d in all_data] # type: ignore
|
||||
offsets = [d[1] for d in all_data] # type: ignore
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
# custom allreduce requires input byte size to be multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
if not is_weak_contiguous(inp):
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if self.world_size == 2 or self.fully_connected:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
def all_reduce(self,
|
||||
inp: torch.Tensor,
|
||||
*,
|
||||
out: torch.Tensor = None,
|
||||
registered: bool = False):
|
||||
"""Performs an out-of-place all reduce.
|
||||
|
||||
If registered is True, this assumes inp's pointer is already
|
||||
IPC-registered. Otherwise, inp is first copied into a pre-registered
|
||||
buffer.
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
if registered:
|
||||
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
||||
else:
|
||||
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
|
||||
self.max_size)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
"""The main allreduce API that provides support for cuda graph."""
|
||||
# When custom allreduce is disabled, this will be None.
|
||||
if self.disabled or not self.should_custom_ar(input):
|
||||
return None
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
return self.all_reduce(input, registered=True)
|
||||
else:
|
||||
# If warm up, mimic the allocation pattern since custom
|
||||
# allreduce is out-of-place.
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
# Note: outside of cuda graph context, custom allreduce incurs a
|
||||
# cost of cudaMemcpy, which should be small (<=1% of overall
|
||||
# latency) compared to the performance gain of using custom kernels
|
||||
return self.all_reduce(input, registered=False)
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
if ops is not None:
|
||||
ops.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
|
||||
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(size_in_bytes: int,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
uncached: Optional[bool] = False) -> list[int]:
|
||||
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
|
||||
|
||||
world_size = dist.get_world_size(group=group)
|
||||
rank = dist.get_rank(group=group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
pointers: list[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer) # type: ignore
|
||||
else:
|
||||
pointers.append(ops.open_mem_handle(h))
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(pointers: list[int],
|
||||
group: Optional[ProcessGroup] = None,
|
||||
rank: Optional[int] = 0) -> None:
|
||||
if rank is None:
|
||||
rank = dist.get_rank(group=group)
|
||||
if ops is not None:
|
||||
ops.free_shared_buffer(pointers[rank])
|
||||
259
vllm/distributed/device_communicators/custom_all_reduce_utils.py
Normal file
259
vllm/distributed/device_communicators/custom_all_reduce_utils.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ctypes
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from itertools import product
|
||||
from typing import Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (cuda_device_count_stateless,
|
||||
update_environment_variables)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def producer(batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables(
|
||||
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for i in batch_src:
|
||||
lib.cudaSetDevice(i)
|
||||
pointer = lib.cudaMalloc(1024)
|
||||
lib.cudaMemset(pointer, 1, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
producer_queue.put(handle)
|
||||
open_success = consumer_queue.get()
|
||||
if open_success:
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.put(0)
|
||||
consumer_queue.get()
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def consumer(batch_tgt: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables(
|
||||
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for j in batch_tgt:
|
||||
lib.cudaSetDevice(j)
|
||||
handle = producer_queue.get()
|
||||
open_success = False
|
||||
try:
|
||||
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
|
||||
open_success = True
|
||||
except RuntimeError:
|
||||
# cannot error out here, because the producer process
|
||||
# is still waiting for the response.
|
||||
pass
|
||||
consumer_queue.put(open_success)
|
||||
if open_success:
|
||||
# modify the memory
|
||||
lib.cudaMemset(pointer, 2, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.get()
|
||||
consumer_queue.put(0)
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def can_actually_p2p(
|
||||
batch_src: Sequence[int],
|
||||
batch_tgt: Sequence[int],
|
||||
) -> Sequence[bool]:
|
||||
"""
|
||||
Usually, checking if P2P access is enabled can be done by
|
||||
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
|
||||
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
|
||||
returns `True` even if P2P access is not actually possible.
|
||||
See https://github.com/vllm-project/vllm/issues/2728 and
|
||||
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
|
||||
Therefore, we have to perform a real P2P access to check if it is actually
|
||||
possible.
|
||||
|
||||
Note on p2p and cuda IPC:
|
||||
Usually, one process uses one GPU:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|
||||
We need to combine p2p and cuda IPC, so that:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|shared|
|
||||
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
|
||||
That is to say, process src creates a tensor in GPU src, passes IPC handle to
|
||||
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
|
||||
tensor in process tgt will be reflected in the tensor in process src, because
|
||||
they are the same memory segment.
|
||||
It is important to note that process tgt accesses the tensor in GPU tgt, not
|
||||
GPU src. That's why we need p2p access.
|
||||
|
||||
The most time-consuming part is the process creation. To avoid creating
|
||||
processes for every pair of GPUs, we use batched testing. We create two
|
||||
processes for testing all pairs of GPUs in batch. The trick is to reset
|
||||
the device after each test (which is not available in PyTorch).
|
||||
""" # noqa
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
# pass the CUDA_VISIBLE_DEVICES to the child process
|
||||
# to make sure they see the same set of GPUs
|
||||
|
||||
# make sure the processes are spawned
|
||||
smp = mp.get_context("spawn")
|
||||
producer_queue = smp.Queue()
|
||||
consumer_queue = smp.Queue()
|
||||
result_queue = smp.Queue()
|
||||
p_src = smp.Process(target=producer,
|
||||
args=(batch_src, producer_queue, consumer_queue,
|
||||
result_queue, cuda_visible_devices))
|
||||
p_tgt = smp.Process(target=consumer,
|
||||
args=(batch_tgt, producer_queue, consumer_queue,
|
||||
result_queue, cuda_visible_devices))
|
||||
p_src.start()
|
||||
p_tgt.start()
|
||||
p_src.join()
|
||||
p_tgt.join()
|
||||
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
|
||||
result: list[bool] = []
|
||||
for src, tgt in zip(batch_src, batch_tgt):
|
||||
a = result_queue.get()
|
||||
b = result_queue.get()
|
||||
if a != b:
|
||||
logger.warning(
|
||||
"Two processes do not agree on the P2P access"
|
||||
" status on %d -> %d, treat as disabled.", src, tgt)
|
||||
result.append(False)
|
||||
else:
|
||||
result.append(a)
|
||||
return result
|
||||
|
||||
|
||||
# why do we need this cache?
|
||||
# we are testing peer-to-peer (p2p) access between GPUs,across processes.
|
||||
# if we test it every time, it will be very slow, because we need to create
|
||||
# N * N * 2 processes, where N is the world size. This is very slow.
|
||||
# to reduce the time, we use a cache file to store the p2p access status.
|
||||
# the cache file is generated by the master process if it does not exist.
|
||||
# then all the processes can read the cache file to check the p2p access status.
|
||||
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
|
||||
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
"""Check if GPU src can access GPU tgt."""
|
||||
|
||||
# if the cache variable is already calculated,
|
||||
# read from the cache instead of checking it again
|
||||
global _gpu_p2p_access_cache
|
||||
if _gpu_p2p_access_cache is not None:
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = cuda_device_count_stateless()
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
||||
path = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
if ((not is_distributed or get_world_group().local_rank == 0)
|
||||
and (not os.path.exists(path))):
|
||||
# only the local master process (with local_rank == 0) can
|
||||
# enter this block to calculate the cache
|
||||
logger.info("generating GPU P2P access cache in %s", path)
|
||||
cache: dict[str, bool] = {}
|
||||
ids = list(range(num_dev))
|
||||
# batch of all pairs of GPUs
|
||||
batch_src, batch_tgt = zip(*list(product(ids, ids)))
|
||||
# NOTE: we use `subprocess` rather than `multiprocessing` here
|
||||
# because the caller might not have `if __name__ == "__main__":`,
|
||||
# in that case we cannot use spawn method in multiprocessing.
|
||||
# However, `can_actually_p2p` requires spawn method.
|
||||
# The fix is, we use `subprocess` to call the function,
|
||||
# where we have `if __name__ == "__main__":` in this file.
|
||||
|
||||
# use a temporary file to store the result
|
||||
# we don't use the output of the subprocess directly,
|
||||
# because the subprocess might produce logging output
|
||||
with tempfile.NamedTemporaryFile() as output_file:
|
||||
input_bytes = pickle.dumps(
|
||||
(batch_src, batch_tgt, output_file.name))
|
||||
returned = subprocess.run([sys.executable, __file__],
|
||||
input=input_bytes,
|
||||
capture_output=True)
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
except Exception as e:
|
||||
# wrap raised exception to provide more information
|
||||
raise RuntimeError(
|
||||
f"Error happened when batch testing "
|
||||
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
|
||||
f"{returned.stderr.decode()}") from e
|
||||
with open(output_file.name, "rb") as f:
|
||||
result = pickle.load(f)
|
||||
for _i, _j, r in zip(batch_src, batch_tgt, result):
|
||||
cache[f"{_i}->{_j}"] = r
|
||||
with open(path, "w") as f:
|
||||
json.dump(cache, f, indent=4)
|
||||
if is_distributed:
|
||||
get_world_group().barrier()
|
||||
logger.info("reading GPU P2P access cache from %s", path)
|
||||
with open(path) as f:
|
||||
cache = json.load(f)
|
||||
_gpu_p2p_access_cache = cache
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
|
||||
result = can_actually_p2p(batch_src, batch_tgt)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(pickle.dumps(result))
|
||||
46
vllm/distributed/device_communicators/hpu_communicator.py
Normal file
46
vllm/distributed/device_communicators/hpu_communicator.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
if current_platform.is_hpu():
|
||||
import habana_frameworks.torch as htorch # noqa: F401
|
||||
|
||||
|
||||
class HpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
||||
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
||||
# (which is required for tensor parallel HPUGraph inference)
|
||||
htorch.core.mark_step()
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty((world_size, ) + input_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
htorch.core.mark_step()
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
20
vllm/distributed/device_communicators/neuron_communicator.py
Normal file
20
vllm/distributed/device_communicators/neuron_communicator.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_neuron():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
class NeuronCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return xm.all_reduce(xm.REDUCE_SUM, x)
|
||||
|
||||
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
assert dim == -1, "Neuron only supports dim=-1 for all-gather."
|
||||
return xm.all_gather(x, dim=dim)
|
||||
218
vllm/distributed/device_communicators/pynccl.py
Normal file
218
vllm/distributed/device_communicators/pynccl.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
# ===================== import region =====================
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
|
||||
ncclRedOpTypeEnum, ncclUniqueId)
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import current_stream
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PyNcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the PyNcclCommunicator to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
library_path: the path to the NCCL library. If None, it will
|
||||
use the default library path.
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device.
|
||||
"""
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
else:
|
||||
self.rank = group.rank
|
||||
self.world_size = group.world_size
|
||||
|
||||
self.group = group
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
try:
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing NCCL library
|
||||
# e.g. in a non-GPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
|
||||
|
||||
if self.rank == 0:
|
||||
# get the unique id from NCCL
|
||||
self.unique_id = self.nccl.ncclGetUniqueId()
|
||||
else:
|
||||
# construct an empty unique id
|
||||
self.unique_id = ncclUniqueId()
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
dist.broadcast(tensor, src=ranks[0], group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
else:
|
||||
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
# nccl communicator and stream will use this device
|
||||
# `torch.cuda.device` is a context manager that changes the
|
||||
# current cuda device to the specified one
|
||||
with torch.cuda.device(device):
|
||||
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(self,
|
||||
in_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
return out_tensor
|
||||
|
||||
def all_gather(self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def reduce_scatter(self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclReduceScatter(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def recv(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if src == self.rank:
|
||||
sendbuff = buffer_type(tensor.data_ptr())
|
||||
# NCCL requires the sender also to have a receive buffer
|
||||
recvbuff = buffer_type(tensor.data_ptr())
|
||||
else:
|
||||
sendbuff = buffer_type()
|
||||
recvbuff = buffer_type(tensor.data_ptr())
|
||||
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
341
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
341
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
@@ -0,0 +1,341 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# This file is a pure Python wrapper for the NCCL library.
|
||||
# The main purpose is to use NCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the NCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
|
||||
# doable, but we often encounter issues related with nccl versions, and need
|
||||
# to switch between different versions of NCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of NCCL by
|
||||
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_nccl_library
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
ncclComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class ncclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
cudaStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
ncclInt32 = 2
|
||||
ncclInt = 2
|
||||
ncclUint32 = 3
|
||||
ncclInt64 = 4
|
||||
ncclUint64 = 5
|
||||
ncclFloat16 = 6
|
||||
ncclHalf = 6
|
||||
ncclFloat32 = 7
|
||||
ncclFloat = 7
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclNumTypes = 10
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
ncclMin = 3
|
||||
ncclAvg = 4
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.ncclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.ncclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.ncclMin
|
||||
if op == ReduceOp.AVG:
|
||||
return cls.ncclAvg
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
class NCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* ncclGetErrorString(ncclResult_t result)
|
||||
Function("mcclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
Function("mcclGetVersion", ncclResult_t,
|
||||
[ctypes.POINTER(ctypes.c_int)]),
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
Function("mcclGetUniqueId", ncclResult_t,
|
||||
[ctypes.POINTER(ncclUniqueId)]),
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
Function("mcclCommInitRank", ncclResult_t, [
|
||||
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
|
||||
ctypes.c_int
|
||||
]),
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("mcclAllReduce", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclAllGather(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("mcclAllGather", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclReduceScatter(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("mcclReduceScatter", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclSend(
|
||||
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||
# int dest, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("mcclSend", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclRecv(
|
||||
# void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
# int src, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("mcclRecv", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclBroadcast(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, int root, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
Function("mcclBroadcast", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ctypes.c_int, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
# it is better not to call it at all.
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
Function("mcclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
NCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = NCCLLibrary.path_to_library_cache[so_file]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load NCCL library from %s. "
|
||||
"It is expected if you are not running on NVIDIA/AMD GPUs."
|
||||
"Otherwise, the nccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s. "
|
||||
"If you already have the library, please set the "
|
||||
"environment variable VLLM_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: dict[str, Any] = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def ncclGetErrorString(self, result: ncclResult_t) -> str:
|
||||
return self._funcs["mcclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def NCCL_CHECK(self, result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.ncclGetErrorString(result)
|
||||
raise RuntimeError(f"MCCL error: {error_str}")
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
version = ctypes.c_int()
|
||||
self.NCCL_CHECK(self._funcs["mcclGetVersion"](ctypes.byref(version)))
|
||||
version_str = str(version.value)
|
||||
# something like 21903 --> "2.19.3"
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
patch = version_str[3:].lstrip("0")
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
def ncclGetUniqueId(self) -> ncclUniqueId:
|
||||
unique_id = ncclUniqueId()
|
||||
self.NCCL_CHECK(self._funcs["mcclGetUniqueId"](
|
||||
ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
|
||||
rank: int) -> ncclComm_t:
|
||||
comm = ncclComm_t()
|
||||
self.NCCL_CHECK(self._funcs["mcclCommInitRank"](ctypes.byref(comm),
|
||||
world_size, unique_id,
|
||||
rank))
|
||||
return comm
|
||||
|
||||
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["mcclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
|
||||
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["mcclReduceScatter"](sendbuff, recvbuff,
|
||||
count, datatype, op,
|
||||
comm, stream))
|
||||
|
||||
def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# which is an aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["mcclAllGather"](sendbuff, recvbuff, count,
|
||||
datatype, comm, stream))
|
||||
|
||||
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
|
||||
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["mcclSend"](sendbuff, count, datatype,
|
||||
dest, comm, stream))
|
||||
|
||||
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
|
||||
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["mcclRecv"](recvbuff, count, datatype, src,
|
||||
comm, stream))
|
||||
|
||||
def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, root: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["mcclBroadcast"](sendbuff, recvbuff, count,
|
||||
datatype, root, comm,
|
||||
stream))
|
||||
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["mcclCommDestroy"](comm))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
|
||||
"ncclComm_t", "cudaStream_t", "buffer_type"
|
||||
]
|
||||
585
vllm/distributed/device_communicators/shm_broadcast.py
Normal file
585
vllm/distributed/device_communicators/shm_broadcast.py
Normal file
@@ -0,0 +1,585 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pickle
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from threading import Event
|
||||
from typing import Any, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import zmq
|
||||
from torch.distributed import ProcessGroup
|
||||
from zmq import IPV6 # type: ignore
|
||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address)
|
||||
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SpinTimer:
|
||||
|
||||
def record_activity(self):
|
||||
pass
|
||||
|
||||
def spin(self):
|
||||
sched_yield()
|
||||
|
||||
|
||||
class SpinSleepTimer(SpinTimer):
|
||||
"""
|
||||
In setups which have long inactivity periods it is desirable to reduce
|
||||
system power consumption when vllm does nothing. This would lead to more
|
||||
CPU thermal headroom when a request eventually comes, especially when
|
||||
multiple GPUs are connected as each GPU would otherwise pin one thread at
|
||||
100% CPU usage.
|
||||
|
||||
The simplest solution is to reduce polling frequency when there is no
|
||||
activity for a certain period of time.
|
||||
"""
|
||||
|
||||
def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
|
||||
self.last_activity = time.monotonic()
|
||||
self.busy_loop_s = busy_loop_s
|
||||
self.wait_sleep_s = wait_sleep_s
|
||||
|
||||
def record_activity(self):
|
||||
self.last_activity = time.monotonic()
|
||||
|
||||
def spin(self):
|
||||
curr_time = time.monotonic()
|
||||
if curr_time >= self.last_activity + self.busy_loop_s:
|
||||
time.sleep(self.wait_sleep_s)
|
||||
else:
|
||||
sched_yield()
|
||||
|
||||
|
||||
class ShmRingBuffer:
|
||||
|
||||
def __init__(self,
|
||||
n_reader: int,
|
||||
max_chunk_bytes: int,
|
||||
max_chunks: int,
|
||||
name: Optional[str] = None):
|
||||
"""
|
||||
A shared memory ring buffer implementation for broadcast communication.
|
||||
Essentially, it is a queue where only one will `enqueue` and multiple
|
||||
will `dequeue`. The max size of each item, together with the max number
|
||||
of items that can be stored in the buffer are known in advance.
|
||||
In this case, we don't need to synchronize the access to
|
||||
the buffer.
|
||||
|
||||
Buffer memory layout:
|
||||
data metadata
|
||||
| |
|
||||
| (current_idx) | (current_idx)
|
||||
v v
|
||||
+-------------------------------+----------------------------------------+
|
||||
| chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
|
||||
+-------------------------------+----------------------------------------+
|
||||
| max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
|
||||
|
||||
metadata memory layout: each byte is a flag, the first byte is the written
|
||||
flag, and the rest are reader flags. The flags are set to 0 by default.
|
||||
+--------------+--------------+--------------+-----+--------------+
|
||||
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
|
||||
+--------------+--------------+--------------+-----+--------------+
|
||||
|
||||
The state of metadata is as follows:
|
||||
|
||||
(case 1) 0???...???: the block is not written yet, cannot read, can write
|
||||
(case 2) 1000...000: the block is just written, can read, cannot write
|
||||
(case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
|
||||
(case 4) 1111...111: the block is written and read by all readers, cannot read, can write
|
||||
|
||||
State transition for readers:
|
||||
|
||||
When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
|
||||
Only after the caller finishes reading the block, the reader can mark the block as read.
|
||||
Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
|
||||
|
||||
State transition for writer:
|
||||
|
||||
When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
|
||||
to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
|
||||
can reset the reader flags to 0, and mark the block as written (from 0 to 1).
|
||||
NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
|
||||
|
||||
During creation, `name` is None and the buffer is created. We can pass the
|
||||
created object to other processes by pickling it. The other processes will
|
||||
get the name of the shared memory and open it, so that they can access the
|
||||
same shared memory buffer.
|
||||
"""# noqa
|
||||
self.n_reader = n_reader
|
||||
self.metadata_size = 1 + n_reader
|
||||
self.max_chunk_bytes = max_chunk_bytes
|
||||
self.max_chunks = max_chunks
|
||||
self.total_bytes_of_buffer = (self.max_chunk_bytes +
|
||||
self.metadata_size) * self.max_chunks
|
||||
self.data_offset = 0
|
||||
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
|
||||
|
||||
if name is None:
|
||||
# we are creating a buffer
|
||||
self.is_creator = True
|
||||
self.shared_memory = shared_memory.SharedMemory(
|
||||
create=True, size=self.total_bytes_of_buffer)
|
||||
# initialize the metadata section to 0
|
||||
with memoryview(self.shared_memory.buf[self.metadata_offset:]
|
||||
) as metadata_buffer:
|
||||
torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
|
||||
else:
|
||||
# we are opening an existing buffer
|
||||
self.is_creator = False
|
||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||
# Python incorrectly tracks shared memory even if it is not
|
||||
# created by the process. The following patch is a workaround.
|
||||
with patch("multiprocessing.resource_tracker.register",
|
||||
lambda *args, **kwargs: None):
|
||||
try:
|
||||
self.shared_memory = shared_memory.SharedMemory(name=name)
|
||||
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
|
||||
# Some platforms allocate memory based on page size,
|
||||
# so the shared memory block size may be larger or equal
|
||||
# to the requested size. The size parameter is ignored
|
||||
# when attaching to an existing block.
|
||||
assert (self.shared_memory.size
|
||||
>= self.total_bytes_of_buffer)
|
||||
except FileNotFoundError:
|
||||
# we might deserialize the object in a different node
|
||||
# in this case, this object is not used,
|
||||
# and we should suppress the error
|
||||
pass
|
||||
|
||||
def handle(self):
|
||||
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
|
||||
self.shared_memory.name)
|
||||
|
||||
def __reduce__(self):
|
||||
return (
|
||||
self.__class__,
|
||||
self.handle(),
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "shared_memory"):
|
||||
self.shared_memory.close()
|
||||
if self.is_creator:
|
||||
self.shared_memory.unlink()
|
||||
|
||||
@contextmanager
|
||||
def get_data(self, current_idx: int):
|
||||
start = self.data_offset + current_idx * self.max_chunk_bytes
|
||||
end = start + self.max_chunk_bytes
|
||||
with memoryview(self.shared_memory.buf[start:end]) as buf:
|
||||
yield buf
|
||||
|
||||
@contextmanager
|
||||
def get_metadata(self, current_idx: int):
|
||||
start = self.metadata_offset + current_idx * self.metadata_size
|
||||
end = start + self.metadata_size
|
||||
with memoryview(self.shared_memory.buf[start:end]) as buf:
|
||||
yield buf
|
||||
|
||||
|
||||
@dataclass
|
||||
class Handle:
|
||||
local_reader_ranks: list[int] = field(default_factory=list)
|
||||
|
||||
buffer_handle: Optional[tuple[int, int, int, str]] = None
|
||||
local_subscribe_addr: Optional[str] = None
|
||||
remote_subscribe_addr: Optional[str] = None
|
||||
remote_addr_ipv6: bool = False
|
||||
|
||||
|
||||
class MessageQueue:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_reader, # number of all readers
|
||||
n_local_reader, # number of local readers through shared memory
|
||||
local_reader_ranks: Optional[list[int]] = None,
|
||||
max_chunk_bytes: int = 1024 * 1024 * 10,
|
||||
max_chunks: int = 10,
|
||||
connect_ip: Optional[str] = None,
|
||||
):
|
||||
if local_reader_ranks is None:
|
||||
local_reader_ranks = list(range(n_local_reader))
|
||||
else:
|
||||
assert len(local_reader_ranks) == n_local_reader
|
||||
self.n_local_reader = n_local_reader
|
||||
n_remote_reader = n_reader - n_local_reader
|
||||
self.n_remote_reader = n_remote_reader
|
||||
|
||||
context = Context()
|
||||
|
||||
if n_local_reader > 0:
|
||||
# for local readers, we will:
|
||||
# 1. create a shared memory ring buffer to communicate small data
|
||||
# 2. create a publish-subscribe socket to communicate large data
|
||||
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
|
||||
max_chunks)
|
||||
|
||||
# XPUB is very similar to PUB,
|
||||
# except that it can receive subscription messages
|
||||
# to confirm the number of subscribers
|
||||
self.local_socket = context.socket(XPUB)
|
||||
# set the verbose option so that we can receive every subscription
|
||||
# message. otherwise, we will only receive the first subscription
|
||||
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
||||
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
local_subscribe_addr = get_open_zmq_ipc_path()
|
||||
logger.debug("Binding to %s", local_subscribe_addr)
|
||||
self.local_socket.bind(local_subscribe_addr)
|
||||
|
||||
self.current_idx = 0
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
local_subscribe_addr = None
|
||||
self.local_socket = None
|
||||
self.current_idx = -1
|
||||
|
||||
remote_addr_ipv6 = False
|
||||
if n_remote_reader > 0:
|
||||
# for remote readers, we will:
|
||||
# create a publish-subscribe socket to communicate large data
|
||||
if not connect_ip:
|
||||
connect_ip = get_ip()
|
||||
self.remote_socket = context.socket(XPUB)
|
||||
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
remote_subscribe_port = get_open_port()
|
||||
if is_valid_ipv6_address(connect_ip):
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
remote_addr_ipv6 = True
|
||||
connect_ip = f"[{connect_ip}]"
|
||||
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
self.remote_socket.bind(socket_addr)
|
||||
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
else:
|
||||
remote_subscribe_addr = None
|
||||
self.remote_socket = None
|
||||
|
||||
self._is_writer = True
|
||||
self._is_local_reader = False
|
||||
self.local_reader_rank = -1
|
||||
# rank does not matter for remote readers
|
||||
self._is_remote_reader = False
|
||||
self._read_spin_timer = SpinTimer()
|
||||
|
||||
self.handle = Handle(
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
buffer_handle=self.buffer.handle()
|
||||
if self.buffer is not None else None,
|
||||
local_subscribe_addr=local_subscribe_addr,
|
||||
remote_subscribe_addr=remote_subscribe_addr,
|
||||
remote_addr_ipv6=remote_addr_ipv6,
|
||||
)
|
||||
|
||||
logger.info("vLLM message queue communication handle: %s", self.handle)
|
||||
|
||||
def export_handle(self) -> Handle:
|
||||
return self.handle
|
||||
|
||||
@staticmethod
|
||||
def create_from_handle(handle: Handle, rank) -> "MessageQueue":
|
||||
self = MessageQueue.__new__(MessageQueue)
|
||||
self.handle = handle
|
||||
self._is_writer = False
|
||||
|
||||
context = Context()
|
||||
|
||||
if rank in handle.local_reader_ranks:
|
||||
assert handle.buffer_handle is not None
|
||||
self.buffer = ShmRingBuffer(*handle.buffer_handle)
|
||||
self.current_idx = 0
|
||||
self.local_reader_rank = handle.local_reader_ranks.index(rank)
|
||||
self._is_local_reader = True
|
||||
self._is_remote_reader = False
|
||||
|
||||
self.local_socket = context.socket(SUB)
|
||||
self.local_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
socket_addr = handle.local_subscribe_addr
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.local_socket.connect(socket_addr)
|
||||
|
||||
self.remote_socket = None
|
||||
|
||||
self._read_spin_timer = SpinSleepTimer(
|
||||
) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
self.current_idx = -1
|
||||
self.local_reader_rank = -1
|
||||
self._is_local_reader = False
|
||||
self._is_remote_reader = True
|
||||
|
||||
self.local_socket = None
|
||||
|
||||
self.remote_socket = context.socket(SUB)
|
||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
if handle.remote_addr_ipv6:
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
socket_addr = handle.remote_subscribe_addr
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.remote_socket.connect(socket_addr)
|
||||
|
||||
return self
|
||||
|
||||
def wait_until_ready(self):
|
||||
"""This is a collective operation. All processes (including the
|
||||
readers and the writer) should call this function.
|
||||
"""
|
||||
if self._is_writer:
|
||||
# wait for all readers to connect
|
||||
|
||||
# local readers
|
||||
for i in range(self.n_local_reader):
|
||||
# wait for subscription messages from all local readers
|
||||
self.local_socket.recv()
|
||||
if self.n_local_reader > 0:
|
||||
# send a message to all local readers
|
||||
# to make sure the publish channel is working
|
||||
self.local_socket.send(b"READY")
|
||||
|
||||
# remote readers
|
||||
for i in range(self.n_remote_reader):
|
||||
# wait for subscription messages from all remote readers
|
||||
self.remote_socket.recv()
|
||||
if self.n_remote_reader > 0:
|
||||
# send a message to all remote readers
|
||||
# to make sure the publish channel is working
|
||||
self.remote_socket.send(b"READY")
|
||||
elif self._is_local_reader:
|
||||
# wait for the writer to send a message
|
||||
recv = self.local_socket.recv()
|
||||
assert recv == b"READY"
|
||||
elif self._is_remote_reader:
|
||||
# wait for the writer to send a message
|
||||
recv = self.remote_socket.recv()
|
||||
assert recv == b"READY"
|
||||
|
||||
@contextmanager
|
||||
def acquire_write(self, timeout: Optional[float] = None):
|
||||
assert self._is_writer, "Only writers can acquire write"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
while True:
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
read_count = sum(metadata_buffer[1:])
|
||||
written_flag = metadata_buffer[0]
|
||||
if written_flag and read_count != self.buffer.n_reader:
|
||||
# this block is written and not read by all readers
|
||||
# for writers, `self.current_idx` is the next block to write
|
||||
# if this block is not ready to write,
|
||||
# we need to wait until it is read by all readers
|
||||
|
||||
# Release the processor to other threads
|
||||
sched_yield()
|
||||
|
||||
# if we wait for a long time, log a message
|
||||
if (time.monotonic() - start_time
|
||||
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||
logger.debug(
|
||||
("No available shared memory broadcast block found"
|
||||
" in %s second."),
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
# if we time out, raise an exception
|
||||
if (timeout is not None
|
||||
and time.monotonic() - start_time > timeout):
|
||||
raise TimeoutError
|
||||
|
||||
continue
|
||||
# found a block that is either
|
||||
# (1) not written
|
||||
# (2) read by all readers
|
||||
|
||||
# mark the block as not written
|
||||
metadata_buffer[0] = 0
|
||||
# let caller write to the buffer
|
||||
with self.buffer.get_data(self.current_idx) as buf:
|
||||
yield buf
|
||||
|
||||
# caller has written to the buffer
|
||||
# NOTE: order is important here
|
||||
# first set the read flags to 0
|
||||
# then set the written flag to 1
|
||||
# otherwise, the readers may think they already read the block
|
||||
for i in range(1, self.buffer.n_reader + 1):
|
||||
# set read flag to 0, meaning it is not read yet
|
||||
metadata_buffer[i] = 0
|
||||
# mark the block as written
|
||||
metadata_buffer[0] = 1
|
||||
self.current_idx = (self.current_idx +
|
||||
1) % self.buffer.max_chunks
|
||||
break
|
||||
|
||||
@contextmanager
|
||||
def acquire_read(self,
|
||||
timeout: Optional[float] = None,
|
||||
cancel: Optional[Event] = None):
|
||||
assert self._is_local_reader, "Only readers can acquire read"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
while True:
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
read_flag = metadata_buffer[self.local_reader_rank + 1]
|
||||
written_flag = metadata_buffer[0]
|
||||
if not written_flag or read_flag:
|
||||
# this block is either
|
||||
# (1) not written
|
||||
# (2) already read by this reader
|
||||
|
||||
# for readers, `self.current_idx` is the next block to read
|
||||
# if this block is not ready,
|
||||
# we need to wait until it is written
|
||||
|
||||
# Release the processor to other threads
|
||||
self._read_spin_timer.spin()
|
||||
|
||||
# if we wait for a long time, log a message
|
||||
if (time.monotonic() - start_time
|
||||
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||
logger.debug(
|
||||
("No available shared memory broadcast block found"
|
||||
" in %s second."),
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
if cancel is not None and cancel.is_set():
|
||||
raise RuntimeError("cancelled")
|
||||
|
||||
# if we time out, raise an exception
|
||||
if (timeout is not None
|
||||
and time.monotonic() - start_time > timeout):
|
||||
raise TimeoutError
|
||||
|
||||
continue
|
||||
# found a block that is not read by this reader
|
||||
# let caller read from the buffer
|
||||
with self.buffer.get_data(self.current_idx) as buf:
|
||||
yield buf
|
||||
|
||||
# caller has read from the buffer
|
||||
# set the read flag
|
||||
metadata_buffer[self.local_reader_rank + 1] = 1
|
||||
self.current_idx = (self.current_idx +
|
||||
1) % self.buffer.max_chunks
|
||||
|
||||
self._read_spin_timer.record_activity()
|
||||
break
|
||||
|
||||
def enqueue(self, obj, timeout: Optional[float] = None):
|
||||
""" Write to message queue with optional timeout (in seconds) """
|
||||
assert self._is_writer, "Only writers can enqueue"
|
||||
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
if self.n_local_reader > 0:
|
||||
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
|
||||
with self.acquire_write(timeout) as buf:
|
||||
buf[0] = 1 # overflow
|
||||
self.local_socket.send(serialized_obj)
|
||||
else:
|
||||
with self.acquire_write(timeout) as buf:
|
||||
buf[0] = 0 # not overflow
|
||||
buf[1:len(serialized_obj) + 1] = serialized_obj
|
||||
if self.n_remote_reader > 0:
|
||||
self.remote_socket.send(serialized_obj)
|
||||
|
||||
def dequeue(self,
|
||||
timeout: Optional[float] = None,
|
||||
cancel: Optional[Event] = None):
|
||||
""" Read from message queue with optional timeout (in seconds) """
|
||||
if self._is_local_reader:
|
||||
with self.acquire_read(timeout, cancel) as buf:
|
||||
overflow = buf[0] == 1
|
||||
if not overflow:
|
||||
# no need to know the size of serialized object
|
||||
# pickle format contains the size information internally
|
||||
# see https://docs.python.org/3/library/pickle.html
|
||||
obj = pickle.loads(buf[1:])
|
||||
if overflow:
|
||||
obj = MessageQueue.recv(self.local_socket, timeout)
|
||||
elif self._is_remote_reader:
|
||||
obj = MessageQueue.recv(self.remote_socket, timeout)
|
||||
else:
|
||||
raise RuntimeError("Only readers can dequeue")
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
|
||||
timeout_ms = None if timeout is None else int(timeout * 1000)
|
||||
if not socket.poll(timeout=timeout_ms):
|
||||
raise TimeoutError
|
||||
recv = socket.recv(copy=False)
|
||||
return pickle.loads(recv.buffer)
|
||||
|
||||
def broadcast_object(self, obj=None):
|
||||
if self._is_writer:
|
||||
self.enqueue(obj)
|
||||
return obj
|
||||
else:
|
||||
return self.dequeue()
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group(pg: Union[ProcessGroup,
|
||||
StatelessProcessGroup],
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
writer_rank=0) -> "MessageQueue":
|
||||
if isinstance(pg, ProcessGroup):
|
||||
group_rank = dist.get_rank(pg)
|
||||
group_world_size = dist.get_world_size(pg)
|
||||
global_ranks = dist.get_process_group_ranks(pg)
|
||||
else:
|
||||
group_rank = pg.rank
|
||||
group_world_size = pg.world_size
|
||||
global_ranks = list(range(pg.world_size))
|
||||
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
status = in_the_same_node_as(pg, source_rank=writer_rank)
|
||||
same_node_ranks = [i for i, s in enumerate(status) if s]
|
||||
n_reader = group_world_size - 1
|
||||
n_local_reader = len(same_node_ranks) - 1
|
||||
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
|
||||
buffer_io: MessageQueue
|
||||
if group_rank == writer_rank:
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=n_reader,
|
||||
n_local_reader=n_local_reader,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
if isinstance(pg, ProcessGroup):
|
||||
dist.broadcast_object_list([handle],
|
||||
src=global_ranks[writer_rank],
|
||||
group=pg)
|
||||
else:
|
||||
pg.broadcast_obj(handle, writer_rank)
|
||||
else:
|
||||
if isinstance(pg, ProcessGroup):
|
||||
recv = [None]
|
||||
dist.broadcast_object_list(recv,
|
||||
src=global_ranks[writer_rank],
|
||||
group=pg)
|
||||
handle = recv[0] # type: ignore
|
||||
else:
|
||||
handle = pg.broadcast_obj(None, writer_rank)
|
||||
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io
|
||||
103
vllm/distributed/device_communicators/tpu_communicator.py
Normal file
103
vllm/distributed/device_communicators/tpu_communicator.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
USE_RAY = parallel_config = get_current_vllm_config(
|
||||
).parallel_config.distributed_executor_backend == "ray"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
import torch_xla
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
from torch_xla._internal import pjrt
|
||||
from torch_xla.distributed.xla_multiprocessing import (
|
||||
create_optimized_replica_groups)
|
||||
|
||||
if USE_RAY:
|
||||
from vllm.executor import ray_utils
|
||||
|
||||
|
||||
class TpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
|
||||
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
|
||||
# must be used together. Therefore, the local rank and world size can
|
||||
# be simply calculated as follows.
|
||||
global_rank = self.global_rank
|
||||
global_world_size = self.global_world_size
|
||||
|
||||
if USE_RAY:
|
||||
logger.info("TpuCommunicator initialized with RAY")
|
||||
# Calculate how many TPU nodes are in the current deployment. This
|
||||
# is the Ray placement group if it is deployed with Ray. Default
|
||||
# to the number of TPU nodes in the Ray cluster. The number of TPU
|
||||
# nodes is computed by the total number of TPUs divided by the
|
||||
# number of TPU accelerators per node, to account for clusters
|
||||
# with both CPUs and TPUs.
|
||||
num_nodes = ray_utils.get_num_tpu_nodes()
|
||||
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
|
||||
if num_nodes_in_pg > 0:
|
||||
num_nodes = num_nodes_in_pg
|
||||
|
||||
local_world_size = global_world_size // num_nodes
|
||||
local_rank = global_rank % local_world_size
|
||||
else:
|
||||
logger.info("TpuCommunicator initialized with MP")
|
||||
# Sanity: Verify we run on a single host
|
||||
num_hosts = torch_xla.tpu.num_tpu_workers()
|
||||
assert num_hosts == 1
|
||||
|
||||
# Get the current number of TPUs (we have locally)
|
||||
local_world_size = torch_xla.tpu.num_available_chips()
|
||||
|
||||
# Get current rank
|
||||
local_rank = global_rank % local_world_size
|
||||
|
||||
# Ensure environment variables are set for multihost deployments.
|
||||
# On GKE, this is needed for libtpu and TPU driver to know which TPU
|
||||
# chip is actually visible. Otherwise the TPU driver will fail to
|
||||
# initialize because the number of devices would be different from
|
||||
# the number of visible worker addresses.
|
||||
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
|
||||
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
|
||||
|
||||
pjrt.initialize_multiprocess(local_rank, local_world_size)
|
||||
xr._init_world_size_ordinal()
|
||||
self.groups = create_optimized_replica_groups()
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
# TODO: Remove the groups specification after XLA compiler can support
|
||||
# auto-reordering the ring order for all-reduce.
|
||||
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
||||
return xm.all_gather(input_, dim=dim)
|
||||
|
||||
|
||||
try:
|
||||
from tpu_commons.distributed.device_communicators import (
|
||||
TpuCommunicator as TpuCommonsCommunicator)
|
||||
TpuCommunicator = TpuCommonsCommunicator # type: ignore
|
||||
except ImportError:
|
||||
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
|
||||
pass
|
||||
55
vllm/distributed/device_communicators/xpu_communicator.py
Normal file
55
vllm/distributed/device_communicators/xpu_communicator.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
|
||||
class XpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
|
||||
def all_reduce(self, input_) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
# For xpu path, gather doesn't work properly together with ray
|
||||
# cluster so we use all_gather instead for now.
|
||||
input_size = input_.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty((self.world_size, ) + input_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
if self.rank_in_group == dst:
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
Reference in New Issue
Block a user