[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
6
vllm/distributed/__init__.py
Normal file
6
vllm/distributed/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .communication_op import *
|
||||
from .parallel_state import *
|
||||
from .utils import *
|
||||
41
vllm/distributed/communication_op.py
Normal file
41
vllm/distributed/communication_op.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from .parallel_state import get_tp_group
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
return get_tp_group().all_reduce(input_)
|
||||
|
||||
|
||||
def tensor_model_parallel_all_gather(input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
"""All-gather the input tensor across model parallel group."""
|
||||
return get_tp_group().all_gather(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
"""Reduce-Scatter the input tensor across model parallel group."""
|
||||
return get_tp_group().reduce_scatter(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_gather(input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
"""Gather the input tensor across model parallel group."""
|
||||
return get_tp_group().gather(input_, dst, dim)
|
||||
|
||||
|
||||
def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor,
|
||||
Any]]] = None,
|
||||
src: int = 0):
|
||||
if not torch.distributed.is_initialized():
|
||||
return tensor_dict
|
||||
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
|
||||
0
vllm/distributed/device_communicators/__init__.py
Normal file
0
vllm/distributed/device_communicators/__init__.py
Normal file
264
vllm/distributed/device_communicators/all2all.py
Normal file
264
vllm/distributed/device_communicators/all2all.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.util
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
else:
|
||||
FusedMoE = None
|
||||
|
||||
|
||||
class NaiveAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
A naive implementation of all2all communication.
|
||||
It uses all-reduce under the hood, which is not
|
||||
efficient at all. The main purpose is for testing and
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(self.dp_world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
self.dp_group.broadcast(buffer[start:end, :], idx)
|
||||
|
||||
return buffer
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_dp_cpu)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_dp_cpu)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
|
||||
all_hidden_states = self.dp_group.all_reduce(hidden_states)
|
||||
hidden_states = all_hidden_states[start:end, :]
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||
assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init)
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: "
|
||||
"rank=%d, world size=%d", self.rank, self.world_size)
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
dist.broadcast(uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
has_deepep = importlib.util.find_spec("deep_ep") is not None
|
||||
assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
# reasonable defaults based on profiling.
|
||||
self.num_sms = 20
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_rdma_bytes = None
|
||||
num_qps_per_rank = None
|
||||
|
||||
if self.internode:
|
||||
num_rdma_bytes = 1024 * 1024 * 1024
|
||||
num_qps_per_rank = self.num_sms // 2
|
||||
else:
|
||||
num_rdma_bytes = 0
|
||||
num_qps_per_rank = 1
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
assert num_qps_per_rank is not None
|
||||
return dict(group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
|
||||
assert len(kwargs) == 0, (
|
||||
"DeepEPHTAll2AllManager expects no arguments. All the required "
|
||||
"args are computed in the Manager itself.")
|
||||
|
||||
import deep_ep
|
||||
buffer_kwargs = self._make_all2all_kwargs()
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
# It is dangerous to set num sms outside this function. num_sms is not
|
||||
# a part of the hash-key that identifies this object. If we are in a
|
||||
# situation where we make objects with different num_sms, the hash key
|
||||
# in get_or_create must be updated.
|
||||
handle.set_num_sms(self.num_sms)
|
||||
return handle
|
||||
|
||||
|
||||
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
max_num_tokens_per_dp_rank: int,
|
||||
token_hidden_size: int,
|
||||
num_ep_ranks: int,
|
||||
num_global_experts: int,
|
||||
num_local_experts: int,
|
||||
) -> dict[Any, Any]:
|
||||
"""
|
||||
max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank
|
||||
can dispatch all the ranks must hold the same value.
|
||||
token_hidden_size: the hidden dimension of each token.
|
||||
num_ep_ranks: the number of EP group ranks.
|
||||
num_global_experts: Number of experts in the model.
|
||||
num_local_experts: Number of experts in an EP rank.
|
||||
"""
|
||||
import deep_ep
|
||||
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_qps_per_rank = num_local_experts
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
hidden=token_hidden_size,
|
||||
num_ranks=num_ep_ranks,
|
||||
num_experts=num_global_experts)
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
return dict(group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
"""
|
||||
The kwargs for DeepEPLLAll2AllManager is dictated by
|
||||
_make_all2all_kwargs.
|
||||
"""
|
||||
import deep_ep
|
||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
# It is dangerous to set num sms outside this function. num_sms is not
|
||||
# a part of the hash-key that identifies this object. If we are in a
|
||||
# situation where we make objects with different num_sms, the hash key
|
||||
# in get_or_create must be updated.
|
||||
handle.set_num_sms(self.num_sms)
|
||||
return handle
|
||||
@@ -0,0 +1,260 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Optional
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class Cache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
|
||||
def get_or_create(self, kwargs, func):
|
||||
# Create a hashable key from the kwargs
|
||||
key = tuple(sorted((k, v) for k, v in kwargs.items()))
|
||||
|
||||
with self._lock:
|
||||
instance = self._cache.get(key)
|
||||
if instance is None:
|
||||
instance = func(**kwargs)
|
||||
self._cache[key] = instance
|
||||
return instance
|
||||
|
||||
|
||||
class All2AllManagerBase:
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (get_dp_group,
|
||||
get_tp_group,
|
||||
in_the_same_node_as)
|
||||
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
# no self.ep_group since self.ep_group is still in construction
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
# based on the kwargs.
|
||||
# different layers can have different configs,
|
||||
# e.g. one layer has hidden size 1024, another has 2048.
|
||||
# usually the underlying implementation caches the handle
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceCommunicatorBase:
|
||||
"""
|
||||
Base class for device-specific communicator.
|
||||
It can use the `cpu_group` to initialize the communicator.
|
||||
If the device has PyTorch integration (PyTorch can recognize its
|
||||
communication backend), the `device_group` will also be given.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
self.device = device or torch.device("cpu")
|
||||
self.cpu_group = cpu_group
|
||||
self.device_group = device_group
|
||||
self.unique_name = unique_name
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group,
|
||||
self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
from vllm.config import get_current_vllm_config
|
||||
config = get_current_vllm_config()
|
||||
if config is not None:
|
||||
# as long as we use data parallel (coupled data parallel
|
||||
# where all data parallel ranks execute forward together),
|
||||
# we initialize the all2all manager used in expert parallel.
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
|
||||
self.use_all2all = "ep" in unique_name and use_ep
|
||||
self.all2all_manager: Optional[All2AllManagerBase] = None
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
def reduce_scatter(self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size, ) + input_tensor.shape[1:]
|
||||
|
||||
output_tensor = torch.empty(output_shape,
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
|
||||
# Perform reduce-scatter operation
|
||||
torch.distributed.reduce_scatter_tensor(output_tensor,
|
||||
input_tensor,
|
||||
group=self.device_group)
|
||||
|
||||
# Reshape before returning
|
||||
return output_tensor.movedim(0, dim).contiguous()
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Allocate output tensor.
|
||||
if self.rank_in_group == dst:
|
||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
else:
|
||||
gather_list = None
|
||||
# Gather.
|
||||
torch.distributed.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a non-blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(self,
|
||||
size: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
src: Optional[int] = None) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
def prepare_communication_buffer_for_model(self,
|
||||
model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepare the communication buffer for the model.
|
||||
"""
|
||||
if not self.use_all2all:
|
||||
return
|
||||
|
||||
moe_modules = [
|
||||
module for module in model.modules()
|
||||
if module.__class__.__name__ == "FusedMoE"
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.quant_method.init_prepare_finalize(module.moe_config,
|
||||
module.quant_config)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states
|
||||
145
vllm/distributed/device_communicators/cpu_communicator.py
Normal file
145
vllm/distributed/device_communicators/cpu_communicator.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
|
||||
class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
self.dist_module = torch.distributed
|
||||
|
||||
if (current_platform.get_cpu_architecture()
|
||||
== CpuArchEnum.X86) and hasattr(
|
||||
torch.ops._C,
|
||||
"init_shm_manager") and unique_name.startswith("tp"):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
|
||||
def all_reduce(self, input_):
|
||||
self.dist_module.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Allocate output tensor.
|
||||
if self.rank_in_group == dst:
|
||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
else:
|
||||
gather_list = None
|
||||
|
||||
# Gather.
|
||||
self.dist_module.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
self.dist_module.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
unique_name = communicator.unique_name
|
||||
instance_identifier = f"{instance_identifier}-{unique_name}"
|
||||
self.communicator = communicator
|
||||
|
||||
group_ranks = [str(rank) for rank in self.communicator.ranks]
|
||||
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
|
||||
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
handle = torch.ops._C.init_shm_manager(
|
||||
self.group_name,
|
||||
self.communicator.world_size,
|
||||
self.communicator.rank,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
torch.ops._C.join_shm_manager(
|
||||
handle,
|
||||
self.group_name,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
|
||||
return handle
|
||||
|
||||
def all_reduce(self,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
torch.ops._C.shm_allreduce(self.handle, input)
|
||||
|
||||
def gather(self,
|
||||
input: torch.Tensor,
|
||||
gather_list: Optional[list[torch.Tensor]],
|
||||
dst: int = -1,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
# Note: different from the torch gather, here we use local dst rank.
|
||||
torch.ops._C.shm_gather(self.handle, input, gather_list,
|
||||
torch.distributed.get_group_rank(group, dst))
|
||||
|
||||
def all_gather_into_tensor(self,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
torch.ops._C.shm_all_gather(self.handle, input, output)
|
||||
176
vllm/distributed/device_communicators/cuda_communicator.py
Normal file
176
vllm/distributed/device_communicators/cuda_communicator.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
if "tp" not in unique_name:
|
||||
# only tp uses custom allreduce
|
||||
use_custom_allreduce = False
|
||||
else:
|
||||
from vllm.distributed.parallel_state import (
|
||||
_ENABLE_CUSTOM_ALL_REDUCE)
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
|
||||
# ep does not use pynccl
|
||||
use_pynccl = "ep" not in unique_name
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce)
|
||||
from vllm.distributed.device_communicators.pynccl import (
|
||||
PyNcclCommunicator)
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
if use_pynccl and self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
self.ca_comm = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self.use_all2all:
|
||||
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
if all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
elif all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
logger.info("Using PPLX all2all manager.")
|
||||
elif all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
logger.info("Using DeepEP High-Throughput all2all manager.")
|
||||
elif all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
logger.info("Using DeepEP Low-Latency all2all manager.")
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
|
||||
|
||||
def all_reduce(self, input_):
|
||||
# always try custom allreduce first,
|
||||
# and then pynccl.
|
||||
ca_comm = self.ca_comm
|
||||
if ca_comm is not None and not ca_comm.disabled and \
|
||||
ca_comm.should_custom_ar(input_):
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None
|
||||
out = pynccl_comm.all_reduce(input_)
|
||||
if out is None:
|
||||
# fall back to the default all-reduce using PyTorch.
|
||||
# this usually happens during testing.
|
||||
# when we run the model, allreduce only happens for the TP
|
||||
# group, where we always have either custom allreduce or pynccl.
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size, ) + input_tensor.shape[1:]
|
||||
|
||||
output = torch.empty(output_shape,
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
|
||||
pynccl_comm.reduce_scatter(output, input_)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a non-blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.send(tensor, dst)
|
||||
else:
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(self,
|
||||
size: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
src: Optional[int] = None) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.recv(tensor, src)
|
||||
else:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
if self.pynccl_comm is not None:
|
||||
self.pynccl_comm = None
|
||||
if self.ca_comm is not None:
|
||||
self.ca_comm = None
|
||||
if self.all2all_manager is not None:
|
||||
self.all2all_manager.destroy()
|
||||
self.all2all_manager = None
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
return hidden_states
|
||||
180
vllm/distributed/device_communicators/cuda_wrapper.py
Normal file
180
vllm/distributed/device_communicators/cuda_wrapper.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This file is a pure Python wrapper for the cudart library.
|
||||
It avoids the need to compile a separate shared library, and is
|
||||
convenient for use when we just need to call a few functions.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === export types and functions from cudart to Python ===
|
||||
# for the original cudart definition, please check
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
|
||||
|
||||
cudaError_t = ctypes.c_int
|
||||
cudaMemcpyKind = ctypes.c_int
|
||||
|
||||
|
||||
class cudaIpcMemHandle_t(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
"""
|
||||
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
shared libraries loaded by the process. We can use this file to find the path of the
|
||||
a loaded library.
|
||||
""" # noqa
|
||||
found = False
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
if lib_name in line:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
# the library is not loaded in the current process
|
||||
return None
|
||||
# if lib_name is libcudart, we need to match a line with:
|
||||
# address /path/to/libcudart-hash.so.11.0
|
||||
start = line.index("/")
|
||||
path = line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(lib_name), \
|
||||
f"Unexpected filename: {filename} for library {lib_name}"
|
||||
return path
|
||||
|
||||
|
||||
class CudaRTLibrary:
|
||||
exported_functions = [
|
||||
# cudaError_t cudaSetDevice ( int device )
|
||||
Function("mcSetDevice", cudaError_t, [ctypes.c_int]),
|
||||
# cudaError_t cudaDeviceSynchronize ( void )
|
||||
Function("mcDeviceSynchronize", cudaError_t, []),
|
||||
# cudaError_t cudaDeviceReset ( void )
|
||||
Function("mcDeviceReset", cudaError_t, []),
|
||||
|
||||
# const char* cudaGetErrorString ( cudaError_t error )
|
||||
Function("mcGetErrorString", ctypes.c_char_p, [cudaError_t]),
|
||||
|
||||
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
|
||||
Function("mcMalloc", cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
|
||||
# cudaError_t cudaFree ( void* devPtr )
|
||||
Function("mcFree", cudaError_t, [ctypes.c_void_p]),
|
||||
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
|
||||
Function("mcMemset", cudaError_t,
|
||||
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
|
||||
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
|
||||
Function("mcMemcpy", cudaError_t, [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
|
||||
]),
|
||||
|
||||
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
|
||||
Function("mcIpcGetMemHandle", cudaError_t,
|
||||
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
|
||||
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
|
||||
Function("mcIpcOpenMemHandle", cudaError_t, [
|
||||
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
|
||||
]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
if so_file is None:
|
||||
so_file = find_loaded_library("libmcruntime")
|
||||
if so_file is None:
|
||||
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
|
||||
assert so_file is not None, \
|
||||
(
|
||||
"libcudart is not loaded in the current process, "
|
||||
"try setting VLLM_CUDART_SO_PATH"
|
||||
)
|
||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in CudaRTLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in CudaRTLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def CUDART_CHECK(self, result: cudaError_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.cudaGetErrorString(result)
|
||||
raise RuntimeError(f"CUDART error: {error_str}")
|
||||
|
||||
def cudaGetErrorString(self, error: cudaError_t) -> str:
|
||||
return self.funcs["mcGetErrorString"](error).decode("utf-8")
|
||||
|
||||
def cudaSetDevice(self, device: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcSetDevice"](device))
|
||||
|
||||
def cudaDeviceSynchronize(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcDeviceSynchronize"]())
|
||||
|
||||
def cudaDeviceReset(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcDeviceReset"]())
|
||||
|
||||
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["mcMalloc"](ctypes.byref(devPtr), size))
|
||||
return devPtr
|
||||
|
||||
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcFree"](devPtr))
|
||||
|
||||
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
|
||||
count: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["mcMemset"](devPtr, value, count))
|
||||
|
||||
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
|
||||
count: int) -> None:
|
||||
cudaMemcpyDefault = 4
|
||||
kind = cudaMemcpyDefault
|
||||
self.CUDART_CHECK(self.funcs["mcMemcpy"](dst, src, count, kind))
|
||||
|
||||
def cudaIpcGetMemHandle(self,
|
||||
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
||||
handle = cudaIpcMemHandle_t()
|
||||
self.CUDART_CHECK(self.funcs["mcIpcGetMemHandle"](
|
||||
ctypes.byref(handle), devPtr))
|
||||
return handle
|
||||
|
||||
def cudaIpcOpenMemHandle(self,
|
||||
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
||||
cudaIpcMemLazyEnablePeerAccess = 1
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["mcIpcOpenMemHandle"](
|
||||
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
|
||||
return devPtr
|
||||
304
vllm/distributed/device_communicators/custom_all_reduce.py
Normal file
304
vllm/distributed/device_communicators/custom_all_reduce.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check)
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
try:
|
||||
ops.meta_size()
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
# For CPUs
|
||||
custom_ar = False
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
continue
|
||||
if envs.VLLM_SKIP_P2P_CHECK:
|
||||
logger.info(
|
||||
"Skipping P2P check and trusting the driver's P2P report.")
|
||||
return torch.cuda.can_device_access_peer(rank, i)
|
||||
if not gpu_p2p_access_check(rank, i):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (inp.storage().nbytes() -
|
||||
inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size())
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
|
||||
# max_size: max supported allreduce size
|
||||
def __init__(self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
max_size=8192 * 1024) -> None:
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self._IS_CAPTURING = False
|
||||
self.disabled = True
|
||||
|
||||
if not custom_ar:
|
||||
# disable because of missing custom allreduce library
|
||||
# e.g. in a non-GPU environment
|
||||
logger.info("Custom allreduce is disabled because "
|
||||
"of missing custom allreduce library")
|
||||
return
|
||||
|
||||
self.group = group
|
||||
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"CustomAllreduce should be attached to a non-NCCL group.")
|
||||
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom allreduce for multi-node case.
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because this process group"
|
||||
" spans across nodes.")
|
||||
return
|
||||
|
||||
rank = dist.get_rank(group=self.group)
|
||||
self.rank = rank
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
if world_size == 1:
|
||||
# No need to initialize custom allreduce for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled due to an unsupported world"
|
||||
" size: %d. Supported world sizes: %s. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly.",
|
||||
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id],
|
||||
dtype=torch.int,
|
||||
device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu")
|
||||
for _ in range(world_size)
|
||||
]
|
||||
dist.all_gather(gather_list, tensor, group=self.group)
|
||||
physical_device_ids = [t.item() for t in gather_list]
|
||||
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
assert current_platform.is_cuda_alike()
|
||||
fully_connected = current_platform.is_fully_connected(
|
||||
physical_device_ids)
|
||||
if world_size > 2 and not fully_connected:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because it's not supported on"
|
||||
" more than two PCIe-only GPUs. To silence this warning, "
|
||||
"specify disable_custom_all_reduce=True explicitly.")
|
||||
return
|
||||
# test P2P capability, this checks software/cudaruntime support
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
||||
if not current_platform.is_rocm() and not _can_p2p(rank, world_size):
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because your platform lacks "
|
||||
"GPU P2P capability or P2P test failed. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly.")
|
||||
return
|
||||
|
||||
self.disabled = False
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
|
||||
group=group,
|
||||
uncached=True)
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
# are first copied into this buffer before allreduce is performed
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
# This is a buffer for storing the tuples of pointers pointing to
|
||||
# IPC buffers from all ranks. Each registered tuple has size of
|
||||
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
||||
# is enough for 131072 such tuples. The largest model I've seen only
|
||||
# needs less than 10000 of registered tuples.
|
||||
self.rank_data = torch.empty(8 * 1024 * 1024,
|
||||
dtype=torch.uint8,
|
||||
device=self.device)
|
||||
self.max_size = max_size
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.fully_connected = fully_connected
|
||||
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
|
||||
self.fully_connected)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
|
||||
@contextmanager
|
||||
def capture(self):
|
||||
"""
|
||||
The main responsibility of this context manager is the
|
||||
`register_graph_buffers` call at the end of the context.
|
||||
It records all the buffer addresses used in the CUDA graph.
|
||||
"""
|
||||
try:
|
||||
self._IS_CAPTURING = True
|
||||
yield
|
||||
finally:
|
||||
self._IS_CAPTURING = False
|
||||
if not self.disabled:
|
||||
self.register_graph_buffers()
|
||||
|
||||
def register_graph_buffers(self):
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
# We cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
all_data = [[None, None]
|
||||
for _ in range(dist.get_world_size(group=self.group))]
|
||||
all_data[self.rank] = [handle, offset]
|
||||
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
||||
for i, rank in enumerate(ranks):
|
||||
dist.broadcast_object_list(all_data[i],
|
||||
src=rank,
|
||||
group=self.group,
|
||||
device="cpu")
|
||||
# Unpack list of tuples to tuple of lists.
|
||||
handles = [d[0] for d in all_data] # type: ignore
|
||||
offsets = [d[1] for d in all_data] # type: ignore
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
# custom allreduce requires input byte size to be multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
if not is_weak_contiguous(inp):
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if self.world_size == 2 or self.fully_connected:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
def all_reduce(self,
|
||||
inp: torch.Tensor,
|
||||
*,
|
||||
out: torch.Tensor = None,
|
||||
registered: bool = False):
|
||||
"""Performs an out-of-place all reduce.
|
||||
|
||||
If registered is True, this assumes inp's pointer is already
|
||||
IPC-registered. Otherwise, inp is first copied into a pre-registered
|
||||
buffer.
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
if registered:
|
||||
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
||||
else:
|
||||
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
|
||||
self.max_size)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
"""The main allreduce API that provides support for cuda graph."""
|
||||
# When custom allreduce is disabled, this will be None.
|
||||
if self.disabled or not self.should_custom_ar(input):
|
||||
return None
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
return self.all_reduce(input, registered=True)
|
||||
else:
|
||||
# If warm up, mimic the allocation pattern since custom
|
||||
# allreduce is out-of-place.
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
# Note: outside of cuda graph context, custom allreduce incurs a
|
||||
# cost of cudaMemcpy, which should be small (<=1% of overall
|
||||
# latency) compared to the performance gain of using custom kernels
|
||||
return self.all_reduce(input, registered=False)
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
if ops is not None:
|
||||
ops.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
|
||||
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(size_in_bytes: int,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
uncached: Optional[bool] = False) -> list[int]:
|
||||
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
|
||||
|
||||
world_size = dist.get_world_size(group=group)
|
||||
rank = dist.get_rank(group=group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
pointers: list[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer) # type: ignore
|
||||
else:
|
||||
pointers.append(ops.open_mem_handle(h))
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(pointers: list[int],
|
||||
group: Optional[ProcessGroup] = None,
|
||||
rank: Optional[int] = 0) -> None:
|
||||
if rank is None:
|
||||
rank = dist.get_rank(group=group)
|
||||
if ops is not None:
|
||||
ops.free_shared_buffer(pointers[rank])
|
||||
259
vllm/distributed/device_communicators/custom_all_reduce_utils.py
Normal file
259
vllm/distributed/device_communicators/custom_all_reduce_utils.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ctypes
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from itertools import product
|
||||
from typing import Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (cuda_device_count_stateless,
|
||||
update_environment_variables)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def producer(batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables(
|
||||
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for i in batch_src:
|
||||
lib.cudaSetDevice(i)
|
||||
pointer = lib.cudaMalloc(1024)
|
||||
lib.cudaMemset(pointer, 1, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
producer_queue.put(handle)
|
||||
open_success = consumer_queue.get()
|
||||
if open_success:
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.put(0)
|
||||
consumer_queue.get()
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def consumer(batch_tgt: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables(
|
||||
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for j in batch_tgt:
|
||||
lib.cudaSetDevice(j)
|
||||
handle = producer_queue.get()
|
||||
open_success = False
|
||||
try:
|
||||
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
|
||||
open_success = True
|
||||
except RuntimeError:
|
||||
# cannot error out here, because the producer process
|
||||
# is still waiting for the response.
|
||||
pass
|
||||
consumer_queue.put(open_success)
|
||||
if open_success:
|
||||
# modify the memory
|
||||
lib.cudaMemset(pointer, 2, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.get()
|
||||
consumer_queue.put(0)
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def can_actually_p2p(
|
||||
batch_src: Sequence[int],
|
||||
batch_tgt: Sequence[int],
|
||||
) -> Sequence[bool]:
|
||||
"""
|
||||
Usually, checking if P2P access is enabled can be done by
|
||||
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
|
||||
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
|
||||
returns `True` even if P2P access is not actually possible.
|
||||
See https://github.com/vllm-project/vllm/issues/2728 and
|
||||
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
|
||||
Therefore, we have to perform a real P2P access to check if it is actually
|
||||
possible.
|
||||
|
||||
Note on p2p and cuda IPC:
|
||||
Usually, one process uses one GPU:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|
||||
We need to combine p2p and cuda IPC, so that:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|shared|
|
||||
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
|
||||
That is to say, process src creates a tensor in GPU src, passes IPC handle to
|
||||
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
|
||||
tensor in process tgt will be reflected in the tensor in process src, because
|
||||
they are the same memory segment.
|
||||
It is important to note that process tgt accesses the tensor in GPU tgt, not
|
||||
GPU src. That's why we need p2p access.
|
||||
|
||||
The most time-consuming part is the process creation. To avoid creating
|
||||
processes for every pair of GPUs, we use batched testing. We create two
|
||||
processes for testing all pairs of GPUs in batch. The trick is to reset
|
||||
the device after each test (which is not available in PyTorch).
|
||||
""" # noqa
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
# pass the CUDA_VISIBLE_DEVICES to the child process
|
||||
# to make sure they see the same set of GPUs
|
||||
|
||||
# make sure the processes are spawned
|
||||
smp = mp.get_context("spawn")
|
||||
producer_queue = smp.Queue()
|
||||
consumer_queue = smp.Queue()
|
||||
result_queue = smp.Queue()
|
||||
p_src = smp.Process(target=producer,
|
||||
args=(batch_src, producer_queue, consumer_queue,
|
||||
result_queue, cuda_visible_devices))
|
||||
p_tgt = smp.Process(target=consumer,
|
||||
args=(batch_tgt, producer_queue, consumer_queue,
|
||||
result_queue, cuda_visible_devices))
|
||||
p_src.start()
|
||||
p_tgt.start()
|
||||
p_src.join()
|
||||
p_tgt.join()
|
||||
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
|
||||
result: list[bool] = []
|
||||
for src, tgt in zip(batch_src, batch_tgt):
|
||||
a = result_queue.get()
|
||||
b = result_queue.get()
|
||||
if a != b:
|
||||
logger.warning(
|
||||
"Two processes do not agree on the P2P access"
|
||||
" status on %d -> %d, treat as disabled.", src, tgt)
|
||||
result.append(False)
|
||||
else:
|
||||
result.append(a)
|
||||
return result
|
||||
|
||||
|
||||
# why do we need this cache?
|
||||
# we are testing peer-to-peer (p2p) access between GPUs,across processes.
|
||||
# if we test it every time, it will be very slow, because we need to create
|
||||
# N * N * 2 processes, where N is the world size. This is very slow.
|
||||
# to reduce the time, we use a cache file to store the p2p access status.
|
||||
# the cache file is generated by the master process if it does not exist.
|
||||
# then all the processes can read the cache file to check the p2p access status.
|
||||
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
|
||||
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
"""Check if GPU src can access GPU tgt."""
|
||||
|
||||
# if the cache variable is already calculated,
|
||||
# read from the cache instead of checking it again
|
||||
global _gpu_p2p_access_cache
|
||||
if _gpu_p2p_access_cache is not None:
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = cuda_device_count_stateless()
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
||||
path = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
if ((not is_distributed or get_world_group().local_rank == 0)
|
||||
and (not os.path.exists(path))):
|
||||
# only the local master process (with local_rank == 0) can
|
||||
# enter this block to calculate the cache
|
||||
logger.info("generating GPU P2P access cache in %s", path)
|
||||
cache: dict[str, bool] = {}
|
||||
ids = list(range(num_dev))
|
||||
# batch of all pairs of GPUs
|
||||
batch_src, batch_tgt = zip(*list(product(ids, ids)))
|
||||
# NOTE: we use `subprocess` rather than `multiprocessing` here
|
||||
# because the caller might not have `if __name__ == "__main__":`,
|
||||
# in that case we cannot use spawn method in multiprocessing.
|
||||
# However, `can_actually_p2p` requires spawn method.
|
||||
# The fix is, we use `subprocess` to call the function,
|
||||
# where we have `if __name__ == "__main__":` in this file.
|
||||
|
||||
# use a temporary file to store the result
|
||||
# we don't use the output of the subprocess directly,
|
||||
# because the subprocess might produce logging output
|
||||
with tempfile.NamedTemporaryFile() as output_file:
|
||||
input_bytes = pickle.dumps(
|
||||
(batch_src, batch_tgt, output_file.name))
|
||||
returned = subprocess.run([sys.executable, __file__],
|
||||
input=input_bytes,
|
||||
capture_output=True)
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
except Exception as e:
|
||||
# wrap raised exception to provide more information
|
||||
raise RuntimeError(
|
||||
f"Error happened when batch testing "
|
||||
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
|
||||
f"{returned.stderr.decode()}") from e
|
||||
with open(output_file.name, "rb") as f:
|
||||
result = pickle.load(f)
|
||||
for _i, _j, r in zip(batch_src, batch_tgt, result):
|
||||
cache[f"{_i}->{_j}"] = r
|
||||
with open(path, "w") as f:
|
||||
json.dump(cache, f, indent=4)
|
||||
if is_distributed:
|
||||
get_world_group().barrier()
|
||||
logger.info("reading GPU P2P access cache from %s", path)
|
||||
with open(path) as f:
|
||||
cache = json.load(f)
|
||||
_gpu_p2p_access_cache = cache
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
|
||||
result = can_actually_p2p(batch_src, batch_tgt)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(pickle.dumps(result))
|
||||
46
vllm/distributed/device_communicators/hpu_communicator.py
Normal file
46
vllm/distributed/device_communicators/hpu_communicator.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
if current_platform.is_hpu():
|
||||
import habana_frameworks.torch as htorch # noqa: F401
|
||||
|
||||
|
||||
class HpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
||||
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
||||
# (which is required for tensor parallel HPUGraph inference)
|
||||
htorch.core.mark_step()
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty((world_size, ) + input_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
htorch.core.mark_step()
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
20
vllm/distributed/device_communicators/neuron_communicator.py
Normal file
20
vllm/distributed/device_communicators/neuron_communicator.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_neuron():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
class NeuronCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return xm.all_reduce(xm.REDUCE_SUM, x)
|
||||
|
||||
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
assert dim == -1, "Neuron only supports dim=-1 for all-gather."
|
||||
return xm.all_gather(x, dim=dim)
|
||||
218
vllm/distributed/device_communicators/pynccl.py
Normal file
218
vllm/distributed/device_communicators/pynccl.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
# ===================== import region =====================
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
|
||||
ncclRedOpTypeEnum, ncclUniqueId)
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import current_stream
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PyNcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the PyNcclCommunicator to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
library_path: the path to the NCCL library. If None, it will
|
||||
use the default library path.
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device.
|
||||
"""
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
else:
|
||||
self.rank = group.rank
|
||||
self.world_size = group.world_size
|
||||
|
||||
self.group = group
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
try:
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing NCCL library
|
||||
# e.g. in a non-GPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
|
||||
|
||||
if self.rank == 0:
|
||||
# get the unique id from NCCL
|
||||
self.unique_id = self.nccl.ncclGetUniqueId()
|
||||
else:
|
||||
# construct an empty unique id
|
||||
self.unique_id = ncclUniqueId()
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
dist.broadcast(tensor, src=ranks[0], group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
else:
|
||||
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
# nccl communicator and stream will use this device
|
||||
# `torch.cuda.device` is a context manager that changes the
|
||||
# current cuda device to the specified one
|
||||
with torch.cuda.device(device):
|
||||
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(self,
|
||||
in_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
return out_tensor
|
||||
|
||||
def all_gather(self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def reduce_scatter(self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclReduceScatter(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def recv(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if src == self.rank:
|
||||
sendbuff = buffer_type(tensor.data_ptr())
|
||||
# NCCL requires the sender also to have a receive buffer
|
||||
recvbuff = buffer_type(tensor.data_ptr())
|
||||
else:
|
||||
sendbuff = buffer_type()
|
||||
recvbuff = buffer_type(tensor.data_ptr())
|
||||
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
341
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
341
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
@@ -0,0 +1,341 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# This file is a pure Python wrapper for the NCCL library.
|
||||
# The main purpose is to use NCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the NCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
|
||||
# doable, but we often encounter issues related with nccl versions, and need
|
||||
# to switch between different versions of NCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of NCCL by
|
||||
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_nccl_library
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
ncclComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class ncclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
cudaStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
ncclInt32 = 2
|
||||
ncclInt = 2
|
||||
ncclUint32 = 3
|
||||
ncclInt64 = 4
|
||||
ncclUint64 = 5
|
||||
ncclFloat16 = 6
|
||||
ncclHalf = 6
|
||||
ncclFloat32 = 7
|
||||
ncclFloat = 7
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclNumTypes = 10
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
ncclMin = 3
|
||||
ncclAvg = 4
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.ncclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.ncclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.ncclMin
|
||||
if op == ReduceOp.AVG:
|
||||
return cls.ncclAvg
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
class NCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* ncclGetErrorString(ncclResult_t result)
|
||||
Function("mcclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
Function("mcclGetVersion", ncclResult_t,
|
||||
[ctypes.POINTER(ctypes.c_int)]),
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
Function("mcclGetUniqueId", ncclResult_t,
|
||||
[ctypes.POINTER(ncclUniqueId)]),
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
Function("mcclCommInitRank", ncclResult_t, [
|
||||
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
|
||||
ctypes.c_int
|
||||
]),
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("mcclAllReduce", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclAllGather(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("mcclAllGather", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclReduceScatter(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("mcclReduceScatter", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclSend(
|
||||
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||
# int dest, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("mcclSend", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclRecv(
|
||||
# void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
# int src, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("mcclRecv", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclBroadcast(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, int root, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
Function("mcclBroadcast", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ctypes.c_int, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
# it is better not to call it at all.
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
Function("mcclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
NCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = NCCLLibrary.path_to_library_cache[so_file]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load NCCL library from %s. "
|
||||
"It is expected if you are not running on NVIDIA/AMD GPUs."
|
||||
"Otherwise, the nccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s. "
|
||||
"If you already have the library, please set the "
|
||||
"environment variable VLLM_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: dict[str, Any] = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def ncclGetErrorString(self, result: ncclResult_t) -> str:
|
||||
return self._funcs["mcclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def NCCL_CHECK(self, result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.ncclGetErrorString(result)
|
||||
raise RuntimeError(f"MCCL error: {error_str}")
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
version = ctypes.c_int()
|
||||
self.NCCL_CHECK(self._funcs["mcclGetVersion"](ctypes.byref(version)))
|
||||
version_str = str(version.value)
|
||||
# something like 21903 --> "2.19.3"
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
patch = version_str[3:].lstrip("0")
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
def ncclGetUniqueId(self) -> ncclUniqueId:
|
||||
unique_id = ncclUniqueId()
|
||||
self.NCCL_CHECK(self._funcs["mcclGetUniqueId"](
|
||||
ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
|
||||
rank: int) -> ncclComm_t:
|
||||
comm = ncclComm_t()
|
||||
self.NCCL_CHECK(self._funcs["mcclCommInitRank"](ctypes.byref(comm),
|
||||
world_size, unique_id,
|
||||
rank))
|
||||
return comm
|
||||
|
||||
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["mcclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
|
||||
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["mcclReduceScatter"](sendbuff, recvbuff,
|
||||
count, datatype, op,
|
||||
comm, stream))
|
||||
|
||||
def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# which is an aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["mcclAllGather"](sendbuff, recvbuff, count,
|
||||
datatype, comm, stream))
|
||||
|
||||
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
|
||||
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["mcclSend"](sendbuff, count, datatype,
|
||||
dest, comm, stream))
|
||||
|
||||
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
|
||||
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["mcclRecv"](recvbuff, count, datatype, src,
|
||||
comm, stream))
|
||||
|
||||
def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, root: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["mcclBroadcast"](sendbuff, recvbuff, count,
|
||||
datatype, root, comm,
|
||||
stream))
|
||||
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["mcclCommDestroy"](comm))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
|
||||
"ncclComm_t", "cudaStream_t", "buffer_type"
|
||||
]
|
||||
585
vllm/distributed/device_communicators/shm_broadcast.py
Normal file
585
vllm/distributed/device_communicators/shm_broadcast.py
Normal file
@@ -0,0 +1,585 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pickle
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from threading import Event
|
||||
from typing import Any, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import zmq
|
||||
from torch.distributed import ProcessGroup
|
||||
from zmq import IPV6 # type: ignore
|
||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address)
|
||||
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SpinTimer:
|
||||
|
||||
def record_activity(self):
|
||||
pass
|
||||
|
||||
def spin(self):
|
||||
sched_yield()
|
||||
|
||||
|
||||
class SpinSleepTimer(SpinTimer):
|
||||
"""
|
||||
In setups which have long inactivity periods it is desirable to reduce
|
||||
system power consumption when vllm does nothing. This would lead to more
|
||||
CPU thermal headroom when a request eventually comes, especially when
|
||||
multiple GPUs are connected as each GPU would otherwise pin one thread at
|
||||
100% CPU usage.
|
||||
|
||||
The simplest solution is to reduce polling frequency when there is no
|
||||
activity for a certain period of time.
|
||||
"""
|
||||
|
||||
def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
|
||||
self.last_activity = time.monotonic()
|
||||
self.busy_loop_s = busy_loop_s
|
||||
self.wait_sleep_s = wait_sleep_s
|
||||
|
||||
def record_activity(self):
|
||||
self.last_activity = time.monotonic()
|
||||
|
||||
def spin(self):
|
||||
curr_time = time.monotonic()
|
||||
if curr_time >= self.last_activity + self.busy_loop_s:
|
||||
time.sleep(self.wait_sleep_s)
|
||||
else:
|
||||
sched_yield()
|
||||
|
||||
|
||||
class ShmRingBuffer:
|
||||
|
||||
def __init__(self,
|
||||
n_reader: int,
|
||||
max_chunk_bytes: int,
|
||||
max_chunks: int,
|
||||
name: Optional[str] = None):
|
||||
"""
|
||||
A shared memory ring buffer implementation for broadcast communication.
|
||||
Essentially, it is a queue where only one will `enqueue` and multiple
|
||||
will `dequeue`. The max size of each item, together with the max number
|
||||
of items that can be stored in the buffer are known in advance.
|
||||
In this case, we don't need to synchronize the access to
|
||||
the buffer.
|
||||
|
||||
Buffer memory layout:
|
||||
data metadata
|
||||
| |
|
||||
| (current_idx) | (current_idx)
|
||||
v v
|
||||
+-------------------------------+----------------------------------------+
|
||||
| chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
|
||||
+-------------------------------+----------------------------------------+
|
||||
| max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
|
||||
|
||||
metadata memory layout: each byte is a flag, the first byte is the written
|
||||
flag, and the rest are reader flags. The flags are set to 0 by default.
|
||||
+--------------+--------------+--------------+-----+--------------+
|
||||
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
|
||||
+--------------+--------------+--------------+-----+--------------+
|
||||
|
||||
The state of metadata is as follows:
|
||||
|
||||
(case 1) 0???...???: the block is not written yet, cannot read, can write
|
||||
(case 2) 1000...000: the block is just written, can read, cannot write
|
||||
(case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
|
||||
(case 4) 1111...111: the block is written and read by all readers, cannot read, can write
|
||||
|
||||
State transition for readers:
|
||||
|
||||
When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
|
||||
Only after the caller finishes reading the block, the reader can mark the block as read.
|
||||
Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
|
||||
|
||||
State transition for writer:
|
||||
|
||||
When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
|
||||
to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
|
||||
can reset the reader flags to 0, and mark the block as written (from 0 to 1).
|
||||
NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
|
||||
|
||||
During creation, `name` is None and the buffer is created. We can pass the
|
||||
created object to other processes by pickling it. The other processes will
|
||||
get the name of the shared memory and open it, so that they can access the
|
||||
same shared memory buffer.
|
||||
"""# noqa
|
||||
self.n_reader = n_reader
|
||||
self.metadata_size = 1 + n_reader
|
||||
self.max_chunk_bytes = max_chunk_bytes
|
||||
self.max_chunks = max_chunks
|
||||
self.total_bytes_of_buffer = (self.max_chunk_bytes +
|
||||
self.metadata_size) * self.max_chunks
|
||||
self.data_offset = 0
|
||||
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
|
||||
|
||||
if name is None:
|
||||
# we are creating a buffer
|
||||
self.is_creator = True
|
||||
self.shared_memory = shared_memory.SharedMemory(
|
||||
create=True, size=self.total_bytes_of_buffer)
|
||||
# initialize the metadata section to 0
|
||||
with memoryview(self.shared_memory.buf[self.metadata_offset:]
|
||||
) as metadata_buffer:
|
||||
torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
|
||||
else:
|
||||
# we are opening an existing buffer
|
||||
self.is_creator = False
|
||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||
# Python incorrectly tracks shared memory even if it is not
|
||||
# created by the process. The following patch is a workaround.
|
||||
with patch("multiprocessing.resource_tracker.register",
|
||||
lambda *args, **kwargs: None):
|
||||
try:
|
||||
self.shared_memory = shared_memory.SharedMemory(name=name)
|
||||
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
|
||||
# Some platforms allocate memory based on page size,
|
||||
# so the shared memory block size may be larger or equal
|
||||
# to the requested size. The size parameter is ignored
|
||||
# when attaching to an existing block.
|
||||
assert (self.shared_memory.size
|
||||
>= self.total_bytes_of_buffer)
|
||||
except FileNotFoundError:
|
||||
# we might deserialize the object in a different node
|
||||
# in this case, this object is not used,
|
||||
# and we should suppress the error
|
||||
pass
|
||||
|
||||
def handle(self):
|
||||
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
|
||||
self.shared_memory.name)
|
||||
|
||||
def __reduce__(self):
|
||||
return (
|
||||
self.__class__,
|
||||
self.handle(),
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "shared_memory"):
|
||||
self.shared_memory.close()
|
||||
if self.is_creator:
|
||||
self.shared_memory.unlink()
|
||||
|
||||
@contextmanager
|
||||
def get_data(self, current_idx: int):
|
||||
start = self.data_offset + current_idx * self.max_chunk_bytes
|
||||
end = start + self.max_chunk_bytes
|
||||
with memoryview(self.shared_memory.buf[start:end]) as buf:
|
||||
yield buf
|
||||
|
||||
@contextmanager
|
||||
def get_metadata(self, current_idx: int):
|
||||
start = self.metadata_offset + current_idx * self.metadata_size
|
||||
end = start + self.metadata_size
|
||||
with memoryview(self.shared_memory.buf[start:end]) as buf:
|
||||
yield buf
|
||||
|
||||
|
||||
@dataclass
|
||||
class Handle:
|
||||
local_reader_ranks: list[int] = field(default_factory=list)
|
||||
|
||||
buffer_handle: Optional[tuple[int, int, int, str]] = None
|
||||
local_subscribe_addr: Optional[str] = None
|
||||
remote_subscribe_addr: Optional[str] = None
|
||||
remote_addr_ipv6: bool = False
|
||||
|
||||
|
||||
class MessageQueue:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_reader, # number of all readers
|
||||
n_local_reader, # number of local readers through shared memory
|
||||
local_reader_ranks: Optional[list[int]] = None,
|
||||
max_chunk_bytes: int = 1024 * 1024 * 10,
|
||||
max_chunks: int = 10,
|
||||
connect_ip: Optional[str] = None,
|
||||
):
|
||||
if local_reader_ranks is None:
|
||||
local_reader_ranks = list(range(n_local_reader))
|
||||
else:
|
||||
assert len(local_reader_ranks) == n_local_reader
|
||||
self.n_local_reader = n_local_reader
|
||||
n_remote_reader = n_reader - n_local_reader
|
||||
self.n_remote_reader = n_remote_reader
|
||||
|
||||
context = Context()
|
||||
|
||||
if n_local_reader > 0:
|
||||
# for local readers, we will:
|
||||
# 1. create a shared memory ring buffer to communicate small data
|
||||
# 2. create a publish-subscribe socket to communicate large data
|
||||
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
|
||||
max_chunks)
|
||||
|
||||
# XPUB is very similar to PUB,
|
||||
# except that it can receive subscription messages
|
||||
# to confirm the number of subscribers
|
||||
self.local_socket = context.socket(XPUB)
|
||||
# set the verbose option so that we can receive every subscription
|
||||
# message. otherwise, we will only receive the first subscription
|
||||
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
||||
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
local_subscribe_addr = get_open_zmq_ipc_path()
|
||||
logger.debug("Binding to %s", local_subscribe_addr)
|
||||
self.local_socket.bind(local_subscribe_addr)
|
||||
|
||||
self.current_idx = 0
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
local_subscribe_addr = None
|
||||
self.local_socket = None
|
||||
self.current_idx = -1
|
||||
|
||||
remote_addr_ipv6 = False
|
||||
if n_remote_reader > 0:
|
||||
# for remote readers, we will:
|
||||
# create a publish-subscribe socket to communicate large data
|
||||
if not connect_ip:
|
||||
connect_ip = get_ip()
|
||||
self.remote_socket = context.socket(XPUB)
|
||||
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
remote_subscribe_port = get_open_port()
|
||||
if is_valid_ipv6_address(connect_ip):
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
remote_addr_ipv6 = True
|
||||
connect_ip = f"[{connect_ip}]"
|
||||
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
self.remote_socket.bind(socket_addr)
|
||||
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
else:
|
||||
remote_subscribe_addr = None
|
||||
self.remote_socket = None
|
||||
|
||||
self._is_writer = True
|
||||
self._is_local_reader = False
|
||||
self.local_reader_rank = -1
|
||||
# rank does not matter for remote readers
|
||||
self._is_remote_reader = False
|
||||
self._read_spin_timer = SpinTimer()
|
||||
|
||||
self.handle = Handle(
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
buffer_handle=self.buffer.handle()
|
||||
if self.buffer is not None else None,
|
||||
local_subscribe_addr=local_subscribe_addr,
|
||||
remote_subscribe_addr=remote_subscribe_addr,
|
||||
remote_addr_ipv6=remote_addr_ipv6,
|
||||
)
|
||||
|
||||
logger.info("vLLM message queue communication handle: %s", self.handle)
|
||||
|
||||
def export_handle(self) -> Handle:
|
||||
return self.handle
|
||||
|
||||
@staticmethod
|
||||
def create_from_handle(handle: Handle, rank) -> "MessageQueue":
|
||||
self = MessageQueue.__new__(MessageQueue)
|
||||
self.handle = handle
|
||||
self._is_writer = False
|
||||
|
||||
context = Context()
|
||||
|
||||
if rank in handle.local_reader_ranks:
|
||||
assert handle.buffer_handle is not None
|
||||
self.buffer = ShmRingBuffer(*handle.buffer_handle)
|
||||
self.current_idx = 0
|
||||
self.local_reader_rank = handle.local_reader_ranks.index(rank)
|
||||
self._is_local_reader = True
|
||||
self._is_remote_reader = False
|
||||
|
||||
self.local_socket = context.socket(SUB)
|
||||
self.local_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
socket_addr = handle.local_subscribe_addr
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.local_socket.connect(socket_addr)
|
||||
|
||||
self.remote_socket = None
|
||||
|
||||
self._read_spin_timer = SpinSleepTimer(
|
||||
) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
self.current_idx = -1
|
||||
self.local_reader_rank = -1
|
||||
self._is_local_reader = False
|
||||
self._is_remote_reader = True
|
||||
|
||||
self.local_socket = None
|
||||
|
||||
self.remote_socket = context.socket(SUB)
|
||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
if handle.remote_addr_ipv6:
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
socket_addr = handle.remote_subscribe_addr
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.remote_socket.connect(socket_addr)
|
||||
|
||||
return self
|
||||
|
||||
def wait_until_ready(self):
|
||||
"""This is a collective operation. All processes (including the
|
||||
readers and the writer) should call this function.
|
||||
"""
|
||||
if self._is_writer:
|
||||
# wait for all readers to connect
|
||||
|
||||
# local readers
|
||||
for i in range(self.n_local_reader):
|
||||
# wait for subscription messages from all local readers
|
||||
self.local_socket.recv()
|
||||
if self.n_local_reader > 0:
|
||||
# send a message to all local readers
|
||||
# to make sure the publish channel is working
|
||||
self.local_socket.send(b"READY")
|
||||
|
||||
# remote readers
|
||||
for i in range(self.n_remote_reader):
|
||||
# wait for subscription messages from all remote readers
|
||||
self.remote_socket.recv()
|
||||
if self.n_remote_reader > 0:
|
||||
# send a message to all remote readers
|
||||
# to make sure the publish channel is working
|
||||
self.remote_socket.send(b"READY")
|
||||
elif self._is_local_reader:
|
||||
# wait for the writer to send a message
|
||||
recv = self.local_socket.recv()
|
||||
assert recv == b"READY"
|
||||
elif self._is_remote_reader:
|
||||
# wait for the writer to send a message
|
||||
recv = self.remote_socket.recv()
|
||||
assert recv == b"READY"
|
||||
|
||||
@contextmanager
|
||||
def acquire_write(self, timeout: Optional[float] = None):
|
||||
assert self._is_writer, "Only writers can acquire write"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
while True:
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
read_count = sum(metadata_buffer[1:])
|
||||
written_flag = metadata_buffer[0]
|
||||
if written_flag and read_count != self.buffer.n_reader:
|
||||
# this block is written and not read by all readers
|
||||
# for writers, `self.current_idx` is the next block to write
|
||||
# if this block is not ready to write,
|
||||
# we need to wait until it is read by all readers
|
||||
|
||||
# Release the processor to other threads
|
||||
sched_yield()
|
||||
|
||||
# if we wait for a long time, log a message
|
||||
if (time.monotonic() - start_time
|
||||
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||
logger.debug(
|
||||
("No available shared memory broadcast block found"
|
||||
" in %s second."),
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
# if we time out, raise an exception
|
||||
if (timeout is not None
|
||||
and time.monotonic() - start_time > timeout):
|
||||
raise TimeoutError
|
||||
|
||||
continue
|
||||
# found a block that is either
|
||||
# (1) not written
|
||||
# (2) read by all readers
|
||||
|
||||
# mark the block as not written
|
||||
metadata_buffer[0] = 0
|
||||
# let caller write to the buffer
|
||||
with self.buffer.get_data(self.current_idx) as buf:
|
||||
yield buf
|
||||
|
||||
# caller has written to the buffer
|
||||
# NOTE: order is important here
|
||||
# first set the read flags to 0
|
||||
# then set the written flag to 1
|
||||
# otherwise, the readers may think they already read the block
|
||||
for i in range(1, self.buffer.n_reader + 1):
|
||||
# set read flag to 0, meaning it is not read yet
|
||||
metadata_buffer[i] = 0
|
||||
# mark the block as written
|
||||
metadata_buffer[0] = 1
|
||||
self.current_idx = (self.current_idx +
|
||||
1) % self.buffer.max_chunks
|
||||
break
|
||||
|
||||
@contextmanager
|
||||
def acquire_read(self,
|
||||
timeout: Optional[float] = None,
|
||||
cancel: Optional[Event] = None):
|
||||
assert self._is_local_reader, "Only readers can acquire read"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
while True:
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
read_flag = metadata_buffer[self.local_reader_rank + 1]
|
||||
written_flag = metadata_buffer[0]
|
||||
if not written_flag or read_flag:
|
||||
# this block is either
|
||||
# (1) not written
|
||||
# (2) already read by this reader
|
||||
|
||||
# for readers, `self.current_idx` is the next block to read
|
||||
# if this block is not ready,
|
||||
# we need to wait until it is written
|
||||
|
||||
# Release the processor to other threads
|
||||
self._read_spin_timer.spin()
|
||||
|
||||
# if we wait for a long time, log a message
|
||||
if (time.monotonic() - start_time
|
||||
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||
logger.debug(
|
||||
("No available shared memory broadcast block found"
|
||||
" in %s second."),
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
if cancel is not None and cancel.is_set():
|
||||
raise RuntimeError("cancelled")
|
||||
|
||||
# if we time out, raise an exception
|
||||
if (timeout is not None
|
||||
and time.monotonic() - start_time > timeout):
|
||||
raise TimeoutError
|
||||
|
||||
continue
|
||||
# found a block that is not read by this reader
|
||||
# let caller read from the buffer
|
||||
with self.buffer.get_data(self.current_idx) as buf:
|
||||
yield buf
|
||||
|
||||
# caller has read from the buffer
|
||||
# set the read flag
|
||||
metadata_buffer[self.local_reader_rank + 1] = 1
|
||||
self.current_idx = (self.current_idx +
|
||||
1) % self.buffer.max_chunks
|
||||
|
||||
self._read_spin_timer.record_activity()
|
||||
break
|
||||
|
||||
def enqueue(self, obj, timeout: Optional[float] = None):
|
||||
""" Write to message queue with optional timeout (in seconds) """
|
||||
assert self._is_writer, "Only writers can enqueue"
|
||||
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
if self.n_local_reader > 0:
|
||||
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
|
||||
with self.acquire_write(timeout) as buf:
|
||||
buf[0] = 1 # overflow
|
||||
self.local_socket.send(serialized_obj)
|
||||
else:
|
||||
with self.acquire_write(timeout) as buf:
|
||||
buf[0] = 0 # not overflow
|
||||
buf[1:len(serialized_obj) + 1] = serialized_obj
|
||||
if self.n_remote_reader > 0:
|
||||
self.remote_socket.send(serialized_obj)
|
||||
|
||||
def dequeue(self,
|
||||
timeout: Optional[float] = None,
|
||||
cancel: Optional[Event] = None):
|
||||
""" Read from message queue with optional timeout (in seconds) """
|
||||
if self._is_local_reader:
|
||||
with self.acquire_read(timeout, cancel) as buf:
|
||||
overflow = buf[0] == 1
|
||||
if not overflow:
|
||||
# no need to know the size of serialized object
|
||||
# pickle format contains the size information internally
|
||||
# see https://docs.python.org/3/library/pickle.html
|
||||
obj = pickle.loads(buf[1:])
|
||||
if overflow:
|
||||
obj = MessageQueue.recv(self.local_socket, timeout)
|
||||
elif self._is_remote_reader:
|
||||
obj = MessageQueue.recv(self.remote_socket, timeout)
|
||||
else:
|
||||
raise RuntimeError("Only readers can dequeue")
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
|
||||
timeout_ms = None if timeout is None else int(timeout * 1000)
|
||||
if not socket.poll(timeout=timeout_ms):
|
||||
raise TimeoutError
|
||||
recv = socket.recv(copy=False)
|
||||
return pickle.loads(recv.buffer)
|
||||
|
||||
def broadcast_object(self, obj=None):
|
||||
if self._is_writer:
|
||||
self.enqueue(obj)
|
||||
return obj
|
||||
else:
|
||||
return self.dequeue()
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group(pg: Union[ProcessGroup,
|
||||
StatelessProcessGroup],
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
writer_rank=0) -> "MessageQueue":
|
||||
if isinstance(pg, ProcessGroup):
|
||||
group_rank = dist.get_rank(pg)
|
||||
group_world_size = dist.get_world_size(pg)
|
||||
global_ranks = dist.get_process_group_ranks(pg)
|
||||
else:
|
||||
group_rank = pg.rank
|
||||
group_world_size = pg.world_size
|
||||
global_ranks = list(range(pg.world_size))
|
||||
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
status = in_the_same_node_as(pg, source_rank=writer_rank)
|
||||
same_node_ranks = [i for i, s in enumerate(status) if s]
|
||||
n_reader = group_world_size - 1
|
||||
n_local_reader = len(same_node_ranks) - 1
|
||||
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
|
||||
buffer_io: MessageQueue
|
||||
if group_rank == writer_rank:
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=n_reader,
|
||||
n_local_reader=n_local_reader,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
if isinstance(pg, ProcessGroup):
|
||||
dist.broadcast_object_list([handle],
|
||||
src=global_ranks[writer_rank],
|
||||
group=pg)
|
||||
else:
|
||||
pg.broadcast_obj(handle, writer_rank)
|
||||
else:
|
||||
if isinstance(pg, ProcessGroup):
|
||||
recv = [None]
|
||||
dist.broadcast_object_list(recv,
|
||||
src=global_ranks[writer_rank],
|
||||
group=pg)
|
||||
handle = recv[0] # type: ignore
|
||||
else:
|
||||
handle = pg.broadcast_obj(None, writer_rank)
|
||||
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io
|
||||
103
vllm/distributed/device_communicators/tpu_communicator.py
Normal file
103
vllm/distributed/device_communicators/tpu_communicator.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
USE_RAY = parallel_config = get_current_vllm_config(
|
||||
).parallel_config.distributed_executor_backend == "ray"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
import torch_xla
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
from torch_xla._internal import pjrt
|
||||
from torch_xla.distributed.xla_multiprocessing import (
|
||||
create_optimized_replica_groups)
|
||||
|
||||
if USE_RAY:
|
||||
from vllm.executor import ray_utils
|
||||
|
||||
|
||||
class TpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
|
||||
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
|
||||
# must be used together. Therefore, the local rank and world size can
|
||||
# be simply calculated as follows.
|
||||
global_rank = self.global_rank
|
||||
global_world_size = self.global_world_size
|
||||
|
||||
if USE_RAY:
|
||||
logger.info("TpuCommunicator initialized with RAY")
|
||||
# Calculate how many TPU nodes are in the current deployment. This
|
||||
# is the Ray placement group if it is deployed with Ray. Default
|
||||
# to the number of TPU nodes in the Ray cluster. The number of TPU
|
||||
# nodes is computed by the total number of TPUs divided by the
|
||||
# number of TPU accelerators per node, to account for clusters
|
||||
# with both CPUs and TPUs.
|
||||
num_nodes = ray_utils.get_num_tpu_nodes()
|
||||
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
|
||||
if num_nodes_in_pg > 0:
|
||||
num_nodes = num_nodes_in_pg
|
||||
|
||||
local_world_size = global_world_size // num_nodes
|
||||
local_rank = global_rank % local_world_size
|
||||
else:
|
||||
logger.info("TpuCommunicator initialized with MP")
|
||||
# Sanity: Verify we run on a single host
|
||||
num_hosts = torch_xla.tpu.num_tpu_workers()
|
||||
assert num_hosts == 1
|
||||
|
||||
# Get the current number of TPUs (we have locally)
|
||||
local_world_size = torch_xla.tpu.num_available_chips()
|
||||
|
||||
# Get current rank
|
||||
local_rank = global_rank % local_world_size
|
||||
|
||||
# Ensure environment variables are set for multihost deployments.
|
||||
# On GKE, this is needed for libtpu and TPU driver to know which TPU
|
||||
# chip is actually visible. Otherwise the TPU driver will fail to
|
||||
# initialize because the number of devices would be different from
|
||||
# the number of visible worker addresses.
|
||||
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
|
||||
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
|
||||
|
||||
pjrt.initialize_multiprocess(local_rank, local_world_size)
|
||||
xr._init_world_size_ordinal()
|
||||
self.groups = create_optimized_replica_groups()
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
# TODO: Remove the groups specification after XLA compiler can support
|
||||
# auto-reordering the ring order for all-reduce.
|
||||
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
||||
return xm.all_gather(input_, dim=dim)
|
||||
|
||||
|
||||
try:
|
||||
from tpu_commons.distributed.device_communicators import (
|
||||
TpuCommunicator as TpuCommonsCommunicator)
|
||||
TpuCommunicator = TpuCommonsCommunicator # type: ignore
|
||||
except ImportError:
|
||||
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
|
||||
pass
|
||||
55
vllm/distributed/device_communicators/xpu_communicator.py
Normal file
55
vllm/distributed/device_communicators/xpu_communicator.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
|
||||
class XpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
|
||||
def all_reduce(self, input_) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
# For xpu path, gather doesn't work properly together with ray
|
||||
# cluster so we use all_gather instead for now.
|
||||
input_size = input_.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty((self.world_size, ) + input_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
if self.rank_in_group == dst:
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
356
vllm/distributed/kv_events.py
Normal file
356
vllm/distributed/kv_events.py
Normal file
@@ -0,0 +1,356 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from dataclasses import asdict
|
||||
from itertools import count
|
||||
from queue import Queue
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
|
||||
from vllm.config import KVEventsConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EventBatch(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False, # type: ignore[call-arg]
|
||||
):
|
||||
ts: float
|
||||
events: list[Any]
|
||||
data_parallel_rank: Optional[int] = None
|
||||
|
||||
|
||||
class KVCacheEvent(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False, # type: ignore[call-arg]
|
||||
tag=True):
|
||||
"""Base class for all KV cache-related events"""
|
||||
|
||||
|
||||
class BlockStored(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
parent_block_hash: Optional[int]
|
||||
token_ids: list[int]
|
||||
block_size: int
|
||||
lora_id: Optional[int]
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
|
||||
pass
|
||||
|
||||
|
||||
class KVEventBatch(EventBatch):
|
||||
events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
|
||||
|
||||
|
||||
class EventPublisher(ABC):
|
||||
"""Lightweight publisher for EventBatch batches with data parallelism
|
||||
support.
|
||||
|
||||
In data parallel setups, each DP rank runs its own EventPublisher instance
|
||||
to avoid duplicate events and ensure proper event attribution:
|
||||
|
||||
- Each DP rank creates a separate publisher
|
||||
- Publishers automatically annotate events with their data_parallel_rank
|
||||
- This allows consumers to distinguish events from different DP ranks
|
||||
|
||||
The publisher is responsible for adding DP metadata since the scheduler
|
||||
operates independently of DP topology and shouldn't need DP awareness.
|
||||
"""
|
||||
|
||||
def __init__(self, data_parallel_rank: int = 0) -> None:
|
||||
self._data_parallel_rank = data_parallel_rank
|
||||
|
||||
@abstractmethod
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
"""Emit events in order.
|
||||
|
||||
Implementations should guarantee at-least-once delivery and
|
||||
monotonic ordering (e.g., via sequence numbers).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the publisher."""
|
||||
|
||||
|
||||
class NullEventPublisher(EventPublisher):
|
||||
"""No-op implementation (default when disabled)."""
|
||||
|
||||
def publish(self, events) -> None:
|
||||
return
|
||||
|
||||
def shutdown(self) -> None:
|
||||
return
|
||||
|
||||
|
||||
class ZmqEventPublisher(EventPublisher):
|
||||
"""Reliable PUB/ROUTER publisher with an in-memory replay buffer.
|
||||
|
||||
Spawns a separate thread to handle publishing from a queue.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
endpoint:
|
||||
PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
|
||||
connect.
|
||||
replay_endpoint:
|
||||
Optional ROUTER address for replay requests. When given, subscribers can
|
||||
request missed batches by sending the starting sequence number as an
|
||||
8-byte big-endian integer.
|
||||
buffer_steps:
|
||||
Number of past batches to keep for replay.
|
||||
hwm:
|
||||
ZeroMQ high-water-mark for PUB socket.
|
||||
max_queue_size:
|
||||
Maximum number of events to buffer in memory.
|
||||
topic:
|
||||
Topic to publish events to.
|
||||
"""
|
||||
SHUTDOWN_TIMEOUT: float = 1.0
|
||||
END_SEQ = (-1).to_bytes(8, "big", signed=True)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_parallel_rank: int,
|
||||
endpoint: str = "tcp://*:5557",
|
||||
replay_endpoint: Optional[str] = None,
|
||||
buffer_steps: int = 10_000,
|
||||
hwm: int = 100_000,
|
||||
max_queue_size: int = 100_000,
|
||||
topic: str = "",
|
||||
) -> None:
|
||||
# Storage
|
||||
super().__init__(data_parallel_rank)
|
||||
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
|
||||
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
|
||||
|
||||
# ZMQ sockets
|
||||
self._ctx = zmq.Context.instance()
|
||||
self._pub: Optional[zmq.Socket] = None
|
||||
self._replay: Optional[zmq.Socket] = None
|
||||
self._dp_rank = data_parallel_rank
|
||||
|
||||
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
|
||||
self._replay_endpoint = self.offset_endpoint_port(
|
||||
replay_endpoint, self._dp_rank)
|
||||
self._hwm = hwm
|
||||
self._socket_setup()
|
||||
|
||||
# Payload
|
||||
self._seq_gen = count()
|
||||
self._topic_bytes = topic.encode('utf-8')
|
||||
|
||||
# Thread
|
||||
self._running = True
|
||||
logger.info("Starting ZMQ publisher thread")
|
||||
|
||||
self._thread = threading.Thread(target=self._publisher_thread,
|
||||
daemon=True,
|
||||
name="zmq-publisher")
|
||||
self._thread.start()
|
||||
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
if not self._running:
|
||||
raise RuntimeError("Publisher is closed")
|
||||
if events.data_parallel_rank is None:
|
||||
events.data_parallel_rank = self._data_parallel_rank
|
||||
self._event_queue.put(events)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Stop the publisher thread and clean up resources."""
|
||||
self._running = False
|
||||
self._event_queue.put_nowait(None)
|
||||
|
||||
start = time.time()
|
||||
pending_items = True
|
||||
while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
|
||||
pending_items = not self._event_queue.empty()
|
||||
if pending_items:
|
||||
time.sleep(0.1)
|
||||
|
||||
if pending_items:
|
||||
logger.warning(
|
||||
"Warning: Queue still has %s items after %s seconds timeout",
|
||||
self._event_queue.qsize(),
|
||||
self.SHUTDOWN_TIMEOUT,
|
||||
)
|
||||
|
||||
if self._thread.is_alive():
|
||||
self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
|
||||
|
||||
# Clean up ZMQ resources
|
||||
try:
|
||||
if self._pub is not None:
|
||||
self._pub.close(linger=0)
|
||||
if self._replay is not None:
|
||||
self._replay.close(linger=0)
|
||||
finally:
|
||||
pass # Do not terminate context; other sockets may use it
|
||||
|
||||
def _socket_setup(self) -> None:
|
||||
"""Initialize sockets
|
||||
https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
|
||||
"""
|
||||
if self._pub is None:
|
||||
self._pub = self._ctx.socket(zmq.PUB)
|
||||
self._pub.set_hwm(self._hwm)
|
||||
# Heuristic: bind if wildcard / * present, else connect.
|
||||
# bind stable, connect volatile convention
|
||||
if (self._endpoint is not None
|
||||
and ("*" in self._endpoint or "::" in self._endpoint
|
||||
or self._endpoint.startswith("ipc://")
|
||||
or self._endpoint.startswith("inproc://"))):
|
||||
self._pub.bind(self._endpoint)
|
||||
elif self._endpoint is not None:
|
||||
self._pub.connect(self._endpoint)
|
||||
|
||||
# Set up replay socket: use ROUTER
|
||||
# 1) handles multiple REQ clients (identities)
|
||||
# 2) lets us send back one request → many replies (streamed events)
|
||||
# 3) works in our non‑blocking poll loop alongside PUB
|
||||
if self._replay_endpoint is not None:
|
||||
self._replay = self._ctx.socket(zmq.ROUTER)
|
||||
self._replay.bind(self._replay_endpoint)
|
||||
|
||||
def _publisher_thread(self) -> None:
|
||||
"""Background thread that processes the event queue."""
|
||||
self._pack = msgspec.msgpack.Encoder()
|
||||
|
||||
assert self._pub is not None # narrows type for mypy
|
||||
|
||||
while self._running or self._event_queue.qsize() > 0:
|
||||
# --- replay (non-critical) ---------------------------------
|
||||
if self._replay is not None and self._replay.poll(0):
|
||||
try:
|
||||
self._service_replay()
|
||||
except Exception as e:
|
||||
logger.exception("Error in replay: %s", e)
|
||||
|
||||
# --- main queue (critical) ---------------------------------
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
if event is None:
|
||||
break # Sentinel received, exit thread
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
try:
|
||||
seq = next(self._seq_gen)
|
||||
|
||||
payload = self._pack.encode(event)
|
||||
seq_bytes = seq.to_bytes(8, "big")
|
||||
self._pub.send_multipart(
|
||||
(self._topic_bytes, seq_bytes, payload))
|
||||
|
||||
self._buffer.append((seq, payload))
|
||||
self._event_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
# Publishing failed; back-off a bit to avoid a tight error loop
|
||||
logger.exception("Error in publisher thread: %s", e)
|
||||
time.sleep(0.1)
|
||||
|
||||
def _service_replay(self) -> None:
|
||||
"""If a replay request is waiting, send buffered batches."""
|
||||
assert self._replay is not None # narrows type for mypy
|
||||
|
||||
frame = self._replay.recv_multipart()
|
||||
if len(frame) != 3:
|
||||
logger.warning("Invalid replay request: %s", frame)
|
||||
return
|
||||
client_id, _, start_seq_bytes = frame
|
||||
start_seq = int.from_bytes(start_seq_bytes, "big")
|
||||
|
||||
for seq, buf in self._buffer:
|
||||
if seq >= start_seq:
|
||||
# [identity, empty_delim, seq_bytes, payload]
|
||||
# (identity, empty_delim) are stripped off by the router
|
||||
# receiving payload is (seq_bytes, payload)
|
||||
self._replay.send_multipart(
|
||||
(client_id, b"", seq.to_bytes(8, "big"), buf))
|
||||
# Send end of sequence marker
|
||||
# receiving payload is (-1, b""")
|
||||
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
||||
|
||||
@staticmethod
|
||||
def offset_endpoint_port(endpoint: Optional[str],
|
||||
data_parallel_rank: int) -> Optional[str]:
|
||||
"""Helper function to offset the port in an endpoint by
|
||||
the data parallel rank.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint string
|
||||
(e.g., "tcp://*:5557" or "inproc://cache")
|
||||
data_parallel_rank: The data parallel rank to offset by
|
||||
|
||||
Returns:
|
||||
The endpoint with the port offset by data_parallel_rank
|
||||
or suffix appended
|
||||
"""
|
||||
# Do nothing if input is None or data_parallel_rank is 0
|
||||
if not endpoint or data_parallel_rank == 0:
|
||||
return endpoint
|
||||
|
||||
if "inproc" in endpoint:
|
||||
return f"{endpoint}_dp{data_parallel_rank}"
|
||||
if "tcp" in endpoint:
|
||||
if endpoint and ":" in endpoint:
|
||||
# Get everything after the last colon (the port)
|
||||
last_colon_idx = endpoint.rfind(":")
|
||||
base_addr = endpoint[:last_colon_idx]
|
||||
base_port = int(endpoint[last_colon_idx + 1:])
|
||||
new_port = base_port + data_parallel_rank
|
||||
return f"{base_addr}:{new_port}"
|
||||
return endpoint
|
||||
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
|
||||
|
||||
|
||||
class EventPublisherFactory:
|
||||
_registry: dict[str, Callable[..., EventPublisher]] = {
|
||||
"null": NullEventPublisher,
|
||||
"zmq": ZmqEventPublisher,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_publisher(cls, name: str,
|
||||
ctor: Callable[..., EventPublisher]) -> None:
|
||||
if name in cls._registry:
|
||||
raise KeyError(f"publisher '{name}' already registered")
|
||||
cls._registry[name] = ctor
|
||||
|
||||
@classmethod
|
||||
def create(cls,
|
||||
config: Optional[KVEventsConfig],
|
||||
data_parallel_rank: int = 0) -> EventPublisher:
|
||||
"""Create publisher from a config mapping."""
|
||||
if not config:
|
||||
return NullEventPublisher()
|
||||
|
||||
config_dict = asdict(config)
|
||||
|
||||
kind = config_dict.pop("publisher", "null")
|
||||
config_dict.pop("enable_kv_cache_events")
|
||||
try:
|
||||
constructor = cls._registry[kind]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Unknown event publisher '{kind}'") from exc
|
||||
return constructor(data_parallel_rank=data_parallel_rank,
|
||||
**config_dict)
|
||||
29
vllm/distributed/kv_transfer/README.md
Normal file
29
vllm/distributed/kv_transfer/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
|
||||
# Distributed KV cache transfer
|
||||
|
||||
This folder implements distributed KV cache transfer across vLLM instances.
|
||||
Currently the main usecase is for disaggregated prefilling.
|
||||
|
||||
## Abstractions
|
||||
|
||||
The KV cache transfer contains three layer of abstractions:
|
||||
|
||||
- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`.
|
||||
- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics).
|
||||
- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`.
|
||||
|
||||
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
|
||||
|
||||
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed
|
||||
communication service already supports key-value-based lookup (like redis or
|
||||
RDMA database).
|
||||
|
||||
NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates.
|
||||
|
||||
## Disaggregated prefilling
|
||||
|
||||
The example usage is in [this file](../../../examples/online_serving/disaggregated_prefill.sh).
|
||||
|
||||
Here is the diagram of how we run disaggregated prefilling.
|
||||
|
||||

|
||||
12
vllm/distributed/kv_transfer/__init__.py
Normal file
12
vllm/distributed/kv_transfer/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group,
|
||||
has_kv_transfer_group, is_v1_kv_transfer_group)
|
||||
|
||||
__all__ = [
|
||||
"get_kv_transfer_group", "has_kv_transfer_group",
|
||||
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
|
||||
"KVConnectorBaseType"
|
||||
]
|
||||
BIN
vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg
Normal file
BIN
vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 139 KiB |
128
vllm/distributed/kv_transfer/kv_connector/base.py
Normal file
128
vllm/distributed/kv_transfer/kv_connector/base.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
|
||||
|
||||
The class provides two primary abstract methods:
|
||||
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
|
||||
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
|
||||
class KVConnectorBase(ABC):
|
||||
"""
|
||||
Abstract base class for a KV connector.
|
||||
|
||||
The class provides two primary abstract methods:
|
||||
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
|
||||
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: "VllmConfig",
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the buffer and release resources.
|
||||
|
||||
This method is responsible for cleaning up resources related to the
|
||||
connector when it is no longer needed.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
"""
|
||||
Send KV caches and hidden states to the connector.
|
||||
|
||||
This method processes the input tokens, KV caches, and
|
||||
hidden/intermediate states for a given model and sends the data to the
|
||||
decode instance.
|
||||
|
||||
Args:
|
||||
model_executable (torch.nn.Module): The model executable containing
|
||||
start and end layer information.
|
||||
model_input (ModelInputForGPUWithSamplingMetadata): The input
|
||||
metadata from vLLM.
|
||||
kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
|
||||
for each layer.
|
||||
hidden_or_intermediate_states (Union[torch.Tensor,
|
||||
IntermediateTensors]):
|
||||
The hidden or intermediate states associated with the tokens.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor]
|
||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
"""
|
||||
Receive KV caches and hidden states from the connector.
|
||||
|
||||
This method attempts to retrieve KV caches and hidden states for input
|
||||
tokens. If all required KV caches and hidden states are received, it
|
||||
will bypass model input, else it will fall back to normal vLLM model
|
||||
forwarding.
|
||||
|
||||
Args:
|
||||
model_executable (torch.nn.Module):
|
||||
The model executable from vLLM modelrunner.
|
||||
model_input (ModelInputForGPUWithSamplingMetadata):
|
||||
The model input from vLLM modelrunner.
|
||||
kv_caches (list[torch.Tensor]):
|
||||
List of KV caches for each layer.
|
||||
|
||||
Returns:
|
||||
- hidden_or_intermediate_states (torch.Tensor or
|
||||
IntermediateTensors):
|
||||
Concatenated hidden states if all required data is retrieved,
|
||||
otherwise `None`.
|
||||
- bypass_model_exec (bool):
|
||||
Indicates whether the model execution can be skipped (True) or
|
||||
needs to be redone (False).
|
||||
- model_input (ModelInputForGPUWithSamplingMetadata):
|
||||
Optionally adjusted input metadata for re-execution when
|
||||
`bypass_model_exec=False`.
|
||||
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
|
||||
128
vllm/distributed/kv_transfer/kv_connector/factory.py
Normal file
128
vllm/distributed/kv_transfer/kv_connector/factory.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||
KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base import KVConnectorBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KVConnectorFactory:
|
||||
_registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register_connector(cls, name: str, module_path: str,
|
||||
class_name: str) -> None:
|
||||
"""Register a connector with a lazy-loading module and class name."""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Connector '{name}' is already registered.")
|
||||
|
||||
def loader() -> type[KVConnectorBaseType]:
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
cls._registry[name] = loader
|
||||
|
||||
@classmethod
|
||||
def create_connector_v0(cls, rank: int, local_rank: int,
|
||||
config: "VllmConfig") -> KVConnectorBase:
|
||||
if envs.VLLM_USE_V1:
|
||||
raise ValueError("Attempting to initialize a V0 Connector, "
|
||||
f"but found {envs.VLLM_USE_V1=}")
|
||||
|
||||
connector_name = config.kv_transfer_config.kv_connector
|
||||
if connector_name not in cls._registry:
|
||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
||||
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
assert issubclass(connector_cls, KVConnectorBase)
|
||||
return connector_cls(rank, local_rank, config)
|
||||
|
||||
@classmethod
|
||||
def create_connector_v1(
|
||||
cls,
|
||||
config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
) -> KVConnectorBase_V1:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError("Attempting to initialize a V1 Connector, "
|
||||
f"but found {envs.VLLM_USE_V1=}")
|
||||
|
||||
kv_transfer_config = config.kv_transfer_config
|
||||
connector_name = kv_transfer_config.kv_connector
|
||||
if connector_name in cls._registry:
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
else:
|
||||
connector_module_path = kv_transfer_config.kv_connector_module_path
|
||||
if connector_module_path is None:
|
||||
raise ValueError(
|
||||
f"Unsupported connector type: {connector_name}")
|
||||
connector_module = importlib.import_module(connector_module_path)
|
||||
connector_cls = getattr(connector_module, connector_name)
|
||||
assert issubclass(connector_cls, KVConnectorBase_V1)
|
||||
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
||||
connector_name, kv_transfer_config.engine_id)
|
||||
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
||||
# Scheduler connector:
|
||||
# - Co-locate with scheduler process
|
||||
# - Should only be used inside the Scheduler class
|
||||
# Worker connector:
|
||||
# - Co-locate with worker process
|
||||
# - Should only be used inside the forward context & attention layer
|
||||
# We build separately to enforce strict separation
|
||||
return connector_cls(config, role)
|
||||
|
||||
|
||||
# Register various connectors here.
|
||||
# The registration should not be done in each individual file, as we want to
|
||||
# only load the files corresponding to the current connector.
|
||||
KVConnectorFactory.register_connector(
|
||||
"PyNcclConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
||||
"SimpleConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
||||
"SimpleConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"LMCacheConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
|
||||
"LMCacheConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeStoreConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
|
||||
"MooncakeStoreConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"SharedStorageConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
|
||||
"SharedStorageConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"LMCacheConnectorV1",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
|
||||
"LMCacheConnectorV1")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"NixlConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
|
||||
"NixlConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MultiConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
|
||||
"MultiConnector")
|
||||
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
LMCache KV Cache Connector for Distributed Machine Learning Inference
|
||||
|
||||
The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
|
||||
(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache;
|
||||
(2) offload and share KV caches.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LMCacheConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
|
||||
self.transfer_config = config.kv_transfer_config
|
||||
self.vllm_config = config
|
||||
|
||||
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
|
||||
from lmcache.integration.vllm.utils import ENGINE_NAME
|
||||
from lmcache.integration.vllm.vllm_adapter import (
|
||||
RetrieveStatus, StoreStatus, init_lmcache_engine,
|
||||
lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
|
||||
lmcache_store_kv)
|
||||
logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
|
||||
self.transfer_config)
|
||||
|
||||
# TODO (Jiayi): Find model_config, parallel_config, and cache_config
|
||||
self.engine = init_lmcache_engine(config.model_config,
|
||||
config.parallel_config,
|
||||
config.cache_config)
|
||||
self.lmcache_engine_name = ENGINE_NAME
|
||||
self.lmcache_engine_builder = LMCacheEngineBuilder
|
||||
|
||||
self.model_config = config.model_config
|
||||
self.parallel_config = config.parallel_config
|
||||
self.cache_config = config.cache_config
|
||||
self.lmcache_retrieve_kv = lmcache_retrieve_kv
|
||||
self.lmcache_store_kv = lmcache_store_kv
|
||||
self.lmcache_should_retrieve = lmcache_should_retrieve
|
||||
self.lmcache_should_store = lmcache_should_store
|
||||
self.store_status = StoreStatus
|
||||
self.retrieve_status = RetrieveStatus
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor]
|
||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
|
||||
retrieve_status = self.lmcache_should_retrieve(model_input)
|
||||
model_input, bypass_model_exec, hidden_or_intermediate_states =\
|
||||
self.lmcache_retrieve_kv(
|
||||
model_executable, model_input, self.cache_config, kv_caches,
|
||||
retrieve_status)
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
|
||||
store_status = self.lmcache_should_store(model_input)
|
||||
self.lmcache_store_kv(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.cache_config,
|
||||
model_executable,
|
||||
model_input,
|
||||
kv_caches,
|
||||
store_status,
|
||||
)
|
||||
|
||||
def close(self):
|
||||
self.lmcache_engine_builder.destroy(self.lmcache_engine_name)
|
||||
@@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
MooncakeStore Connector for Distributed Machine Learning Inference
|
||||
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
|
||||
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
|
||||
database-style KVStore.
|
||||
"""
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
model_aware_kv_ops_helper as kv_helper)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MooncakeStoreConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
self.kv_transfer_config = config.kv_transfer_config
|
||||
self.kv_helper = kv_helper(config)
|
||||
self.local_tp_rank = local_rank
|
||||
|
||||
# Init kv_store
|
||||
if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector":
|
||||
# Check if MOONCAKE_CONFIG_PATH is set
|
||||
import os
|
||||
use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None
|
||||
|
||||
if not use_mooncake_store:
|
||||
raise ValueError(
|
||||
"To use MooncakeStoreConnector, you need to pass the ENV: "
|
||||
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
|
||||
else:
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501
|
||||
MooncakeStore)
|
||||
logger.info(
|
||||
"Initializing KVStoreConnector under kv_transfer_config %s",
|
||||
self.kv_transfer_config)
|
||||
self.kv_store = MooncakeStore(config)
|
||||
else:
|
||||
logger.error("Can not find %s",
|
||||
self.kv_transfer_config.kv_connector)
|
||||
|
||||
assert self.kv_store is not None
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the buffer and release resources.
|
||||
This method is responsible for cleaning up resources related to the
|
||||
connector when it is no longer needed.
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
self.kv_store.close()
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
num_heads, head_size = self.kv_helper.get_model_args(model_executable)
|
||||
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
store_key_prefix = self.tensor_hash(current_tokens)
|
||||
keys, values = [], []
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
key_cache, value_cache = self.kv_helper.get_kv_from_cache(
|
||||
kv_cache, num_heads, head_size)
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
||||
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
kvcache_to_sent = torch.stack((keys, values), dim=0)
|
||||
store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}"
|
||||
self.kv_store.put(store_kvcache_key, kvcache_to_sent)
|
||||
|
||||
hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}"
|
||||
self.kv_store.put(hidden_key,
|
||||
hidden_or_intermediate_states[start_pos:end_pos])
|
||||
|
||||
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor]
|
||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
bypass_model_exec = True
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# This can happen during inflight batching. See:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You should set --enable_chunked_prefill=False "
|
||||
"and --max_num_batched_tokens "
|
||||
"should be equal to max_seq_len_to_capture")
|
||||
bypass_model_exec = False
|
||||
assert start_pos == num_prefill_tokens
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
|
||||
# get roi for current seq
|
||||
load_key_prefix = self.tensor_hash(current_tokens)
|
||||
load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}"
|
||||
remote_kv = self.kv_store.get(load_kvcache_key)
|
||||
hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}"
|
||||
hidden = self.kv_store.get(hidden_key)
|
||||
|
||||
if remote_kv is None or hidden is None:
|
||||
# didn't find any match.
|
||||
bypass_model_exec = False
|
||||
continue
|
||||
|
||||
num_computed_tokens = current_tokens.shape[0]
|
||||
|
||||
# update the end position based on how many tokens are cached.
|
||||
end_pos = start_pos + num_computed_tokens
|
||||
|
||||
# call self.kv_store to get kv layer by layer
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
layer = model_executable.model.layers[layer_id]
|
||||
# get kvcache object
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
|
||||
# get remote kvcache
|
||||
remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
|
||||
layer_id]
|
||||
|
||||
self.kv_helper.put_kv_to_cache(model_executable, remote_k,
|
||||
remote_v, layer, kv_cache,
|
||||
slot_mapping, start_pos,
|
||||
end_pos)
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
||||
|
||||
if not bypass_model_exec:
|
||||
logger.warning(
|
||||
"[rank%d]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"[rank%d]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
@staticmethod
|
||||
def tensor_hash(tensor: torch.Tensor) -> int:
|
||||
"""Calculate the hash value of the tensor."""
|
||||
tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes()
|
||||
hash_object = hashlib.blake2b(tensor_bytes)
|
||||
hash_hex = hash_object.hexdigest()
|
||||
return int(hash_hex[:16], 16)
|
||||
329
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
Normal file
329
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Simple KV Cache Connector for Distributed Machine Learning Inference
|
||||
|
||||
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
|
||||
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
|
||||
MooncakePipe.
|
||||
|
||||
But the logic can be extended to support other pipe and lookup buffer.
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
model_aware_kv_ops_helper as kv_helper)
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
|
||||
SimpleBuffer)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SimpleConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
|
||||
self.config = config.kv_transfer_config
|
||||
self.kv_helper = kv_helper(config)
|
||||
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
|
||||
PyNcclPipe)
|
||||
logger.info(
|
||||
"Initializing PyNcclConfig under kv_transfer_config %s",
|
||||
self.config)
|
||||
elif self.config.kv_connector == "MooncakeConnector":
|
||||
# Check if MOONCAKE_CONFIG_PATH is set
|
||||
import os
|
||||
use_mooncake_distributed_pipe = os.getenv(
|
||||
'MOONCAKE_CONFIG_PATH') is not None
|
||||
|
||||
if not use_mooncake_distributed_pipe:
|
||||
raise ValueError(
|
||||
"To use MooncakeConnector, you need to pass the ENV: "
|
||||
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
|
||||
else:
|
||||
from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501
|
||||
MooncakePipe)
|
||||
logger.info(
|
||||
"Initializing MooncakeConfig under kv_transfer_config %s",
|
||||
self.config)
|
||||
|
||||
self.lookup_buffer_size = self.config.kv_buffer_size
|
||||
|
||||
self.producer_buffer: Optional[SimpleBuffer] = None
|
||||
self.consumer_buffer: Optional[SimpleBuffer] = None
|
||||
|
||||
self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
|
||||
# 2 pipes for every rank in the world
|
||||
port_offset_base = 2 * rank
|
||||
|
||||
# In disaggregated prefill, the prefill vLLM only uses send pipe
|
||||
# and the decode vLLM only uses recv pipe
|
||||
if self.config.is_kv_producer:
|
||||
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
self.producer_data_pipe = PyNcclPipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base,
|
||||
)
|
||||
self.producer_signal_pipe = PyNcclPipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base + 1,
|
||||
device="cpu",
|
||||
)
|
||||
elif self.config.kv_connector == "MooncakeConnector":
|
||||
self.producer_data_pipe = MooncakePipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
)
|
||||
# We only need to initialize MooncakePipe once
|
||||
self.producer_signal_pipe = self.producer_data_pipe
|
||||
|
||||
self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
|
||||
self.producer_data_pipe,
|
||||
self.config.kv_buffer_size)
|
||||
|
||||
else:
|
||||
|
||||
# the current vLLM instance is KV consumer, so it needs to connect
|
||||
# its recv pipe to the send pipe of KV producer
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
self.consumer_data_pipe = PyNcclPipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base,
|
||||
)
|
||||
self.consumer_signal_pipe = PyNcclPipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base + 1,
|
||||
device="cpu",
|
||||
)
|
||||
elif self.config.kv_connector == "MooncakeConnector":
|
||||
self.consumer_data_pipe = MooncakePipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
)
|
||||
self.consumer_signal_pipe = self.consumer_data_pipe
|
||||
|
||||
self.consumer_buffer = SimpleBuffer(
|
||||
self.consumer_signal_pipe,
|
||||
self.consumer_data_pipe,
|
||||
self.config.kv_buffer_size,
|
||||
)
|
||||
|
||||
def select(self, input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
|
||||
|
||||
assert self.consumer_buffer is not None, "Please initialize the "\
|
||||
"consumer buffer before calling select."
|
||||
return self.consumer_buffer.drop_select(input_tokens, roi)
|
||||
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
|
||||
assert self.producer_buffer is not None, "Please initialize the "\
|
||||
"producer buffer before calling insert."
|
||||
|
||||
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
num_heads, head_size = self.kv_helper.get_model_args(model_executable)
|
||||
|
||||
# query_lens contains new KV caches that are added to vLLM.
|
||||
# so we will send them to decode instance
|
||||
# FIXME(Kuntai): This assume that all requests are prefill.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You have some decode requests while using "
|
||||
"SimpleConnector. Their KVCache won't be sent.")
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
|
||||
keys, values = [], []
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
key_cache, value_cache = self.kv_helper.get_kv_from_cache(
|
||||
kv_cache, num_heads, head_size)
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
||||
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
|
||||
self.insert(current_tokens,
|
||||
torch.ones_like(current_tokens,
|
||||
dtype=bool), keys, values,
|
||||
hidden_or_intermediate_states[start_pos:end_pos])
|
||||
|
||||
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor]
|
||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
|
||||
# When bypass_model_exec is set to False, it means that at least for one
|
||||
# request its corresponding KV cache or hidden state is missing.
|
||||
# In this case we need to do prefilling to recompute missing KV cache
|
||||
# and hidden states.
|
||||
bypass_model_exec = True
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
input_tokens_list = []
|
||||
num_computed_tokens_list = []
|
||||
start_pos_list = []
|
||||
|
||||
# enumerate different requests
|
||||
# FIXME(Kuntai): This impl assumes that all requests are prefill.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# This can happen during inflight batching. See:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You should set --enable_chunked_prefill=False "
|
||||
"and --max_num_batched_tokens "
|
||||
"should be equal to --max_seq_len_to_capture")
|
||||
bypass_model_exec = False
|
||||
assert start_pos == num_prefill_tokens
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
num_tokens = slen
|
||||
|
||||
# collecting data for rebuilding the input
|
||||
input_tokens_list.append(current_tokens)
|
||||
start_pos_list.append(start_pos)
|
||||
|
||||
ret = self.select(current_tokens,
|
||||
torch.ones_like(current_tokens, dtype=bool))
|
||||
if ret[0] is None:
|
||||
# didn't find any match.
|
||||
bypass_model_exec = False
|
||||
num_computed_tokens_list.append(0)
|
||||
continue
|
||||
|
||||
roi: torch.Tensor = ret[1]
|
||||
keys: torch.Tensor = ret[2]
|
||||
values: torch.Tensor = ret[3]
|
||||
hidden: torch.Tensor = ret[4]
|
||||
|
||||
num_computed_tokens = roi.shape[0]
|
||||
num_computed_tokens_list.append(num_computed_tokens)
|
||||
|
||||
# check if both KV cache and the hidden states are received
|
||||
# If not, need to redo the forwarding to compute missing states
|
||||
if not all([(num_computed_tokens == num_tokens), hidden is not None
|
||||
]):
|
||||
bypass_model_exec = False
|
||||
|
||||
# update the end position based on how many tokens are cached.
|
||||
end_pos = start_pos + num_computed_tokens
|
||||
|
||||
# put received KV caches into paged memory
|
||||
for cur_layer in range(start_layer, end_layer):
|
||||
|
||||
layer_id = cur_layer - start_layer
|
||||
kv_cache = kv_caches[layer_id]
|
||||
layer = model_executable.model.layers[cur_layer]
|
||||
|
||||
# get remote kvcache
|
||||
remote_k, remote_v = keys[layer_id], values[layer_id]
|
||||
|
||||
self.kv_helper.put_kv_to_cache(model_executable, remote_k,
|
||||
remote_v, layer, kv_cache,
|
||||
slot_mapping, start_pos,
|
||||
end_pos)
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
||||
|
||||
if not bypass_model_exec:
|
||||
# Some of the KV cache is not retrieved
|
||||
# Here we will fall back to normal model forwarding
|
||||
# But optionally you can adjust model_input so that you only do
|
||||
# prefilling on those tokens that are missing KV caches.
|
||||
logger.warning(
|
||||
"[rank%d]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"[rank%d]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
def close(self):
|
||||
self.producer_data_pipe.close()
|
||||
self.consumer_data_pipe.close()
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
self.producer_signal_pipe.close()
|
||||
self.consumer_signal_pipe.close()
|
||||
elif self.config.kv_connector == "MooncakeConnector":
|
||||
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
|
||||
# close the data_pipe.
|
||||
pass
|
||||
108
vllm/distributed/kv_transfer/kv_connector/utils.py
Normal file
108
vllm/distributed/kv_transfer/kv_connector/utils.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
KV cache helper for store.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class model_aware_kv_ops_helper:
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
self.is_deepseek_mla = config.model_config.is_deepseek_mla
|
||||
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
|
||||
def get_model_args(self, model_executable: torch.nn.Module):
|
||||
|
||||
model_config = model_executable.model.config
|
||||
self.model_executable = model_executable
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
|
||||
# Deepseek's MLA (Multi-head Latent Attention) uses two different
|
||||
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
|
||||
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
|
||||
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
|
||||
# kv_lora_rank + qk_rope_head_dim].
|
||||
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
|
||||
# to a kv_cache shape of [2, num_blks, blk_size,
|
||||
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
|
||||
# For more details, see vllm/attention/backends/mla/common.py.
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
head_size = model_config.kv_lora_rank + \
|
||||
model_config.qk_rope_head_dim
|
||||
num_heads = 1
|
||||
elif self.is_deepseek_mla and not self.use_mla_opt:
|
||||
head_size = model_config.qk_nope_head_dim + \
|
||||
model_config.qk_rope_head_dim
|
||||
else:
|
||||
head_size = getattr(model_config, "head_dim", None)
|
||||
if head_size is None:
|
||||
head_size = int(hidden_size // num_attention_heads)
|
||||
|
||||
return num_heads, head_size
|
||||
|
||||
def get_kv_from_cache(self, kv_cache, num_heads, head_size):
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
key_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
else:
|
||||
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
||||
return key_cache, value_cache
|
||||
|
||||
def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
|
||||
layer, kv_cache, slot_mapping, start_pos, end_pos):
|
||||
|
||||
model_config = model_executable.model.config
|
||||
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
layer.self_attn.attn = layer.self_attn.mla_attn
|
||||
k_c_normed_k_pe = keys.squeeze(1)
|
||||
k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
|
||||
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed.to(kv_cache.device),
|
||||
k_pe.to(kv_cache.device),
|
||||
kv_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
)
|
||||
else:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
ops.reshape_and_cache_flash(
|
||||
keys.to(key_cache.device),
|
||||
values.to(value_cache.device),
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
layer.self_attn.attn._v_scale,
|
||||
)
|
||||
|
||||
|
||||
def get_kv_connector_cache_layout():
|
||||
vllm_config = get_current_vllm_config()
|
||||
kv_config = vllm_config.kv_transfer_config
|
||||
if vllm_config.model_config is None:
|
||||
logger.warning("Unable to detect current VLLM config. " \
|
||||
"Defaulting to NHD kv cache layout.")
|
||||
else:
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
if not use_mla and kv_config.kv_connector == "NixlConnector":
|
||||
logger.info("NixlConnector detected. Setting KV cache " \
|
||||
"layout to HND for better xfer performance.")
|
||||
return "HND"
|
||||
return "NHD"
|
||||
6
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
6
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorRole)
|
||||
|
||||
__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
|
||||
283
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
283
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
@@ -0,0 +1,283 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
|
||||
communication in vLLM v1
|
||||
|
||||
The class provides the following primitives:
|
||||
Scheduler-side: runs in the scheduler, binds metadata, which
|
||||
is used by the worker-side to load/save KV cache.
|
||||
get_num_new_matched_tokens() - get number of new tokens
|
||||
that exist in the remote KV cache. Might be called multiple
|
||||
times for a given request and should be side-effect free.
|
||||
update_state_after_alloc() - update KVConnector state after
|
||||
temporary buffer alloc by the CacheManager.
|
||||
request_finished() - called when a request is finished, with
|
||||
the computed kv cache blocks for the request.
|
||||
Returns whether KV cache should be freed now or will be
|
||||
freed asynchronously and optionally returns KV transfer
|
||||
params.
|
||||
|
||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||
the Connector based on the metadata.
|
||||
start_load_kv() - starts loading all KVs (maybe async)
|
||||
wait_for_layer_load() - blocks until layer i load is done
|
||||
|
||||
save_kv_layer() - starts saving KV for layer i (maybe async)
|
||||
wait_for_save() - blocks until all saves are done
|
||||
|
||||
get_finished() - called with ids of finished requests, returns
|
||||
ids of requests that have completed async sending/recving.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KVConnectorRole(enum.Enum):
|
||||
# Connector running in the scheduler process
|
||||
SCHEDULER = 0
|
||||
|
||||
# Connector running in the worker process
|
||||
WORKER = 1
|
||||
|
||||
|
||||
class KVConnectorMetadata:
|
||||
"""
|
||||
Abstract Metadata used to communicate between the
|
||||
Scheduler KVConnector and Worker KVConnector.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class KVConnectorBase_V1(ABC):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
logger.warning(
|
||||
"Initializing KVConnectorBase_V1. This API is experimental and "
|
||||
"subject to change in the future as we iterate the design.")
|
||||
self._connector_metadata = KVConnectorMetadata()
|
||||
self._vllm_config = vllm_config
|
||||
self._role = role
|
||||
|
||||
@property
|
||||
def role(self) -> KVConnectorRole:
|
||||
return self._role
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
"""Set the connector metadata from the scheduler.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
before the model execution. The metadata will be used for runtime
|
||||
KV cache loading and saving.
|
||||
|
||||
Args:
|
||||
connector_metadata (dict): the connector metadata.
|
||||
"""
|
||||
self._connector_metadata = connector_metadata
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
"""Clear the connector metadata.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
after the model execution.
|
||||
"""
|
||||
self._connector_metadata = KVConnectorMetadata()
|
||||
|
||||
def _get_connector_metadata(self) -> KVConnectorMetadata:
|
||||
"""Get the connector metadata.
|
||||
|
||||
This function should only be called inside the connector.
|
||||
|
||||
Returns:
|
||||
ConnectorMetadata: the connector metadata.
|
||||
"""
|
||||
return self._connector_metadata
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Initialize with the KV caches. Useful for pre-registering the
|
||||
KV Caches in the KVConnector (e.g. for NIXL).
|
||||
|
||||
Args: kv_caches:
|
||||
dictionary of layer names, kv cache
|
||||
"""
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""
|
||||
Start saving a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_save(self):
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
return None, None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
@abstractmethod
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- The number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
- `True` if external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps). Must be
|
||||
'False' if the first element is 0.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If get_num_new_matched_tokens previously returned True for a
|
||||
request, this function may be called twice for that same request -
|
||||
first when blocks are allocated for the connector tokens to be
|
||||
asynchronously loaded into, and second when any additional blocks
|
||||
are allocated, after the load/transfer is complete.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
blocks (KVCacheBlocks): the blocks allocated for the request.
|
||||
num_external_tokens (int): the number of tokens that will be
|
||||
loaded from the external KV cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
return False, None
|
||||
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
self._lmcache_engine.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""
|
||||
Start saving the a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
|
||||
**kwargs)
|
||||
|
||||
def wait_for_save(self):
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
self._lmcache_engine.wait_for_save()
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
return self._lmcache_engine.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens), False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
"""
|
||||
self._lmcache_engine.update_state_after_alloc(request,
|
||||
num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
return self._lmcache_engine.build_connector_meta(scheduler_output)
|
||||
201
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Normal file
201
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiKVConnectorMetadata(KVConnectorMetadata):
|
||||
metadata: tuple[KVConnectorMetadata, ...]
|
||||
extra_async_saves: Optional[dict[str, int]] = None
|
||||
|
||||
|
||||
class MultiConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
A wrapper for using multiple KVConnectors at the same time.
|
||||
|
||||
The current logic is:
|
||||
- Load KV from the first connector that advertises available tokens from
|
||||
get_num_new_matched_tokens(), based on the order in the config.
|
||||
- Save to all connectors.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self._connectors: list[KVConnectorBase_V1] = []
|
||||
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors")
|
||||
assert ktcs is not None
|
||||
for ktc in ktcs:
|
||||
temp_config = copy.copy(vllm_config)
|
||||
temp_config.kv_transfer_config = KVTransferConfig(**ktc)
|
||||
self._connectors.append(
|
||||
KVConnectorFactory.create_connector_v1(temp_config, role))
|
||||
|
||||
# A mapping from request id to the index of the connector chosen to
|
||||
# load the request from (if any).
|
||||
self._requests_to_connector: dict[str, int] = {}
|
||||
|
||||
# Keeps track of *additional* remaining async saves (beyond 1) to be
|
||||
# finished per request. Not needed for async loads since we only allow
|
||||
# a single connector to load.
|
||||
# Propagated from scheduler to worker side via the connector metadata.
|
||||
self._extra_async_saves: dict[str, int] = {}
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
for c in self._connectors:
|
||||
c.register_kv_caches(kv_caches)
|
||||
|
||||
# We must override the base class method here because we need to bind
|
||||
# the metadata to each connector in the order of the connectors in the
|
||||
# MultiKVConnectorMetadata.
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
|
||||
if connector_metadata.extra_async_saves:
|
||||
self._extra_async_saves.update(
|
||||
connector_metadata.extra_async_saves)
|
||||
for c, cm in zip(self._connectors, connector_metadata.metadata):
|
||||
c.bind_connector_metadata(cm)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
for c in self._connectors:
|
||||
c.clear_connector_metadata()
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
for c in self._connectors:
|
||||
c.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
for c in self._connectors:
|
||||
c.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
for c in self._connectors:
|
||||
c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
|
||||
|
||||
def wait_for_save(self):
|
||||
for c in self._connectors:
|
||||
c.wait_for_save()
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
finished_sending: set[str] = set()
|
||||
finished_recving: set[str] = set()
|
||||
for c in self._connectors:
|
||||
sending, recving = c.get_finished(finished_req_ids)
|
||||
if not recving and not sending:
|
||||
continue
|
||||
# Aggregate finished recving request ids.
|
||||
finished_recving.update(recving or ())
|
||||
# Aggregate finished sending request ids - only include
|
||||
# once we've drained the "extra" count (for cases where
|
||||
# more than one connector is async-saving the same request).
|
||||
for req_id in sending or ():
|
||||
extra_pending = self._extra_async_saves.get(req_id)
|
||||
if extra_pending is None:
|
||||
finished_sending.add(req_id)
|
||||
continue
|
||||
assert extra_pending > 0
|
||||
if extra_pending == 1:
|
||||
del self._extra_async_saves[req_id]
|
||||
else:
|
||||
self._extra_async_saves[req_id] = extra_pending - 1
|
||||
|
||||
return finished_sending or None, finished_recving or None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
to_return = (0, False)
|
||||
for i, c in enumerate(self._connectors):
|
||||
toks, load_async = c.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
# The first connector that has new matched tokens will be assigned
|
||||
# to this request.
|
||||
if to_return[0] == 0 and toks > 0:
|
||||
self._requests_to_connector[request.request_id] = i
|
||||
to_return = (toks, load_async)
|
||||
return to_return
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
chosen_connector = self._requests_to_connector.get(
|
||||
request.request_id, -1)
|
||||
empty_blocks = blocks.new_empty()
|
||||
for i, c in enumerate(self._connectors):
|
||||
if i == chosen_connector:
|
||||
# Forward call to the chosen connector (if any).
|
||||
c.update_state_after_alloc(request, blocks,
|
||||
num_external_tokens)
|
||||
else:
|
||||
# Call with empty blocks for other connectors.
|
||||
c.update_state_after_alloc(request, empty_blocks, 0)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
|
||||
metadata = MultiKVConnectorMetadata(metadata=tuple(
|
||||
c.build_connector_meta(scheduler_output)
|
||||
for c in self._connectors))
|
||||
if self._extra_async_saves:
|
||||
metadata.extra_async_saves = self._extra_async_saves
|
||||
self._extra_async_saves = {}
|
||||
return metadata
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
blocks: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
async_saves = 0
|
||||
kv_txfer_params = None
|
||||
for c in self._connectors:
|
||||
async_save, txfer_params = c.request_finished(request, blocks)
|
||||
if async_save:
|
||||
async_saves += 1
|
||||
if txfer_params is not None:
|
||||
if kv_txfer_params is not None:
|
||||
#TODO we can probably change this to merge the dicts here,
|
||||
# checking for key clashes.
|
||||
raise RuntimeError(
|
||||
"Only one connector can produce KV transfer params")
|
||||
kv_txfer_params = txfer_params
|
||||
if async_saves > 1:
|
||||
self._extra_async_saves[request.request_id] = async_saves - 1
|
||||
|
||||
# Clean up other state for this request.
|
||||
self._requests_to_connector.pop(request.request_id, None)
|
||||
|
||||
return async_saves > 0, kv_txfer_params
|
||||
1030
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Normal file
1030
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,384 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
# Slot mappings, should have the same length as token_ids
|
||||
slot_mapping: torch.Tensor
|
||||
# Is store or load
|
||||
is_store: bool
|
||||
|
||||
@staticmethod
|
||||
def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
|
||||
is_store: bool) -> "ReqMeta":
|
||||
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
|
||||
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
|
||||
block_ids_tensor = torch.tensor(block_ids)
|
||||
num_blocks = block_ids_tensor.shape[0]
|
||||
block_offsets = torch.arange(0, block_size)
|
||||
slot_mapping = block_offsets.reshape((1, block_size)) + \
|
||||
block_ids_tensor.reshape((num_blocks, 1)) * block_size
|
||||
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
|
||||
return ReqMeta(
|
||||
token_ids=token_ids_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
is_store=is_store,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedStorageConnectorMetadata(KVConnectorMetadata):
|
||||
requests: list[ReqMeta]
|
||||
|
||||
def __init__(self):
|
||||
self.requests = []
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
is_store: bool,
|
||||
) -> None:
|
||||
self.requests.append(
|
||||
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
|
||||
|
||||
|
||||
class SharedStorageConnector(KVConnectorBase_V1):
|
||||
# NOTE: This is Simple debug implementation of the KV connector.
|
||||
# It save / load the KV cache to / from the disk.
|
||||
# It does extra work which will overwrite the existing prefix-cache in GPU
|
||||
# - to remove the overhead, need to add some "mask" in the ReqMeta class
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
self._requests_need_load: dict[str, Request] = {}
|
||||
transfer_config = vllm_config.kv_transfer_config
|
||||
self._storage_path = transfer_config.get_from_extra_config(
|
||||
"shared_storage_path", "/tmp")
|
||||
logger.info(vllm_config.kv_transfer_config)
|
||||
logger.info("Shared storage path is %s", self._storage_path)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
"""Start loading the KV cache from the connector buffer to vLLM's
|
||||
paged KV buffer.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
"""
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
def inject_kv_into_layer(
|
||||
dst_kv_cache_layer: torch.Tensor,
|
||||
src_kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
"""Inject the KV cache into the layer.
|
||||
|
||||
Args:
|
||||
dst_kv_cache_layer (torch.Tensor): the destination KV cache
|
||||
layer. In shape [2, num_pages, page_size, xxx] if not
|
||||
using MLA, [num_pages, page_size, xxx] otherwise.
|
||||
src_kv_cache (torch.Tensor): the source KV cache. In shape
|
||||
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
|
||||
otherwise.
|
||||
slot_mapping (torch.Tensor): the slot mapping. In shape
|
||||
[num_tokens].
|
||||
"""
|
||||
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
num_pages = dst_kv_cache_layer_shape[0]
|
||||
page_size = dst_kv_cache_layer_shape[1]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
num_pages * page_size, -1)
|
||||
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
|
||||
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
|
||||
else:
|
||||
num_pages = dst_kv_cache_layer_shape[1]
|
||||
page_size = dst_kv_cache_layer_shape[2]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
2, num_pages * page_size, -1)
|
||||
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
|
||||
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
|
||||
|
||||
# Get the metadata
|
||||
metadata: KVConnectorMetadata = self._get_connector_metadata()
|
||||
assert isinstance(metadata, SharedStorageConnectorMetadata)
|
||||
|
||||
if metadata is None:
|
||||
logger.warning(
|
||||
"In connector.start_load_kv, but the connector metadata is None"
|
||||
)
|
||||
return
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
logger.warning(
|
||||
"In connector.start_load_kv, but the attn_metadata is None")
|
||||
return
|
||||
|
||||
# Load the KV for each request each layer
|
||||
for request in metadata.requests:
|
||||
if request.is_store:
|
||||
continue
|
||||
logger.info("Inject KV cache of %d tokens to the paged memory",
|
||||
len(request.slot_mapping))
|
||||
for layer_name in forward_context.no_compile_layers:
|
||||
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache_layer = attn_layer.kv_cache[\
|
||||
forward_context.virtual_engine]
|
||||
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids)
|
||||
kv_cache = safetensors.torch.load_file(
|
||||
filename)["kv_cache"].cuda()
|
||||
inject_kv_into_layer(kv_cache_layer, kv_cache,
|
||||
request.slot_mapping)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""Blocking until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
return
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
to the connector.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
|
||||
def extract_kv_from_layer(
|
||||
layer: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Extract the KV cache from the layer.
|
||||
|
||||
Assume the shape of the layer is (2, num_pages, page_size, xxx)
|
||||
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
|
||||
"""
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
num_pages, page_size = layer.shape[0], layer.shape[1]
|
||||
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
|
||||
...]
|
||||
num_pages, page_size = layer.shape[1], layer.shape[2]
|
||||
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
|
||||
...]
|
||||
|
||||
connector_metadata = self._get_connector_metadata()
|
||||
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
|
||||
for request in connector_metadata.requests:
|
||||
if request.is_store:
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids)
|
||||
kv_cache = extract_kv_from_layer(kv_layer,
|
||||
request.slot_mapping)
|
||||
tensors = {"kv_cache": kv_cache.detach().cpu()}
|
||||
safetensors.torch.save_file(tensors, filename)
|
||||
|
||||
def wait_for_save(self):
|
||||
return
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
# NOTE: in this debug implementation, we assume that the prompt is
|
||||
# cached_prompt + newly_generated_single_token
|
||||
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
|
||||
|
||||
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
|
||||
# with the block granularity. And it expects the returned blocks and
|
||||
# num_computed_tokens to also be aligned with the block granularity.
|
||||
if not self._found_match_for_request(request):
|
||||
return 0, False
|
||||
|
||||
logger.info("External Cache Hit!")
|
||||
|
||||
# Now, first num_tokens_to_check tokens are hit, we need to prepare
|
||||
# the metadata for the worker connector to correctly load the KV
|
||||
num_tokens_to_check = align_to_block_size(
|
||||
len(request.prompt_token_ids) - 1, self._block_size)
|
||||
|
||||
return num_tokens_to_check - num_computed_tokens, False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If blocks were allocated, add to _requests_need_load,
|
||||
such that we load the KVs in the next forward pass.
|
||||
"""
|
||||
if num_external_tokens > 0:
|
||||
self._requests_need_load[request.request_id] = request
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
meta = SharedStorageConnectorMetadata()
|
||||
|
||||
total_need_load = 0
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
if new_req.req_id in self._requests_need_load:
|
||||
meta.add_request(token_ids=new_req.prompt_token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=False)
|
||||
total_need_load += 1
|
||||
else:
|
||||
# NOTE: here, we set the store and load being exclusive,
|
||||
# but a single request can have both store and load.
|
||||
# NOTE(rob): for this debug implementation, we only cache
|
||||
# the original prompt tokens.
|
||||
if not self._found_match_for_request(new_req):
|
||||
meta.add_request(token_ids=new_req.prompt_token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=True)
|
||||
|
||||
for cached_req in scheduler_output.scheduled_cached_reqs:
|
||||
# NOTE(rob): here we rely on the resumed requests being
|
||||
# the first N requests in the list scheduled_cache_reqs.
|
||||
if not cached_req.resumed_from_preemption:
|
||||
break
|
||||
if cached_req.req_id in self._requests_need_load:
|
||||
# NOTE(rob): cached_req_data does not have the full
|
||||
# list of token ids (only new tokens). So we look it
|
||||
# up in the actual request object.
|
||||
request = self._requests_need_load[cached_req.req_id]
|
||||
total_tokens = (len(cached_req.new_token_ids) +
|
||||
cached_req.num_computed_tokens)
|
||||
token_ids = request.all_token_ids[:total_tokens]
|
||||
|
||||
# NOTE(rob): For resumed req, new_block_ids is all
|
||||
# of the block_ids for the request.
|
||||
block_ids = cached_req.new_block_ids[0]
|
||||
|
||||
meta.add_request(token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
block_size=self._block_size,
|
||||
is_store=False)
|
||||
total_need_load += 1
|
||||
|
||||
assert total_need_load == len(self._requests_need_load)
|
||||
self._requests_need_load.clear()
|
||||
return meta
|
||||
|
||||
# ==============================
|
||||
# Helper functions
|
||||
# ==============================
|
||||
|
||||
def _found_match_for_request(
|
||||
self,
|
||||
request: "Request",
|
||||
) -> bool:
|
||||
"""Check if the cache is hit for the request.
|
||||
"""
|
||||
num_tokens_to_check = align_to_block_size(
|
||||
len(request.prompt_token_ids) - 1, self._block_size)
|
||||
foldername = self._generate_foldername_debug(torch.tensor(
|
||||
request.prompt_token_ids)[:num_tokens_to_check],
|
||||
create_folder=False)
|
||||
return os.path.exists(foldername)
|
||||
|
||||
def _generate_foldername_debug(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
create_folder=False,
|
||||
) -> str:
|
||||
"""Generate a folder name based on the hash of the bytes of the input
|
||||
ids.
|
||||
"""
|
||||
input_ids_bytes = input_ids.numpy().tobytes()
|
||||
input_ids_hash = hashlib.md5(input_ids_bytes,
|
||||
usedforsecurity=False).hexdigest()
|
||||
foldername = os.path.join(self._storage_path, input_ids_hash)
|
||||
if create_folder:
|
||||
os.makedirs(foldername, exist_ok=True)
|
||||
return foldername
|
||||
|
||||
def _generate_filename_debug(
|
||||
self,
|
||||
layer_name: str,
|
||||
input_ids: torch.Tensor,
|
||||
) -> str:
|
||||
"""Generate a file name based on the layer name and the hash
|
||||
of the bytes of the input ids.
|
||||
"""
|
||||
foldername = self._generate_foldername_debug(input_ids,
|
||||
create_folder=True)
|
||||
return os.path.join(foldername, f"{layer_name}.safetensors")
|
||||
|
||||
|
||||
def align_to_block_size(num_tokens: int, block_size) -> int:
|
||||
"""Align the number of tokens to the block size.
|
||||
"""
|
||||
return (num_tokens - 1) // block_size * block_size
|
||||
77
vllm/distributed/kv_transfer/kv_connector_agent.py
Normal file
77
vllm/distributed/kv_transfer/kv_connector_agent.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A centralized entrypoint to perform distributed KV cache transfer.
|
||||
|
||||
This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
|
||||
1. `send_kv_caches_and_hidden_states`
|
||||
2. `recv_kv_caches_and_hidden_states
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KVTransferAgent:
|
||||
"""
|
||||
A class designated for distributed KV transfer
|
||||
|
||||
Target use cases:
|
||||
1. Disaggregated prefill
|
||||
2. Remote KV cache storage
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: "VllmConfig",
|
||||
):
|
||||
|
||||
self.config = config
|
||||
|
||||
if config.kv_transfer_config is None:
|
||||
raise ValueError("KVTransferConfig is not set in the VllmConfig,"
|
||||
" cannot initialize KVConnector.")
|
||||
|
||||
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
|
||||
"TransferAgent should only be used when kv_connector is set."
|
||||
|
||||
self.connector = KVConnectorFactory.create_connector_v0(
|
||||
rank, local_rank, config)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
|
||||
self.connector.send_kv_caches_and_hidden_states(
|
||||
model_executable, model_input, kv_caches,
|
||||
hidden_or_intermediate_states)
|
||||
|
||||
def close(self) -> None:
|
||||
self.connector.close()
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor]
|
||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
|
||||
return self.connector.recv_kv_caches_and_hidden_states(
|
||||
model_executable, model_input, kv_caches)
|
||||
175
vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
Normal file
175
vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains a new class `KVLookupBufferBase` that allows developers to
|
||||
think of KV cache operations as inserting new KV cache entries (`insert`)
|
||||
into the lookup buffer and querying existing KV caches (`drop_select`)
|
||||
from the lookup buffer.
|
||||
|
||||
This file also contains a new class `KVStoreBufferBase` that allows developers
|
||||
to manage the KVCache buffer as a simple key-value storage buffer with basic
|
||||
put/get operations.
|
||||
|
||||
These classes above are abstracted behind class `KVCacheBufferBase`.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class KVCacheBufferBase(ABC):
|
||||
"""
|
||||
Abstract base class for a KVCache buffer.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the buffer and release resources.
|
||||
|
||||
This method is responsible for cleaning up resources related to the
|
||||
KVCache buffer when it is no longer needed.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KVLookupBufferBase(KVCacheBufferBase):
|
||||
"""
|
||||
Abstract base class for a KVCache lookup buffer.
|
||||
|
||||
This class provides an abstraction for a key-value (KV) cache lookup buffer.
|
||||
|
||||
The key of the lookup buffer:
|
||||
- input_tokens: token IDs of the request
|
||||
- roi: a binary mask on top of input_tokens.
|
||||
- Purpose of roi: Since KV cache may only be available for a subset of
|
||||
tokens in the input (for example, when vLLM is connected to an external
|
||||
KV cache service), roi specifies the subset of tokens that the KV cache
|
||||
is associated with.
|
||||
- NOTE: roi can be further extended to describe which part of KV the
|
||||
current process is holding (each process may only hold a part of KV
|
||||
due to TP and PP). This is not implemented for now.
|
||||
|
||||
The value of the lookup buffer:
|
||||
- key: the key tensor in the KV cache
|
||||
- value: the value tensor in the KV cache
|
||||
- hidden: the final hidden state generated by model forwarding. This allows
|
||||
vLLM to bypass further model forwarding by transmitting the hidden state.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
"""Insert into the lookup buffer.
|
||||
|
||||
The functionality is similar to the following python statement
|
||||
```
|
||||
buffer[input_tokens, roi] = [key, value, hidden]
|
||||
```
|
||||
|
||||
FIXME: in the future, we should only have two arguments, key and value,
|
||||
where key is a tensor dict and value is a tensor dict.
|
||||
|
||||
FIXME: we should transmit both sampler outputs and the hidden states.
|
||||
|
||||
Args:
|
||||
input_tokens (torch.Tensor): token IDs.
|
||||
roi (torch.Tensor): A binary mask on top of the input tokens
|
||||
key (torch.Tensor): The key tensor in the KV cache.
|
||||
value (torch.Tensor): The value tensor in the KV cache.
|
||||
hidden (torch.Tensor): The final hidden state tensor generated
|
||||
during model forwarding to bypass model
|
||||
forwarding.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def drop_select(
|
||||
self, input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
|
||||
"""Select and *drop* KV cache entries from the lookup buffer.
|
||||
|
||||
The functionality is similar to the following python statements
|
||||
```
|
||||
ret = buffer.pop(input_tokens, roi)
|
||||
return ret
|
||||
```
|
||||
|
||||
If `input_tokens` and `roi` is `None`, it means selecting any of the
|
||||
KV caches in the buffer, return, and remove it from the buffer, useful
|
||||
when offloading KV cache to KV cache storage service.
|
||||
|
||||
Args:
|
||||
input_tokens (torch.Tensor): token IDs.
|
||||
roi (torch.Tensor): A binary mask on top of the input tokens
|
||||
|
||||
Returns:
|
||||
list[Optional[torch.Tensor]]: A list of tensors. Can be None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KVStoreBufferBase(KVCacheBufferBase):
|
||||
"""
|
||||
Abstract base class for a KVCache storage buffer with key-value semantics.
|
||||
This class provides a simple key-value storage buffer abstract with basic
|
||||
put/get operations, which enables flexible KVCache transfer granular
|
||||
control.
|
||||
|
||||
The functionality is similar to a distributed key-value store, where:
|
||||
- Key: A unique string identifier for the cached entry
|
||||
- Value:
|
||||
- Tensor to be stored and retrieved
|
||||
- None (indicating deletion or empty value)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
value: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
"""Store a key-value pair in the buffer.
|
||||
|
||||
Args:
|
||||
key (str): Unique identifier for a tensor, this tensor could be the
|
||||
key cache tensor, value cache tensor, or hidden state tensor
|
||||
generated during model forwarding.
|
||||
|
||||
value (Optional[torch.Tensor]): Tensor to be stored.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Retrieve a value from the buffer by key.
|
||||
|
||||
Args:
|
||||
key (str): Unique identifier for a tensor, this tensor could be the
|
||||
key cache tensor, value cache tensor, or hidden state tensor
|
||||
generated during model forwarding.
|
||||
|
||||
Returns:
|
||||
Optional[torch.Tensor]: Stored tensor if exists, None otherwise.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
161
vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py
Normal file
161
vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains a new class `MooncakeStore` that allows developers to
|
||||
think of KV cache transfer operations as putting new KV cache entries
|
||||
into a remote KVStore-based lookup buffer and getting existing KV caches
|
||||
from this remote lookup buffer.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load as safetensors_load
|
||||
from safetensors.torch import save as safetensors_save
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
|
||||
KVStoreBufferBase)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
||||
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeStoreConfig:
|
||||
local_hostname: str
|
||||
metadata_server: str
|
||||
global_segment_size: int
|
||||
local_buffer_size: int
|
||||
protocol: str
|
||||
device_name: str
|
||||
master_server_address: str
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> 'MooncakeStoreConfig':
|
||||
"""Load the config from a JSON file."""
|
||||
with open(file_path) as fin:
|
||||
config = json.load(fin)
|
||||
return MooncakeStoreConfig(
|
||||
local_hostname=config.get("local_hostname"),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
global_segment_size=config.get("global_segment_size",
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE),
|
||||
local_buffer_size=config.get("local_buffer_size",
|
||||
DEFAULT_LOCAL_BUFFER_SIZE),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
master_server_address=config.get("master_server_address"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> 'MooncakeStoreConfig':
|
||||
"""Load config from a file specified in the environment variable."""
|
||||
config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
|
||||
if config_file_path is None:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeStoreConfig.from_file(config_file_path)
|
||||
|
||||
|
||||
class MooncakeStore(KVStoreBufferBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VllmConfig,
|
||||
):
|
||||
|
||||
try:
|
||||
from mooncake.store import MooncakeDistributedStore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
|
||||
try:
|
||||
self.store = MooncakeDistributedStore()
|
||||
self.config = MooncakeStoreConfig.load_from_env()
|
||||
logger.info("Mooncake Configuration loaded successfully.")
|
||||
|
||||
self.store.setup(self.config.local_hostname,
|
||||
self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol, self.config.device_name,
|
||||
self.config.master_server_address)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("Configuration loading failed: %s", e)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"An error occurred while loading the configuration: %s", exc)
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
# MooncakeDistributedStore will automatically call the destructor, so
|
||||
# it is unnecessary to close it manually.
|
||||
pass
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
value: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
# A message queue needs to be introduced before making it asynchronous.
|
||||
if value is not None:
|
||||
self._put_impl(key, value)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# A message queue needs to be introduced before making it asynchronous.
|
||||
value = self._get_impl(key)
|
||||
return value
|
||||
|
||||
def _put_impl(
|
||||
self,
|
||||
key: str,
|
||||
value: torch.Tensor,
|
||||
) -> None:
|
||||
"""Put KVCache to Mooncake Store"""
|
||||
device_id = value.device.index if value.device.type == 'cuda' else -1
|
||||
device_tensor = torch.tensor(device_id, dtype=torch.int32)
|
||||
value_bytes = safetensors_save({
|
||||
"tensor": value,
|
||||
"device_id": device_tensor
|
||||
})
|
||||
try:
|
||||
self.store.put(key, value_bytes)
|
||||
except TypeError as err:
|
||||
logger.error("Failed to put value into Mooncake Store: %s", err)
|
||||
raise TypeError("Mooncake Store Put Type Error.") from err
|
||||
|
||||
def _get_impl(
|
||||
self,
|
||||
key: str,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Get KVCache from Mooncake Store"""
|
||||
try:
|
||||
data = self.store.get(key)
|
||||
except TypeError as err:
|
||||
logger.error("Failed to get value from Mooncake Store: %s", err)
|
||||
raise TypeError("Mooncake Store Get Type Error.") from err
|
||||
|
||||
if data:
|
||||
loaded_tensors = safetensors_load(data)
|
||||
tensor = loaded_tensors["tensor"]
|
||||
device_id_tensor = loaded_tensors["device_id"]
|
||||
device_id = int(device_id_tensor.item())
|
||||
device = torch.device(
|
||||
'cuda', device_id) if device_id >= 0 else torch.device('cpu')
|
||||
return tensor.to(device)
|
||||
|
||||
return None
|
||||
237
vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
Normal file
237
vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Implements a distributed key-value (KV) cache transfer mechanism.
|
||||
|
||||
Key Features:
|
||||
- Distributed KV cache transmission using PyNccl pipes.
|
||||
- Non-blocking `insert`, blocking `drop_select`.
|
||||
- Use CPU signal pipe to avoid racing condition
|
||||
- Handles buffer size constraints and provide backpressure mechanism to
|
||||
stop the prefill instance when the decode instance is slow.
|
||||
"""
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
|
||||
KVLookupBufferBase)
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SimpleBuffer(KVLookupBufferBase):
|
||||
|
||||
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
|
||||
buffer_size_thresh: float):
|
||||
"""
|
||||
signal_pipe: on CPU
|
||||
|
||||
NOTE: on-device recv will block all threads in the process, making the
|
||||
KV cache producer unable to listen to new request while transmitting
|
||||
KV cache. Luckily CPU recv only blocks the current thread so we use
|
||||
CPU recv to listen to new request.
|
||||
|
||||
data_pipe: on device (e.g. GPU)
|
||||
"""
|
||||
|
||||
self.buffer: deque[list[torch.Tensor]] = deque()
|
||||
|
||||
self.buffer_size = 0
|
||||
self.buffer_size_threshold = buffer_size_thresh
|
||||
self.buffer_cv = threading.Condition()
|
||||
self.signal_pipe = signal_pipe
|
||||
self.data_pipe = data_pipe
|
||||
self.request_handling_thread: Optional[threading.Thread] = None
|
||||
|
||||
self.normal_signal = torch.tensor([0], device="cpu")
|
||||
self.end_signal = None
|
||||
|
||||
def _matches(self, tokens_roi_sender: list[torch.Tensor],
|
||||
tokens_roi_recver: list[torch.Tensor]):
|
||||
|
||||
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
|
||||
# tokens_roi_recver: tokens and roi of the consumer (query)
|
||||
|
||||
tokens_sender = tokens_roi_sender[0]
|
||||
tokens_recver = tokens_roi_recver[0]
|
||||
roi_sender = tokens_roi_sender[1]
|
||||
roi_recver = tokens_roi_recver[1]
|
||||
|
||||
if tokens_recver is None:
|
||||
# consumer sends an empty request
|
||||
# semantics: DROP SELECT * LIMIT 1
|
||||
# so any of the data in the buffer can be drop-selected
|
||||
return True
|
||||
|
||||
# Assuming that roi is a binary mask on tokens
|
||||
tokens_sender = tokens_sender[roi_sender]
|
||||
tokens_recver = tokens_recver[roi_recver]
|
||||
|
||||
# simple common prefix matching
|
||||
min_length = min(len(tokens_sender), len(tokens_recver))
|
||||
if torch.allclose(tokens_sender[:min_length],
|
||||
tokens_recver[:min_length]):
|
||||
return min_length
|
||||
|
||||
return 0
|
||||
|
||||
def _send_tensor_and_dec_size(self,
|
||||
tensor: Optional[torch.Tensor]) -> None:
|
||||
|
||||
assert tensor is not None, "Use self.data_pipe.send(None) instead"
|
||||
self.buffer_size -= tensor.element_size() * tensor.numel()
|
||||
if tensor.dtype == torch.bool:
|
||||
tensor = tensor.float()
|
||||
self.data_pipe.send_tensor(tensor)
|
||||
|
||||
def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]):
|
||||
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.element_size() * data.numel()
|
||||
if not data:
|
||||
# cannot perform `not data` on a tensor
|
||||
# so this check needs to go after the check above
|
||||
return 0
|
||||
|
||||
raise AssertionError(f"Unknown data type {type(data)}")
|
||||
|
||||
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor):
|
||||
|
||||
if isinstance(input_tokens, torch.Tensor):
|
||||
input_tokens = input_tokens.clone()
|
||||
if isinstance(roi, torch.Tensor):
|
||||
roi = roi.clone()
|
||||
if isinstance(key, torch.Tensor):
|
||||
key = key.clone()
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = value.clone()
|
||||
if isinstance(hidden, torch.Tensor):
|
||||
hidden = hidden.clone()
|
||||
|
||||
buffer_item = [input_tokens, roi, key, value, hidden]
|
||||
data_size = sum([self._get_element_size(data) for data in buffer_item])
|
||||
|
||||
with self.buffer_cv:
|
||||
if self.buffer_size + data_size > self.buffer_size_threshold:
|
||||
# log outside the while loop to avoid this message being logged
|
||||
# repeatedly.
|
||||
logger.debug("KV transfer buffer is full. Handling...")
|
||||
while self.buffer_size + data_size > self.buffer_size_threshold:
|
||||
self.buffer_cv.wait()
|
||||
|
||||
self.buffer_size += data_size
|
||||
self.buffer.append(buffer_item)
|
||||
self.buffer_cv.notify()
|
||||
|
||||
def _is_end_signal(self, signal):
|
||||
return signal is None
|
||||
|
||||
def drop_select_handler(self):
|
||||
|
||||
try:
|
||||
|
||||
while True:
|
||||
signal = self.signal_pipe.recv_tensor()
|
||||
if self._is_end_signal(signal):
|
||||
logger.info("Received end signal!")
|
||||
break
|
||||
|
||||
input_tokens = self.data_pipe.recv_tensor()
|
||||
|
||||
roi = self.data_pipe.recv_tensor()
|
||||
assert roi is not None, "Please provide the roi when sending "\
|
||||
"drop-select request"
|
||||
roi = (roi > 0.5)
|
||||
tokens_roi_recver = [input_tokens, roi]
|
||||
|
||||
def is_buffer_available(
|
||||
tokens_roi_recver: list[torch.Tensor], ) -> bool:
|
||||
# perform input tokens and roi matching
|
||||
# FIXME: this matching is O(n), ideally it should be O(1)
|
||||
# but this buffer size won't (and shouldn't) be too large so
|
||||
# the fix is not urgent.
|
||||
for _ in range(len(self.buffer)):
|
||||
if self._matches(self.buffer[0],
|
||||
tokens_roi_recver) > 0:
|
||||
return True
|
||||
# rotate the element we just accessed to the end
|
||||
self.buffer.rotate(-1)
|
||||
return False
|
||||
|
||||
with self.buffer_cv:
|
||||
while not is_buffer_available(tokens_roi_recver):
|
||||
logger.debug(
|
||||
"KV transfer buffer is not available. Waiting...")
|
||||
self.buffer_cv.wait()
|
||||
# need to clone the tensor
|
||||
# in case the tensor is freed before sending finishes
|
||||
matched_item = self.buffer.popleft()
|
||||
for tensor in matched_item:
|
||||
self._send_tensor_and_dec_size(tensor)
|
||||
self.buffer_cv.notify()
|
||||
|
||||
except RuntimeError as e:
|
||||
if 'Connection closed by peer' not in str(e):
|
||||
raise e
|
||||
|
||||
logger.debug("Closing drop_select_handler")
|
||||
|
||||
def drop_select(
|
||||
self, input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
|
||||
|
||||
assert self.request_handling_thread is None, \
|
||||
"drop_select should be called by the KV cache consumer "\
|
||||
"(e.g. the decode vLLM instance)"
|
||||
|
||||
if isinstance(input_tokens, torch.Tensor):
|
||||
input_tokens = input_tokens.clone()
|
||||
if isinstance(roi, torch.Tensor):
|
||||
roi = roi.clone().float()
|
||||
|
||||
self.signal_pipe.send_tensor(self.normal_signal)
|
||||
self.data_pipe.send_tensor(input_tokens)
|
||||
self.data_pipe.send_tensor(roi)
|
||||
|
||||
input_tokens = self.data_pipe.recv_tensor()
|
||||
roi = self.data_pipe.recv_tensor()
|
||||
if roi is not None:
|
||||
# convert from float tensor to bool tensor
|
||||
# as PyNccl does not support sending bool tensor
|
||||
roi = (roi > 0.5)
|
||||
key = self.data_pipe.recv_tensor()
|
||||
value = self.data_pipe.recv_tensor()
|
||||
hidden = self.data_pipe.recv_tensor()
|
||||
|
||||
return [input_tokens, roi, key, value, hidden]
|
||||
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
|
||||
self._add_to_buffer(input_tokens, roi, key, value, hidden)
|
||||
|
||||
# when calling the insert, the current process is a sender
|
||||
# need to launch the request handler and start listening to request.
|
||||
if self.request_handling_thread is None:
|
||||
self.request_handling_thread = threading.Thread(
|
||||
target=self.drop_select_handler)
|
||||
self.request_handling_thread.start()
|
||||
|
||||
def close(self):
|
||||
|
||||
if hasattr(self, "request_handling_thread"
|
||||
) and self.request_handling_thread is not None:
|
||||
self.request_handling_thread.join()
|
||||
|
||||
else:
|
||||
# TODO: have a explicit close signal and have a explicit way to
|
||||
# check if it's requester
|
||||
self.signal_pipe.send_tensor(self.end_signal)
|
||||
0
vllm/distributed/kv_transfer/kv_pipe/__init__.py
Normal file
0
vllm/distributed/kv_transfer/kv_pipe/__init__.py
Normal file
67
vllm/distributed/kv_transfer/kv_pipe/base.py
Normal file
67
vllm/distributed/kv_transfer/kv_pipe/base.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file defines an interface `KVPipeBase`
|
||||
that provides an abstraction for sending and receiving tensors, or None, via
|
||||
distributed communications.
|
||||
|
||||
All classes instantiated from this interface are assumed to be a FIFO pipe.
|
||||
|
||||
If your distributed communication platform already supports key-value lookup,
|
||||
you can bypass this interface and directly start from `kv_lookup_buffer`.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class KVPipeBase(ABC):
|
||||
"""
|
||||
This class provides an interface for sending and receiving tensors, or
|
||||
None, by distributed communications.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
"""Send a tensor, or None, via the pipe.
|
||||
|
||||
Need to support sending None -- important for error handling.
|
||||
|
||||
TODO: add a `key` argument so that we can use traditional
|
||||
key-value database as the distributed communication mechanism behind
|
||||
the pipe.
|
||||
|
||||
Args:
|
||||
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def recv_tensor(self) -> Optional[torch.Tensor]:
|
||||
"""Receive a tensor (can be None) from the pipeline.
|
||||
|
||||
Returns:
|
||||
Optional[torch.Tensor]: The tensor received from the pipeline. Can
|
||||
be None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the pipeline and release resources.
|
||||
|
||||
This method is responsible for closing the communication pipeline
|
||||
and releasing any resources associated with it.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
280
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
Normal file
280
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from safetensors.torch import load as safetensors_load
|
||||
from safetensors.torch import save as safetensors_save
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
NONE_INT = -150886311
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeTransferEngineConfig:
|
||||
prefill_url: str
|
||||
decode_url: str
|
||||
metadata_backend: Union[str, None]
|
||||
metadata_server: str
|
||||
protocol: str
|
||||
device_name: str
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> 'MooncakeTransferEngineConfig':
|
||||
"""Load the config from a JSON file."""
|
||||
with open(file_path) as fin:
|
||||
config = json.load(fin)
|
||||
return MooncakeTransferEngineConfig(
|
||||
prefill_url=config.get("prefill_url"),
|
||||
decode_url=config.get("decode_url"),
|
||||
metadata_backend=config.get("metadata_backend", None),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> 'MooncakeTransferEngineConfig':
|
||||
"""Load config from a file specified in the environment variable."""
|
||||
config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
|
||||
if config_file_path is None:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeTransferEngineConfig.from_file(config_file_path)
|
||||
|
||||
|
||||
class MooncakeTransferEngine:
|
||||
"""Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""
|
||||
|
||||
def __init__(self, kv_rank: int, local_rank: int):
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
|
||||
self.engine = TransferEngine()
|
||||
self.local_rank = local_rank
|
||||
|
||||
try:
|
||||
self.config = MooncakeTransferEngineConfig.load_from_env()
|
||||
logger.info("Mooncake Configuration loaded successfully.")
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"An error occurred while loading the configuration: %s", exc)
|
||||
raise
|
||||
prefill_host, base_prefill_port = self.config.prefill_url.split(':')
|
||||
decode_host, base_decode_port = self.config.decode_url.split(':')
|
||||
|
||||
# Avoid ports conflict when running prefill and decode on the same node
|
||||
if prefill_host == decode_host and \
|
||||
base_prefill_port == base_decode_port:
|
||||
base_decode_port = str(int(base_decode_port) + 100)
|
||||
|
||||
prefill_port = int(base_prefill_port) + self.local_rank
|
||||
decode_port = int(base_decode_port) + self.local_rank
|
||||
self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
|
||||
self.decode_url = ':'.join([decode_host, str(decode_port)])
|
||||
|
||||
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
|
||||
self.config.metadata_server, self.config.protocol,
|
||||
self.config.device_name, self.config.metadata_backend)
|
||||
|
||||
self.remote_url = (self.decode_url
|
||||
if kv_rank == 0 else self.prefill_url)
|
||||
|
||||
# Initialize ZeroMQ context and sockets
|
||||
self.context = zmq.Context() # type: ignore[attr-defined]
|
||||
self.sender_socket = self.context.socket(zmq.constants.PUSH)
|
||||
self.receiver_socket = self.context.socket(zmq.constants.PULL)
|
||||
self.sender_ack = self.context.socket(zmq.constants.PULL)
|
||||
self.receiver_ack = self.context.socket(zmq.constants.PUSH)
|
||||
|
||||
self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
|
||||
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
|
||||
decode_host, base_decode_port)
|
||||
|
||||
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
|
||||
d_host: str, d_port: str) -> None:
|
||||
"""Set up ZeroMQ sockets for sending and receiving data."""
|
||||
# Offsets < 8 are left for initialization in case tp and pp are enabled
|
||||
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
|
||||
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
|
||||
if kv_rank == 0:
|
||||
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||
else:
|
||||
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||
|
||||
def initialize(self, local_hostname: str, metadata_server: str,
|
||||
protocol: str, device_name: str,
|
||||
metadata_backend: Union[str, None]) -> None:
|
||||
"""Initialize the mooncake instance."""
|
||||
if metadata_backend is None:
|
||||
self.engine.initialize(local_hostname, metadata_server, protocol,
|
||||
device_name)
|
||||
else:
|
||||
supported_backend = ["etcd", "redis"]
|
||||
metadata_backend = metadata_backend.lower()
|
||||
if metadata_backend not in supported_backend:
|
||||
raise ValueError(
|
||||
"Mooncake Configuration error. `metadata_backend`"
|
||||
f" should be one of {supported_backend}.")
|
||||
|
||||
self.engine.initialize_ext(local_hostname, metadata_server,
|
||||
protocol, device_name, metadata_backend)
|
||||
|
||||
def allocate_managed_buffer(self, length: int) -> int:
|
||||
"""Allocate a managed buffer of the specified length."""
|
||||
ret = self.engine.allocate_managed_buffer(length)
|
||||
if ret <= 0:
|
||||
logger.error("Allocation Return Error")
|
||||
raise Exception("Allocation Return Error")
|
||||
return ret
|
||||
|
||||
def free_managed_buffer(self, buffer: int, length: int) -> int:
|
||||
"""Free a previously allocated managed buffer."""
|
||||
return self.engine.free_managed_buffer(buffer, length)
|
||||
|
||||
def transfer_sync(self, buffer: int, peer_buffer_address: int,
|
||||
length: int) -> int:
|
||||
"""Synchronously transfer data to the specified address."""
|
||||
ret = self.engine.transfer_sync_read(self.remote_url, buffer,
|
||||
peer_buffer_address, length)
|
||||
if ret < 0:
|
||||
logger.error("Transfer Return Error")
|
||||
raise Exception("Transfer Return Error")
|
||||
return ret
|
||||
|
||||
def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
|
||||
length: int) -> int:
|
||||
"""Write bytes to the allocated buffer."""
|
||||
return self.engine.write_bytes_to_buffer(buffer, user_data, length)
|
||||
|
||||
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
|
||||
"""Read bytes from the allocated buffer."""
|
||||
return self.engine.read_bytes_from_buffer(buffer, length)
|
||||
|
||||
def wait_for_ack(self, src_ptr: int, length: int) -> None:
|
||||
"""Asynchronously wait for ACK from the receiver."""
|
||||
ack = self.sender_ack.recv()
|
||||
if ack != b'ACK':
|
||||
logger.error("Failed to receive ACK from the receiver")
|
||||
|
||||
self.free_managed_buffer(src_ptr, length)
|
||||
|
||||
def send_bytes(self, user_data: bytes) -> None:
|
||||
"""Send bytes to the remote process."""
|
||||
length = len(user_data)
|
||||
src_ptr = self.allocate_managed_buffer(length)
|
||||
self.write_bytes_to_buffer(src_ptr, user_data, length)
|
||||
self.sender_socket.send_multipart(
|
||||
[struct.pack("!Q", src_ptr),
|
||||
struct.pack("!Q", length)])
|
||||
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
|
||||
|
||||
def recv_bytes(self) -> bytes:
|
||||
"""Receive bytes from the remote process."""
|
||||
data = self.receiver_socket.recv_multipart()
|
||||
src_ptr = struct.unpack("!Q", data[0])[0]
|
||||
length = struct.unpack("!Q", data[1])[0]
|
||||
dst_ptr = self.allocate_managed_buffer(length)
|
||||
self.transfer_sync(dst_ptr, src_ptr, length)
|
||||
ret = self.read_bytes_from_buffer(dst_ptr, length)
|
||||
|
||||
# Buffer cleanup
|
||||
self.receiver_ack.send(b'ACK')
|
||||
self.free_managed_buffer(dst_ptr, length)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class MooncakePipe(KVPipeBase):
|
||||
"""MooncakeTransferEngine based Pipe implementation."""
|
||||
|
||||
def __init__(self,
|
||||
local_rank: int,
|
||||
config: KVTransferConfig,
|
||||
device: Optional[str] = None):
|
||||
"""Initialize the mooncake pipe and set related parameters."""
|
||||
self.config = config
|
||||
self.local_rank = local_rank
|
||||
self.kv_rank = self.config.kv_rank
|
||||
if device is None:
|
||||
self.device = self._select_device(self.config.kv_buffer_device)
|
||||
else:
|
||||
self.device = self._select_device(device)
|
||||
|
||||
self.transfer_engine = MooncakeTransferEngine(self.kv_rank,
|
||||
self.local_rank)
|
||||
self.transport_thread: Optional[ThreadPoolExecutor] = None
|
||||
self.none_tensor = torch.tensor([NONE_INT], device=self.device)
|
||||
|
||||
def _select_device(self, device: str) -> torch.device:
|
||||
"""Select available device (CUDA or CPU)."""
|
||||
logger.info("Selecting device: %s", device)
|
||||
if device == "cuda":
|
||||
return torch.device(f"cuda:{self.local_rank}")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def tensor_hash(self, tensor: torch.Tensor) -> int:
|
||||
"""Calculate the hash value of the tensor."""
|
||||
return hash(tensor.data_ptr())
|
||||
|
||||
def _send_impl(self, tensor: torch.Tensor) -> None:
|
||||
"""Implement the tensor sending logic using safetensors."""
|
||||
self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))
|
||||
|
||||
def _recv_impl(self) -> torch.Tensor:
|
||||
"""Implement the tensor receiving logic using safetensors."""
|
||||
data = self.transfer_engine.recv_bytes()
|
||||
return safetensors_load(data)["tensor"].to(self.device)
|
||||
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
"""Send tensor to the target process."""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
tensor = tensor if tensor is not None else self.none_tensor
|
||||
assert (len(tensor.shape) > 0)
|
||||
self.transport_thread.submit(self._send_impl, tensor)
|
||||
|
||||
def recv_tensor(self) -> Optional[torch.Tensor]:
|
||||
"""Receive tensor from other processes."""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
tensor = self.transport_thread.submit(self._recv_impl).result()
|
||||
if tensor.numel() == 1 and tensor.item() == NONE_INT:
|
||||
return None
|
||||
else:
|
||||
return tensor
|
||||
|
||||
def close(self) -> None:
|
||||
"""Cleanup logic when closing the pipe."""
|
||||
self.transfer_engine.sender_socket.close()
|
||||
self.transfer_engine.receiver_socket.close()
|
||||
self.transfer_engine.sender_ack.close()
|
||||
self.transfer_engine.receiver_ack.close()
|
||||
self.transfer_engine.context.term() # Terminate the ZMQ context
|
||||
logger.info("Closed the transfer engine and cleaned up resources.")
|
||||
280
vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
Normal file
280
vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This module implements a PyNccl pipe for sending and receiving
|
||||
Optional[torch.Tensor] between distributed ranks with advanced
|
||||
communication features.
|
||||
|
||||
Key Features:
|
||||
- Supports sending and receiving tensors with metadata
|
||||
- Handles both CUDA and CPU device communications
|
||||
- Implements a non-blocking tensor transfer mechanism
|
||||
- Manages buffer size and provides backpressure control
|
||||
- Supports distributed process groups with configurable parameters
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BrokenPipeException(Exception):
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
Metadata = dict[str, Optional[torch.Tensor]]
|
||||
|
||||
|
||||
class PyNcclPipe(KVPipeBase):
|
||||
|
||||
METADATA_LENGTH = 16
|
||||
MAX_TENSOR_DIMENSIONS = 14
|
||||
METADATA_DTYPE = torch.int64
|
||||
|
||||
def __init__(self,
|
||||
local_rank: int,
|
||||
config: KVTransferConfig,
|
||||
device: Optional[str] = None,
|
||||
port_offset: int = 0):
|
||||
self.config = config
|
||||
self.local_rank = local_rank
|
||||
self.kv_rank = self.config.kv_rank
|
||||
self.kv_parallel_size = self.config.kv_parallel_size
|
||||
if device is None:
|
||||
self.device = self._select_device(self.config.kv_buffer_device)
|
||||
else:
|
||||
self.device = self._select_device(device)
|
||||
|
||||
# build distributed connection and send/recv implementation
|
||||
store_timeout = self.config.get_from_extra_config("store_timeout", 300)
|
||||
self.group = StatelessProcessGroup.create(
|
||||
host=self.config.kv_ip,
|
||||
port=self.config.kv_port + port_offset,
|
||||
rank=self.kv_rank,
|
||||
world_size=self.kv_parallel_size,
|
||||
store_timeout=store_timeout,
|
||||
)
|
||||
# add a barrier to make sure the connection is initiated properly
|
||||
self.group.barrier()
|
||||
impl = self._get_device_send_recv_impl(self.group)
|
||||
self.device_send_func, self.device_recv_func = impl
|
||||
# set target rank
|
||||
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
|
||||
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
|
||||
|
||||
# transportation-related variables
|
||||
self.transport_thread: Optional[ThreadPoolExecutor] = None
|
||||
self.buffer_size = 0
|
||||
self.buffer_size_lock = threading.Lock()
|
||||
self.buffer_size_thresh = self.config.kv_buffer_size
|
||||
|
||||
def _get_device_send_recv_impl(
|
||||
self, group: StatelessProcessGroup
|
||||
) -> tuple[Callable[[torch.Tensor, int], None], Callable[
|
||||
[torch.Tensor, int], None]]:
|
||||
|
||||
send: Callable[[torch.Tensor, int], None]
|
||||
recv: Callable[[torch.Tensor, int], None]
|
||||
if self.device.type == "cuda":
|
||||
# use PyNCCL for send / recv
|
||||
comm = PyNcclCommunicator(group, device=self.local_rank)
|
||||
comm.disabled = False
|
||||
send, recv = comm.send, comm.recv # type: ignore
|
||||
else:
|
||||
# This send / recv implementation here is NOT intended to transfer
|
||||
# KV caches (and should NOT be repurposed to transfer KV caches).
|
||||
# Currently it is only used to transmit control-plane messages
|
||||
# for PyNcclBuffer.
|
||||
send = group.send_obj
|
||||
|
||||
def my_recv(x, src):
|
||||
x[...] = group.recv_obj(src)
|
||||
|
||||
recv = my_recv
|
||||
|
||||
return send, recv
|
||||
|
||||
def _select_device(self, device: str):
|
||||
logger.info("Selecting device: %s", device)
|
||||
if device == "cuda":
|
||||
return torch.device(f"cuda:{self.local_rank}")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
|
||||
"""
|
||||
Create the metadata as a dictionary based on the input tensor.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor or None if no tensor is provided.
|
||||
|
||||
Returns:
|
||||
metadata: A dictionary with the following keys:
|
||||
- "dtype": The data type of the tensor or None.
|
||||
- "shape": The shape of the tensor or None.
|
||||
"""
|
||||
if tensor is None:
|
||||
return {"dtype": None, "shape": None}
|
||||
else:
|
||||
return {"dtype": tensor.dtype, "shape": tensor.shape}
|
||||
|
||||
def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
|
||||
"""
|
||||
Create a buffer to receive the tensor based on the provided metadata.
|
||||
|
||||
Args:
|
||||
metadata: A dictionary with keys "dtype" and "shape",
|
||||
describing the tensor's data type and shape.
|
||||
|
||||
Returns:
|
||||
buffer: A tensor of the specified type and shape,
|
||||
allocated on `self.device`.
|
||||
"""
|
||||
return torch.empty(metadata["shape"],
|
||||
dtype=metadata["dtype"],
|
||||
device=self.device)
|
||||
|
||||
def _send_metadata(self, metadata: Metadata):
|
||||
"""
|
||||
Send the metadata dictionary to the target rank.
|
||||
|
||||
Args:
|
||||
metadata: A dictionary with keys "dtype" and "shape".
|
||||
"""
|
||||
self.group.send_obj(metadata, self.target_rank_for_send)
|
||||
|
||||
def _recv_metadata(self) -> Metadata:
|
||||
"""
|
||||
Receive the metadata dictionary from the target rank.
|
||||
|
||||
Returns:
|
||||
metadata: A dictionary with keys "dtype" and "shape"
|
||||
describing the tensor.
|
||||
"""
|
||||
return self.group.recv_obj(self.target_rank_for_recv)
|
||||
|
||||
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
"""
|
||||
The actual implementation of sending the tensor and its metadata to the
|
||||
target rank.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor to be sent, or `None` if no tensor is
|
||||
being sent.
|
||||
"""
|
||||
metadata = self._make_metadata(tensor)
|
||||
self._send_metadata(metadata)
|
||||
if tensor is not None:
|
||||
self.device_send_func(tensor.to(self.device),
|
||||
self.target_rank_for_send)
|
||||
|
||||
def _recv_impl(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
The actual implementation of receiving a tensor and its metadata from
|
||||
the target rank.
|
||||
|
||||
Returns:
|
||||
buffer: The received tensor, or `None` if no tensor is received.
|
||||
"""
|
||||
metadata = self._recv_metadata()
|
||||
if metadata["dtype"] is None:
|
||||
return None
|
||||
buffer = self._prepare_recv_buffer(metadata)
|
||||
self.device_recv_func(buffer, self.target_rank_for_recv)
|
||||
|
||||
return buffer
|
||||
|
||||
def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
|
||||
tensor_size: int) -> None:
|
||||
"""
|
||||
Wrapper for _send_impl to handle exceptions and update buffer size.
|
||||
"""
|
||||
try:
|
||||
self._send_impl(tensor)
|
||||
|
||||
with self.buffer_size_lock:
|
||||
self.buffer_size -= tensor_size
|
||||
except Exception as e:
|
||||
logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
|
||||
torch.distributed.get_rank(), str(tensor), str(e))
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def block_if_full(self):
|
||||
"""
|
||||
Block the current thread if the buffer size is larger than the
|
||||
threshold.
|
||||
"""
|
||||
while self.buffer_size > self.buffer_size_thresh:
|
||||
logger.debug("KV cache transfer pipe is full. Waiting...")
|
||||
time.sleep(0.05)
|
||||
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
"""
|
||||
Sends a tensor and its metadata to the destination rank in a
|
||||
non-blocking way.
|
||||
|
||||
Args:
|
||||
tensor: The tensor to send, or `None` if no tensor is being sent.
|
||||
"""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
if tensor is not None:
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
else:
|
||||
tensor_size = 0
|
||||
|
||||
self.block_if_full()
|
||||
|
||||
with self.buffer_size_lock:
|
||||
self.buffer_size += tensor_size
|
||||
|
||||
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
|
||||
tensor_size)
|
||||
|
||||
def recv_tensor(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Receives a tensor and its metadata from the source rank. Blocking call.
|
||||
|
||||
Args:
|
||||
tensor: The received tensor, or `None` if no tensor is received.
|
||||
"""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
future = self.transport_thread.submit(self._recv_impl)
|
||||
|
||||
try:
|
||||
tensor = future.result()
|
||||
except Exception as e:
|
||||
logger.error("Encountering exception in KV receiving thread")
|
||||
logger.error("%s", e)
|
||||
logger.error("My device: %s", self.device)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
return tensor
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the pipe and release associated resources.
|
||||
"""
|
||||
if hasattr(self,
|
||||
"transport_thread") and self.transport_thread is not None:
|
||||
self.transport_thread.shutdown()
|
||||
71
vllm/distributed/kv_transfer/kv_transfer_state.py
Normal file
71
vllm/distributed/kv_transfer/kv_transfer_state.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm import envs
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||
KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
|
||||
|
||||
|
||||
def get_kv_transfer_group() -> KVConnectorBaseType:
|
||||
assert _KV_CONNECTOR_AGENT is not None, (
|
||||
"disaggregated KV cache transfer parallel group is not initialized")
|
||||
return _KV_CONNECTOR_AGENT
|
||||
|
||||
|
||||
def has_kv_transfer_group() -> bool:
|
||||
return _KV_CONNECTOR_AGENT is not None
|
||||
|
||||
|
||||
def is_v1_kv_transfer_group(
|
||||
connector: Optional[KVConnectorBaseType] = None) -> bool:
|
||||
"""Check if the KV connector is the v1 connector.
|
||||
If the argument is None, it will check the global KV connector
|
||||
|
||||
Args:
|
||||
connector: The KV connector to check. If None, it will check the
|
||||
global KV connector.
|
||||
|
||||
Note:
|
||||
This function will no-longer be needed after the v1 KV connector
|
||||
becomes the default.
|
||||
"""
|
||||
if connector is None:
|
||||
connector = _KV_CONNECTOR_AGENT
|
||||
|
||||
if connector is None:
|
||||
return False
|
||||
|
||||
return isinstance(connector, KVConnectorBase_V1)
|
||||
|
||||
|
||||
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Initialize KV cache transfer parallel group.
|
||||
"""
|
||||
|
||||
global _KV_CONNECTOR_AGENT
|
||||
|
||||
if vllm_config.kv_transfer_config is None:
|
||||
return
|
||||
|
||||
if (vllm_config.kv_transfer_config.is_kv_transfer_instance
|
||||
and _KV_CONNECTOR_AGENT is None):
|
||||
if envs.VLLM_USE_V1:
|
||||
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
|
||||
config=vllm_config, role=KVConnectorRole.WORKER)
|
||||
else:
|
||||
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
|
||||
rank=get_world_group().rank,
|
||||
local_rank=get_world_group().local_rank,
|
||||
config=vllm_config,
|
||||
)
|
||||
1297
vllm/distributed/parallel_state.py
Normal file
1297
vllm/distributed/parallel_state.py
Normal file
File diff suppressed because it is too large
Load Diff
177
vllm/distributed/tpu_distributed_utils.py
Normal file
177
vllm/distributed/tpu_distributed_utils.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_xla.distributed.spmd as xs
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XlaQKVParallelLinear(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
qkv_linear: nn.Module,
|
||||
mesh: Optional["xs.Mesh"] = None):
|
||||
super().__init__()
|
||||
assert isinstance(qkv_linear, QKVParallelLinear)
|
||||
self.skip_bias_add = qkv_linear.skip_bias_add
|
||||
self.return_bias = qkv_linear.return_bias
|
||||
assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
|
||||
|
||||
self.q_weight: Parameter
|
||||
self.k_weight: Parameter
|
||||
self.v_weight: Parameter
|
||||
self.q_bias: Optional[Parameter]
|
||||
self.k_bias: Optional[Parameter]
|
||||
self.v_bias: Optional[Parameter]
|
||||
self._load_weights_from_qkv_linear(qkv_linear)
|
||||
if mesh is not None:
|
||||
self._shard_weight(mesh)
|
||||
|
||||
def _shard_weight(self, mesh: "xs.Mesh"):
|
||||
self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False)
|
||||
self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False)
|
||||
self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False)
|
||||
xs.mark_sharding(self.q_weight, mesh, ('x', None))
|
||||
xs.mark_sharding(self.k_weight, mesh, ('x', None))
|
||||
xs.mark_sharding(self.v_weight, mesh, ('x', None))
|
||||
if self.q_bias is not None:
|
||||
assert self.k_bias is not None and self.v_bias is not None, \
|
||||
"QKVParallelLinear should have q, k, and v biases together."
|
||||
self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False)
|
||||
xs.mark_sharding(self.q_bias, mesh, ('x', ))
|
||||
self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False)
|
||||
xs.mark_sharding(self.k_bias, mesh, ('x', ))
|
||||
self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False)
|
||||
xs.mark_sharding(self.v_bias, mesh, ('x', ))
|
||||
|
||||
def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
|
||||
q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
|
||||
# The weight of qkv linear is a concatenation of q, k, and v weights
|
||||
# along the output dimension.
|
||||
qkv_weight = qkv_linear.weight.data.cpu()
|
||||
q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
|
||||
k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size],
|
||||
requires_grad=False)
|
||||
v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:],
|
||||
requires_grad=False)
|
||||
self.register_parameter("q_weight", q_weight)
|
||||
self.register_parameter("k_weight", k_weight)
|
||||
self.register_parameter("v_weight", v_weight)
|
||||
|
||||
if qkv_linear.bias is not None:
|
||||
q_bias = Parameter(qkv_linear.bias[:q_proj_size],
|
||||
requires_grad=False)
|
||||
k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size +
|
||||
k_proj_size],
|
||||
requires_grad=False)
|
||||
v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:],
|
||||
requires_grad=False)
|
||||
self.register_parameter("q_bias", q_bias)
|
||||
self.register_parameter("k_bias", k_bias)
|
||||
self.register_parameter("v_bias", v_bias)
|
||||
else:
|
||||
self.register_parameter("q_bias", None)
|
||||
self.register_parameter("k_bias", None)
|
||||
self.register_parameter("v_bias", None)
|
||||
|
||||
def forward(self, input):
|
||||
# Same forward functionality as QKVParallelLinear, but doing qkv porj
|
||||
# separately.
|
||||
q_bias = self.q_bias if not self.skip_bias_add else None
|
||||
k_bias = self.k_bias if not self.skip_bias_add else None
|
||||
v_bias = self.v_bias if not self.skip_bias_add else None
|
||||
q_proj = F.linear(input, self.q_weight, q_bias)
|
||||
k_proj = F.linear(input, self.k_weight, k_bias)
|
||||
v_proj = F.linear(input, self.v_weight, v_bias)
|
||||
# The q/k/v projections will be split outside of the QKVParallelLinear.
|
||||
# Because we are replacing XlaQKVParallelLinear with the
|
||||
# QKVParallelLinear, we need to concatenate q, k, and v projections to
|
||||
# match the output shape of the QKVParallelLinear implementation even if
|
||||
# it seems to be redundant.
|
||||
# The concat and the following split will be noop, and should be
|
||||
# optimized away by the compiler.
|
||||
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
|
||||
output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \
|
||||
self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return qkv_proj
|
||||
return qkv_proj, output_bias
|
||||
|
||||
|
||||
def partition_column_parallel_linear(layer: torch.nn.Module,
|
||||
mesh: xs.Mesh) -> torch.nn.Module:
|
||||
assert isinstance(layer, ColumnParallelLinear)
|
||||
xs.mark_sharding(layer.weight, mesh, ('x', None))
|
||||
logger.debug("Applied column-parallel sharding to %s", layer)
|
||||
return layer
|
||||
|
||||
|
||||
def partition_row_parallel_linear(layer: torch.nn.Module,
|
||||
mesh: xs.Mesh) -> torch.nn.Module:
|
||||
assert isinstance(layer, RowParallelLinear)
|
||||
xs.mark_sharding(layer.weight, mesh, (None, 'x'))
|
||||
logger.debug("Applied row-parallel sharding to %s", layer)
|
||||
return layer
|
||||
|
||||
|
||||
def partition_qkv_parallel_linear(layer: torch.nn.Module,
|
||||
mesh: xs.Mesh) -> torch.nn.Module:
|
||||
assert isinstance(layer, QKVParallelLinear)
|
||||
xla_layer = XlaQKVParallelLinear(layer, mesh)
|
||||
logger.debug("Applied qkv parallel sharding to %s", layer)
|
||||
return xla_layer
|
||||
|
||||
|
||||
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([
|
||||
("QKVParallelLinear", partition_qkv_parallel_linear),
|
||||
("ColumnParallelLinear", partition_column_parallel_linear),
|
||||
("RowParallelLinear", partition_row_parallel_linear),
|
||||
])
|
||||
|
||||
|
||||
def get_fqn(module):
|
||||
# Get the fully qualified name of the module
|
||||
return module.__class__.__qualname__
|
||||
|
||||
|
||||
def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
|
||||
"""
|
||||
Recursively check a PyTorch model and apply appropriate sharding based on
|
||||
the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
|
||||
|
||||
Args:
|
||||
model: torch.nn.Module to process
|
||||
mesh: An XLA SPMD mesh object used for sharding
|
||||
"""
|
||||
|
||||
def _process_module(module, name=None, parent=None):
|
||||
for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
|
||||
if get_fqn(module) == module_type:
|
||||
wrapped_module = wrapping_func(module, mesh)
|
||||
|
||||
assert parent is not None and name is not None, (
|
||||
"Top Level module is not expected to be wrapped.")
|
||||
if wrapped_module is not module:
|
||||
# Wrapped module and module are different py object.
|
||||
# The original module should be replaced by the
|
||||
# wrapped_module.
|
||||
logger.debug("replace %s with %s", module, wrapped_module)
|
||||
setattr(parent, name, wrapped_module)
|
||||
|
||||
module = wrapped_module
|
||||
break
|
||||
|
||||
for child_name, child_module in list(module.named_children()):
|
||||
_process_module(child_module, child_name, module)
|
||||
|
||||
_process_module(model)
|
||||
536
vllm/distributed/utils.py
Normal file
536
vllm/distributed/utils.py
Normal file
@@ -0,0 +1,536 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
import dataclasses
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup, TCPStore
|
||||
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
|
||||
_get_default_timeout,
|
||||
_unregister_process_group)
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# We prefer to use os.sched_yield as it results in tighter polling loops,
|
||||
# measured to be around 3e-7 seconds. However on earlier versions of Python
|
||||
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
|
||||
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
|
||||
or (sys.version_info[:2] == (3, 10)
|
||||
and sys.version_info[2] >= 8))
|
||||
|
||||
|
||||
def sched_yield():
|
||||
if USE_SCHED_YIELD:
|
||||
os.sched_yield()
|
||||
else:
|
||||
time.sleep(0)
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator."""
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(
|
||||
numerator, denominator)
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator and return
|
||||
the division value."""
|
||||
ensure_divisibility(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
contiguous_split_chunks: bool = False,
|
||||
) -> Sequence[torch.Tensor]:
|
||||
""" Split a tensor along its last dimension.
|
||||
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
|
||||
Returns:
|
||||
A list of Tensors
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# NOTE: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
|
||||
|
||||
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
|
||||
pp_size: int) -> tuple[int, int]:
|
||||
"""Try to evenly distribute layers across partitions.
|
||||
|
||||
If the number of layers is not divisible by the number of partitions,
|
||||
the remaining layers are evenly distributed across all but the last
|
||||
partition. The last partition is excluded because it often contains an
|
||||
additional norm layer and we are attempting to balance compute.
|
||||
|
||||
If `pp_size > 2` and the number of remaining layers is
|
||||
`0 < x <= pp_size - 2` then the remaining layers are evenly distributed
|
||||
across the middle partitions. The first and last partitions are excluded
|
||||
because they contain the input and output embeddings respectively and we
|
||||
are attempting to reduce maximum memory consumption across partitions.
|
||||
"""
|
||||
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
|
||||
if partition_list_str is not None:
|
||||
try:
|
||||
partitions = [
|
||||
int(layer) for layer in partition_list_str.split(",")
|
||||
]
|
||||
except ValueError as err:
|
||||
raise ValueError("Invalid partition string: {}".format(
|
||||
partition_list_str)) from err
|
||||
if len(partitions) != pp_size:
|
||||
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
|
||||
if sum(partitions) != num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
|
||||
else:
|
||||
layers_per_partition = num_hidden_layers // pp_size
|
||||
partitions = [layers_per_partition for _ in range(pp_size)]
|
||||
|
||||
if remaining_layers := num_hidden_layers % pp_size:
|
||||
for i in range(2, remaining_layers + 2):
|
||||
partitions[-i] += 1
|
||||
logger.info(
|
||||
"Hidden layers were unevenly partitioned: [%s]. "
|
||||
"This can be manually overridden using the "
|
||||
"VLLM_PP_LAYER_PARTITION environment variable",
|
||||
",".join(str(p) for p in partitions))
|
||||
|
||||
start_layer = sum(partitions[:pp_rank])
|
||||
end_layer = start_layer + partitions[pp_rank]
|
||||
|
||||
return (start_layer, end_layer)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StatelessProcessGroup:
|
||||
"""A dataclass to hold a metadata store, and the rank, world_size of the
|
||||
group. Only use it to communicate metadata between processes.
|
||||
For data-plane communication, create NCCL-related objects.
|
||||
"""
|
||||
rank: int
|
||||
world_size: int
|
||||
store: torch._C._distributed_c10d.Store
|
||||
|
||||
# stores a reference to the socket so that the file descriptor stays alive
|
||||
socket: Optional[socket.socket]
|
||||
|
||||
data_expiration_seconds: int = 3600 # 1 hour
|
||||
|
||||
# dst rank -> counter
|
||||
send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
# src rank -> counter
|
||||
recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
broadcast_send_counter: int = 0
|
||||
broadcast_recv_src_counter: dict[int, int] = dataclasses.field(
|
||||
default_factory=dict)
|
||||
|
||||
# A deque to store the data entries, with key and timestamp.
|
||||
entries: deque[tuple[str,
|
||||
float]] = dataclasses.field(default_factory=deque)
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.rank < self.world_size
|
||||
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.broadcast_recv_src_counter = {
|
||||
i: 0
|
||||
for i in range(self.world_size)
|
||||
}
|
||||
|
||||
def send_obj(self, obj: Any, dst: int):
|
||||
"""Send an object to a destination rank."""
|
||||
self.expire_data()
|
||||
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.send_dst_counter[dst] += 1
|
||||
self.entries.append((key, time.time()))
|
||||
|
||||
def expire_data(self):
|
||||
"""Expire data that is older than `data_expiration_seconds` seconds."""
|
||||
while self.entries:
|
||||
# check the oldest entry
|
||||
key, timestamp = self.entries[0]
|
||||
if time.time() - timestamp > self.data_expiration_seconds:
|
||||
self.store.delete_key(key)
|
||||
self.entries.popleft()
|
||||
else:
|
||||
break
|
||||
|
||||
def recv_obj(self, src: int) -> Any:
|
||||
"""Receive an object from a source rank."""
|
||||
obj = pickle.loads(
|
||||
self.store.get(
|
||||
f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
|
||||
self.recv_src_counter[src] += 1
|
||||
return obj
|
||||
|
||||
def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
|
||||
"""Broadcast an object from a source rank to all other ranks.
|
||||
It does not clean up after all ranks have received the object.
|
||||
Use it for limited times, e.g., for initialization.
|
||||
"""
|
||||
if self.rank == src:
|
||||
self.expire_data()
|
||||
key = (f"broadcast_from/{src}/"
|
||||
f"{self.broadcast_send_counter}")
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.broadcast_send_counter += 1
|
||||
self.entries.append((key, time.time()))
|
||||
return obj
|
||||
else:
|
||||
key = (f"broadcast_from/{src}/"
|
||||
f"{self.broadcast_recv_src_counter[src]}")
|
||||
recv_obj = pickle.loads(self.store.get(key))
|
||||
self.broadcast_recv_src_counter[src] += 1
|
||||
return recv_obj
|
||||
|
||||
def all_gather_obj(self, obj: Any) -> list[Any]:
|
||||
"""All gather an object from all ranks."""
|
||||
gathered_objs = []
|
||||
for i in range(self.world_size):
|
||||
if i == self.rank:
|
||||
gathered_objs.append(obj)
|
||||
self.broadcast_obj(obj, src=self.rank)
|
||||
else:
|
||||
recv_obj = self.broadcast_obj(None, src=i)
|
||||
gathered_objs.append(recv_obj)
|
||||
return gathered_objs
|
||||
|
||||
def barrier(self, timeout: float = 30.0):
|
||||
"""A robust barrier to synchronize all ranks.
|
||||
|
||||
|
||||
Uses a multi-phase approach to ensure all processes reach the barrier
|
||||
before proceeding:
|
||||
|
||||
1. Each process signals it has reached the barrier
|
||||
|
||||
2. Each process signals that it has confirmed the arrival of all other
|
||||
ranks.
|
||||
|
||||
3. Rank 0 waits for all other ranks to signal their departure to ensure
|
||||
that all ranks have departed the barrier first.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time in seconds to wait for each phase (in seconds)
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If coordination fails or times out
|
||||
"""
|
||||
# Generate a barrier ID that is globally unique
|
||||
try:
|
||||
if self.rank == 0:
|
||||
barrier_id = f"barrier_{uuid.uuid4()}"
|
||||
self.broadcast_obj(barrier_id, src=0)
|
||||
else:
|
||||
barrier_id = self.broadcast_obj(None, src=0)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Failed to broadcast barrier_id") from e
|
||||
|
||||
# Phase 1: Signal arrival at barrier
|
||||
# Wait for all processes to arrive
|
||||
# We need all ranks to confirm the arrival of all other ranks.
|
||||
# This is the key synchronization point.
|
||||
arrival_key = f"arrival_{barrier_id}_{self.rank}"
|
||||
try:
|
||||
self.store.set(arrival_key, b"1")
|
||||
except Exception as e:
|
||||
raise RuntimeError("Failed to signal barrier arrival") from e
|
||||
|
||||
start_time = time.time()
|
||||
processes_arrived: set[int] = set()
|
||||
|
||||
while len(processes_arrived) < self.world_size:
|
||||
# Check for timeout
|
||||
cur_time = time.time()
|
||||
if cur_time - start_time > timeout:
|
||||
raise RuntimeError("Barrier timed out after %f seconds",
|
||||
timeout)
|
||||
|
||||
# Check for each process
|
||||
for i in range(self.world_size):
|
||||
if i in processes_arrived:
|
||||
continue
|
||||
|
||||
key = f"arrival_{barrier_id}_{i}"
|
||||
try:
|
||||
# Try to get the key - if it exists, we'll get a value
|
||||
# If it doesn't exist, it will throw an exception
|
||||
self.store.get(key)
|
||||
processes_arrived.add(i)
|
||||
except KeyError:
|
||||
# Key doesn't exist yet
|
||||
pass
|
||||
except Exception as check_e:
|
||||
logger.debug("Error checking key existence: %s", check_e)
|
||||
sched_yield()
|
||||
|
||||
# Short sleep to avoid tight polling
|
||||
if len(processes_arrived) < self.world_size:
|
||||
sched_yield()
|
||||
|
||||
# Phase 2: Signal departure from barrier
|
||||
# We only care to block at this stage in rank 0, which runs the
|
||||
# server side of the TCPStore. We want to make sure that all
|
||||
# clients have departed the barrier before rank 0 in case the
|
||||
# next thing after the barrier is a shutdown, including tearing
|
||||
# down the TCPStore. Other ranks can exit the barrier immediately
|
||||
# after signaling their departure.
|
||||
departure_key = f"departure_{barrier_id}_{self.rank}"
|
||||
try:
|
||||
self.store.set(departure_key, b"1")
|
||||
except Exception as e:
|
||||
raise RuntimeError("Failed to signal barrier departure") from e
|
||||
|
||||
if self.rank != 0:
|
||||
return
|
||||
|
||||
# Make rank 0 wait for all processes to signal departure
|
||||
start_time = time.time()
|
||||
processes_departed: set[int] = set()
|
||||
|
||||
while len(processes_departed) < self.world_size:
|
||||
# Check for timeout
|
||||
if time.time() - start_time > timeout:
|
||||
raise RuntimeError("Barrier departure timed out after %f s",
|
||||
timeout)
|
||||
|
||||
# Check for each process
|
||||
for i in range(self.world_size):
|
||||
if i in processes_departed:
|
||||
continue
|
||||
|
||||
key = f"departure_{barrier_id}_{i}"
|
||||
try:
|
||||
# Try to get the key - if it exists, we'll get a value
|
||||
# If it doesn't exist, it will throw an exception
|
||||
self.store.get(key)
|
||||
processes_departed.add(i)
|
||||
except KeyError:
|
||||
# Key doesn't exist yet
|
||||
pass
|
||||
except Exception as check_e:
|
||||
logger.debug("Error checking key existence: %s", check_e)
|
||||
sched_yield()
|
||||
|
||||
# Short sleep to avoid tight polling
|
||||
if len(processes_departed) < self.world_size:
|
||||
sched_yield()
|
||||
|
||||
# Clean up keys to avoid leaking memory in the store
|
||||
for i in range(self.world_size):
|
||||
try:
|
||||
self.store.delete_key(f"arrival_{barrier_id}_{i}")
|
||||
except Exception:
|
||||
logger.debug("Error deleting key: %s",
|
||||
f'arrival_{barrier_id}_{i}')
|
||||
|
||||
try:
|
||||
self.store.delete_key(f"departure_{barrier_id}_{i}")
|
||||
except Exception:
|
||||
logger.debug("Error deleting key: %s",
|
||||
f'departure_{barrier_id}_{i}')
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
host: str,
|
||||
port: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
data_expiration_seconds: int = 3600,
|
||||
store_timeout: int = 300,
|
||||
) -> "StatelessProcessGroup":
|
||||
"""A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state.
|
||||
|
||||
If we have process A and process B called `torch.distributed.init_process_group`
|
||||
to form a group, and then we want to form another group with process A, B, C,
|
||||
D, it is not possible in PyTorch, because process A and process B have already
|
||||
formed a group, and process C and process D cannot join that group. This
|
||||
function is a workaround for this issue.
|
||||
|
||||
`torch.distributed.init_process_group` is a global call, while this function
|
||||
is a stateless call. It will return a `StatelessProcessGroup` object that can be
|
||||
used for exchanging metadata. With this function, process A and process B
|
||||
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
||||
C, and D can call `StatelessProcessGroup.create` to form another group.
|
||||
""" # noqa
|
||||
launch_server = rank == 0
|
||||
if launch_server:
|
||||
# listen on the specified interface (instead of 0.0.0.0)
|
||||
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
listen_socket.bind((host, port))
|
||||
listen_socket.listen()
|
||||
listen_fd = listen_socket.fileno()
|
||||
else:
|
||||
listen_socket = None
|
||||
listen_fd = None
|
||||
|
||||
store = TCPStore(
|
||||
host_name=host,
|
||||
port=port,
|
||||
world_size=world_size,
|
||||
is_master=launch_server,
|
||||
timeout=timedelta(seconds=store_timeout),
|
||||
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
|
||||
master_listen_fd=listen_fd,
|
||||
)
|
||||
|
||||
return StatelessProcessGroup(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
store=store,
|
||||
socket=listen_socket,
|
||||
data_expiration_seconds=data_expiration_seconds)
|
||||
|
||||
|
||||
def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore,
|
||||
group_rank: int, group_size: int,
|
||||
timeout: timedelta) -> ProcessGroup:
|
||||
"""
|
||||
Stateless init ProcessGroup with gloo backend compatible with
|
||||
different torch versions.
|
||||
"""
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
else:
|
||||
options = ProcessGroup.Options(backend=backend)
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
options,
|
||||
)
|
||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||
backend_class = ProcessGroupGloo(prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
timeout=timeout)
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
device = torch.device("cpu")
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
# _set_default_backend is supported in torch >= 2.6
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
|
||||
def stateless_init_torch_distributed_process_group(
|
||||
host: str, port: int, rank: int, world_size: int,
|
||||
backend: str) -> ProcessGroup:
|
||||
"""
|
||||
A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state. The created ProcessGroup object can be used for
|
||||
some operations such as `allreduce`, because it does not depend on the
|
||||
global rank. However, some operations such as `broadcast` cannot be used
|
||||
because it depends on the global rank.
|
||||
|
||||
# TODO: ask for help from PyTorch team if we need the `broadcast` operation.
|
||||
|
||||
This function is useful when we are not sure about the total number of
|
||||
processes in the process group. For example, we may have process
|
||||
1, 2, ..., 8 who want to communicate, and process 9 might be the same
|
||||
process as process 1, or it might be a different process; process 10
|
||||
might be the same process as process 5, or it might be a different process.
|
||||
In this case, how can we reliably form a communication channel within
|
||||
process 9 and 10, without affecting the communication channel within
|
||||
process 1, 2, ..., 8?
|
||||
|
||||
One possible solution is to figure out if process 9 and 10 are the same
|
||||
as process 1 and 5 beforehand, and then form a communication channel
|
||||
based on the information, adjusting the ranks and world_size etc. However,
|
||||
figuring out the information is not always easy, and it will interfere
|
||||
with the main communication channel.
|
||||
|
||||
Our solution is to always form a communication channel with process 1, 2,
|
||||
..., 8, and then use this function to form another communication channel
|
||||
with process 9 and 10. This way, regardless of whether process 9 and 10
|
||||
are the same as process 1 and 5, the main communication channel is
|
||||
always formed with process 1, 2, ..., 8, and the additional communication
|
||||
channel is formed with process 9 and 10.
|
||||
"""
|
||||
init_method = get_tcp_uri(host, port)
|
||||
backend = Backend(backend) # it is basically string
|
||||
timeout = _get_default_timeout(backend)
|
||||
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
store.set_timeout(timeout)
|
||||
|
||||
group_rank = rank
|
||||
group_size = world_size
|
||||
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
prefix_store = PrefixStore(init_method, store)
|
||||
|
||||
if backend == "gloo":
|
||||
return init_gloo_process_group(backend=backend,
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout)
|
||||
from vllm.platforms import current_platform
|
||||
return current_platform.stateless_init_device_torch_dist_pg(
|
||||
backend=backend,
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout)
|
||||
|
||||
|
||||
def stateless_destroy_torch_distributed_process_group(
|
||||
pg: ProcessGroup) -> None:
|
||||
"""
|
||||
Destroy ProcessGroup returned by
|
||||
stateless_init_torch_distributed_process_group().
|
||||
"""
|
||||
if is_torch_equal_or_newer("2.7"):
|
||||
pg.shutdown()
|
||||
else:
|
||||
# Lazy import for non-CUDA backends.
|
||||
from torch.distributed.distributed_c10d import _shutdown_backend
|
||||
_shutdown_backend(pg)
|
||||
|
||||
_unregister_process_group(pg.group_name)
|
||||
Reference in New Issue
Block a user