first commit

2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .communication_op import *
from .parallel_state import *
from .utils import *

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
import torch
import torch.distributed
from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)
def tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""Reduce-Scatter the input tensor across model parallel group."""
return get_tp_group().reduce_scatter(input_, dim)
def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)
def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor,
Any]]] = None,
src: int = 0):
if not torch.distributed.is_initialized():
return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
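A minimal sketch of how these wrappers are typically consumed by a tensor-parallel layer, assuming vLLM's distributed environment and tensor-parallel group are already initialized and that the package __init__ shown at the top of this commit re-exports them; `row_parallel_forward` is an illustrative helper, not part of this file:

import torch

from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)


def row_parallel_forward(x: torch.Tensor,
                         weight_shard: torch.Tensor) -> torch.Tensor:
    # Each rank multiplies against its own weight shard; the partial results
    # are then summed across the tensor-parallel group.
    partial = x @ weight_shard
    return tensor_model_parallel_all_reduce(partial)


def column_parallel_gather(partial: torch.Tensor) -> torch.Tensor:
    # Concatenate the per-rank output shards along the last dimension.
    return tensor_model_parallel_all_gather(partial, dim=-1)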

View File

@@ -0,0 +1,440 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
from vllm.utils.flashinfer import has_flashinfer_all2all
from .base_device_communicator import All2AllManagerBase, Cache
if has_flashinfer_all2all():
from flashinfer.comm import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
logger = init_logger(__name__)
class NaiveAll2AllManager(All2AllManagerBase):
"""
A naive implementation of all2all communication.
It uses all-reduce under the hood, which is not
efficient at all. The main purpose is for testing and
debugging.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_sp_cpu: torch.Tensor,
is_sequence_parallel: bool) -> torch.Tensor:
assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
rank = self.rank if is_sequence_parallel else self.dp_rank
world_size = (self.world_size
if is_sequence_parallel else self.dp_world_size)
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
end = cu_tokens_across_sp_cpu[rank]
buffer[start:end, :].copy_(x)
for idx in range(world_size):
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
end = cu_tokens_across_sp_cpu[idx]
get_ep_group().broadcast(buffer[start:end, :], idx)
return buffer
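To make the indexing above concrete, a tiny standalone sketch of how the cumulative token counts partition the shared buffer (the values are illustrative, not taken from the code):

cu_tokens = [3, 7, 12]          # cumulative tokens per rank; 12 rows in total
for rank in range(len(cu_tokens)):
    start = 0 if rank == 0 else cu_tokens[rank - 1]
    end = cu_tokens[rank]
    print(rank, (start, end))   # rank 0 -> (0, 3), rank 1 -> (3, 7), rank 2 -> (7, 12)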
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
dp_metadata = get_forward_context().dp_metadata
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
end = cu_tokens_across_sp_cpu[ep_rank]
all_hidden_states = get_ep_group().all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
return hidden_states
def destroy(self):
pass
class AgRsAll2AllManager(All2AllManagerBase):
"""
An implementation of all2all communication based on
all-gather (dispatch) and reduce-scatter (combine).
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states
def destroy(self):
pass
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""
def __init__(self, cpu_group):
assert has_pplx(
), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: "
"rank=%d, world size=%d", self.rank, self.world_size)
uid = nvshmem_get_unique_id(
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
dist.broadcast(uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx
return self.handle_cache.get_or_create(
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
assert has_deep_ep(
), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
super().__init__(cpu_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self.num_sms = 20
def get_handle(self, kwargs):
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
pass
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
"""
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_rdma_bytes = None
num_qps_per_rank = None
if self.internode:
num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = self.num_sms // 2
else:
num_rdma_bytes = 0
num_qps_per_rank = 1
assert num_rdma_bytes is not None
assert num_qps_per_rank is not None
return dict(group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank)
def get_handle(self, kwargs):
assert len(kwargs) == 0, (
"DeepEPHTAll2AllManager expects no arguments. All the required "
"args are computed in the Manager itself.")
import deep_ep
buffer_kwargs = self._make_all2all_kwargs()
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
return handle
def set_num_sms(self, num_sms: int):
import deep_ep
# Right now the buffers are sized for only what the kernels were
# created with. So we can only reduce the number of SMs used
# but not increase it.
if num_sms > self.num_sms:
num_sms = self.num_sms
deep_ep.Buffer.set_num_sms(num_sms)
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
"""
All2All communication based on DeepEP Low-Latency kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def _make_all2all_kwargs(
self,
max_num_tokens_per_dp_rank: int,
token_hidden_size: int,
num_ep_ranks: int,
num_global_experts: int,
num_local_experts: int,
) -> dict[Any, Any]:
"""
max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank
can dispatch; all the ranks must hold the same value.
token_hidden_size: the hidden dimension of each token.
num_ep_ranks: the number of EP group ranks.
num_global_experts: Number of experts in the model.
num_local_experts: Number of experts in an EP rank.
"""
import deep_ep
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = num_local_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
hidden=token_hidden_size,
num_ranks=num_ep_ranks,
num_experts=num_global_experts)
assert num_rdma_bytes is not None
return dict(group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank)
def get_handle(self, kwargs):
"""
The kwargs for DeepEPLLAll2AllManager is dictated by
_make_all2all_kwargs.
"""
import deep_ep
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
return handle
# DeepEP LL uses RDMA so no SMs are used for communication
def max_sms_used(self) -> Optional[int]:
return 0
class FlashInferAllToAllManager(All2AllManagerBase):
"""
All2All communication based on flashinfer kernels.
"""
def __init__(self, cpu_group):
assert has_flashinfer_all2all(
), "flashinfer all2all module not found. Please install/check flashinfer" # noqa
super().__init__(cpu_group)
logger.debug(
"Initialize for flashinfer All2All "
"rank=%d, world size=%d", self.rank, self.world_size)
self.initialized = False
self.alltoall_info = None
def initialize(
self,
world_size: int,
rank: int,
gpus_per_node: int,
):
"""Initialize workspace"""
if self.initialized:
return
self.cleanup()
logger.debug("making map: "
"rank=%d, world size=%d", rank, world_size)
self.mapping = Mapping(
world_size,
rank,
gpus_per_node,
tp_size=world_size,
)
from vllm.distributed.device_communicators.mnnvl_compat import (
CustomCommunicator)
dp_config = MnnvlConfig(
comm_backend=CustomCommunicator(get_dp_group().cpu_group),
fabric_page_size=1 << 29, # 512MB
allocation_granularity=0 # Auto-detect
)
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
self.mapping, dp_config)
self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
self.mapping, dp_config)
self.world_size = world_size
self.rank = rank
self.gpus_per_node = gpus_per_node
self.initialized = True
logger.info("FlashInfer All2All initialized for rank %s, size %s",
rank, world_size)
def ensure_alltoall_workspace_initialized(self):
"""Ensure workspace is initialized"""
if not has_flashinfer_all2all():
return False
if self.world_size <= 1:
return False
if not self.initialized:
self.initialize(
world_size=self.world_size,
rank=self.rank,
gpus_per_node=torch.cuda.device_count(),
)
return self.initialized
def get_handle(self, kwargs):
return self
def cleanup(self):
"""Clean up workspace"""
if self.initialized and self.workspace_tensor is not None \
and self.prepare_workspace_tensor is not None:
try:
del self.workspace_tensor
del self.prepare_workspace_tensor
except Exception as e:
logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
finally:
self.workspace_tensor = None
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False
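All of the managers above implement the same dispatch/combine contract. A rough sketch of how a MoE layer drives one of them, assuming a manager has already been constructed inside an initialized DP/EP group; `run_experts` is a hypothetical stand-in for the fused expert computation, not a vLLM API:

import torch


def run_experts(hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
    # Placeholder for the per-rank expert computation.
    return hidden_states


def moe_forward(manager, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
    # Gather (or multicast) tokens so each EP rank sees the tokens routed
    # to its local experts.
    hidden_states, router_logits = manager.dispatch(hidden_states,
                                                    router_logits)
    expert_out = run_experts(hidden_states, router_logits)
    # Return this rank's share of the summed expert outputs.
    return manager.combine(expert_out)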

View File

@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ctypes
import json
import os
import pickle
import subprocess
import sys
import tempfile
from collections.abc import Sequence
from itertools import product
from typing import Any, Optional
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)
logger = init_logger(__name__)
MiB = 1024 * 1024
# Max size for each world size in case symmetric memory is available
# For different SM architectures
CUSTOM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: MiB // 2, # 512 KB
8: MiB // 4, # 256 KB
},
"10.0": {
2: 2 * MiB, # 2 MB
4: 2 * MiB, # 2 MB
6: 1 * MiB, # 1 MB
8: 1 * MiB, # 1 MB
}
}
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: 64 * MiB, # 64 MB
8: 64 * MiB, # 64 MB
},
"10.0": {
2: 8 * MiB, # 8 MB
4: 32 * MiB, # 32 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
}
}
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
"min_world_size": 4,
"thresholds": {
4: 2 * MiB, # 2 MB
8: 1 * MiB, # 1 MB
},
"always_use_above_world_size": 8 # Always use symm mem for world_size > 8
}
def should_nccl_symm_mem_allreduce(world_size: int,
input_tensor: torch.Tensor) -> bool:
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled)
if not is_symmetric_memory_enabled():
return False
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
return False
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
if threshold is not None and input_tensor.nbytes >= threshold:
return True
return (world_size
> NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"])
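A quick sanity check of the threshold arithmetic above (pure Python, no distributed setup): a (8192, 4096) bfloat16 tensor occupies 64 MiB, which clears the 2 MiB threshold for world_size=4, so the symmetric-memory all-reduce path would be taken whenever symmetric memory is enabled.

MiB = 1024 * 1024
nbytes = 8192 * 4096 * 2       # bfloat16 elements are 2 bytes each
threshold = 2 * MiB            # NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"][4]
print(nbytes // MiB, nbytes >= threshold)   # 64 True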
def producer(batch_src: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None:
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for i in batch_src:
lib.cudaSetDevice(i)
pointer = lib.cudaMalloc(1024)
lib.cudaMemset(pointer, 1, 1024)
lib.cudaDeviceSynchronize()
handle = lib.cudaIpcGetMemHandle(pointer)
producer_queue.put(handle)
open_success = consumer_queue.get()
if open_success:
# use two queues to simulate barrier
producer_queue.put(0)
consumer_queue.get()
# check if the memory is modified
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
for i in range(1024):
if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()
def consumer(batch_tgt: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None:
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for j in batch_tgt:
lib.cudaSetDevice(j)
handle = producer_queue.get()
open_success = False
try:
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
open_success = True
except RuntimeError:
# cannot error out here, because the producer process
# is still waiting for the response.
pass
consumer_queue.put(open_success)
if open_success:
# modify the memory
lib.cudaMemset(pointer, 2, 1024)
lib.cudaDeviceSynchronize()
# use two queues to simulate barrier
producer_queue.get()
consumer_queue.put(0)
# check if the memory is modified
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
for i in range(1024):
if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()
def can_actually_p2p(
batch_src: Sequence[int],
batch_tgt: Sequence[int],
) -> Sequence[bool]:
"""
Usually, checking if P2P access is enabled can be done by
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
returns `True` even if P2P access is not actually possible.
See https://github.com/vllm-project/vllm/issues/2728 and
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
Therefore, we have to perform a real P2P access to check if it is actually
possible.
Note on p2p and cuda IPC:
Usually, one process uses one GPU:
GPU src --> cuda context src --> tensor src --> process src
We need to combine p2p and cuda IPC, so that:
GPU src --> cuda context src --> tensor src --> process src
|shared|
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
That is to say, process src creates a tensor in GPU src, passes IPC handle to
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
tensor in process tgt will be reflected in the tensor in process src, because
they are the same memory segment.
It is important to note that process tgt accesses the tensor in GPU tgt, not
GPU src. That's why we need p2p access.
The most time-consuming part is the process creation. To avoid creating
processes for every pair of GPUs, we use batched testing. We create two
processes for testing all pairs of GPUs in batch. The trick is to reset
the device after each test (which is not available in PyTorch).
""" # noqa
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
# pass the CUDA_VISIBLE_DEVICES to the child process
# to make sure they see the same set of GPUs
# make sure the processes are spawned
smp = mp.get_context("spawn")
producer_queue = smp.Queue()
consumer_queue = smp.Queue()
result_queue = smp.Queue()
p_src = smp.Process(target=producer,
args=(batch_src, producer_queue, consumer_queue,
result_queue, cuda_visible_devices))
p_tgt = smp.Process(target=consumer,
args=(batch_tgt, producer_queue, consumer_queue,
result_queue, cuda_visible_devices))
p_src.start()
p_tgt.start()
p_src.join()
p_tgt.join()
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
result: list[bool] = []
for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get()
b = result_queue.get()
if a != b:
logger.warning(
"Two processes do not agree on the P2P access"
" status on %d -> %d, treat as disabled.", src, tgt)
result.append(False)
else:
result.append(a)
return result
# why do we need this cache?
# we are testing peer-to-peer (p2p) access between GPUs, across processes.
# if we test it every time, it will be very slow, because we need to create
# N * N * 2 processes, where N is the world size. This is very slow.
# to reduce the time, we use a cache file to store the p2p access status.
# the cache file is generated by the master process if it does not exist.
# then all the processes can read the cache file to check the p2p access status.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
"""Check if GPU src can access GPU tgt."""
# if the cache variable is already calculated,
# read from the cache instead of checking it again
global _gpu_p2p_access_cache
if _gpu_p2p_access_cache is not None:
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
is_distributed = dist.is_initialized()
num_dev = cuda_device_count_stateless()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
path = os.path.join(
envs.VLLM_CACHE_ROOT,
f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
os.makedirs(os.path.dirname(path), exist_ok=True)
from vllm.distributed.parallel_state import get_world_group
if ((not is_distributed or get_world_group().local_rank == 0)
and (not os.path.exists(path))):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
logger.info("generating GPU P2P access cache in %s", path)
cache: dict[str, bool] = {}
ids = list(range(num_dev))
# batch of all pairs of GPUs
batch_src, batch_tgt = zip(*list(product(ids, ids)))
# NOTE: we use `subprocess` rather than `multiprocessing` here
# because the caller might not have `if __name__ == "__main__":`,
# in that case we cannot use spawn method in multiprocessing.
# However, `can_actually_p2p` requires spawn method.
# The fix is, we use `subprocess` to call the function,
# where we have `if __name__ == "__main__":` in this file.
# use a temporary file to store the result
# we don't use the output of the subprocess directly,
# because the subprocess might produce logging output
with tempfile.NamedTemporaryFile() as output_file:
input_bytes = pickle.dumps(
(batch_src, batch_tgt, output_file.name))
returned = subprocess.run([sys.executable, __file__],
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r
with open(path, "w") as f:
json.dump(cache, f, indent=4)
if is_distributed:
get_world_group().barrier()
logger.info("reading GPU P2P access cache from %s", path)
with open(path) as f:
cache = json.load(f)
_gpu_p2p_access_cache = cache
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
__all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__":
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
result = can_actually_p2p(batch_src, batch_tgt)
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))
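For reference, a sketch of the cache file this check produces and how it is consumed on a machine with at least two visible GPUs; the path and device pair are illustrative (the actual location depends on VLLM_CACHE_ROOT and CUDA_VISIBLE_DEVICES):

# <VLLM_CACHE_ROOT>/gpu_p2p_access_cache_for_0,1.json (illustrative contents):
# {
#     "0->0": true,
#     "0->1": true,
#     "1->0": true,
#     "1->1": true
# }
from vllm.distributed.device_communicators.all_reduce_utils import (
    gpu_p2p_access_check)

# The first call on the local master generates the cache via the subprocess
# test above; later calls (and other ranks) just read the JSON file.
print(gpu_p2p_access_check(0, 1))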

View File

@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional, Union
from weakref import WeakValueDictionary
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class Cache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
def get_or_create(self, kwargs, func):
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))
with self._lock:
instance = self._cache.get(key)
if instance is None:
instance = func(**kwargs)
self._cache[key] = instance
return instance
class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group):
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
in_the_same_node_as)
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
# no self.ep_group since self.ep_group is still in construction
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
def get_handle(self, kwargs):
# get a handle for the all2all communication,
# based on the kwargs.
# different layers can have different configs,
# e.g. one layer has hidden size 1024, another has 2048.
# usually the underlying implementation caches the handle
# and reuses it for the same config.
raise NotImplementedError
def dispatch(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError
def set_num_sms(self, num_sms: int):
pass
def max_sms_used(self) -> Optional[int]:
return None # None means it could use the whole GPU
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError
def destroy(self):
pass
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
It can use the `cpu_group` to initialize the communicator.
If the device has PyTorch integration (PyTorch can recognize its
communication backend), the `device_group` will also be given.
"""
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
self.is_ep_communicator = "ep" in unique_name
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
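The movedim/reshape bookkeeping above can be checked locally without any process group; a sketch with world_size=2 shards of shape (4, 6) gathered along dim=1:

import torch

world_size, shard = 2, torch.arange(24.).reshape(4, 6)
flat = torch.cat([shard, shard])                     # what all_gather_into_tensor yields
out = flat.reshape((world_size,) + shard.shape)      # (2, 4, 6)
out = out.movedim(0, 1).reshape(4, world_size * 6)   # (4, 12), shards side by side
print(out.shape)                                     # torch.Size([4, 12])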
def all_gatherv(
self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
raise NotImplementedError
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output_tensor = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
# Perform reduce-scatter operation
torch.distributed.reduce_scatter_tensor(output_tensor,
input_tensor,
group=self.device_group)
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()
def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None) -> torch.Tensor:
raise NotImplementedError
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
pass
def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
if not self.is_ep_communicator:
return
moe_modules = [
module for module in model.modules()
# TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module)
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
return hidden_states
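A small standalone sketch of the kwargs-keyed caching used by `Cache`, with `Handle` as an illustrative stand-in for an all2all handle: identical kwargs return the same instance, and the WeakValueDictionary drops an entry once no caller holds a reference.

from vllm.distributed.device_communicators.base_device_communicator import Cache


class Handle:
    def __init__(self, hidden_size: int, dtype: str):
        self.hidden_size = hidden_size
        self.dtype = dtype


cache = Cache()
a = cache.get_or_create({"hidden_size": 1024, "dtype": "bf16"}, Handle)
b = cache.get_or_create({"hidden_size": 1024, "dtype": "bf16"}, Handle)
c = cache.get_or_create({"hidden_size": 2048, "dtype": "bf16"}, Handle)
assert a is b       # same config -> cached instance is reused
assert c is not a   # different config -> new instance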

View File

@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Optional, Union
import torch
from torch.distributed import ProcessGroup
from vllm.distributed.utils import pickle
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .base_device_communicator import DeviceCommunicatorBase
class CpuCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
self.dist_module = torch.distributed
if (current_platform.get_cpu_architecture()
== CpuArchEnum.X86) and hasattr(
torch.ops._C,
"init_shm_manager") and (unique_name.startswith("tp")
or unique_name.startswith("pp")):
self.dist_module = _CPUSHMDistributed(self)
def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group)
return input_
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
self.dist_module.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
self.dist_module.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: int,
) -> None:
return self.dist_module.send_tensor_dict(tensor_dict, dst)
def recv_tensor_dict(
self,
src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
return self.dist_module.recv_tensor_dict(src)
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
instance_identifier = os.environ["VLLM_DIST_IDENT"]
unique_name = communicator.unique_name
instance_identifier = f"{instance_identifier}-{unique_name}"
self.communicator = communicator
group_ranks = [str(rank) for rank in self.communicator.ranks]
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
self.handle = self._init_cpu_shm()
def _init_cpu_shm(self) -> int:
handle = torch.ops._C.init_shm_manager(
self.group_name,
self.communicator.world_size,
self.communicator.rank,
)
torch.distributed.barrier(self.communicator.device_group)
torch.ops._C.join_shm_manager(
handle,
self.group_name,
)
torch.distributed.barrier(self.communicator.device_group)
return handle
def all_reduce(self,
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
torch.ops._C.shm_allreduce(self.handle, input)
def gather(self,
input: torch.Tensor,
gather_list: Optional[list[torch.Tensor]],
dst: int = -1,
group: Optional[ProcessGroup] = None) -> None:
# Note: different from the torch gather, here we use local dst rank.
torch.ops._C.shm_gather(self.handle, input, gather_list,
torch.distributed.get_group_rank(group, dst))
def all_gather_into_tensor(self,
output: torch.Tensor,
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
torch.ops._C.shm_all_gather(self.handle, input, output)
def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: int,
) -> None:
key_list = list(tensor_dict.keys())
value_list = list(tensor_dict.values())
size_list = []
for v in value_list:
if not isinstance(v, torch.Tensor):
raise RuntimeError(
"CpuCommunicator only supports sending tensors.")
size_list.append(v.size())
key_size_tensor = torch.frombuffer(pickle.dumps([key_list, size_list]),
dtype=torch.uint8)
value_list.append(key_size_tensor)
torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst)
return None
def recv_tensor_dict(
self,
src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src)
value_list: list[torch.Tensor] = tensor_list[:-1]
key_size_tensor = tensor_list[-1]
key_size = pickle.loads(key_size_tensor.numpy().tobytes())
key_list = key_size[0]
size_list = key_size[1]
assert len(key_list) == len(size_list)
assert len(key_list) == len(value_list)
tensor_dict: dict[str, torch.Tensor] = {}
for key, size, t in zip(key_list, size_list, value_list):
tensor_dict[key] = t.view(size)
return tensor_dict
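The metadata framing used by send_tensor_dict/recv_tensor_dict can be exercised without the shared-memory transport (the torch.ops._C.shm_* calls): keys and shapes are pickled into a trailing uint8 tensor, which the receiver decodes to rebuild the dict. A local sketch of just that encoding step:

import pickle

import torch

tensor_dict = {"x": torch.zeros(2, 3), "y": torch.ones(4)}
key_list = list(tensor_dict.keys())
size_list = [v.size() for v in tensor_dict.values()]
# Sender side: append the pickled (keys, sizes) as a uint8 tensor.
meta = torch.frombuffer(pickle.dumps([key_list, size_list]), dtype=torch.uint8)
# Receiver side: decode the trailing tensor and pair keys with payloads.
keys, sizes = pickle.loads(meta.numpy().tobytes())
print(keys, sizes)   # ['x', 'y'] [torch.Size([2, 3]), torch.Size([4])]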

View File

@@ -0,0 +1,323 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.distributed.device_communicators.all_reduce_utils import (
should_nccl_symm_mem_allreduce)
from vllm.distributed.device_communicators.pynccl import (
register_nccl_symmetric_ops)
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class CudaCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
use_torch_symm_mem = False
else:
from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE)
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
if is_symmetric_memory_enabled():
register_nccl_symmetric_ops(self.pynccl_comm)
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_torch_symm_mem and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
symm_mem_enabled=(self.symm_mem_comm is not None
and not self.symm_mem_comm.disabled),
)
if current_platform.is_rocm():
# Initialize a custom quick all-reduce implementation for AMD.
# Quick reduce is designed as a complement to custom allreduce.
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
# If it's a rocm, 'use_custom_allreduce==True' means it must
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
elif all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
logger.info("Using AllGather-ReduceScatter all2all manager.")
elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
logger.info("Using PPLX all2all manager.")
elif all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
logger.info("Using DeepEP High-Throughput all2all manager.")
elif all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group)
logger.info("Using Flashinfer all2allv manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
def all_reduce(self, input_):
# since currently we perform copy input -> symm_input -> out-of-place AR
# return symm_output, we don't need to check if input is symmetric
if self.pynccl_comm is not None and \
should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_):
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
if out is not None:
return out
# always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
if qr_comm is not None and not qr_comm.disabled and \
qr_comm.should_quick_allreduce(input_):
out = qr_comm.quick_all_reduce(input_)
assert out is not None
return out
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
if sizes is not None:
assert len(sizes) == world_size
assert input_tensor.shape[0] == sum(sizes)
chunk_size = sizes[self.rank_in_group]
else:
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else:
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
def all_gatherv(self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None):
if dim != 0:
raise NotImplementedError("only dim 0 all-gatherv is supported")
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None and not pynccl_comm.disabled
# 'sizes' is not needed if all inputs in the same group have the same
# shape
if sizes is not None and all(s == sizes[0] for s in sizes):
sizes = None
def _all_gather_single(input_: torch.Tensor,
sizes: Optional[list[int]] = None):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[dim] == sizes[self.rank_in_group], (
f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
output_size = (sum(sizes), ) + input_size[1:]
else:
output_size = (input_size[0] * world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
if sizes is not None:
pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
else:
pynccl_comm.all_gather(output_tensor, input_)
return output_tensor
if isinstance(input_, torch.Tensor):
return _all_gather_single(input_, sizes)
output_list = []
pynccl_comm.group_start()
for inp in input_:
output_list.append(_all_gather_single(inp, sizes=sizes))
pynccl_comm.group_end()
return output_list
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states
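The all2all backend above is chosen purely from the VLLM_ALL2ALL_BACKEND environment variable when the EP communicator is constructed, so switching implementations is a configuration change rather than a code change. For all_reduce, the selection order is fixed in the code above: quick reduce (ROCm), then custom all-reduce, then symmetric memory, then PyNCCL, with a plain torch.distributed all-reduce as the last resort. A sketch of picking the backend (set before the engine and process groups are created):

import os

# One of: "naive", "allgather_reducescatter", "pplx",
# "deepep_high_throughput", "deepep_low_latency", "flashinfer_all2allv".
os.environ["VLLM_ALL2ALL_BACKEND"] = "allgather_reducescatter"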

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""
import ctypes
from dataclasses import dataclass
from typing import Any, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
class cudaIpcMemHandle_t(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
@dataclass
class Function:
name: str
restype: Any
argtypes: list[Any]
def find_loaded_library(lib_name) -> Optional[str]:
"""
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which include the
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
""" # noqa
found = False
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found = True
break
if not found:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = line.index("/")
path = line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), \
f"Unexpected filename: {filename} for library {lib_name}"
return path
class CudaRTLibrary:
exported_functions = [
# cudaError_t cudaSetDevice ( int device )
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
# cudaError_t cudaDeviceSynchronize ( void )
Function("cudaDeviceSynchronize", cudaError_t, []),
# cudaError_t cudaDeviceReset ( void )
Function("cudaDeviceReset", cudaError_t, []),
# const char* cudaGetErrorString ( cudaError_t error )
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function("cudaMalloc", cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
# cudaError_t cudaFree ( void* devPtr )
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function("cudaMemset", cudaError_t,
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function("cudaMemcpy", cudaError_t, [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
]),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function("cudaIpcGetMemHandle", cudaError_t,
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function("cudaIpcOpenMemHandle", cudaError_t, [
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
if so_file is None:
so_file = find_loaded_library("libcudart")
if so_file is None:
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
assert so_file is not None, \
(
"libcudart is not loaded in the current process, "
"try setting VLLM_CUDART_SO_PATH"
)
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
if so_file not in CudaRTLibrary.path_to_dict_mapping:
_funcs = {}
for func in CudaRTLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
def CUDART_CHECK(self, result: cudaError_t) -> None:
if result != 0:
error_str = self.cudaGetErrorString(result)
raise RuntimeError(f"CUDART error: {error_str}")
def cudaGetErrorString(self, error: cudaError_t) -> str:
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
def cudaSetDevice(self, device: int) -> None:
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
def cudaDeviceSynchronize(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
def cudaDeviceReset(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
return devPtr
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
count: int) -> None:
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
count: int) -> None:
cudaMemcpyDefault = 4
kind = cudaMemcpyDefault
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
def cudaIpcGetMemHandle(self,
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
handle = cudaIpcMemHandle_t()
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
ctypes.byref(handle), devPtr))
return handle
def cudaIpcOpenMemHandle(self,
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
cudaIpcMemLazyEnablePeerAccess = 1
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
return devPtr
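A short sketch of driving CudaRTLibrary directly, mirroring the producer-side steps of the P2P check earlier in this commit; it assumes a CUDA-capable machine with a CUDA build of PyTorch, since importing the module loads torch and thereby libcudart:

import ctypes

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

lib = CudaRTLibrary()
lib.cudaSetDevice(0)
ptr = lib.cudaMalloc(1024)
lib.cudaMemset(ptr, 1, 1024)
lib.cudaDeviceSynchronize()
host = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host, ptr, 1024)
assert all(b == 1 for b in host.raw)   # every byte was set to 1
lib.cudaFree(ptr)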

View File

@@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.all_reduce_utils import (
CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless
try:
ops.meta_size()
custom_ar = True
except Exception:
# For CPUs
custom_ar = False
logger = init_logger(__name__)
def _can_p2p(rank: int, world_size: int) -> bool:
for i in range(world_size):
if i == rank:
continue
if envs.VLLM_SKIP_P2P_CHECK:
logger.info(
"Skipping P2P check and trusting the driver's P2P report.")
return torch.cuda.can_device_access_peer(rank, i)
if not gpu_p2p_access_check(rank, i):
return False
return True
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
# max_size: max supported allreduce size
def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024,
symm_mem_enabled=False) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-GPU environment
logger.info("Custom allreduce is disabled because "
"of missing custom allreduce library")
return
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"
" spans across nodes.")
return
rank = dist.get_rank(group=self.group)
self.rank = rank
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
assert current_platform.is_cuda_alike()
fully_connected = current_platform.is_fully_connected(
physical_device_ids)
if world_size > 2 and not fully_connected:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly.")
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if not current_platform.is_rocm() and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.")
return
self.disabled = False
# Buffer memory is owned by this Python class and passed to C++.
# The metadata consists of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group,
uncached=True)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs fewer than 10000 registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device=self.device)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.fully_connected = fully_connected
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
self.fully_connected)
ops.register_buffer(self._ptr, self.buffer_ptrs)
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
logger.info("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None]
for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.fully_connected:
return inp_size < self.max_size
return False
def all_reduce(self,
inp: torch.Tensor,
*,
out: torch.Tensor = None,
registered: bool = False):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if out is None:
out = torch.empty_like(inp)
if registered:
ops.all_reduce(self._ptr, inp, out, 0, 0)
else:
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
self.max_size)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=True)
else:
# During warmup, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, registered=False)
def close(self):
if not self.disabled and self._ptr:
if ops is not None:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
def __del__(self):
self.close()
@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: list[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer) # type: ignore
else:
pointers.append(ops.open_mem_handle(h))
return pointers
@staticmethod
def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = None) -> None:
if rank is None:
rank = dist.get_rank(group=group)
if ops is not None:
ops.free_shared_buffer(pointers[rank])
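# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# Typical eager-mode flow, assuming `group` is a CPU (gloo) process group
# whose ranks all live on the same node and `local_rank` indexes the GPU.
def _example_custom_allreduce(group, local_rank: int):
    comm = CustomAllreduce(group=group, device=local_rank)
    x = torch.randn(1024, 1024, dtype=torch.float16,
                    device=f"cuda:{local_rank}")  # 2 MiB, below max_size
    out = comm.custom_all_reduce(x)
    if out is None:
        # The input did not qualify (custom allreduce disabled, size not a
        # multiple of 16 bytes, not weakly contiguous, or more than two
        # PCIe-only GPUs); the caller would fall back to NCCL here.
        pass
    # During CUDA graph capture the same call is wrapped in `comm.capture()`
    # so that graph buffer addresses are registered when the context exits.
    return out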

View File

@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
from vllm.utils.flashinfer import has_flashinfer_all2all
assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
class CustomCommunicator(CommBackend):
def __init__(self, group):
self._group = group
def Get_rank(self) -> int:
return self._group.rank()
def Get_size(self) -> int:
return self._group.size()
def allgather(self, data: int):
gathered = [None] * self.Get_size()
dist.all_gather_object(gathered, data, group=self._group)
return gathered
def Split(self, color: int, key: int) -> 'CustomCommunicator':
return self
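# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# The adapter above exposes the MPI-like CommBackend interface expected by
# flashinfer's MNNVL code on top of a torch.distributed process group.
def _example_custom_communicator(group):
    comm = CustomCommunicator(group)
    # Gather one small Python object (this rank's index) from every rank.
    return comm.allgather(comm.Get_rank())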

View File

@@ -0,0 +1,340 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import current_stream
logger = init_logger(__name__)
_NCCL_SYMM_OPS_REGISTERED = False
def register_nccl_symmetric_ops(pynccl_comm):
from vllm.distributed.device_communicators.pynccl_allocator import (
nccl_symm_mem_context)
from vllm.utils import direct_register_custom_op
global _NCCL_SYMM_OPS_REGISTERED
if _NCCL_SYMM_OPS_REGISTERED:
return
_NCCL_SYMM_OPS_REGISTERED = True
def all_reduce_symmetric_with_copy_impl(
input_tensor: torch.Tensor) -> torch.Tensor:
with nccl_symm_mem_context(pynccl_comm):
symm_input = torch.empty_like(input_tensor)
symm_output = torch.empty_like(input_tensor)
symm_input.copy_(input_tensor)
symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
return symm_output
def all_reduce_symmetric_with_copy_fake(
input_tensor: torch.Tensor) -> torch.Tensor:
return torch.empty_like(input_tensor)
direct_register_custom_op(
op_name="all_reduce_symmetric_with_copy",
op_func=all_reduce_symmetric_with_copy_impl,
fake_impl=all_reduce_symmetric_with_copy_fake,
)
class PyNcclCommunicator:
def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyNcclCommunicator to. If None,
it will be bound to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
is bound to a unique device.
"""
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert dist.get_backend(group) != dist.Backend.NCCL, (
"PyNcclCommunicator should be attached to a non-NCCL group.")
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
else:
self.rank = group.rank
self.world_size = group.world_size
self.group = group
# if world_size == 1, no need to create communicator
if self.world_size == 1:
self.available = False
self.disabled = True
return
try:
self.nccl = NCCLLibrary(library_path)
except Exception:
# disable because of missing NCCL library
# e.g. in a non-GPU environment
self.available = False
self.disabled = True
return
self.available = True
self.disabled = False
self.nccl_version = self.nccl.ncclGetRawVersion()
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
if self.rank == 0:
# get the unique id from NCCL
self.unique_id = self.nccl.ncclGetUniqueId()
else:
# construct an empty unique id
self.unique_id = ncclUniqueId()
if not isinstance(group, StatelessProcessGroup):
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank)
stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
stream.synchronize()
del data
def all_reduce(self,
in_tensor: torch.Tensor,
out_tensor: torch.Tensor = None,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
if self.disabled:
return None
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")
if out_tensor is None:
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = current_stream()
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
ncclDataTypeEnum.from_torch(in_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
return out_tensor
def all_gather(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
cudaStream_t(stream.cuda_stream))
def all_gatherv(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
sizes: list[int],
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()
assert output_tensor.shape[0] == sum(sizes)
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
dst_slice = output_tensor[split_offset:split_offset + split_size]
self.nccl.ncclBroadcast(
buffer_type(input_tensor.data_ptr()),
buffer_type(dst_slice.data_ptr()),
dst_slice.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
def reduce_scatter(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
def reduce_scatterv(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
sizes: list[int],
op: ReduceOp = ReduceOp.SUM,
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
chunk = input_tensor[split_offset:split_offset + split_size, ...]
self.nccl.ncclReduce(
buffer_type(chunk.data_ptr()),
buffer_type(output_tensor.data_ptr()), chunk.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), root, self.comm,
cudaStream_t(stream.cuda_stream))
split_offset += split_size
self.nccl.ncclGroupEnd()
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))
def recv(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer
recvbuff = buffer_type(tensor.data_ptr())
else:
sendbuff = buffer_type()
recvbuff = buffer_type(tensor.data_ptr())
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
def group_start(self):
self.nccl.ncclGroupStart()
def group_end(self):
self.nccl.ncclGroupEnd()
def register_comm_window(self, tensor: torch.Tensor):
return self.nccl.ncclCommWindowRegister(
self.comm,
buffer_type(tensor.data_ptr()),
tensor.numel() * tensor.element_size(),
1,
)
def register_comm_window_raw(self, ptr: int, size: int):
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
size, 1)
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
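# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# A PyNcclCommunicator is attached to a CPU (gloo) group and bound to one
# GPU; its collectives run on the current CUDA stream, which is what makes
# them usable while capturing a CUDA graph.
def _example_pynccl_allreduce(cpu_group, local_rank: int):
    comm = PyNcclCommunicator(group=cpu_group, device=local_rank)
    if comm.disabled:
        return None
    x = torch.ones(16, device=f"cuda:{local_rank}")
    y = comm.all_reduce(x)  # out-of-place; returns a new tensor
    torch.cuda.synchronize()
    return y  # every rank now holds world_size * torch.ones(16)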

View File

@@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import contextlib
import tempfile
from typing import Any, Optional
import torch
from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator
from torch.utils.cpp_extension import load_inline
from vllm import envs
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import find_nccl_include_paths
logger = init_logger(__name__)
nccl_allocator_source = """
#include <nccl.h>
extern "C" {
void* nccl_alloc_plug(size_t size, int device, void* stream) {
void* ptr;
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;
}
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
ncclResult_t err = ncclMemFree(ptr);
}
}
"""
_allocator = None
_allocator_wrapper = None
_mem_pool = None
_registered_base_addrs = set()
_graph_pool_id = None
_nccl_allocator_failed_to_compile = False
_cached_pool_snapshot = None
def is_symmetric_memory_enabled():
global _nccl_allocator_failed_to_compile
return envs.VLLM_USE_NCCL_SYMM_MEM and not _nccl_allocator_failed_to_compile
def is_symmetric_memory_tensor(tensor: torch.Tensor):
if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
return False
for segment in _cached_pool_snapshot:
for block in segment["blocks"]:
if block["address"] == tensor.untyped_storage().data_ptr():
return True
return False
def set_graph_pool_id(graph_pool_id):
global _graph_pool_id
_graph_pool_id = graph_pool_id
def compile_nccl_allocator():
global _allocator, _allocator_wrapper, _nccl_allocator_failed_to_compile
if not current_platform.is_cuda():
_nccl_allocator_failed_to_compile = True
return
try:
out_dir = tempfile.gettempdir()
nccl_allocator_libname = "nccl_allocator"
nccl_include_paths = find_nccl_include_paths()
load_inline(
name=nccl_allocator_libname,
cpp_sources=nccl_allocator_source,
with_cuda=True,
extra_ldflags=["-lnccl"],
verbose=envs.VLLM_LOGGING_LEVEL == "DEBUG",
is_python_module=False,
build_directory=out_dir,
extra_include_paths=nccl_include_paths,
)
_allocator_wrapper = CUDAPluggableAllocator(
f"{out_dir}/{nccl_allocator_libname}.so",
"nccl_alloc_plug",
"nccl_free_plug",
)
_allocator = _allocator_wrapper.allocator()
except Exception as e:
_nccl_allocator_failed_to_compile = True
logger.warning(
"Failed to compile NCCL memory allocator. "
"Symmetric memory will be disabled. "
"This is expected if NCCL headers are not available. "
"optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory "
"containing the NCCL header. "
"Error: %s", str(e))
def get_nccl_mem_pool():
global _mem_pool, _nccl_allocator_failed_to_compile
if _mem_pool is None and not _nccl_allocator_failed_to_compile:
compile_nccl_allocator()
if _allocator is not None:
_mem_pool = torch.cuda.MemPool(_allocator)
return _mem_pool
def _cleanup_nccl_mem_pool():
global _mem_pool
_mem_pool = None
def _cleanup_nccl_allocator_wrapper():
global _allocator_wrapper
_allocator_wrapper = None
atexit.register(_cleanup_nccl_mem_pool)
atexit.register(_cleanup_nccl_allocator_wrapper)
class nccl_symm_mem_context:
def __init__(
self,
pynccl_comm: PyNcclCommunicator,
disabled: bool = False,
):
self.disabled = (disabled or not is_symmetric_memory_enabled()
or pynccl_comm.world_size == 1
or not current_platform.is_cuda()
or get_nccl_mem_pool() is None or version.parse(
torch.__version__) < version.parse("2.8.0.a0"))
if self.disabled:
self.pynccl_comm: Optional[PyNcclCommunicator] = None
self._mem_pool_ctx: contextlib.AbstractContextManager[
Any] = contextlib.nullcontext()
self.is_graph_capture = None
self.device = None
else:
self.pynccl_comm = pynccl_comm
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()
def __enter__(self):
if self.disabled:
return self
assert (
self.pynccl_comm
is not None), "Symmetric memory requires pynccl to be initalized"
assert (
self.pynccl_comm.nccl_version >= 22703
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
if self.is_graph_capture:
assert (
_graph_pool_id
is not None), "graph_pool_id is not set under graph capture"
# Pause graph memory pool to use symmetric memory with cuda graph
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
self._mem_pool_ctx.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.disabled:
return
global _cached_pool_snapshot
global _registered_base_addrs
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
_pool = get_nccl_mem_pool()
assert _pool is not None
_cached_pool_snapshot = _pool.snapshot()
assert self.pynccl_comm is not None
for segment in _cached_pool_snapshot:
if segment["address"] not in _registered_base_addrs:
self.pynccl_comm.register_comm_window_raw(
segment["address"], segment["total_size"])
_registered_base_addrs.add(segment["address"])
if self.is_graph_capture:
torch._C._cuda_beginAllocateCurrentThreadToPool(
self.device, _graph_pool_id)
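# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# Tensors allocated inside nccl_symm_mem_context come from the NCCL mem pool
# above; when the context exits, any new pool segments are registered as NCCL
# windows so that later collectives can take the symmetric-memory fast path.
# This mirrors all_reduce_symmetric_with_copy_impl() in pynccl.py.
def _example_symm_mem_allreduce(pynccl_comm, x):
    # `x` is assumed to be a CUDA tensor on this rank's device.
    with nccl_symm_mem_context(pynccl_comm):
        symm_in = torch.empty_like(x)   # allocated from the NCCL mem pool
        symm_out = torch.empty_like(x)
    symm_in.copy_(x)
    return pynccl_comm.all_reduce(symm_in, symm_out)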

View File

@@ -0,0 +1,416 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch.distributed import ReduceOp
from vllm.logger import init_logger
from vllm.utils import find_nccl_library
logger = init_logger(__name__)
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
ncclWindow_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: list[Any]
class NCCLLibrary:
exported_functions = [
# const char* ncclGetErrorString(ncclResult_t result)
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
# ncclResult_t ncclGetVersion(int *version);
Function("ncclGetVersion", ncclResult_t,
[ctypes.POINTER(ctypes.c_int)]),
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
Function("ncclGetUniqueId", ncclResult_t,
[ctypes.POINTER(ncclUniqueId)]),
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
Function("ncclCommInitRank", ncclResult_t, [
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
ctypes.c_int
]),
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, int root,
# ncclComm_t comm, cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllGather", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduceScatter", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
Function("ncclSend", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclRecv(
# void* recvbuff, size_t count, ncclDataType_t datatype,
# int src, ncclComm_t comm, cudaStream_t stream);
Function("ncclRecv", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclBroadcast(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, int root, ncclComm_t comm,
# cudaStream_t stream);
Function("ncclBroadcast", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ctypes.c_int, ncclComm_t, cudaStream_t
]),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
# ncclResult_t ncclGroupStart();
Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd();
Function("ncclGroupEnd", ncclResult_t, []),
# ncclResult_t ncclCommWindowRegister(
# ncclComm_t comm, void* buff, size_t size,
# ncclWindow_t* win, int winFlags);
Function(
"ncclCommWindowRegister",
ncclResult_t,
[
ncclComm_t,
buffer_type,
ctypes.c_size_t,
ctypes.POINTER(ncclWindow_t),
ctypes.c_int,
],
),
# ncclResult_t ncclCommWindowDeregister(
# ncclComm_t comm, ncclWindow_t win);
Function("ncclCommWindowDeregister", ncclResult_t,
[ncclComm_t, ncclWindow_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_nccl_library()
try:
if so_file not in NCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
NCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = NCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
logger.error(
"Failed to load NCCL library from %s. "
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s. "
"If you already have the library, please set the "
"environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.", so_file,
platform.platform())
raise e
if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs: dict[str, Any] = {}
for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
def ncclGetErrorString(self, result: ncclResult_t) -> str:
return self._funcs["ncclGetErrorString"](result).decode("utf-8")
def NCCL_CHECK(self, result: ncclResult_t) -> None:
if result != 0:
error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}")
def ncclGetRawVersion(self) -> int:
version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
# something like 21903
return version.value
def ncclGetVersion(self) -> str:
version_str = str(self.ncclGetRawVersion())
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
def ncclGetUniqueId(self) -> ncclUniqueId:
unique_id = ncclUniqueId()
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
ctypes.byref(unique_id)))
return unique_id
def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
if len(data) != 128:
raise ValueError(
f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes")
unique_id = ncclUniqueId()
ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
return unique_id
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
rank: int) -> ncclComm_t:
comm = ncclComm_t()
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
world_size, unique_id,
rank))
return comm
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))
def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, root: int,
comm: ncclComm_t, stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
datatype, op, root, comm,
stream))
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
count, datatype, op,
comm, stream))
def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# which is an aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
datatype, comm, stream))
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
dest, comm, stream))
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
comm, stream))
def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, root: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
datatype, root, comm,
stream))
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
def ncclGroupStart(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
def ncclGroupEnd(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type,
size: int, win_flags: int) -> ncclWindow_t:
window = ncclWindow_t()
self.NCCL_CHECK(self._funcs["ncclCommWindowRegister"](
comm, buff, size, ctypes.byref(window), win_flags))
return window
def ncclCommWindowDeregister(self, comm: ncclComm_t,
window: ncclWindow_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
"ncclComm_t", "cudaStream_t", "buffer_type"
]
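# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# Loading the library and decoding the raw version integer; for example,
# a raw value of 22203 is rendered by ncclGetVersion() above as "2.22.3".
def _example_nccl_version() -> str:
    nccl = NCCLLibrary()  # honours VLLM_NCCL_SO_PATH when set
    raw = nccl.ncclGetRawVersion()  # e.g. 22203
    return f"NCCL {nccl.ncclGetVersion()} (raw={raw})"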

View File

@@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless
logger = init_logger(__name__)
try:
ops.qr_max_size()
quick_ar = True
except Exception:
# For CPUs and CUDA
quick_ar = False
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())
class QuickReduceRegime(Enum):
FP = 0
INT8 = 1
INT6 = 2
INT4 = 3
NONE = 4
MB = 1024 * 1024
class QuickAllReduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 8]
_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
# The following data is based on kernel tests.
# In this order [FP, INT8, INT6, INT4].
_QR_MIN_SIZE = {
(torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
(torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
(torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
(torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
(torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
}
def __init__(self, group: ProcessGroup,
device: Union[int, str, torch.device]) -> None:
"""
Custom allreduce provides non-destructive acceleration and is
available for CUDA and ROCm MI300 series.
Custom quick allreduce leverages quantization for further
acceleration on ROCm. It currently supports Q8, Q6, and Q4
quantization formats and FP (float16, bfloat16).
Quick allreduce is designed as a complement to custom allreduce.
Its initialization requires even stricter conditions.
Only the ROCm MI300 series is supported for quick allreduce at
this time.
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the QuickAllReduce to. If None,
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bound to a unique device, and all communicators in this group
are on the same node.
"""
self.disabled = True
if not self._rocm_arch_available():
logger.debug(
"Custom quick allreduce is only supported on ROCm MI300 series."
)
return
if not quick_ar:
# disable because of missing quick reduce library
# e.g. in a cuda environment
logger.info("Custom quick allreduce is disabled because "
"of missing custom quick allreduce library")
return
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"Custom quick allreduce should be attached to a non-NCCL group.")
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom quick allreduce for
# multi-node case.
logger.warning("Custom quick allreduce is disabled because this "
"process group spans across nodes.")
return
rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
self.rank = rank
self.world_size = world_size
if world_size == 1:
# No need to initialize QuickReduce for single GPU case.
return
if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom quick allreduce is disabled due to an "
"unsupported world size: %d. Supported world sizes: %s.",
world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES))
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(self.world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom quick allreduce is not supported
# this checks hardware and driver support for NVLink
assert current_platform.is_cuda_alike()
self.fully_connected = current_platform.is_fully_connected(
physical_device_ids)
if self.world_size > 2 and not self.fully_connected:
logger.debug(
"Custom quick allreduce is disabled because it's not supported "
"on more than two PCIe-only GPUs. ")
return
self.init_quick_all_reduce()
def init_quick_all_reduce(self):
# On ROCm, bfloat16 kernels are slower than fp16
# due to slower math operations
# If environment variable is set to 1, we convert input to fp16
self.use_fp16_kernels = envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16
regime_str = envs.VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
if regime_str not in QuickReduceRegime.__members__:
logger.warning(
    "Custom quick allreduce: invalid quantization level %s. "
    "Supported levels: %s", regime_str,
    list(QuickReduceRegime.__members__.keys()))
return
if regime_str == "NONE":
logger.debug("Custom quick allreduce is disabled based "
"on env variable "
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'")
return
self.qr_quant_level = QuickReduceRegime[regime_str]
vllm_config = get_current_vllm_config()
if vllm_config is not None and \
hasattr(vllm_config, "model_config") and \
hasattr(vllm_config.model_config, "dtype"):
dtype = vllm_config.model_config.dtype
if dtype not in [torch.float16, torch.bfloat16]:
logger.debug(
"Custom quick allreduce disabled: only supports "
"float16 and float16, but get %s.", dtype)
return
if dtype == torch.bfloat16 and self.use_fp16_kernels:
logger.info(
"Custom quick allreduce: BF16 inputs will be converted "
"to FP16 to improve performance. set "
"envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 "
"to turn off.")
# VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
if qr_max_size is not None:
if qr_max_size < 1:
logger.info(
"You should not set a max_size smaller than 1MB, which can "
"lead to error or degradation to custom allreduce or rccl."
)
qr_max_size = qr_max_size * MB
self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
self.qr_max_size = qr_max_size if qr_max_size is not None \
else ops.qr_max_size()
self.create_shared_buffer()
self.disabled = False
def _rocm_arch_available(self):
if not current_platform.is_rocm():
return False
try:
props = torch.cuda.get_device_properties(0)
gcn_arch = getattr(props, "gcnArchName", "")
supported_archs = ['gfx94', 'gfx95']
return any(gfx in gcn_arch for gfx in supported_archs)
except Exception as e:
logger.warning("Failed to determine ROCm for quick allreduce: %s",
e)
return False
def create_shared_buffer(self):
"""
Creates a shared buffer for quickreduce.
Has to be called after init_custom_qr
"""
handle = ops.qr_get_handle(self._ptr)
world_size = dist.get_world_size(group=self.group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=self.group)
ops.qr_open_handles(self._ptr, handles)
def should_quick_allreduce(self, inp: torch.Tensor):
"""
Check if quickreduce is available
"""
if self.disabled:
return False
if inp.dtype not in self._SUPPORTED_DTYPES:
return False
inp_size = inp.numel() * inp.element_size()
# custom quick allreduce requires input byte size to be
# multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
dtype = inp.dtype
if self.use_fp16_kernels:
dtype = torch.float16
return inp_size <= self.qr_max_size and \
inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\
[self.qr_quant_level.value]
def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
"""Performs an out-of-place custom quick all reduce."""
# quick allreduce doesn't require a separate graph mode,
# as QR uses static IPC buffer.
if out is None:
out = torch.empty_like(inp)
ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value,
self.use_fp16_kernels)
return out
def close(self):
if not self.disabled and getattr(self, "_ptr", None):
if ops is not None:
ops.qr_destroy(self._ptr)
self._ptr = 0
self.disabled = True
def __del__(self):
self.close()
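# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# Quick allreduce is consulted before the regular custom allreduce / RCCL
# path on ROCm MI300-series GPUs; the caller falls back when it declines.
def _example_quick_allreduce(qr: QuickAllReduce, x: torch.Tensor):
    if qr.should_quick_allreduce(x):
        # Quantized in-kernel allreduce (FP/INT8/INT6/INT4, selected via
        # VLLM_ROCM_QUICK_REDUCE_QUANTIZATION).
        return qr.quick_all_reduce(x)
    return None  # caller falls back to custom allreduce or RCCL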

View File

@@ -0,0 +1,258 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from typing import Any, Optional
import ray
import torch
from ray.exceptions import RayChannelError
from ray.experimental.channel.communicator import (Communicator,
TorchTensorAllocator)
from torch.distributed import ReduceOp
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.utils import current_stream
logger = init_logger(__name__)
class RayPPCommunicator(Communicator):
"""
Communicator to be used for pipeline parallelism in Ray Compiled Graph.
This wraps around the vLLM _PP GroupCoordinator.
This class is not thread-safe.
"""
_comm: Optional[DeviceCommunicatorBase]
def __init__(
self,
world_size: int,
comm_id: Any,
rank: Optional[int],
actor_handles: list["ray.actor.ActorHandle"],
cuda_stream: Optional[torch.cuda.Stream],
use_communication_streams: bool = False,
):
"""
Initialize a RayPPCommunicator that can be used to communicate with
other Ray Compiled Graph actors for pipeline parallelism.
Args:
world_size: The number of participating actors.
comm_id: A unique communicator ID. This is just to conform with
the Ray Communicator API and is not used.
rank: The rank of this actor. If None, then the caller is not a
participant of the RayPPCommunicator group (e.g., the Ray
driver).
actor_handles: A list of actor handles.
cuda_stream: A CUDA stream to dispatch communication ops to. This
is not supported.
use_communication_streams: Whether to use communication streams.
This is not supported.
"""
self._world_size = world_size
self._rank: Optional[int] = None
self._actor_handles = actor_handles
if use_communication_streams:
raise NotImplementedError(
"use_communication_streams is not supported")
if cuda_stream is not None and cuda_stream != current_stream():
raise ValueError(
"cuda_stream other than the current stream is not supported")
if rank is not None:
# Rank is not None, this is Ray worker
assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"
self._comm = get_pp_group().device_communicator
assert self._comm is not None
# Since we wrap around the vLLM _PP communicator, we use
# the rank from the vLLM communicator, and ignore the rank
# passed in from Ray.
# TODO(rui): refactor the Ray Communicator API so that
# it also supports no rank passed in.
self._rank = self._comm.rank_in_group
self._build_actor_rank_mapping()
else:
# Rank is None, this is Ray driver
self._comm = None
self._closed = False
def _build_actor_rank_mapping(self):
"""
Use collective communication to build a mapping from actor IDs to ranks.
This should be called once during initialization.
"""
if self._comm is None:
return {}
current_actor = ray.get_runtime_context().current_actor
actor_id_str = current_actor._actor_id.hex()
# Ray actor IDs are 32-character hex strings (128 bits)
ACTOR_ID_LEN = 32
actor_id_bytes = actor_id_str.encode('utf-8')
assert len(
actor_id_bytes
) == ACTOR_ID_LEN, f"Unexpected actor ID length: {len(actor_id_bytes)}"
actor_id_tensor = torch.frombuffer(
actor_id_bytes, dtype=torch.uint8).to(self._comm.device)
# All-gather full actor IDs from all actors
gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)
# Build mapping: actor_id -> device_comm_rank
self._actor_id_to_rank = {}
for rank in range(self._world_size):
start_idx = rank * ACTOR_ID_LEN
end_idx = (rank + 1) * ACTOR_ID_LEN
actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy(
).tobytes()
actor_id = actor_bytes.decode('utf-8')
self._actor_id_to_rank[actor_id] = rank
def initialize(self, rank: int) -> None:
# No additional initialization is needed.
pass
def get_actor_handles(self) -> list["ray.actor.ActorHandle"]:
return self._actor_handles
def get_rank(self, actor: ray.actor.ActorHandle) -> int:
"""
Return the given actor's rank using device communicator collective ops.
"""
assert hasattr(self, '_actor_id_to_rank'), (
"Actor rank mapping not built. "
"This should have been done during initialization.")
actor_id_str = actor._actor_id.hex()
if actor_id_str in self._actor_id_to_rank:
return self._actor_id_to_rank[actor_id_str] # type: ignore
else:
raise ValueError(f"Actor {actor} not found in communicator group")
def get_self_rank(self) -> Optional[int]:
"""
Return this actor's rank.
"""
return self._rank
def get_world_size(self) -> int:
"""
Return the number of ranks in the RayPPCommunicator group.
"""
return self._world_size
def send(self, buf: "torch.Tensor", peer_rank: int) -> None:
"""
Send a torch.Tensor to a peer.
This returns when the send kernel has been queued, but the kernel may
not have completed. Therefore, the caller should ensure that there are
no concurrent writes to the sent `buf` until the send has finished.
That is, either all writes should be submitted on the current stream
(self._cuda_stream) or, if on a different stream, that stream should
synchronize with the current stream.
Args:
buf: The torch.Tensor to send. It should already be on this
actor's default device.
peer_rank: The rank of the actor to send to.
"""
if self._closed:
raise RayChannelError("RayPPCommunicator has been destroyed.")
assert self._comm is not None
self._comm.send(buf, peer_rank)
def recv(
self,
shape: tuple[int, ...],
dtype: "torch.dtype",
peer_rank: int,
allocator: TorchTensorAllocator,
) -> "torch.Tensor":
"""
Receive a torch.Tensor from a peer and synchronize the current stream.
After this call returns, the receive buffer is safe to read from
any stream. A RayChannelError will be raised if an error occurred
(e.g., remote actor died), and the buffer is not safe to read.
Args:
shape: The shape of the tensor to receive.
dtype: The dtype of the tensor to receive.
peer_rank: The rank of the actor to receive from.
allocator: The allocator to use to create the received tensor.
This is ignored for this implementation.
"""
if self._closed:
raise RayChannelError("RayPPCommunicator has been destroyed.")
assert self._comm is not None
size = torch.Size(shape)
buf = self._comm.recv(size, dtype, src=peer_rank)
# Buffer values are undefined if NCCL ops are aborted. Therefore, we
# need to synchronize here and check that the channel is still
# open to ensure that the receive buffer is valid.
# TODO(swang): Avoid CUDA synchronization.
current_stream().synchronize()
if self._closed:
raise RayChannelError("RayPPCommunicator has been destroyed.")
return buf
def allgather(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
):
raise NotImplementedError("allgather is not supported")
def allreduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
):
raise NotImplementedError("allreduce is not supported")
def reducescatter(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
):
raise NotImplementedError("reducescatter is not supported")
@property
def recv_stream(self):
return torch.cuda.StreamContext(current_stream())
@property
def send_stream(self):
return torch.cuda.StreamContext(current_stream())
def destroy(self) -> None:
# Just sets a flag, vLLM manages the lifecycle of the underlying
# _PP GroupCoordinator.
self._closed = True
def get_transport_name(self) -> str:
return "nccl"
@classmethod
def generate_communicator_id(cls) -> Any:
return uuid.uuid4()

View File

@@ -0,0 +1,589 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from threading import Event
from typing import Any, Optional, Union
from unittest.mock import patch
import torch
import torch.distributed as dist
import zmq
from torch.distributed import ProcessGroup
from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
from vllm.logger import init_logger
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
is_valid_ipv6_address)
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
logger = init_logger(__name__)
class SpinTimer:
def record_activity(self):
pass
def spin(self):
sched_yield()
class SpinSleepTimer(SpinTimer):
"""
In setups with long inactivity periods it is desirable to reduce
system power consumption while vLLM is idle. This leaves more CPU
thermal headroom when a request eventually arrives, which matters
especially when multiple GPUs are connected, since each GPU would
otherwise pin one thread at 100% CPU usage.
The simplest solution is to reduce polling frequency when there is no
activity for a certain period of time.
"""
def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
self.last_activity = time.monotonic()
self.busy_loop_s = busy_loop_s
self.wait_sleep_s = wait_sleep_s
def record_activity(self):
self.last_activity = time.monotonic()
def spin(self):
curr_time = time.monotonic()
if curr_time >= self.last_activity + self.busy_loop_s:
time.sleep(self.wait_sleep_s)
else:
sched_yield()
class ShmRingBuffer:
def __init__(self,
n_reader: int,
max_chunk_bytes: int,
max_chunks: int,
name: Optional[str] = None):
"""
A shared memory ring buffer implementation for broadcast communication.
Essentially, it is a queue where only one will `enqueue` and multiple
will `dequeue`. The max size of each item, together with the max number
of items that can be stored in the buffer are known in advance.
In this case, we don't need to synchronize the access to
the buffer.
Buffer memory layout:
          data                                  metadata
            |                                       |
            | (current_idx)                         | (current_idx)
            v                                       v
+-------------------------------+----------------------------------------+
| chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
+-------------------------------+----------------------------------------+
| max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
metadata memory layout: each byte is a flag, the first byte is the written
flag, and the rest are reader flags. The flags are set to 0 by default.
+--------------+--------------+--------------+-----+--------------+
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+--------------+--------------+--------------+-----+--------------+
The state of metadata is as follows:
(case 1) 0???...???: the block is not written yet, cannot read, can write
(case 2) 1000...000: the block is just written, can read, cannot write
(case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
(case 4) 1111...111: the block is written and read by all readers, cannot read, can write
State transition for readers:
When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only set their own flag from 0 to 1 to mark a block as read;
        the writer is the only one that resets reader flags back to 0 and
        toggles the written flag.
State transition for writer:
When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here: first reset the reader flags (so
        that we are still in case 1), then mark the block as written. This
        makes the state transition effectively atomic. If we did it in the
        reverse order, the block would go through case 3 and then back to
        case 2, and readers might observe the intermediate case 3, which is
        not correct.
During creation, `name` is None and the buffer is created. We can pass the
created object to other processes by pickling it. The other processes will
get the name of the shared memory and open it, so that they can access the
same shared memory buffer.
"""# noqa
self.n_reader = n_reader
self.metadata_size = 1 + n_reader
self.max_chunk_bytes = max_chunk_bytes
self.max_chunks = max_chunks
self.total_bytes_of_buffer = (self.max_chunk_bytes +
self.metadata_size) * self.max_chunks
self.data_offset = 0
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
if name is None:
# we are creating a buffer
self.is_creator = True
self.shared_memory = shared_memory.SharedMemory(
create=True, size=self.total_bytes_of_buffer)
# initialize the metadata section to 0
with self.shared_memory.buf[self.
metadata_offset:] as metadata_buffer:
torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
else:
# we are opening an existing buffer
self.is_creator = False
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch("multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None):
try:
self.shared_memory = shared_memory.SharedMemory(name=name)
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
# Some platforms allocate memory based on page size,
# so the shared memory block size may be larger or equal
# to the requested size. The size parameter is ignored
# when attaching to an existing block.
assert (self.shared_memory.size
>= self.total_bytes_of_buffer)
except FileNotFoundError:
# we might deserialize the object in a different node
# in this case, this object is not used,
# and we should suppress the error
pass
def handle(self):
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name)
def __reduce__(self):
return (
self.__class__,
self.handle(),
)
def __del__(self):
if hasattr(self, "shared_memory"):
self.shared_memory.close()
if self.is_creator:
self.shared_memory.unlink()
@contextmanager
def get_data(self, current_idx: int):
start = self.data_offset + current_idx * self.max_chunk_bytes
end = start + self.max_chunk_bytes
with self.shared_memory.buf[start:end] as buf:
yield buf
@contextmanager
def get_metadata(self, current_idx: int):
start = self.metadata_offset + current_idx * self.metadata_size
end = start + self.metadata_size
with self.shared_memory.buf[start:end] as buf:
yield buf
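# --- Editorial sketch (not part of vLLM): the flag protocol documented in the
# ShmRingBuffer docstring above, simulated on a plain bytearray for one block
# with two readers. --------------------------------------------------------
def _example_flag_protocol() -> None:
    n_reader = 2
    metadata = bytearray(1 + n_reader)  # [written_flag, reader0_flag, reader1_flag]

    # writer: the block is writable (case 1, written_flag == 0); after writing,
    # reset the reader flags *before* setting the written flag (case 2)
    assert metadata[0] == 0
    for i in range(1, n_reader + 1):
        metadata[i] = 0
    metadata[0] = 1

    # readers: each one marks the block as read after consuming it
    for i in range(1, n_reader + 1):
        assert metadata[0] == 1 and metadata[i] == 0
        metadata[i] = 1

    # writer: every reader flag is set (case 4), so the block can be reused
    assert all(metadata[i] == 1 for i in range(1, n_reader + 1))
    metadata[0] = 0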
@dataclass
class Handle:
local_reader_ranks: list[int] = field(default_factory=list)
buffer_handle: Optional[tuple[int, int, int, str]] = None
local_subscribe_addr: Optional[str] = None
remote_subscribe_addr: Optional[str] = None
remote_addr_ipv6: bool = False
class MessageQueue:
def __init__(
self,
n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory
local_reader_ranks: Optional[list[int]] = None,
max_chunk_bytes: int = 1024 * 1024 * 10,
max_chunks: int = 10,
connect_ip: Optional[str] = None,
):
if local_reader_ranks is None:
local_reader_ranks = list(range(n_local_reader))
else:
assert len(local_reader_ranks) == n_local_reader
self.n_local_reader = n_local_reader
n_remote_reader = n_reader - n_local_reader
self.n_remote_reader = n_remote_reader
context = Context()
if n_local_reader > 0:
# for local readers, we will:
# 1. create a shared memory ring buffer to communicate small data
# 2. create a publish-subscribe socket to communicate large data
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self.local_socket = context.socket(XPUB)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_addr = get_open_zmq_ipc_path()
logger.debug("Binding to %s", local_subscribe_addr)
self.local_socket.bind(local_subscribe_addr)
self.current_idx = 0
else:
self.buffer = None # type: ignore
local_subscribe_addr = None
self.local_socket = None
self.current_idx = -1
remote_addr_ipv6 = False
if n_remote_reader > 0:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
if not connect_ip:
connect_ip = get_ip()
self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port()
if is_valid_ipv6_address(connect_ip):
self.remote_socket.setsockopt(IPV6, 1)
remote_addr_ipv6 = True
connect_ip = f"[{connect_ip}]"
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
self.remote_socket.bind(socket_addr)
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
else:
remote_subscribe_addr = None
self.remote_socket = None
self._is_writer = True
self._is_local_reader = False
self.local_reader_rank = -1
# rank does not matter for remote readers
self._is_remote_reader = False
self._read_spin_timer = SpinTimer()
self.handle = Handle(
local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle()
if self.buffer is not None else None,
local_subscribe_addr=local_subscribe_addr,
remote_subscribe_addr=remote_subscribe_addr,
remote_addr_ipv6=remote_addr_ipv6,
)
logger.info("vLLM message queue communication handle: %s", self.handle)
def export_handle(self) -> Handle:
return self.handle
@staticmethod
def create_from_handle(handle: Handle, rank) -> "MessageQueue":
self = MessageQueue.__new__(MessageQueue)
self.handle = handle
self._is_writer = False
context = Context()
if rank in handle.local_reader_ranks:
assert handle.buffer_handle is not None
self.buffer = ShmRingBuffer(*handle.buffer_handle)
self.current_idx = 0
self.local_reader_rank = handle.local_reader_ranks.index(rank)
self._is_local_reader = True
self._is_remote_reader = False
self.local_socket = context.socket(SUB)
self.local_socket.setsockopt_string(SUBSCRIBE, "")
socket_addr = handle.local_subscribe_addr
logger.debug("Connecting to %s", socket_addr)
self.local_socket.connect(socket_addr)
self.remote_socket = None
self._read_spin_timer = SpinSleepTimer(
) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
else:
self.buffer = None # type: ignore
self.current_idx = -1
self.local_reader_rank = -1
self._is_local_reader = False
self._is_remote_reader = True
self.local_socket = None
self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
if handle.remote_addr_ipv6:
self.remote_socket.setsockopt(IPV6, 1)
socket_addr = handle.remote_subscribe_addr
logger.debug("Connecting to %s", socket_addr)
self.remote_socket.connect(socket_addr)
return self
def wait_until_ready(self):
"""This is a collective operation. All processes (including the
readers and the writer) should call this function.
"""
if self._is_writer:
# wait for all readers to connect
# local readers
for i in range(self.n_local_reader):
# wait for subscription messages from all local readers
self.local_socket.recv()
if self.n_local_reader > 0:
# send a message to all local readers
# to make sure the publish channel is working
self.local_socket.send(b"READY")
# remote readers
for i in range(self.n_remote_reader):
# wait for subscription messages from all remote readers
self.remote_socket.recv()
if self.n_remote_reader > 0:
# send a message to all remote readers
# to make sure the publish channel is working
self.remote_socket.send(b"READY")
elif self._is_local_reader:
# wait for the writer to send a message
recv = self.local_socket.recv()
assert recv == b"READY"
elif self._is_remote_reader:
# wait for the writer to send a message
recv = self.remote_socket.recv()
assert recv == b"READY"
@contextmanager
def acquire_write(self, timeout: Optional[float] = None):
assert self._is_writer, "Only writers can acquire write"
start_time = time.monotonic()
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader:
# this block is written and not read by all readers
# for writers, `self.current_idx` is the next block to write
# if this block is not ready to write,
# we need to wait until it is read by all readers
# Release the processor to other threads
sched_yield()
# if we time out, raise an exception
elapsed = time.monotonic() - start_time
if timeout is not None and elapsed > timeout:
raise TimeoutError
# if we wait for a long time, log a message
if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
logger.info(
"No available shared memory broadcast block found"
" in %s seconds. This typically happens when some"
" processes are hanging or doing some"
" time-consuming work (e.g. compilation)",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
continue
# found a block that is either
# (1) not written
# (2) read by all readers
# mark the block as not written
metadata_buffer[0] = 0
# let caller write to the buffer
with self.buffer.get_data(self.current_idx) as buf:
yield buf
# caller has written to the buffer
# NOTE: order is important here
# first set the read flags to 0
# then set the written flag to 1
# otherwise, the readers may think they already read the block
for i in range(1, self.buffer.n_reader + 1):
# set read flag to 0, meaning it is not read yet
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
break
@contextmanager
def acquire_read(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
read_flag = metadata_buffer[self.local_reader_rank + 1]
written_flag = metadata_buffer[0]
if not written_flag or read_flag:
# this block is either
# (1) not written
# (2) already read by this reader
# for readers, `self.current_idx` is the next block to read
# if this block is not ready,
# we need to wait until it is written
# Release the processor to other threads
self._read_spin_timer.spin()
if cancel is not None and cancel.is_set():
raise RuntimeError("cancelled")
# if we time out, raise an exception
elapsed = time.monotonic() - start_time
if timeout is not None and elapsed > timeout:
raise TimeoutError
# if we wait for a long time, log a message
if not indefinite and (elapsed
> VLLM_RINGBUFFER_WARNING_INTERVAL *
n_warning):
logger.info(
"No available shared memory broadcast block found"
" in %s seconds. This typically happens when some"
" processes are hanging or doing some"
" time-consuming work (e.g. compilation).",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
continue
# found a block that is not read by this reader
# let caller read from the buffer
with self.buffer.get_data(self.current_idx) as buf:
yield buf
# caller has read from the buffer
# set the read flag
metadata_buffer[self.local_reader_rank + 1] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
self._read_spin_timer.record_activity()
break
def enqueue(self, obj, timeout: Optional[float] = None):
""" Write to message queue with optional timeout (in seconds) """
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
with self.acquire_write(timeout) as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
else:
with self.acquire_write(timeout) as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)
def dequeue(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False):
""" Read from message queue with optional timeout (in seconds) """
if self._is_local_reader:
with self.acquire_read(timeout, cancel, indefinite) as buf:
overflow = buf[0] == 1
if not overflow:
# no need to know the size of serialized object
# pickle format contains the size information internally
# see https://docs.python.org/3/library/pickle.html
obj = pickle.loads(buf[1:])
if overflow:
obj = MessageQueue.recv(self.local_socket, timeout)
elif self._is_remote_reader:
obj = MessageQueue.recv(self.remote_socket, timeout)
else:
raise RuntimeError("Only readers can dequeue")
return obj
@staticmethod
def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
timeout_ms = None if timeout is None else int(timeout * 1000)
if not socket.poll(timeout=timeout_ms):
raise TimeoutError
recv = socket.recv(copy=False)
return pickle.loads(recv.buffer)
def broadcast_object(self, obj=None):
if self._is_writer:
self.enqueue(obj)
return obj
else:
return self.dequeue()
@staticmethod
def create_from_process_group(pg: Union[ProcessGroup,
StatelessProcessGroup],
max_chunk_bytes,
max_chunks,
writer_rank=0) -> "MessageQueue":
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
global_ranks = dist.get_process_group_ranks(pg)
else:
group_rank = pg.rank
group_world_size = pg.world_size
global_ranks = list(range(pg.world_size))
from vllm.distributed.parallel_state import in_the_same_node_as
status = in_the_same_node_as(pg, source_rank=writer_rank)
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
n_local_reader = len(same_node_ranks) - 1
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
buffer_io: MessageQueue
if group_rank == writer_rank:
buffer_io = MessageQueue(
n_reader=n_reader,
n_local_reader=n_local_reader,
local_reader_ranks=local_reader_ranks,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
if isinstance(pg, ProcessGroup):
dist.broadcast_object_list([handle],
src=global_ranks[writer_rank],
group=pg)
else:
pg.broadcast_obj(handle, writer_rank)
else:
if isinstance(pg, ProcessGroup):
recv = [None]
dist.broadcast_object_list(recv,
src=global_ranks[writer_rank],
group=pg)
handle = recv[0] # type: ignore
else:
handle = pg.broadcast_obj(None, writer_rank)
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
buffer_io.wait_until_ready()
return buffer_io
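A minimal usage sketch for the `MessageQueue` defined above (editorial illustration, not part of the committed file; it assumes `torch.distributed` is already initialized and that `pg` is a gloo group covering every rank, with rank 0 acting as the single writer):

```
import torch.distributed as dist

pg = dist.new_group(backend="gloo")
mq = MessageQueue.create_from_process_group(
    pg, max_chunk_bytes=1 << 20, max_chunks=10, writer_rank=0)

if dist.get_rank(pg) == 0:
    # objects smaller than max_chunk_bytes reach same-node readers via shared
    # memory; remote readers always receive them over the ZMQ socket
    mq.enqueue({"step": 1, "payload": [1, 2, 3]})
else:
    obj = mq.dequeue(timeout=10.0)  # blocks until the writer publishes
```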

View File

@@ -0,0 +1,635 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle
from abc import ABC, abstractmethod
from collections.abc import Iterable
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import chain
from multiprocessing import shared_memory
from multiprocessing.synchronize import Lock as LockType
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
class SingleWriterShmRingBuffer:
"""
A single-writer, multiple-reader ring buffer implementation using shared
memory. This class provides a thread-safe ring buffer where one process
can write data while multiple processes/threads can read from it.
Architecture:
- Uses shared memory for cross-process communication
- Maintains metadata for each allocated buffer chunk in the writer process
- Supports custom "is_free_fn" functions to determine when buffers can be
reused
- Each buffer chunk contains: `[4-byte id][4-byte size][actual_data]`
Key Concepts:
- monotonic_id_start/end: Track the range of active buffer IDs
- data_buffer_start/end: Track the physical memory range in use
- Automatic wraparound when reaching buffer end
- Lazy garbage collection based on is_free_fn checks
Example Usage Scenarios:
Scenario 1: Simple Linear Allocation
```
Buffer size: 100 bytes
Initial state: [................................................. ]
^start=end(0)
After allocating 20 bytes (id=0):
[id:0|size:20|data........][...................................]
^start(0) ^end(28)
After allocating 30 bytes (id=1):
[id:0|size:20|data........][id:1|size:30|data..............][..]
^start(0) ^end(66)
```
Scenario 2: Memory Reclamation
```
Before freeing (both buffers still in use):
[id:0|size:20|data........][id:1|size:30|data..............][..]
^start(0) ^end(66)
After id:0 is marked free by readers:
[FREED.................... ][id:1|size:30|data..............][..]
^start(28) ^end(66)
After both are freed:
[FREED..............................................][..]
^start=end(66)
```
Scenario 3: Wraparound Allocation (continuing from Scenario 2)
```
Starting from after memory reclamation in Scenario 2:
[FREED..............................................][..]
^start=end(66)
Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound:
[id:2|size:40|data........................][FREED.............][..]
^end(148) ^start(66)
```
Scenario 4: Error Handling - Out of Space
```
Starting from after wraparound allocation in Scenario 3:
[id:2|size:40|data........................][FREED.............][..]
^end(148) ^start(66)
Trying to allocate 20 more bytes:
occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100)
-> Raises MemoryError: "Not enough space in the data buffer"
```
Thread Safety:
- Single writer: Only one process/thread should write (allocate_buf)
- Multiple readers: Multiple processes/threads can read (access_buf)
- Reader synchronization handled by is_free_fn callback
- Writer handles garbage collection (free_buf) based on reader feedback
Memory Layout per Buffer Chunk:
`[4-byte monotonic_id][4-byte chunk_size][actual_data...]`
^metadata_start ^data_start
The monotonic_id ensures data integrity - readers can verify they're
accessing the correct data even after buffer wraparound or reuse.
"""
def __init__(
self,
data_buffer_size: int,
name: Optional[str] = None,
create: bool = False,
):
self.data_buffer_size = data_buffer_size
self.is_writer = create
self.ID_NBYTES = 4
self.ID_MAX = 2**31 # exclusive, so 2**31 - 1 is the max value
self.SIZE_NBYTES = 4
# 4 bytes for id, 4 bytes for buffer size
self.MD_SIZE = self.ID_NBYTES + self.SIZE_NBYTES
self.monotonic_id_end = 0
self.monotonic_id_start = 0
self.data_buffer_start = 0
self.data_buffer_end = 0
if create:
# we are creating a buffer
self.metadata = {
self.monotonic_id_end: self.data_buffer_end
} # monotonic_id -> start address
self.shared_memory = shared_memory.SharedMemory(
create=True, size=self.data_buffer_size, name=name)
else:
# we are opening an existing buffer
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch(
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
):
self.shared_memory = shared_memory.SharedMemory(name=name)
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
# Some platforms allocate memory based on page size,
# so the shared memory block size may be larger or equal
# to the requested size. The size parameter is ignored
# when attaching to an existing block.
assert self.shared_memory.size >= self.data_buffer_size
logger.debug("Shared memory created/opened with name: %s, size: %d",
self.shared_memory.name, self.data_buffer_size)
def handle(self):
return (
self.data_buffer_size,
self.shared_memory.name,
)
def clear(self) -> None:
"""Clear the ring buffer."""
assert self.is_writer, "Only the writer can clear the buffer."
self.metadata.clear()
self.monotonic_id_end = 0
self.monotonic_id_start = 0
self.data_buffer_start = 0
self.data_buffer_end = 0
def __del__(self):
if hasattr(self, "shared_memory"):
self.shared_memory.close()
if self.is_writer:
self.shared_memory.unlink()
def int2byte(self, integer: int) -> bytes:
"""Convert an integer to bytes."""
return integer.to_bytes(self.ID_NBYTES, "little", signed=True)
def byte2int(self, byte_data: bytes) -> int:
"""Convert bytes back to an integer."""
return int.from_bytes(byte_data, "little", signed=True)
def allocate_buf(self, size: int) -> tuple[int, int]:
'''
Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory.
Memory layout:
`[4-byte monotonic_id][4-byte size][buffer data...]`
'''
assert self.is_writer, "Only the writer can allocate buffers."
assert size > 0, "Size must be greater than 0"
size += self.MD_SIZE # add metadata size to the buffer size
        # reset to the beginning if there is not enough contiguous space
buffer_end_reset = self.data_buffer_end % self.data_buffer_size
if buffer_end_reset + size > self.data_buffer_size:
buffer_end_reset = (self.data_buffer_end // self.data_buffer_size +
1) * self.data_buffer_size
else: # no reset needed
buffer_end_reset = self.data_buffer_end
# check if we have enough space in the data buffer
# i.e. if the new end (self.data_buffer_end + size)
# exceeds the start of the data buffer
occupied_size_new = buffer_end_reset + size - self.data_buffer_start
if occupied_size_new > self.data_buffer_size:
raise MemoryError("Not enough space in the data buffer, "
"try calling free_buf() to free up space")
self.data_buffer_end = buffer_end_reset
# first 4 bytes as the monotonic id
buf_idx = self.data_buffer_end % self.data_buffer_size
self.shared_memory.buf[buf_idx:buf_idx + self.ID_NBYTES] = \
self.int2byte(self.monotonic_id_end)
# next 4 bytes as the size of the data buffer
self.shared_memory.buf[buf_idx + self.ID_NBYTES: \
buf_idx + self.MD_SIZE] = self.int2byte(size)
# record metadata
self.metadata[self.monotonic_id_end %
self.ID_MAX] = self.data_buffer_end
# update buffer and monotonic id indices
current_buffer_end = self.data_buffer_end
current_id_end = self.monotonic_id_end
self.data_buffer_end += size
self.monotonic_id_end = (self.monotonic_id_end + 1) % self.ID_MAX
return current_buffer_end, current_id_end
@contextmanager
def access_buf(self, address: int):
buf_idx = address % self.data_buffer_size
# read metadata
metadata_buff = self.shared_memory.buf[buf_idx:buf_idx + self.MD_SIZE]
id = self.byte2int(metadata_buff[:self.ID_NBYTES])
size = self.byte2int(metadata_buff[self.ID_NBYTES:self.MD_SIZE])
# yield the data buffer and metadata
data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE:buf_idx +
size]
        with memoryview(data_buff) as data_view:
yield data_view, (id, size)
def free_buf(self,
is_free_fn: Callable[[int, memoryview], bool],
nbytes: Optional[int] = None) -> Iterable[int]:
'''
Free a buffer of the given size. This is a no-op in shared memory,
but we need to keep track of the metadata.
If freed memory spreads across the end and start of the ring buffer,
the actual freed memory will be in two segments. In this case there
still might not be a contiguous space of `nbytes` available.
        Args:
            is_free_fn: Callback that receives (monotonic_id, data memoryview)
                and returns True once the chunk is no longer referenced and
                can be reclaimed.
nbytes (int, optional): The size of the buffer to free. If None,
frees the maximum size of the ring buffer.
'''
assert self.is_writer, "Only the writer can free buffers."
logger.debug(
"Freeing up space in the ring buffer, "
"monotonic_id_start: %d, monotonic_id_end: %d",
self.monotonic_id_start, self.monotonic_id_end)
monotonic_id_before = self.monotonic_id_start
# if nbytes is None, free up the maximum size of the ring buffer
if nbytes is None:
nbytes = self.data_buffer_size
freed_bytes = 0
while self.monotonic_id_start in self.metadata and freed_bytes < nbytes:
address = self.metadata[self.monotonic_id_start]
with self.access_buf(address) as (data_buff, metadata):
if is_free_fn(self.monotonic_id_start, data_buff):
# check passed, we can free the buffer
del self.metadata[self.monotonic_id_start]
self.monotonic_id_start = ((self.monotonic_id_start + 1) %
self.ID_MAX)
self.data_buffer_start = address
freed_bytes += metadata[1]
else:
# there are still readers, we cannot free the buffer
break
logger.debug(
"Freed %d bytes from the ring buffer, "
"monotonic_id_start: %d, monotonic_id_end: %d", freed_bytes,
self.monotonic_id_start, self.monotonic_id_end)
# buffer wrap around
if self.data_buffer_start >= self.data_buffer_size:
self.data_buffer_start -= self.data_buffer_size
self.data_buffer_end -= self.data_buffer_size
monotonic_id_after = self.monotonic_id_start
# id wrap around
if monotonic_id_after >= monotonic_id_before:
return range(monotonic_id_before, monotonic_id_after)
else:
return chain(range(monotonic_id_before, self.ID_MAX),
range(0, monotonic_id_after))
class ObjectSerde(ABC):
@abstractmethod
def serialize(self, value: Any) -> tuple[Any, int, bytes, int]:
"""Serialize an object to bytes."""
raise NotImplementedError
@abstractmethod
def deserialize(self, data: memoryview) -> Any:
"""Deserialize bytes back to an object."""
raise NotImplementedError
class MsgpackSerde(ObjectSerde):
def __init__(self):
# Delayed import to avoid circular dependency
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
self.encoder = MsgpackEncoder()
self.tensor_decoder = MsgpackDecoder(torch.Tensor)
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem)
self._mm_kwargs_item_cls = MultiModalKwargsItem
def serialize(
self,
value: Any) -> tuple[Union[bytes, list[bytes]], int, bytes, int]:
len_arr = None
if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)):
type_name = type(value).__name__
value = self.encoder.encode(value)
len_arr = [len(s) for s in value]
nbytes = sum(len_arr)
else:
value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
type_name = type(value).__name__
nbytes = len(value)
object_metadata = (type_name, nbytes, len_arr)
serialized_metadata = pickle.dumps(object_metadata,
protocol=pickle.HIGHEST_PROTOCOL)
return value, nbytes, serialized_metadata, len(serialized_metadata)
def deserialize(self, data_view: memoryview) -> Any:
        # pickle.loads does not read past the end of a pickled object
# within a large buffer, so we can skip storing the metadata size
type_name, nbytes, len_arr = pickle.loads(data_view)
serialized_data = bytearray(data_view[-nbytes:])
if type_name == torch.Tensor.__name__:
obj = []
start_idx = 0
for length in len_arr:
item_bytes = serialized_data[start_idx:start_idx + length]
obj.append(item_bytes)
start_idx += length
obj = self.tensor_decoder.decode(obj)
elif type_name == self._mm_kwargs_item_cls.__name__:
obj = []
start_idx = 0
for length in len_arr:
item_bytes = serialized_data[start_idx:start_idx + length]
obj.append(item_bytes)
start_idx += length
obj = self.mm_decoder.decode(obj)
elif type_name == bytes.__name__:
obj = pickle.loads(serialized_data)
else:
raise ValueError(
f"Unsupported object type '{type_name}' in metadata")
return obj
@dataclass
class ShmObjectStorageHandle:
max_object_size: int
n_readers: int
ring_buffer_handle: tuple[int, str]
serde_class: type[ObjectSerde]
reader_lock: Optional[LockType]
class SingleWriterShmObjectStorage:
"""
A single-writer, multiple-reader object storage system built on top of a
shared memory ring buffer. Provides key-value storage with automatic memory
management and cross-process serialization support.
This storage system follows a FIFO (First-In-First-Out) eviction policy
where the oldest objects are automatically freed when memory runs low.
Memory is reclaimed based on reader reference counting - objects are only
freed when all readers have finished accessing them.
Architecture:
- Single writer process can put(key, value) objects
- Multiple reader processes can get(address, monotonic_id) objects
- Built on SingleWriterShmRingBuffer for efficient shared memory management
- Thread-safe operations with reader synchronization via locks
Key Features:
- FIFO Eviction: Oldest objects are evicted first when memory is full
- Reference Counting: Objects are only freed when no readers are
accessing them
- Duplicate Key Handling: Existing keys are not overwritten, just
re-referenced
- Customized Serialization: By default uses Msgpack for efficient
serialization of Python objects, but can be extended for custom types
- Cross-Process Safety: Uses shared memory with proper synchronization
- Automatic Cleanup: Garbage collection happens transparently during
allocation
Memory Layout per Object:
`[4-byte reference_count][metadata_size][serialized_object_data]`
Thread Safety:
- Writer operations (put, clear) are single-threaded by design
- Reader operations (get) are thread-safe with lock-based reference
counting
- Memory reclamation is handled exclusively by the writer process
"""
def __init__(
self,
max_object_size: int,
n_readers: int,
ring_buffer: SingleWriterShmRingBuffer,
serde_class: type[ObjectSerde] = MsgpackSerde,
reader_lock: Optional[LockType] = None,
):
"""
Initialize the object storage.
Args:
max_object_size: Maximum size for a single object in bytes.
n_readers: Number of reader processes that can access the storage.
ring_buffer: The shared memory ring buffer for storing objects.
serde_class: Serializer/deserializer for objects.
reader_lock: Optional lock for synchronizing reader access.
Raises:
ValueError: If reader_lock is None for readers.
"""
self.max_object_size = max_object_size
self.n_readers = n_readers
self.serde_class = serde_class
self.ser_de = serde_class()
self.ring_buffer = ring_buffer
self.is_writer = self.ring_buffer.is_writer
self.flag_bytes = 4 # for in-use flag
if self.is_writer:
# Key-value mapping: key -> (address, monotonic_id)
self.key_index: dict[str, tuple[int, int]] = {}
# Reverse mapping: monotonic_id -> key
self.id_index: dict[int, str] = {}
# Writer flag to track in-use status: monotonic_id -> count
self.writer_flag: dict[int, int] = {}
else:
if reader_lock is None:
raise ValueError("Lock must be provided for readers.")
self._reader_lock = reader_lock
def clear(self) -> None:
"""Clear the object storage."""
if self.is_writer:
self.ring_buffer.clear()
self.key_index.clear()
self.id_index.clear()
self.writer_flag.clear()
logger.debug("Object storage cleared and reinitialized.")
def copy_to_buffer(
self,
data: Union[bytes, list[bytes]],
data_bytes: int,
metadata: bytes,
md_bytes: int,
data_view: memoryview,
) -> None:
data_view[self.flag_bytes:self.flag_bytes + md_bytes] = metadata
if isinstance(data, bytes):
data_view[-data_bytes:] = data
elif isinstance(data, list):
start_idx = self.flag_bytes + md_bytes
for item_bytes in data:
item_size = len(item_bytes)
data_view[start_idx:start_idx + item_size] = item_bytes
start_idx += item_size
else:
raise ValueError(
f"Unsupported data type for serialization: {type(data)}")
def increment_writer_flag(self, id: int) -> None:
"""Set the in-use flag for the writer."""
self.writer_flag[id] = self.writer_flag.get(id, 0) + 1
def increment_reader_flag(self, data_view: memoryview) -> None:
"""Set the in-use flag for the reader."""
# >0 for in-use flag
reader_count = self.ring_buffer.byte2int(data_view)
data_view[:] = self.ring_buffer.int2byte(reader_count + 1)
def free_unused(self) -> None:
"""Free unused buffers in the ring buffer."""
# try to free up 2*max_object_size bytes of space in the ring buffer,
# since the buffer might be fragmented
freed_ids = self.ring_buffer.free_buf(self.default_is_free_check,
2 * self.max_object_size)
# update the metadata after freeing up space
for freed_id in freed_ids:
key_to_free = self.id_index[freed_id]
del self.key_index[key_to_free]
del self.id_index[freed_id]
del self.writer_flag[freed_id]
def is_cached(self, key: str) -> bool:
"""
Check if the object with the given key is cached.
"""
return key in self.key_index
def get_cached(self, key: str) -> tuple[int, int]:
"""
Get the cached object by key if it exists.
"""
address, monotonic_id = self.key_index[key]
self.increment_writer_flag(monotonic_id)
return address, monotonic_id
def put(self, key: str, value: Any) -> tuple[int, int]:
"""
Store a key-value pair in the object storage.
        Attempts to free up to 2 * max_object_size bytes in FIFO order
        when the ring buffer runs out of space during a put() operation.
Args:
key: String key to identify the object
value: Any serializable Python object
Raises:
MemoryError: If there's not enough space in the buffer
ValueError: If the serialized object is too large
ValueError: If the key already exists in the storage
"""
if key in self.key_index:
raise ValueError(f"Key '{key}' already exists in the storage.")
object_data, data_bytes, object_metadata, md_bytes = \
self.ser_de.serialize(value)
buffer_size = self.flag_bytes + data_bytes + md_bytes
# Sanity checks
if buffer_size > self.max_object_size:
raise ValueError(
f"Serialized object size ({buffer_size} bytes) exceeds "
f"max object size ({self.max_object_size} bytes)")
# Allocate new buffer
try:
address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
except MemoryError:
self.free_unused()
# try again after freeing up space
address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
# Write data to buffer
with self.ring_buffer.access_buf(address) as (data_view, metadata):
data_view[:self.flag_bytes] = self.ring_buffer.int2byte(0)
self.copy_to_buffer(object_data, data_bytes, object_metadata,
md_bytes, data_view)
self.increment_writer_flag(monotonic_id)
# Update key index
self.key_index[key] = (address, monotonic_id)
self.id_index[monotonic_id] = key
return address, monotonic_id
def get(self, address: int, monotonic_id: int) -> Any:
# Read data from buffer
with self.ring_buffer.access_buf(address) as (data_view, buf_metadata):
# check id from metadata
if buf_metadata[0] != monotonic_id:
raise ValueError(
f"Data for address:id '{address}:{monotonic_id}'"
" has been modified or is invalid.")
obj = self.ser_de.deserialize(data_view[self.flag_bytes:])
            # increment the reader count to record this read
if self._reader_lock is not None:
with self._reader_lock:
self.increment_reader_flag(data_view[:self.flag_bytes])
else:
# if self._reader_lock is None, it means we are the writer
                # in this case, we do not need to update the reader count
assert self.is_writer
return obj
def handle(self):
"""Get handle for sharing across processes."""
return ShmObjectStorageHandle(
max_object_size=self.max_object_size,
n_readers=self.n_readers,
ring_buffer_handle=self.ring_buffer.handle(),
serde_class=self.serde_class,
reader_lock=self._reader_lock,
)
@staticmethod
def create_from_handle(
handle: ShmObjectStorageHandle) -> "SingleWriterShmObjectStorage":
logger.debug("Creating storage from handle: %s", handle)
ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle)
return SingleWriterShmObjectStorage(
max_object_size=handle.max_object_size,
n_readers=handle.n_readers,
ring_buffer=ring_buffer,
serde_class=handle.serde_class,
reader_lock=handle.reader_lock,
)
def default_is_free_check(self, id: int, buf: memoryview) -> bool:
"""
        Default is_free function: the first 4 bytes of the chunk hold the
        reader count. The chunk is free once every reader has consumed it as
        many times as the writer handed it out
        (reader_count >= writer_count * n_readers).
"""
reader_count = int.from_bytes(buf[0:4], "little", signed=True)
writer_count = self.writer_flag[id]
return reader_count >= writer_count * self.n_readers
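A writer-side usage sketch for the object storage above (editorial illustration, not part of the committed file; sizes and keys are made up):

```
from multiprocessing import Lock

# The writer creates the shared-memory ring buffer and the storage; readers
# would be handed storage.handle() and reconstruct their view with
# SingleWriterShmObjectStorage.create_from_handle(handle).
ring = SingleWriterShmRingBuffer(data_buffer_size=1 << 20, create=True)
storage = SingleWriterShmObjectStorage(
    max_object_size=64 * 1024,
    n_readers=2,
    ring_buffer=ring,
    reader_lock=Lock(),
)

address, monotonic_id = storage.put("sample-key", {"tokens": [1, 2, 3]})
# Each reader later fetches the object with storage.get(address, monotonic_id);
# once every reader has done so, the chunk becomes reclaimable.
```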

View File

@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform
try:
import torch.distributed._symmetric_memory as torch_symm_mem
symm_mem_available = True
except ImportError:
symm_mem_available = False
logger = init_logger(__name__)
class SymmMemCommunicator:
_WORLD_SIZES_MULTIMEM = {
"9.0": [4, 6, 8],
"10.0": [6, 8],
}
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None):
self.disabled = True
if not symm_mem_available:
return
if not current_platform.is_cuda():
logger.warning("SymmMemCommunicator: symmetric "
"memory is not available.")
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
torch.cuda.set_device(device)
self.dtype = torch.bfloat16
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = current_platform.get_device_capability(
).as_version_str()
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
"communicator is not available.",
self.device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability]:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
self.world_size,
)
return
# Use override max_size if provided, otherwise use default
if max_size_override is not None:
self.max_size = max_size_override
logger.info(
"SymmMemCommunicator: Using override max_size: %s bytes",
self.max_size,
)
else:
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability][self.world_size]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=self.device,
dtype=self.dtype,
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.force_multimem = force_multimem
self.disabled = False
def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled:
return False
if inp.dtype != self.dtype:
return False
inp_size = inp.numel() * inp.element_size()
if inp_size % 4 != 0:
return False
return inp_size < self.max_size
def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
# Determine which algorithm to use
use_multimem = False
if self.force_multimem is not None:
# Test override: use forced setting
use_multimem = self.force_multimem
else:
# Normal logic: use multimem for supported world sizes
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]
if use_multimem:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
else:
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
return out
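A caller-side sketch of how the communicator above is meant to be used (editorial illustration, not part of the committed file; it assumes `torch.distributed` is initialized with an NCCL group on GPUs that support symmetric memory):

```
import torch
import torch.distributed as dist

group = dist.group.WORLD
comm = SymmMemCommunicator(group, device=torch.cuda.current_device())

x = torch.ones(8192, dtype=torch.bfloat16, device="cuda")
out = comm.all_reduce(x)  # returns None if disabled, wrong dtype, or too large
if out is None:
    dist.all_reduce(x, group=group)  # fall back to the regular NCCL path
    out = x
```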

View File

@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from .base_device_communicator import DeviceCommunicatorBase
USE_RAY = (get_current_vllm_config().parallel_config.distributed_executor_backend
           == "ray")
logger = init_logger(__name__)
if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
if USE_RAY:
from vllm.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
# must be used together. Therefore, the local rank and world size can
# be simply calculated as follows.
global_rank = self.global_rank
global_world_size = self.global_world_size
if USE_RAY:
logger.info("TpuCommunicator initialized with RAY")
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# to the number of TPU nodes in the Ray cluster. The number of TPU
# nodes is computed by the total number of TPUs divided by the
# number of TPU accelerators per node, to account for clusters
# with both CPUs and TPUs.
num_nodes = ray_utils.get_num_tpu_nodes()
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
if num_nodes_in_pg > 0:
num_nodes = num_nodes_in_pg
local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size
else:
logger.info("TpuCommunicator initialized with MP")
# Sanity: Verify we run on a single host
num_hosts = torch_xla.tpu.num_tpu_workers()
assert num_hosts == 1
# Get the current number of TPUs (we have locally)
local_world_size = torch_xla.tpu.num_available_chips()
# Get current rank
local_rank = global_rank % local_world_size
# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
# chip is actually visible. Otherwise the TPU driver will fail to
# initialize because the number of devices would be different from
# the number of visible worker addresses.
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()
self.groups = create_optimized_replica_groups()
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
# TODO: Remove the groups specification after XLA compiler can support
# auto-reordering the ring order for all-reduce.
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(input_, dim=dim)
if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator = TpuCommonsCommunicator # type: ignore
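The Ray-branch rank derivation above reduces to simple modular arithmetic; a toy illustration with made-up numbers (editorial, not part of the committed file):

```
# 16 global ranks spread over 2 TPU nodes -> 8 local ranks per node.
global_world_size, num_nodes = 16, 2
local_world_size = global_world_size // num_nodes
local_ranks = [rank % local_world_size for rank in range(global_world_size)]
# local_ranks == [0, 1, ..., 7, 0, 1, ..., 7]
```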

View File

@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class XpuCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend != "naive":
logger.warning(
"`%s` all2all manager is not supported on XPU."
"Falling back to `naive` all2all manager for XPU.",
all2all_backend)
all2all_backend = "naive"
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def all_reduce(self, input_) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
        # For the XPU path, gather doesn't work properly together with a Ray
        # cluster, so we use all_gather instead for now.
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((self.world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
if self.rank_in_group == dst:
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
else:
output_tensor = None
return output_tensor
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group)
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states
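The gather-via-all_gather path above relies on a `movedim` + `reshape` to interleave the per-rank shards along the target dimension. A single-process sketch with plain CPU tensors (editorial, not part of the committed file; no distributed backend is needed):

```
import torch

world_size, dim = 2, 1
per_rank = [torch.arange(6).reshape(2, 3) + 10 * r for r in range(world_size)]

stacked = torch.stack(per_rank)  # stands in for the all_gather_into_tensor output
input_size = per_rank[0].size()
out = stacked.movedim(0, dim).reshape(
    input_size[:dim] + (world_size * input_size[dim], ) + input_size[dim + 1:])

assert torch.equal(out[:, :3], per_rank[0])
assert torch.equal(out[:, 3:], per_rank[1])
```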

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
'''
Expert parallelism load balancer (EPLB).
'''
from .eplb_state import *
from .rebalance_algo import *

View File

@@ -0,0 +1,620 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) metrics and states.
# Glossary
- **Logical Expert**: An expert that is part of the model's logical structure.
It holds a set of weights and is replicated across multiple physical
experts.
- **Redundant Expert**: To achieve load balancing, for some popular logical
experts, we create additional copies of the expert weights. During inference,
each of these copies can be routed to by the same set of tokens.
- **Physical Expert**: An expert that is instantiated on a specific device.
It is a replica of a logical expert and can be rearranged across devices.
I.e., one logical expert may have multiple sets of weights initialized on
different devices, and each of these sets is a physical expert.
- **Local Physical Expert**: A physical expert that is instantiated on the
current device.
For example: DeepSeek-R1 has 256 logical experts, so each MoE layer
has 256 sets of linear layer weights in the model parameters. If we add 32
redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in
total. And when deploying, we'll have 288 sets of linear layer weights for each
MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
physical experts.
"""
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Union
import torch
from torch.distributed import ProcessGroup, all_reduce
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
in_the_same_node_as)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
from .rebalance_algo import rebalance_experts
from .rebalance_execute import rearrange_expert_weights_inplace
logger = init_logger(__name__)
@dataclass
class EplbState:
"""EPLB metrics."""
physical_to_logical_map: torch.Tensor
"""
Mapping from physical experts to logical experts.
Shape: (num_moe_layers, num_physical_experts)
# Example
For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
EP ranks, the mapping could look like this:
```
[[0, 1, 2, 3, 0, 1],
[0, 2, 0, 1, 0, 3]]
```
"""
logical_to_physical_map: torch.Tensor
"""
Mapping from logical experts to physical experts.
This is a sparse matrix, where -1 indicates no mapping.
Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1)
# Example
For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
EP ranks, the mapping could look like this:
```
[[[0, 4, -1],
[1, 5, -1],
[2, -1, -1],
[3, -1, -1]],
[[0, 2, 4],
[3, -1, -1],
[1, -1, -1],
[5, -1, -1]]]
```
"""
logical_replica_count: torch.Tensor
"""
Number of replicas for each logical expert.
This is exactly the non-`-1` count in the `logical_to_physical_map`.
Shape: (num_moe_layers, num_logical_experts)
# Example
For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
EP ranks, the count could look like this:
```
[[2, 2, 1, 1],
     [3, 1, 1, 1]]
    ```
    """
expert_load_pass: torch.Tensor
"""
Expert load during this forward pass.
We use the token count each expert processes as the load.
Shape: (num_moe_layers, num_physical_experts)
"""
expert_load_window: torch.Tensor
"""
A sliding window of expert load.
Shape: (window_size, num_moe_layers, num_physical_experts)
NOTE: The expert_load_view now records load for all physical experts
rather than just local experts. This ensures consistent load statistics
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
The recorded load will be multiplied by dp_size when using naive all-to-all
due to each DP rank contributing the same token set to the calculation.
See:
https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
"""
expert_load_window_step: int = 0
"""
Current step in the sliding window.
Different from `expert_rearrangement_step`, each EP rank may have its own
`expert_load_window_step`.
"""
expert_load_window_size: int = 0
"""
Size of the expert load sliding window.
This is a constant and is taken from the config.
"""
expert_rearrangement_step: int = 0
"""
Steps after last rearrangement.
Will trigger a rearrangement if it exceeds the threshold.
NOTE: Keep in mind that all EP ranks need to have the same
`expert_rearrangement_step` value to ensure synchronization.
Otherwise, the rearrangement will hang at collective
communication calls.
"""
expert_rearrangement_step_interval: int = 0
"""
Interval for expert rearrangement steps.
This is a constant and is taken from the config.
"""
@staticmethod
def build_initial_global_physical_to_logical_map(
num_routed_experts: int,
num_redundant_experts: int,
) -> Sequence[int]:
"""
Build an initial expert arrangement using the following structure:
[original routed experts, redundant experts]
Returns:
physical_to_logical_map (Sequence[int]): A list of integers,
where each integer is the index of the logical expert
that the corresponding physical expert maps to.
"""
global_physical_to_logical_map = list(range(num_routed_experts))
global_physical_to_logical_map += [
i % num_routed_experts for i in range(num_redundant_experts)
]
return global_physical_to_logical_map
@classmethod
def build(
cls,
model: MixtureOfExperts,
device: torch.device,
parallel_config: ParallelConfig,
global_expert_load: Optional[torch.Tensor] = None,
old_global_expert_indices: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int, int]] = None,
) -> "EplbState":
"""
Build the initial EPLB state.
"""
physical_to_logical_map_list = (
cls.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
model.num_redundant_experts,
))
physical_to_logical_map = torch.tensor(
physical_to_logical_map_list,
device=device,
)
# Assuming 8 GPUs per node, this supports up to
# (1023 + 1) / 8 = 128 nodes for now.
# TODO(rui): make this configurable
MAX_EXPERT_REDUNDANCY = 1023
assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
f"num_redundant_experts {model.num_redundant_experts} "
f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}")
max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
logical_to_physical_map = torch.full(
(model.num_logical_experts, max_slots_per_logical_expert),
-1,
device=device,
)
logical_replica_count = torch.zeros(
(model.num_logical_experts, ),
device=device,
dtype=torch.long,
)
for i in range(model.num_physical_experts):
logical_idx = physical_to_logical_map[i]
logical_to_physical_map[logical_idx,
logical_replica_count[logical_idx]] = i
logical_replica_count[logical_idx] += 1
# Duplicate initial mapping for all layers
physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand(
model.num_moe_layers,
-1,
).contiguous()
logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand(
model.num_moe_layers,
-1,
-1,
).contiguous()
logical_replica_count = logical_replica_count.unsqueeze(0).expand(
model.num_moe_layers,
-1,
).contiguous()
expert_load_pass = torch.zeros(
(model.num_moe_layers, model.num_physical_experts),
dtype=torch.int32,
device=device,
)
expert_load_window_size = parallel_config.eplb_config.window_size
expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers,
model.num_physical_experts),
dtype=torch.int32,
device=device,
)
# Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_config.step_interval
expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (model.num_moe_layers,
model.num_logical_experts)
assert global_expert_load.dtype == torch.int64
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}")
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = (rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
))
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
logical_replica_count,
)
if global_expert_load is not None:
rearrange_expert_weights_inplace(
old_global_expert_indices,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
False,
rank_mapping,
)
expert_rearrangement_step = 0
return cls(
physical_to_logical_map,
logical_to_physical_map,
logical_replica_count,
expert_load_pass,
expert_load_window,
expert_load_window_size=expert_load_window_size,
expert_rearrangement_step=expert_rearrangement_step,
expert_rearrangement_step_interval=eplb_step_interval,
)
def step(self,
model: MixtureOfExperts,
is_dummy: bool = False,
is_profile: bool = False,
log_stats: bool = False) -> None:
"""
Step the EPLB state.
Args:
model (MixtureOfExperts): The MoE model.
is_dummy (bool): If `True`, this is a dummy step and the load
metrics recorded in this forward pass will not count.
Defaults to `False`.
is_profile (bool): If `True`, perform a dummy rearrangement
with maximum communication cost. This is used in
`profile_run` to reserve enough memory
for the communication buffer.
log_stats (bool): If `True`, log the expert load metrics.
# Stats
The metrics are all summed up across layers.
- `avg_tokens`: The average load across ranks.
- `max_tokens`: The maximum load across ranks.
- `balancedness`: The ratio of average load to maximum load.
"""
if is_profile:
self.rearrange(model, is_profile=True)
return
if is_dummy:
# Do not record load metrics for dummy steps
self.expert_load_pass.zero_()
if log_stats:
# total_expert_load_pass: (num_moe_layers, num_physical_experts)
total_expert_load_pass = self.expert_load_pass.clone()
# Collect load metrics from all ranks
ep_group = get_ep_group().device_group
all_reduce(total_expert_load_pass, group=ep_group)
# num_tokens_per_rank: (num_moe_layers, num_ranks)
num_tokens_per_rank = total_expert_load_pass.reshape(
total_expert_load_pass.shape[0], ep_group.size(),
-1).sum(dim=-1).float()
# Compute balancedness ratio:
# for each layer:
# (mean load across ranks) / (max load across ranks)
avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(
dim=0)
# Just to make type checker happy
tokens_tensors: list[float] = torch.stack(
[avg_tokens_tensor, max_tokens_tensor]).tolist()
avg_tokens, max_tokens = tokens_tensors
balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0
if ep_group.rank() == 0:
logger.info(
"EPLB step: avg_tokens=%.2f, max_tokens=%d, "
"balancedness=%.4f", avg_tokens, max_tokens, balancedness)
# Update the expert load sliding window
if not is_dummy:
self.expert_load_window[self.expert_load_window_step] = (
self.expert_load_pass.clone())
self.expert_load_window_step += 1
if self.expert_load_window_step >= self.expert_load_window_size:
self.expert_load_window_step = 0
self.expert_load_pass.zero_()
# Step the expert rearrangement step
# Note that even if this is a dummy step, we still increment the
# rearrangement step and perform rearrangement to ensure all ranks are
# performing collective communication.
self.expert_rearrangement_step += 1
if (self.expert_rearrangement_step
>= self.expert_rearrangement_step_interval):
self.expert_rearrangement_step = 0
self.rearrange(model)
def rearrange(
self,
model: MixtureOfExperts,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int,
int]] = None) -> Optional[torch.Tensor]:
"""
Rearrange the experts according to the current load.
"""
ep_group = get_ep_group().device_group
ep_rank = ep_group.rank()
time_start = None
is_main_rank = ep_rank == 0
if is_main_rank:
torch.cuda.synchronize()
time_start = time.perf_counter()
logger.info("Rearranging experts %s...",
"(profile)" if is_profile else "")
if global_expert_load is None:
# Map the physical expert load to global logical experts
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
model.num_moe_layers,
model.num_logical_experts,
dtype=self.expert_load_window.dtype,
device=self.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=self.physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(),
src=self.expert_load_window,
)
if not execute_shuffle:
metadata = torch.tensor(
[
model.num_moe_layers, model.num_logical_experts,
self.physical_to_logical_map.shape[1]
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(metadata,
group=get_ep_group().cpu_group,
group_src=0)
# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0)
all_reduce(global_expert_load_window, group=ep_group)
if not execute_shuffle:
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = self.physical_to_logical_map
torch.distributed.broadcast(old_global_expert_indices,
group=ep_group,
group_src=0)
return global_expert_load_window
else:
assert execute_shuffle
global_expert_load_window = global_expert_load
# TODO(bowen): Treat differently for prefill and decode nodes
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
if rank_mapping is not None and len(rank_mapping) == ep_group.size():
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released.
cpu_group = get_ep_group().cpu_group
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
num_gpus = sum(new_rank != -1
for new_rank in rank_mapping.values())
num_replicas = num_replicas // ep_group.size(
) * num_gpus # handle num replicas change
else:
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
            num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}")
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = (rebalance_experts(
global_expert_load_window,
num_replicas,
num_groups,
num_nodes,
num_gpus,
))
# Update expert weights
rearrange_expert_weights_inplace(
self.physical_to_logical_map,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
is_profile,
rank_mapping,
)
if not is_profile:
if self.physical_to_logical_map.shape[
1] != new_physical_to_logical_map.shape[1]:
self.physical_to_logical_map = new_physical_to_logical_map.to(
self.physical_to_logical_map.device)
else:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0,
self.logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
self.logical_to_physical_map.copy_(new_logical_to_physical_map)
self.logical_replica_count.copy_(new_logical_replica_count)
if is_main_rank:
assert time_start is not None
torch.cuda.synchronize()
time_end = time.perf_counter()
logger.info(
"Rearranged experts%sin %.2f seconds.",
" (profile) " if is_profile else " ",
time_end - time_start,
)
return None
@staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
"""
Receive the expert load and old placement from the master rank.
"""
ep_group = get_ep_group()
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata,
group=ep_group.cpu_group,
group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = (
metadata.tolist())
global_expert_load = torch.zeros(
(num_moe_layers, num_logical_experts),
dtype=torch.int64,
device=ep_group.device,
)
all_reduce(global_expert_load, group=ep_group.device_group)
old_global_expert_indices = torch.empty(
(num_moe_layers, num_old_physical_experts),
dtype=torch.int64,
device=ep_group.device,
)
torch.distributed.broadcast(old_global_expert_indices,
group=ep_group.device_group,
group_src=0)
return global_expert_load, old_global_expert_indices
def _node_count_with_rank_mapping(
pg: Union[ProcessGroup, StatelessProcessGroup],
rank_mapping: dict[int, int],
) -> int:
if isinstance(pg, ProcessGroup):
world_size = torch.distributed.get_world_size(group=pg)
else:
world_size = pg.world_size
if world_size == 1:
return 1
# Build node assignment map
node_assignment = [0] * world_size # rank -> node_id
next_node_id = 0
for current_rank in range(world_size):
if node_assignment[current_rank] != 0:
continue # Already assigned to a node
assert current_rank in rank_mapping
if rank_mapping[current_rank] == -1:
continue # Pending shutdown
# Assign current rank to a new node
next_node_id += 1
node_assignment[current_rank] = next_node_id
# Find all ranks on the same node as current_rank
same_node_flags = in_the_same_node_as(pg, current_rank)
for other_rank, is_same_node in enumerate(same_node_flags):
if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id
return next_node_id
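# Illustrative example of `rank_mapping` semantics (assumed values): the
# mapping goes from old EP rank to new EP rank, with -1 marking ranks that
# are pending shutdown. Scaling down from 4 to 2 GPUs could use
# {0: 0, 1: 1, 2: -1, 3: -1}, keeping ranks 0 and 1 and releasing ranks 2
# and 3; `_node_count_with_rank_mapping` then counts nodes over the
# surviving ranks only.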

View File

@@ -0,0 +1,239 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.
This module implements the core rearrangement algorithm.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""
import torch
def balanced_packing(weight: torch.Tensor,
num_packs: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
    Pack n weighted objects into m packs, such that each pack contains exactly
    n/m objects and the total weights of all packs are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers, num_groups = weight.shape
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
if groups_per_pack == 1:
pack_index = torch.arange(weight.size(-1),
dtype=torch.int64,
device=weight.device).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
return pack_index, rank_in_pack
indices = weight.float().sort(-1, descending=True).indices.cpu()
pack_index = torch.full_like(weight,
fill_value=-1,
dtype=torch.int64,
device="cpu")
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
for i in range(num_layers):
pack_weights = [0] * num_packs
pack_items = [0] * num_packs
for group in indices[i]:
pack = min(
(i
for i in range(num_packs) if pack_items[i] < groups_per_pack),
key=pack_weights.__getitem__,
)
assert pack_items[pack] < groups_per_pack
pack_index[i, group] = pack
rank_in_pack[i, group] = pack_items[pack]
pack_weights[pack] += weight[i, group]
pack_items[pack] += 1
return pack_index, rank_in_pack
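# Worked example for `balanced_packing` (illustrative numbers): with a single
# layer, weight = [[30, 10, 20, 40]] and num_packs = 2, groups are visited in
# descending weight order (3, 0, 2, 1) and each goes to the lightest pack
# that still has room. The result is pack 0 = {3, 1} and pack 1 = {0, 2},
# both with total weight 50, i.e. pack_index = [[1, 0, 1, 0]] and
# rank_in_pack = [[0, 1, 1, 0]].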
def replicate_experts(
weight: torch.Tensor,
num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64,
device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices
phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt
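# Worked example for `replicate_experts` (illustrative numbers): with
# weight = [[90, 10, 20]] and num_phy = 4, the single redundant slot goes to
# the expert with the highest per-replica load (expert 0, load 90 / 1), so
# phy2log = [[0, 1, 2, 0]], rank = [[0, 0, 0, 1]] and logcnt = [[2, 1, 1]].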
def rebalance_experts_hierarchical(
weight: torch.Tensor,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g., NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map (torch.Tensor):
[num_moe_layers, num_physical_experts]
logical_to_physical_map (torch.Tensor):
[num_moe_layers, num_logical_experts, X]
logical_count (torch.Tensor):
[num_moe_layers, num_logical_experts]
"""
num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0
group_size = num_logical_experts // num_groups
assert num_groups % num_nodes == 0
groups_per_node = num_groups // num_nodes
assert num_gpus % num_nodes == 0
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: torch.Tensor) -> torch.Tensor:
inv = torch.empty_like(perm)
inv.scatter_(
1,
perm,
torch.arange(perm.size(1), dtype=torch.int64,
device=perm.device).expand(perm.shape),
)
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
group_pack_index, group_rank_in_pack = balanced_packing(
tokens_per_group, num_nodes)
log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
group_size).unsqueeze(-1) +
torch.arange(group_size,
dtype=torch.int64,
device=group_pack_index.device)).flatten(-2)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes)
phy2mlog, phyrank, mlogcnt = replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
num_gpus // num_nodes)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather(
-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device,
).view(1, -1, 1)).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
return pphy2log, pphyrank, logcnt
def rebalance_experts(
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all
logical experts
num_replicas: number of physical experts, must be a multiple of
`num_gpus`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map:
[layers, num_replicas], the expert index of each replica
logical_to_physical_map:
[layers, num_logical_experts, X], the replica indices for each
expert
expert_count:
[layers, num_logical_experts], number of physical
replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
weight = weight.float().cpu()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_gpus)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_gpus)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full(
(num_layers, num_logical_experts, maxlogcnt),
-1,
dtype=torch.int64,
device=logcnt.device,
)
log2phy.view(num_layers, -1).scatter_(
-1,
phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64,
device=log2phy.device).expand(num_layers, -1),
)
return phy2log, log2phy, logcnt
__all__ = ["rebalance_experts"]
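# Minimal usage sketch of `rebalance_experts` with hypothetical sizes
# (2 MoE layers, 8 logical experts replicated onto 16 physical slots across
# 2 nodes and 4 GPUs); run this module as a script to inspect the shapes.
if __name__ == "__main__":
    example_load = torch.randint(0, 1000, (2, 8))
    phy2log, log2phy, logcnt = rebalance_experts(
        example_load,
        num_replicas=16,
        num_groups=8,
        num_nodes=2,
        num_gpus=4,
    )
    # phy2log: (2, 16), the logical expert id behind every physical slot
    # log2phy: (2, 8, 9), physical slots backing each logical expert, -1 padded
    # logcnt:  (2, 8), number of replicas per logical expert
    print(phy2log.shape, log2phy.shape, logcnt.shape)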

View File

@@ -0,0 +1,424 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
The actual execution of the rearrangement.
This involves the exchange of expert weights between GPUs.
"""
from collections.abc import Iterable, MutableSequence, Sequence
from functools import partial
from typing import Optional
import torch
from torch.distributed import (P2POp, ProcessGroup, all_gather,
batch_isend_irecv, get_global_rank)
def idx_local_to_global(
local_idx: int,
local_cnt: int,
ep_rank: int,
) -> int:
"""
Convert a local expert index to a global expert index.
"""
return ep_rank * local_cnt + local_idx
def idx_global_to_local(
global_idx: int,
local_cnt: int,
ep_rank: int,
) -> int:
"""
Convert a global expert index to a local expert index.
"""
return global_idx - ep_rank * local_cnt
def global_idx_to_rank(
global_idx: int,
local_cnt: int,
) -> int:
"""
Convert a global expert index to a rank index.
"""
return global_idx // local_cnt
def get_ep_ranks_with_expert(
idx: int,
num_local_experts: int,
old_indices: Sequence[int],
new_indices: Sequence[int],
) -> tuple[MutableSequence[int], MutableSequence[int]]:
"""
Get the ranks of the experts that need to be exchanged.
Args:
idx: The index of the expert.
num_local_experts: The number of local experts.
old_indices: The old indices of the experts.
new_indices: The new indices of the experts.
Returns:
A tuple of two lists:
- The ranks of the experts that need to be sent.
- The ranks of the experts that need to be received.
"""
global2rank = partial(
global_idx_to_rank,
local_cnt=num_local_experts,
)
ranks_to_send: list[int] = []
ranks_to_recv: list[int] = []
for i, e in enumerate(old_indices):
if e == idx:
rank = global2rank(i)
if not ranks_to_send or ranks_to_send[-1] != rank:
ranks_to_send.append(rank)
for i, e in enumerate(new_indices):
if e == idx:
rank = global2rank(i)
if not ranks_to_recv or ranks_to_recv[-1] != rank:
ranks_to_recv.append(rank)
# Remove those ranks that can get this expert locally.
ranks_to_send_set = set(ranks_to_send)
ranks_to_recv_actual = [
rank for rank in ranks_to_recv if rank not in ranks_to_send_set
]
return ranks_to_send, ranks_to_recv_actual
def shuffle_layer(
num_local_experts: int,
ep_rank: int,
old_indices: Sequence[int],
new_indices: Sequence[int],
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup,
) -> None:
"""
Perform expert weights rearrangement of one layer.
"""
local2global = partial(
idx_local_to_global,
local_cnt=num_local_experts,
ep_rank=ep_rank,
)
# 0. Do nothing for experts that did not change.
is_unchanged = [
old_indices[local2global(i)] == new_indices[local2global(i)]
for i in range(num_local_experts)
]
# 1. Perform weight copy inside the local rank.
is_received_locally = is_unchanged[:]
for src in range(num_local_experts):
src_global = local2global(src)
for dst in range(num_local_experts):
dst_global = local2global(dst)
if is_received_locally[dst]:
continue
if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
continue
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights,
expert_weights_buffer):
buffer[dst].copy_(weight[src])
p2p_ops: list[P2POp] = []
# 2. Initiate sending of weights.
experts_send_loc: dict[int, int] = {}
for src in range(num_local_experts):
expert = old_indices[local2global(src)]
if expert == -1:
continue
if expert in experts_send_loc:
continue
experts_send_loc[expert] = src
# We need to sort here to match send/recv
for expert, src in sorted(experts_send_loc.items()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
num_local_experts,
old_indices,
new_indices,
)
# Calculate the ranks to send by this rank
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
sender_pos = ranks_to_send.index(ep_rank)
recv_begin = sender_pos * num_dst_per_sender
recv_end = recv_begin + num_dst_per_sender
recv_ranks = ranks_to_recv[recv_begin:recv_end]
# Tackle remainders
remainder_start = len(ranks_to_send) * num_dst_per_sender
recver_pos = remainder_start + sender_pos
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
dst_global = get_global_rank(ep_group, dst)
p2p_ops += [
P2POp(
torch.distributed.isend,
weight[src],
dst_global,
) for weight in expert_weights
]
# 3. Initiate receiving of weights.
experts_recv_loc: dict[int, int] = {}
for dst in range(num_local_experts):
if is_received_locally[dst]:
continue
expert = new_indices[local2global(dst)]
if expert == -1:
continue
if expert in experts_recv_loc:
continue
experts_recv_loc[expert] = dst
# We need to sort here to match send/recv
for expert, dst in sorted(experts_recv_loc.items()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
num_local_experts,
old_indices,
new_indices,
)
# Calculate the rank to recv by this rank
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
recver_pos = ranks_to_recv.index(ep_rank)
remainder_start = len(ranks_to_send) * num_dst_per_sender
if recver_pos < remainder_start:
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = get_global_rank(ep_group, src)
p2p_ops += [
P2POp(
torch.distributed.irecv,
weight[dst],
src_global,
) for weight in expert_weights_buffer
]
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
# 5. Copy the weights from the buffer back to the original weights.
for dst in range(num_local_experts):
if is_unchanged[dst]:
continue
if is_received_locally[dst]:
for weight, buffer in zip(expert_weights, expert_weights_buffer):
weight[dst].copy_(buffer[dst])
else:
expert = new_indices[local2global(dst)]
if expert == -1:
continue
src = experts_recv_loc[expert]
for weight, buffer in zip(expert_weights, expert_weights_buffer):
weight[dst].copy_(buffer[src])
def rearrange_expert_weights_inplace(
old_global_expert_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Iterable[torch.Tensor]],
ep_group: ProcessGroup,
is_profile: bool = False,
rank_mapping: Optional[dict[int, int]] = None,
) -> None:
"""
Rearranges the expert weights in place according to the new expert indices.
    The values of the index arguments are logical expert indices, while their
    positions (keys) correspond to physical expert slots.
Args:
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        expert_weights: A nested sequence of length num_moe_layers, where each
            element is an iterable of weight_count tensors of shape
            (num_local_physical_experts, hidden_size_i). For example, a linear
            layer may have up and down projections, so weight_count = 2. Each
            weight's hidden size can be different.
ep_group: The device process group for expert parallelism.
is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers.
rank_mapping: A dictionary mapping old rank to new rank.
"""
if rank_mapping is not None:
if len(rank_mapping) == ep_group.size():
# scale down
new_global_expert_indices = \
_map_new_expert_indices_with_rank_mapping(
new_global_expert_indices,
rank_mapping,
)
else:
# scale up
old_global_expert_indices = \
_map_old_expert_indices_with_rank_mapping(
old_global_expert_indices,
rank_mapping,
ep_group.size(),
)
assert old_global_expert_indices.shape[
1] == new_global_expert_indices.shape[1]
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers,
num_physical_experts)
ep_rank = ep_group.rank()
ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts
# A buffer to hold the expert weights in one layer during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]
if is_profile:
# Maximum send size is to send all local experts to all ranks,
# So we use a dummy `all_gather` to reserve enough communication buffer
for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
# A `/dev/null`-like buffer to avoid real memory allocation
dummy_recv_buffer = [buffer for _ in range(ep_size)]
# NOTE(bowen): Needed this barrier to avoid OOM during actual
# execution. I'm not very sure why this is needed
torch.distributed.barrier()
all_gather(
dummy_recv_buffer,
weight,
group=ep_group,
)
return
for layer in range(num_moe_layers):
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()
shuffle_layer(
num_local_physical_experts,
ep_rank,
old_global_expert_indices[layer].tolist(),
new_global_expert_indices[layer].tolist(),
expert_weights[layer],
expert_weights_buffer,
ep_group,
)
def _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices: torch.Tensor,
rank_mapping: dict[int, int],
new_ep_size: int,
) -> torch.Tensor:
"""
Map the old global expert indices to the new global expert indices.
Args:
old_global_expert_indices:
Shape (num_layers, old_ep_size * num_local_physical_experts).
rank_mapping: Mapping from old rank to new rank.
new_ep_size: New expert parallelism size.
Returns:
Mapped expert indices with shape
(num_layers, new_ep_size * num_local_physical_experts).
"""
num_layers, old_num_physical_experts = old_global_expert_indices.shape
assert rank_mapping, "Rank mapping is required"
# Get sizes from parameters and rank_mapping
old_ep_size = len(rank_mapping)
num_local_physical_experts = old_num_physical_experts // old_ep_size
new_num_physical_experts = new_ep_size * num_local_physical_experts
# Create mapped tensor with new shape, initialized to -1
mapped_expert_indices = torch.full(
(num_layers, new_num_physical_experts),
fill_value=-1,
dtype=old_global_expert_indices.dtype,
device=old_global_expert_indices.device,
)
# Handle rank mapping (scale up/down with rank changes)
for old_rank in range(old_ep_size):
new_rank = rank_mapping.get(old_rank)
if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
# This old rank exists in the new configuration
old_start_idx = old_rank * num_local_physical_experts
old_end_idx = (old_rank + 1) * num_local_physical_experts
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts
mapped_expert_indices[:, new_start_idx:new_end_idx] = \
old_global_expert_indices[:, old_start_idx:old_end_idx]
# If new_rank is None or >= new_ep_size, the experts remain -1
# (scale down case)
return mapped_expert_indices
def _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices: torch.Tensor,
rank_mapping: dict[int, int],
) -> torch.Tensor:
num_layers, new_num_physical_experts = new_global_expert_indices.shape
assert rank_mapping, "Rank mapping is required"
# Get sizes from parameters and rank_mapping
old_ep_size = len(rank_mapping)
new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_local_physical_experts = new_num_physical_experts // new_ep_size
old_num_physical_experts = old_ep_size * num_local_physical_experts
mapped_expert_indices = torch.full(
(num_layers, old_num_physical_experts),
fill_value=-1,
dtype=new_global_expert_indices.dtype,
device=new_global_expert_indices.device,
)
for old_rank in range(old_ep_size):
new_rank = rank_mapping[old_rank]
if new_rank >= 0 and new_rank < new_ep_size:
old_start_idx = old_rank * num_local_physical_experts
old_end_idx = (old_rank + 1) * num_local_physical_experts
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts
mapped_expert_indices[:, old_start_idx:old_end_idx] = \
new_global_expert_indices[:, new_start_idx:new_end_idx]
return mapped_expert_indices
__all__ = ["rearrange_expert_weights_inplace"]

View File

@@ -0,0 +1,362 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import asdict
from itertools import count
from queue import Queue
from typing import Any, Callable, Optional, Union
import msgspec
import zmq
from vllm.config.kv_events import KVEventsConfig
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
logger = init_logger(__name__)
class EventBatch(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False, # type: ignore[call-arg]
):
ts: float
events: list[Any]
data_parallel_rank: Optional[int] = None
class KVCacheEvent(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False, # type: ignore[call-arg]
tag=True):
"""Base class for all KV cache-related events"""
MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent):
block_hashes: list[ExternalBlockHash]
parent_block_hash: Optional[ExternalBlockHash]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent):
block_hashes: list[ExternalBlockHash]
medium: Optional[str]
class AllBlocksCleared(KVCacheEvent):
pass
class KVEventBatch(EventBatch):
events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches with data parallelism
support.
In data parallel setups, each DP rank runs its own EventPublisher instance
to avoid duplicate events and ensure proper event attribution:
- Each DP rank creates a separate publisher
- Publishers automatically annotate events with their data_parallel_rank
- This allows consumers to distinguish events from different DP ranks
The publisher is responsible for adding DP metadata since the scheduler
operates independently of DP topology and shouldn't need DP awareness.
"""
def __init__(self, data_parallel_rank: int = 0) -> None:
self._data_parallel_rank = data_parallel_rank
@abstractmethod
def publish(self, events: EventBatch) -> None:
"""Emit events in order.
Implementations should guarantee at-least-once delivery and
monotonic ordering (e.g., via sequence numbers).
"""
@abstractmethod
def shutdown(self) -> None:
"""Shutdown the publisher."""
class NullEventPublisher(EventPublisher):
"""No-op implementation (default when disabled)."""
def publish(self, events) -> None:
return
def shutdown(self) -> None:
return
class ZmqEventPublisher(EventPublisher):
"""Reliable PUB/ROUTER publisher with an in-memory replay buffer.
Spawns a separate thread to handle publishing from a queue.
Parameters
----------
endpoint:
PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
connect.
replay_endpoint:
Optional ROUTER address for replay requests. When given, subscribers can
request missed batches by sending the starting sequence number as an
8-byte big-endian integer.
buffer_steps:
Number of past batches to keep for replay.
hwm:
ZeroMQ high-water-mark for PUB socket.
max_queue_size:
Maximum number of events to buffer in memory.
topic:
Topic to publish events to.
"""
SHUTDOWN_TIMEOUT: float = 1.0
END_SEQ = (-1).to_bytes(8, "big", signed=True)
def __init__(
self,
data_parallel_rank: int,
endpoint: str = "tcp://*:5557",
replay_endpoint: Optional[str] = None,
buffer_steps: int = 10_000,
hwm: int = 100_000,
max_queue_size: int = 100_000,
topic: str = "",
) -> None:
# Storage
super().__init__(data_parallel_rank)
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
# ZMQ sockets
self._ctx = zmq.Context.instance()
self._pub: Optional[zmq.Socket] = None
self._replay: Optional[zmq.Socket] = None
self._dp_rank = data_parallel_rank
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
self._replay_endpoint = self.offset_endpoint_port(
replay_endpoint, self._dp_rank)
self._hwm = hwm
self._socket_setup()
# Payload
self._seq_gen = count()
self._topic_bytes = topic.encode('utf-8')
# Thread
self._running = True
logger.info("Starting ZMQ publisher thread")
self._thread = threading.Thread(target=self._publisher_thread,
daemon=True,
name="zmq-publisher")
self._thread.start()
def publish(self, events: EventBatch) -> None:
if not self._running:
raise RuntimeError("Publisher is closed")
if events.data_parallel_rank is None:
events.data_parallel_rank = self._data_parallel_rank
self._event_queue.put(events)
def shutdown(self) -> None:
"""Stop the publisher thread and clean up resources."""
self._running = False
self._event_queue.put_nowait(None)
start = time.time()
pending_items = True
while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
pending_items = not self._event_queue.empty()
if pending_items:
time.sleep(0.1)
if pending_items:
logger.warning(
"Warning: Queue still has %s items after %s seconds timeout",
self._event_queue.qsize(),
self.SHUTDOWN_TIMEOUT,
)
if self._thread.is_alive():
self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
# Clean up ZMQ resources
try:
if self._pub is not None:
self._pub.close(linger=0)
if self._replay is not None:
self._replay.close(linger=0)
finally:
pass # Do not terminate context; other sockets may use it
def _socket_setup(self) -> None:
"""Initialize sockets
https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
"""
if self._pub is None:
self._pub = self._ctx.socket(zmq.PUB)
self._pub.set_hwm(self._hwm)
# Heuristic: bind if wildcard / * present, else connect.
# bind stable, connect volatile convention
if (self._endpoint is not None
and ("*" in self._endpoint or "::" in self._endpoint
or self._endpoint.startswith("ipc://")
or self._endpoint.startswith("inproc://"))):
self._pub.bind(self._endpoint)
elif self._endpoint is not None:
self._pub.connect(self._endpoint)
# Set up replay socket: use ROUTER
# 1) handles multiple REQ clients (identities)
# 2) lets us send back one request → many replies (streamed events)
# 3) works in our nonblocking poll loop alongside PUB
if self._replay_endpoint is not None:
self._replay = self._ctx.socket(zmq.ROUTER)
self._replay.bind(self._replay_endpoint)
def _publisher_thread(self) -> None:
"""Background thread that processes the event queue."""
self._pack = msgspec.msgpack.Encoder()
assert self._pub is not None # narrows type for mypy
while self._running or self._event_queue.qsize() > 0:
# --- replay (non-critical) ---------------------------------
if self._replay is not None and self._replay.poll(0):
try:
self._service_replay()
except Exception as e:
logger.exception("Error in replay: %s", e)
# --- main queue (critical) ---------------------------------
try:
event = self._event_queue.get(timeout=0.1)
if event is None:
break # Sentinel received, exit thread
except queue.Empty:
continue
try:
seq = next(self._seq_gen)
payload = self._pack.encode(event)
seq_bytes = seq.to_bytes(8, "big")
self._pub.send_multipart(
(self._topic_bytes, seq_bytes, payload))
self._buffer.append((seq, payload))
self._event_queue.task_done()
except Exception as e:
# Publishing failed; back-off a bit to avoid a tight error loop
logger.exception("Error in publisher thread: %s", e)
time.sleep(0.1)
def _service_replay(self) -> None:
"""If a replay request is waiting, send buffered batches."""
assert self._replay is not None # narrows type for mypy
frame = self._replay.recv_multipart()
if len(frame) != 3:
logger.warning("Invalid replay request: %s", frame)
return
client_id, _, start_seq_bytes = frame
start_seq = int.from_bytes(start_seq_bytes, "big")
for seq, buf in self._buffer:
if seq >= start_seq:
# [identity, empty_delim, seq_bytes, payload]
# (identity, empty_delim) are stripped off by the router
# receiving payload is (seq_bytes, payload)
self._replay.send_multipart(
(client_id, b"", seq.to_bytes(8, "big"), buf))
# Send end of sequence marker
        # receiving payload is (-1, b"")
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
@staticmethod
def offset_endpoint_port(endpoint: Optional[str],
data_parallel_rank: int) -> Optional[str]:
"""Helper function to offset the port in an endpoint by
the data parallel rank.
Args:
endpoint: The endpoint string
(e.g., "tcp://*:5557" or "inproc://cache")
data_parallel_rank: The data parallel rank to offset by
Returns:
The endpoint with the port offset by data_parallel_rank
or suffix appended
"""
# Do nothing if input is None or data_parallel_rank is 0
if not endpoint or data_parallel_rank == 0:
return endpoint
if "inproc" in endpoint:
return f"{endpoint}_dp{data_parallel_rank}"
if "tcp" in endpoint:
if endpoint and ":" in endpoint:
# Get everything after the last colon (the port)
last_colon_idx = endpoint.rfind(":")
base_addr = endpoint[:last_colon_idx]
base_port = int(endpoint[last_colon_idx + 1:])
new_port = base_port + data_parallel_rank
return f"{base_addr}:{new_port}"
return endpoint
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
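# Consumer-side sketch (illustrative only; the concrete subscriber lives
# outside this module and the endpoint below is an assumption):
#
#     import msgspec
#     import zmq
#
#     ctx = zmq.Context.instance()
#     sub = ctx.socket(zmq.SUB)
#     sub.connect("tcp://localhost:5557")
#     sub.setsockopt(zmq.SUBSCRIBE, b"")
#     decoder = msgspec.msgpack.Decoder(KVEventBatch)
#     topic, seq_bytes, payload = sub.recv_multipart()
#     seq = int.from_bytes(seq_bytes, "big")
#     batch = decoder.decode(payload)
#
# Missed batches can be recovered through the replay endpoint: send the
# desired starting sequence number as an 8-byte big-endian integer (with an
# empty delimiter frame when using a DEALER socket), then read the streamed
# (seq_bytes, payload) frames until the END_SEQ marker (-1) arrives.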
class EventPublisherFactory:
_registry: dict[str, Callable[..., EventPublisher]] = {
"null": NullEventPublisher,
"zmq": ZmqEventPublisher,
}
@classmethod
def register_publisher(cls, name: str,
ctor: Callable[..., EventPublisher]) -> None:
if name in cls._registry:
raise KeyError(f"publisher '{name}' already registered")
cls._registry[name] = ctor
@classmethod
def create(cls,
config: Optional[KVEventsConfig],
data_parallel_rank: int = 0) -> EventPublisher:
"""Create publisher from a config mapping."""
if not config:
return NullEventPublisher()
config_dict = asdict(config)
kind = config_dict.pop("publisher", "null")
config_dict.pop("enable_kv_cache_events")
try:
constructor = cls._registry[kind]
except KeyError as exc:
raise ValueError(f"Unknown event publisher '{kind}'") from exc
return constructor(data_parallel_rank=data_parallel_rank,
**config_dict)

View File

@@ -0,0 +1,29 @@
# Distributed KV cache transfer
This folder implements distributed KV cache transfer across vLLM instances.
Currently the main use case is for disaggregated prefilling.
## Abstractions
The KV cache transfer contains three layers of abstraction:
- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`.
- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics).
- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`.
Why we need a KV lookup buffer: a FIFO pipe by itself is not enough, because the prefill vLLM worker may process requests in a different order than the decode vLLM worker. For example, under high QPS the prefill worker may handle requests in the order A -> B -> C, while the decode worker processes request C first. A FIFO pipe cannot naturally express this, so we provide the KV lookup buffer to turn a FIFO pipe into a lookup buffer.
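The sketch below illustrates the intended lookup-buffer semantics; the call signatures are simplified assumptions and `buffer` stands in for any concrete KV lookup buffer implementation.

```python
# Prefill instance: publish the KV caches keyed by the prompt tokens.
buffer.insert(tokens, kv_caches, hidden_states)

# Decode instance: fetch whatever entry matches these tokens, independent of
# the order in which the prefill instance produced it (SQL-like drop_select).
kv_caches, hidden_states = buffer.drop_select(tokens)
```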
NOTE: The KV pipe layer is bypassable: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like Redis or an
RDMA-backed database).
NOTE: If you want to not only transfer KV caches but also adjust vLLM's model execution flow (for example, let vLLM receive KV caches for some tokens and run prefill on the remaining tokens), you can bypass both the KV pipe layer and the KV lookup buffer layer and implement directly on the KV connector layer. Bear in mind that because vLLM's model input is constantly changing, such an implementation will likely break when vLLM is updated.
## Disaggregated prefilling
The example usage is in [this file](../../../examples/online_serving/disaggregated_prefill.sh).
Here is the diagram of how we run disaggregated prefilling.
![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg)

View File

@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_transfer_state import (
KVConnectorBaseType, ensure_kv_transfer_initialized,
ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group,
is_v1_kv_transfer_group)
__all__ = [
"get_kv_transfer_group", "has_kv_transfer_group",
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
"ensure_kv_transfer_shutdown", "KVConnectorBaseType"
]

Binary file not shown.

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Defines the base type for KV cache connectors."""
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
KVConnectorBase = KVConnectorBase_V1
KVConnectorBaseType = KVConnectorBase_V1
__all__ = ["KVConnectorBase", "KVConnectorBaseType"]

View File

@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from typing import TYPE_CHECKING, Callable
# yapf: disable
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase, KVConnectorBaseType)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.logger import init_logger
# yapf: enable
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
logger = init_logger(__name__)
class KVConnectorFactory:
_registry: dict[str, Callable[[], type[KVConnectorBase]]] = {}
@classmethod
def register_connector(cls, name: str, module_path: str,
class_name: str) -> None:
"""Register a connector with a lazy-loading module and class name."""
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> type[KVConnectorBase]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_connector(
cls,
config: "VllmConfig",
role: KVConnectorRole,
) -> KVConnectorBase:
if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, "
f"but found {envs.VLLM_USE_V1=}")
kv_transfer_config = config.kv_transfer_config
connector_cls = cls.get_connector_class(kv_transfer_config)
logger.info("Creating v1 connector with name: %s and engine_id: %s",
connector_cls.__name__, kv_transfer_config.engine_id)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
# Scheduler connector:
# - Co-locate with scheduler process
# - Should only be used inside the Scheduler class
# Worker connector:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
return connector_cls(config, role)
@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_name = kv_transfer_config.kv_connector
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(
f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
return connector_cls
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
KVConnectorFactory.register_connector(
"SharedStorageConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
"SharedStorageConnector")
KVConnectorFactory.register_connector(
"P2pNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
"P2pNcclConnector")
KVConnectorFactory.register_connector(
"LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
"LMCacheConnectorV1")
KVConnectorFactory.register_connector(
"NixlConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
"NixlConnector")
KVConnectorFactory.register_connector(
"MultiConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
"MultiConnector")
KVConnectorFactory.register_connector(
"OffloadingConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
"OffloadingConnector")

View File

@@ -0,0 +1,261 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV cache helper for store.
"""
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import Literal, Optional, Union, cast
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
logger = init_logger(__name__)
class model_aware_kv_ops_helper:
def __init__(self, config: VllmConfig):
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
self.tp_size = config.parallel_config.tensor_parallel_size
def get_model_args(self, model_executable: torch.nn.Module):
model_config = model_executable.model.config
self.model_executable = model_executable
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/v1/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + \
model_config.qk_rope_head_dim
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = model_config.qk_nope_head_dim + \
model_config.qk_rope_head_dim
else:
head_size = getattr(model_config, "head_dim", None)
if head_size is None:
head_size = int(hidden_size // num_attention_heads)
return num_heads, head_size
def get_kv_from_cache(self, kv_cache, num_heads, head_size):
if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
return key_cache, value_cache
def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
layer, kv_cache, slot_mapping, start_pos, end_pos):
model_config = model_executable.model.config
if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
k_c_normed_k_pe = keys.squeeze(1)
k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
ops.concat_and_cache_mla(
k_c_normed.to(kv_cache.device),
k_pe.to(kv_cache.device),
kv_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys.to(key_cache.device),
values.to(value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# used for faster transfer.
vllm_config = get_current_vllm_config()
kv_config = vllm_config.kv_transfer_config
if kv_config is not None:
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
vllm_config)
if required_kvcache_layout is not None:
return required_kvcache_layout
logger.info_once("Connectors do not specify a " \
"kv cache layout, defaulting to NHD.")
return "NHD"
class KVOutputAggregator:
"""Utility class to aggregate the output of all workers into a single
output corresponding to Rank 0 for scheduler."""
def __init__(self, world_size: int):
# Complete transfer tracker. Used to track finished requests
# [req_id -> n_remaining_workers]
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
def aggregate(self,
outputs: list[ModelRunnerOutput],
output_rank: int = 0) -> ModelRunnerOutput:
# Aggregate kv_connector_output from all workers
def update_finished_set(req_ids: Optional[set[str]],
remaining_count_dict: dict[str, int],
finished_set: set[str]) -> None:
for req_id in req_ids or ():
remaining_count_dict[req_id] -= 1
if remaining_count_dict[req_id] == 0:
finished_set.add(req_id)
del remaining_count_dict[req_id]
finished_sending = set[str]()
finished_recving = set[str]()
aggregated_kv_connector_stats = None
for model_runner_output in outputs:
output = model_runner_output.kv_connector_output
if not output:
continue
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
self._recv_remaining_count, finished_recving)
# Aggregate kv_connector_stats from all workers.
if aggregated_kv_connector_stats is None:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = output.kv_connector_stats
elif kv_connector_stats := output.kv_connector_stats:
if aggregated_kv_connector_stats is None:
aggregated_kv_connector_stats = kv_connector_stats
else:
assert isinstance(aggregated_kv_connector_stats,
type(kv_connector_stats))
aggregated_kv_connector_stats = \
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
# select output of the worker specified by output_rank
output = outputs[output_rank]
output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
)
return output
def async_aggregate(self,
output_futures: Sequence[Future[ModelRunnerOutput]],
output_rank: int = 0) -> Future[ModelRunnerOutput]:
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future()
outputs: list[Optional[ModelRunnerOutput]] = [None
] * len(output_futures)
def make_callback(idx):
def callback(fut):
if result_future.done():
return
try:
outputs[idx] = fut.result()
except CancelledError:
result_future.cancel()
except Exception as e:
result_future.set_exception(e)
# this check assumes io_thread_pool uses a single thread
if all(outputs):
result_future.set_result(
self.aggregate(cast(list[ModelRunnerOutput], outputs),
output_rank))
return callback
for i, output_future in enumerate(output_futures):
output_future.add_done_callback(make_callback(i))
return result_future
def _make_src_and_dst_indices(
src_block_ids: list[int],
dst_block_ids: list[int],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> tuple[torch.Tensor, torch.Tensor]:
src_indices = torch.tensor(src_block_ids,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_block_ids,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices
def copy_kv_blocks(
src_kv_caches: dict[str, torch.Tensor],
dst_kv_caches: dict[str, torch.Tensor],
src_block_ids: list[int],
dst_block_ids: list[int],
direction: Literal["h2d", "d2h"],
) -> None:
"""Copy kv blocks between different buffers."""
if not src_kv_caches or not dst_kv_caches or \
not src_block_ids or not dst_block_ids or \
len(src_block_ids) != len(dst_block_ids):
return
src_device = next(iter(src_kv_caches.values())).device
dst_device = next(iter(dst_kv_caches.values())).device
src_indices, dst_indices = _make_src_and_dst_indices(
src_block_ids=src_block_ids,
dst_block_ids=dst_block_ids,
src_device=src_device,
dst_device=dst_device)
from vllm.platforms import current_platform
if direction == "h2d":
copy_fn = current_platform.insert_blocks_to_device
else:
copy_fn = current_platform.swap_out_blocks_to_host
for layer_name in src_kv_caches:
src_tensor = src_kv_caches[layer_name]
dst_tensor = dst_kv_caches[layer_name]
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)

View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorRole)
__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]

View File

@@ -0,0 +1,388 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1
The class provides the following primitives:
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save KV cache.
get_num_new_matched_tokens() - get number of new tokens
that exist in the remote KV cache. Might be called multiple
times for a given request and should be side-effect free.
update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager.
update_connector_output() - update KVConnector state after
output is received from worker-side connectors.
request_finished() - called when a request is finished, with
the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done
save_kv_layer() - starts saving KV for layer i (maybe async)
wait_for_save() - blocks until all saves are done
get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
"""
import enum
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
CopyBlocksOp = Callable[[
dict[str, torch.Tensor], dict[
str, torch.Tensor], list[int], list[int], Literal["h2d", "d2h"]
], None]
logger = init_logger(__name__)
class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
# Connector running in the worker process
WORKER = 1
class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector.
"""
pass
class KVConnectorBase_V1(ABC):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design.")
self._connector_metadata: Optional[KVConnectorMetadata] = None
self._vllm_config = vllm_config
self._role = role
@property
def role(self) -> KVConnectorRole:
return self._role
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
This function should be called by the model runner every time
before the model execution. The metadata will be used for runtime
KV cache loading and saving.
Args:
connector_metadata (dict): the connector metadata.
"""
self._connector_metadata = connector_metadata
def clear_connector_metadata(self) -> None:
"""Clear the connector metadata.
This function should be called by the model runner every time
after the model execution.
"""
self._connector_metadata = None
def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.
This function should only be called inside the connector.
Returns:
ConnectorMetadata: the connector metadata.
"""
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args:
kv_caches: dictionary of layer names, kv cache
"""
return
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""
Set the xPU-specific ops for copying KV between host and device.
Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
"""
return
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
pass
@abstractmethod
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
pass
@abstractmethod
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
pass
@abstractmethod
def wait_for_save(self):
"""
Block until all the save operations are done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies the worker-side connector of the ids of requests that have
finished generating tokens on the worker.
The scheduler process (via the Executors) will use this output
to track which workers are done.
Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
def shutdown(self):
"""
Shutdown the connector. This is called when the worker process
is shutting down to ensure that all the async operations are
completed and the connector is cleaned up properly.
"""
return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
Get the KV connector stats collected during the last interval.
"""
return None
# ==============================
# Scheduler-side methods
# ==============================
@abstractmethod
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[Optional[int], bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
`False` if the first element is 0.
"""
pass
@abstractmethod
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
If get_num_new_matched_tokens previously returned True for a
request, this function may be called twice for that same request -
first when blocks are allocated for the connector tokens to be
asynchronously loaded into, and second when any additional blocks
are allocated, after the load/transfer is complete.
Args:
request (Request): the request object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
"""
pass
@abstractmethod
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
pass
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
return
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return False, None
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return ()
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if cls is KVConnectorBase_V1:
raise TypeError("get_required_kvcache_layout should not be called "
"on the abstract base class")
return None
def get_finished_count(self) -> Optional[int]:
"""
Get the count of requests expected to complete send/receive operations
via this connector.
Returns:
int: expected sending or receiving completion count.
"""
return None
@classmethod
def build_kv_connector_stats(
cls,
data: Optional[dict[str,
Any]] = None) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return None
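# Illustrative only: a minimal no-op subclass showing which abstract methods a
# concrete connector must provide. "NoOpConnector" is a hypothetical name used
# for this sketch, not a connector shipped with vLLM.
#
#   class NoOpConnector(KVConnectorBase_V1):
#       def start_load_kv(self, forward_context, **kwargs) -> None:
#           pass
#       def wait_for_layer_load(self, layer_name: str) -> None:
#           pass
#       def save_kv_layer(self, layer_name, kv_layer, attn_metadata,
#                         **kwargs) -> None:
#           pass
#       def wait_for_save(self) -> None:
#           pass
#       def get_num_new_matched_tokens(self, request, num_computed_tokens):
#           return 0, False   # no external tokens, no async load
#       def update_state_after_alloc(self, request, blocks,
#                                    num_external_tokens) -> None:
#           pass
#       def build_connector_meta(self, scheduler_output):
#           return KVConnectorMetadata()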

View File

@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional
import torch
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
self._lmcache_engine.wait_for_layer_load(layer_name)
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
**kwargs)
def wait_for_save(self):
"""
Block until all the save operations are done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
self._lmcache_engine.wait_for_save()
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies the worker-side connector of the ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return self._lmcache_engine.get_finished(finished_req_ids)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[Optional[int], bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
return self._lmcache_engine.get_num_new_matched_tokens(
request, num_computed_tokens), False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
self._lmcache_engine.update_state_after_alloc(request,
num_external_tokens)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
return self._lmcache_engine.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return self._lmcache_engine.request_finished(request, block_ids)
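# Hedged usage sketch: this connector is normally selected via the engine's
# KV transfer configuration rather than instantiated directly. The exact
# CLI/JSON shape below is an assumption for illustration only:
#
#   vllm serve <model> --kv-transfer-config \
#       '{"kv_connector": "LMCacheConnectorV1", "kv_role": "kv_both"}'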

View File

@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_transfer_state import (
has_kv_transfer_group)
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class KVConnectorStats:
"""
Base class for KV Connector Stats, a container for transfer performance
metrics or otherwise important telemetry from the connector.
All sub-classes need to be serializable as stats are sent from worker to
logger process.
"""
data: dict[str, Any] = field(default_factory=dict)
def reset(self):
"""Reset the stats, clear the state."""
raise NotImplementedError
def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
"""
Aggregate stats with another `KVConnectorStats` object.
"""
raise NotImplementedError
def reduce(self) -> dict[str, Union[int, float]]:
"""
Reduce the observations collected during a time interval to one or
more representative values (e.g. avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
raise NotImplementedError
def is_empty(self) -> bool:
"""Return True if the stats are empty."""
raise NotImplementedError
class KVConnectorLogging:
def __init__(self, kv_transfer_config: KVTransferConfig):
# This should be called on frontend process.
assert not has_kv_transfer_group()
# Instantiate the connector's stats class.
if kv_transfer_config and kv_transfer_config.kv_connector:
self.connector_cls = KVConnectorFactory.get_connector_class(
kv_transfer_config)
self.reset()
def reset(self):
self.transfer_stats_accumulator: Optional[KVConnectorStats] = None
def observe(self, transfer_stats_data: dict[str, Any]):
# Should not be called when a KVConnector is not configured.
assert self.connector_cls is not None
# Called periodically when connector syncs with the scheduler.
# Note that this is not the same as the logging interval.
# We expect transfer_stats_data to be aggregated across all workers and
# consist of observations from a single connector or a MultiConnector.
transfer_stats = self.connector_cls.build_kv_connector_stats(
transfer_stats_data)
if transfer_stats is None:
logger.warning_once(
"The connector %s is collecting stats but "
"does not implement the "
"`build_kv_connector_stats` method. "
"Stats will not be logged.", self.connector_cls)
return
if self.transfer_stats_accumulator is None:
self.transfer_stats_accumulator = transfer_stats
else:
# Accumulate last interval stats.
self.transfer_stats_accumulator = \
self.transfer_stats_accumulator.aggregate(transfer_stats)
def log(self, log_fn=logger.info):
"""Log transfer metrics periodically, similar to throughput logging"""
if (self.transfer_stats_accumulator
and not self.transfer_stats_accumulator.is_empty()):
# Produce a single cumulative stats object for the last time
# interval from the recorded observations.
xfer_metrics = self.transfer_stats_accumulator.reduce()
xfer_metrics_str = ", ".join(f"{k}={v}"
for k, v in xfer_metrics.items())
log_fn("KV Transfer metrics: %s", xfer_metrics_str)
# Reset metrics for next interval
self.reset()
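# Hedged sketch of a concrete stats class (the name "TransferBytesStats" and
# its fields are hypothetical) illustrating the reset/aggregate/reduce/is_empty
# contract that KVConnectorLogging relies on above:
#
#   @dataclass
#   class TransferBytesStats(KVConnectorStats):
#       def reset(self):
#           self.data = {"bytes_sent": 0, "num_transfers": 0}
#       def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
#           for key, value in other.data.items():
#               self.data[key] = self.data.get(key, 0) + value
#           return self
#       def reduce(self) -> dict[str, Union[int, float]]:
#           num = max(self.data.get("num_transfers", 0), 1)
#           total = self.data.get("bytes_sent", 0)
#           return {"total_bytes": total, "avg_bytes": total / num}
#       def is_empty(self) -> bool:
#           return not any(self.data.values())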

View File

@@ -0,0 +1,328 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class MultiKVConnectorMetadata(KVConnectorMetadata):
metadata: tuple[KVConnectorMetadata, ...]
extra_async_saves: Optional[dict[str, int]] = None
@dataclass
class MultiKVConnectorStats(KVConnectorStats):
"""
Maintain a dict of KVConnectorStats objects, one for each connector.
This is used to aggregate the stats from all connectors separately.
"""
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
for connector_id, stats in other.data.items():
if connector_id not in self.data:
self[connector_id] = stats
else:
assert isinstance(stats, type(self.data[connector_id]))
self[connector_id] = self[connector_id].aggregate(stats)
return self
def reset(self):
for stats in self.data.values():
stats.reset()
def reduce(self) -> dict[str, Any]:
# TODO (NickLucche) Adjust for logging on separate lines
return {
connector_id: stats.reduce()
for connector_id, stats in self.data.items()
}
def is_empty(self) -> bool:
return all(stats.is_empty() for stats in self.data.values())
def __getitem__(self, connector_id: str) -> KVConnectorStats:
return self.data[connector_id]
def __setitem__(self, connector_id: str, stats: KVConnectorStats):
self.data[connector_id] = stats
class MultiConnector(KVConnectorBase_V1):
"""
A wrapper for using multiple KVConnectors at the same time.
The current logic is:
- Load KV from the first connector that advertises available tokens from
get_num_new_matched_tokens(), based on the order in the config.
- Save to all connectors.
"""
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors")
assert ktcs is not None
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id",
vllm_config.kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id)
self._connectors.append(
KVConnectorFactory.create_connector(temp_config, role))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
self._requests_to_connector: dict[str, int] = {}
# Keeps track of *additional* remaining async saves (beyond 1) to be
# finished per request. Not needed for async loads since we only allow
# a single connector to load.
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)
# We must override the base class method here because we need to bind
# the metadata to each connector in the order of the connectors in the
# MultiKVConnectorMetadata.
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
if connector_metadata.extra_async_saves:
self._extra_async_saves.update(
connector_metadata.extra_async_saves)
for c, cm in zip(self._connectors, connector_metadata.metadata):
c.bind_connector_metadata(cm)
def clear_connector_metadata(self) -> None:
for c in self._connectors:
c.clear_connector_metadata()
def shutdown(self):
exception: Optional[Exception] = None
for c in self._connectors:
try:
c.shutdown()
except Exception as e:
logger.exception("Exception during connector %s shutdown.",
c.__class__.__name__)
exception = e
if exception:
raise exception
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
for c in self._connectors:
c.start_load_kv(forward_context, **kwargs)
def wait_for_layer_load(self, layer_name: str) -> None:
for c in self._connectors:
c.wait_for_layer_load(layer_name)
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
for c in self._connectors:
c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
def wait_for_save(self):
for c in self._connectors:
c.wait_for_save()
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
finished_sending: set[str] = set()
finished_recving: set[str] = set()
for c in self._connectors:
sending, recving = c.get_finished(finished_req_ids)
if not recving and not sending:
continue
# Aggregate finished recving request ids.
finished_recving.update(recving or ())
# Aggregate finished sending request ids - only include
# once we've drained the "extra" count (for cases where
# more than one connector is async-saving the same request).
for req_id in sending or ():
extra_pending = self._extra_async_saves.get(req_id)
if extra_pending is None:
finished_sending.add(req_id)
continue
assert extra_pending > 0
if extra_pending == 1:
del self._extra_async_saves[req_id]
else:
self._extra_async_saves[req_id] = extra_pending - 1
return finished_sending or None, finished_recving or None
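# Worked example (illustrative): if a request was async-saved by two
# connectors, request_finished() records _extra_async_saves[req_id] = 1; the
# first connector to report the completed send only decrements that counter,
# and the second report actually adds req_id to finished_sending.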
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[Optional[int], bool]:
to_return = (0, False)
for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens)
# If there is a connector still looking up the matches,
# we return None to indicate that we are not done yet.
if toks is None:
return (None, False)
# The first connector that has new matched tokens will be assigned
# to this request.
if to_return[0] == 0 and toks > 0:
self._requests_to_connector[request.request_id] = i
to_return = (toks, load_async)
return to_return
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
chosen_connector = self._requests_to_connector.get(
request.request_id, -1)
empty_blocks = blocks.new_empty()
for i, c in enumerate(self._connectors):
if i == chosen_connector:
# Forward call to the chosen connector (if any).
c.update_state_after_alloc(request, blocks,
num_external_tokens)
else:
# Call with empty blocks for other connectors.
c.update_state_after_alloc(request, empty_blocks, 0)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
metadata = MultiKVConnectorMetadata(metadata=tuple(
c.build_connector_meta(scheduler_output)
for c in self._connectors))
if self._extra_async_saves:
metadata.extra_async_saves = self._extra_async_saves
self._extra_async_saves = {}
return metadata
def update_connector_output(self, connector_output: KVConnectorOutput):
for c in self._connectors:
c.update_connector_output(connector_output)
def request_finished(
self,
request: "Request",
blocks: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
async_saves = 0
kv_txfer_params = None
for c in self._connectors:
async_save, txfer_params = c.request_finished(request, blocks)
if async_save:
async_saves += 1
if txfer_params is not None:
if kv_txfer_params is not None:
# TODO we can probably change this to merge the dicts here,
# checking for key clashes.
raise RuntimeError(
"Only one connector can produce KV transfer params")
kv_txfer_params = txfer_params
if async_saves > 1:
self._extra_async_saves[request.request_id] = async_saves - 1
# Clean up other state for this request.
self._requests_to_connector.pop(request.request_id, None)
return async_saves > 0, kv_txfer_params
def take_events(self) -> Iterable["KVCacheEvent"]:
for c in self._connectors:
yield from c.take_events()
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors")
assert ktcs is not None
layouts: set[str] = set()
temp_vllm_config = copy.copy(vllm_config)
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
temp_vllm_config.kv_transfer_config = kv_transfer_config
connector_cls = KVConnectorFactory.get_connector_class(
kv_transfer_config)
required_kvcache_layout = (
connector_cls.get_required_kvcache_layout(temp_vllm_config))
if required_kvcache_layout is not None:
layouts.add(required_kvcache_layout)
if len(layouts) > 1:
raise ValueError(f"KV cache layout mismatch: "
f"found {len(layouts)} different layouts "
f"({', '.join(layouts) })."
f"All connectors must use the same layout.")
return next(iter(layouts), None)
@classmethod
def build_kv_connector_stats(
cls,
data: Optional[dict[str,
Any]] = None) -> Optional[KVConnectorStats]:
return MultiKVConnectorStats(data=data) if data is not None \
else MultiKVConnectorStats()
def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]:
# Group connector stats by connector type.
stats_by_connector: Optional[MultiKVConnectorStats] = None
for c in self._connectors:
stats = c.get_kv_connector_stats()
if stats is None:
continue
if stats_by_connector is None:
# Lazy init to allow optional return value.
stats_by_connector = MultiKVConnectorStats()
stats_by_connector[c.__class__.__name__] = stats
return stats_by_connector
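# Hedged configuration sketch: MultiConnector expects a "connectors" list in
# kv_connector_extra_config, tried in order for loads and fanned out for
# saves. The connector names below are placeholders, not real connectors:
#
#   --kv-transfer-config '{
#       "kv_connector": "MultiConnector",
#       "kv_role": "kv_both",
#       "kv_connector_extra_config": {
#           "connectors": [
#               {"kv_connector": "ConnectorA", "kv_role": "kv_both"},
#               {"kv_connector": "ConnectorB", "kv_role": "kv_both"}
#           ]
#       }
#   }'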

File diff suppressed because it is too large

View File

@@ -0,0 +1,485 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from typing import Any, Optional
import torch
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_offload.abstract import OffloadingManager
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import Request
ReqId = str
logger = init_logger(__name__)
@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
reqs_to_load: dict[ReqId, TransferSpec]
reqs_to_store: dict[ReqId, TransferSpec]
class OffloadingConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
super().__init__(vllm_config, role)
spec = OffloadingSpecFactory.create_spec(vllm_config)
self.connector_scheduler: Optional[OffloadingConnectorScheduler] = None
self.connector_worker: Optional[OffloadingConnectorWorker] = None
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = OffloadingConnectorScheduler(spec)
elif role == KVConnectorRole.WORKER:
self.connector_worker = OffloadingConnectorWorker(spec)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
OffloadingConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
pass
def wait_for_save(self):
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
OffloadingConnectorMetadata)
self.connector_worker.start_store_kv(self._connector_metadata)
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
assert self.connector_worker is not None
return self.connector_worker.get_finished(finished_req_ids)
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def update_connector_output(self, connector_output: KVConnectorOutput):
assert self.connector_scheduler is not None
self.connector_scheduler.update_connector_output(connector_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def take_events(self) -> Iterable[KVCacheEvent]:
assert self.connector_scheduler is not None
return self.connector_scheduler.take_events()
class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, spec: OffloadingSpec):
self.gpu_block_size = spec.gpu_block_size
self.offloaded_block_size = spec.offloaded_block_size
self.block_size_factor = (self.offloaded_block_size //
self.gpu_block_size)
self.manager: OffloadingManager = spec.get_manager()
self._requests: dict[ReqId, Request] = {}
# list of GPU block IDs per request
self._request_block_ids: dict[ReqId, list[int]] = {}
# requests to load for the current scheduler step
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
# request blocks are stored in order
# index of next block (of size offloaded_block_size) to offload
self._next_stored_block_idx: dict[ReqId, int] = {}
# request ID -> set(block hashes being stored/load)
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
def _get_block_hashes(
self,
req: Request,
start_idx: int = 0,
end_idx: Optional[int] = None,
) -> Iterable[BlockHash]:
return islice(
req.block_hashes,
self.block_size_factor * start_idx + self.block_size_factor - 1,
self.block_size_factor * end_idx if end_idx else None,
self.block_size_factor)
def get_num_new_matched_tokens(
self, request: Request,
num_computed_tokens: int) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded beyond the
num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- The number of tokens that can be loaded beyond what is
already computed.
- `True` if tokens will be loaded asynchronously
(between scheduler steps).
"""
num_blocks = request.num_tokens // self.offloaded_block_size
assert (len(request.block_hashes) //
self.block_size_factor == num_blocks)
block_hashes = self._get_block_hashes(request)
self.manager.touch(block_hashes)
full_block_tokens = self.offloaded_block_size * num_blocks
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
# we can load less than a block, skip
return 0, False
start_block_idx = num_computed_tokens // self.offloaded_block_size
hits = self.manager.lookup(
self._get_block_hashes(request, start_idx=start_block_idx))
if hits == 0:
return 0, False
num_hit_tokens = (self.offloaded_block_size *
(start_block_idx + hits) - num_computed_tokens)
logger.debug(
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
request.request_id,
num_hit_tokens,
num_computed_tokens,
)
if num_hit_tokens < self.offloaded_block_size:
return 0, False
return num_hit_tokens, True
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
num_external_tokens: int):
self._requests[request.request_id] = request
# the block ids are updated in _get_reqs_to_store
self._request_block_ids[request.request_id] = []
if num_external_tokens == 0:
return
block_groups = blocks.get_block_ids()
block_ids = block_groups[0]
num_computed_gpu_blocks = sum(block.block_hash is not None
for block in blocks.blocks[0])
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
full_block_tokens = num_computed_tokens + num_external_tokens
assert full_block_tokens % self.offloaded_block_size == 0
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
assert (num_external_tokens == num_pending_gpu_blocks *
self.gpu_block_size)
start_block_idx = num_computed_tokens // self.offloaded_block_size
num_blocks = full_block_tokens // self.offloaded_block_size
assert (len(request.block_hashes) // self.block_size_factor
>= num_blocks)
block_hashes = self._get_block_hashes(request,
start_idx=start_block_idx,
end_idx=num_blocks)
src_spec = self.manager.prepare_load(block_hashes)
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
block_hashes = self._get_block_hashes(request,
start_idx=start_block_idx,
end_idx=num_blocks)
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
self._reqs_being_loaded[request.request_id].update(block_hashes)
self._next_stored_block_idx[request.request_id] = num_blocks
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests
for req_id, new_block_id_groups, preempted in yield_req_data(
scheduler_output):
if preempted:
self._request_block_ids[req_id] = []
if new_block_id_groups:
new_block_ids = new_block_id_groups[0]
self._request_block_ids[req_id] += new_block_ids
block_ids = self._request_block_ids[req_id]
req = self._requests[req_id]
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
total_tokens = req.num_computed_tokens + new_tokens
num_blocks = total_tokens // self.offloaded_block_size
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
num_new_blocks = num_blocks - start_block_idx
if num_new_blocks <= 0:
continue
num_gpu_blocks = num_blocks * self.block_size_factor
assert len(req.block_hashes) >= num_gpu_blocks
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks)
store_output = self.manager.prepare_store(new_block_hashes)
if store_output is None:
logger.warning("Cannot store %s blocks", num_new_blocks)
break
self._next_stored_block_idx[req_id] = num_blocks
if not store_output.block_hashes_to_store:
continue
block_hashes_to_store = set(store_output.block_hashes_to_store)
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
self.manager.touch(block_hashes)
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks)
dst_spec = store_output.store_spec
src_block_ids: list[int] = []
for idx, blk_hash in enumerate(new_block_hashes):
if blk_hash not in block_hashes_to_store:
continue
offloaded_block_idx = start_block_idx + idx
gpu_block_idx = offloaded_block_idx * self.block_size_factor
for i in range(self.block_size_factor):
src_block_ids.append(block_ids[gpu_block_idx + i])
src_spec = GPULoadStoreSpec(src_block_ids)
reqs_to_store[req_id] = (src_spec, dst_spec)
self._reqs_being_stored[req_id] |= block_hashes_to_store
logger.debug(
"Request %s offloading %s blocks starting from block #%d",
req_id,
len(block_hashes_to_store),
start_block_idx,
)
return reqs_to_store
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
meta = OffloadingConnectorMetadata(
reqs_to_load=self._reqs_to_load,
reqs_to_store=self._get_reqs_to_store(scheduler_output))
self._reqs_to_load = {}
return meta
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
for req_id in connector_output.finished_sending or []:
block_hashes = self._reqs_being_stored.pop(req_id, None)
if block_hashes:
self.manager.complete_store(block_hashes)
for req_id in connector_output.finished_recving or []:
block_hashes = self._reqs_being_loaded.pop(req_id, None)
if block_hashes:
self.manager.complete_load(block_hashes)
def request_finished(
self,
request: Request,
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id = request.request_id
self._requests.pop(req_id, None)
self._request_block_ids.pop(req_id, None)
self._next_stored_block_idx.pop(req_id, None)
request_being_stored = req_id in self._reqs_being_stored
return request_being_stored, None
def take_events(self) -> Iterable[KVCacheEvent]:
"""Take the KV cache events from the connector.
Returns:
A list of KV cache events.
"""
for event in self.manager.take_events():
if event.removed:
yield BlockRemoved(block_hashes=event.block_hashes,
medium=event.medium)
else:
yield BlockStored(block_hashes=event.block_hashes,
parent_block_hash=None,
token_ids=[],
lora_id=None,
block_size=event.block_size,
medium=event.medium)
class OffloadingConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, spec: OffloadingSpec):
self.spec = spec
self.worker = OffloadingWorker()
self._job_counter = 0
# req_id -> (job_id, store)
self._jobs: dict[int, tuple[ReqId, bool]] = {}
# req_id -> active job IDs
self._load_job: dict[ReqId, int] = {}
# req_id -> set(active job IDs)
self._store_jobs = defaultdict[ReqId, set[int]](set)
self._finished_reqs_waiting_for_store: set[ReqId] = set()
def _generate_job_id(self) -> int:
job_id = self._job_counter
self._job_counter = job_id + 1
return job_id
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for src_cls, dst_cls, handler in (self.spec.get_handlers(kv_caches)):
self.worker.register_handler(src_cls, dst_cls, handler)
def start_load_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, False)
assert req_id not in self._load_job
self._load_job[req_id] = job_id
assert self.worker.transfer_async(job_id, transfer_spec)
def start_store_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_store.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, True)
self._store_jobs[req_id].add(job_id)
assert self.worker.transfer_async(job_id, transfer_spec)
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""
Notifies the worker-side connector of the ids of requests that have
finished generating tokens.
Returns a list of request IDs that finished loading or storing.
Returns:
ids of requests that have finished asynchronous transfer
tuple of (sending/saving ids, recving/loading ids).
"""
finished_sending = set()
finished_recving = set()
for job_id, success in self.worker.get_finished():
# we currently do not support job failures
assert success
req_id, store = self._jobs.pop(job_id)
if store:
req_jobs = self._store_jobs[req_id]
req_jobs.remove(job_id)
if req_jobs:
continue
if req_id in self._finished_reqs_waiting_for_store:
self._finished_reqs_waiting_for_store.remove(req_id)
finished_sending.add(req_id)
del self._store_jobs[req_id]
else:
req_job = self._load_job[req_id]
assert job_id == req_job
del self._load_job[req_id]
finished_recving.add(req_id)
for req_id in finished_req_ids:
pending_req_jobs = self._store_jobs.get(req_id)
if pending_req_jobs:
self._finished_reqs_waiting_for_store.add(req_id)
elif pending_req_jobs is not None:
finished_sending.add(req_id)
del self._store_jobs[req_id]
return finished_sending, finished_recving
def yield_req_data(
scheduler_output) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
"""
Yields:
(req_id, new_block_id_groups, preempted)
"""
# new requests
for req_data in scheduler_output.scheduled_new_reqs:
yield req_data.req_id, req_data.block_ids, False
# cached requests
cached_reqs = scheduler_output.scheduled_cached_reqs
yield from zip(cached_reqs.req_ids, cached_reqs.new_block_ids,
cached_reqs.resumed_from_preemption)
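# Worked example (illustrative numbers only): with gpu_block_size=16 and
# offloaded_block_size=64, block_size_factor is 64 // 16 == 4, so one
# offloaded block spans 4 GPU blocks. get_num_new_matched_tokens() then
# reports hits in multiples of 64 tokens, and _get_block_hashes() samples the
# last GPU-block hash of every group of 4 as the hash of the offloaded block.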

View File

@@ -0,0 +1,488 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine)
from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request Id
request_id: str
# Request block ids
block_ids: torch.Tensor
# Request num tokens
num_tokens: int
@staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
block_size: int) -> "ReqMeta":
block_ids_tensor = torch.tensor(block_ids)
return ReqMeta(
request_id=request_id,
block_ids=block_ids_tensor,
num_tokens=len(token_ids),
)
@dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta]
def __init__(self):
self.requests = []
def add_request(
self,
request_id: str,
token_ids: list[int],
block_ids: list[int],
block_size: int,
) -> None:
self.requests.append(
ReqMeta.make_meta(request_id, token_ids, block_ids, block_size))
class P2pNcclConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {}
self.config = vllm_config.kv_transfer_config
self.is_producer = self.config.is_kv_producer
self.chunked_prefill: dict[str, Any] = {}
self._rank = get_world_group().rank \
if role == KVConnectorRole.WORKER else 0
self._local_rank = get_world_group().local_rank \
if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pNcclEngine(
local_rank=self._local_rank,
config=self.config,
hostname="",
port_offset=self._rank,
) if role == KVConnectorRole.WORKER else None
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs: Any) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
# Only consumer/decode loads KV Cache
if self.is_producer:
return
assert self.p2p_nccl_engine is not None
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
def inject_kv_into_layer(
layer: torch.Tensor,
kv_cache: torch.Tensor,
block_ids: torch.Tensor,
request_id: str,
) -> None:
"""
Inject KV cache data into a given attention layer tensor.
This function updates `layer` in-place with values from `kv_cache`,
handling different backend layouts:
- MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
If the number of provided block IDs does not match the number of KV
blocks, only the overlapping portion is updated, and a warning is
logged.
Args:
layer (torch.Tensor): The attention layer KV tensor to update.
kv_cache (torch.Tensor): The KV cache tensor to inject.
block_ids (torch.Tensor): Indices of the blocks to update.
request_id (str): Request identifier used for logging.
Returns:
None. The function modifies `layer` in-place.
"""
if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer
num_block = kv_cache.shape[0]
self.check_tensors_except_dim(layer, kv_cache, 0)
if len(block_ids) == num_block:
layer[block_ids, ...] = kv_cache
else:
layer[block_ids[:num_block], ...] = kv_cache
logger.warning(
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids),
num_block, request_id)
elif layer.shape[0] == 2: # FlashAttention
num_block = kv_cache.shape[1]
self.check_tensors_except_dim(layer, kv_cache, 1)
if len(block_ids) == num_block:
layer[:, block_ids, ...] = kv_cache
else:
layer[:, block_ids[:num_block], ...] = kv_cache
logger.warning(
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids),
num_block, request_id)
# Get the metadata
metadata: KVConnectorMetadata = \
self._get_connector_metadata()
assert isinstance(metadata, P2pNcclConnectorMetadata)
if metadata is None:
return
# Load the KV for each request each layer
for request in metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None:
continue
layer = kv_cache[forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name, remote_address)
if kv_cache is None:
logger.warning("🚧kv_cache is None, %s", request.request_id)
continue
inject_kv_into_layer(layer, kv_cache, request.block_ids,
request.request_id)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
# Only producer/prefill saves KV Cache
if not self.is_producer:
return
assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
block_ids: torch.Tensor,
) -> torch.Tensor:
"""
Extract KV cache slices from a given attention layer tensor.
This function handles multiple backend layouts:
- MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
Args:
layer (torch.Tensor): The KV cache from the attention layer.
block_ids (torch.Tensor): Indices of blocks to extract.
Returns:
torch.Tensor: A tensor containing the extracted KV slices.
Returns None if the layout is unsupported.
"""
if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer
return layer[block_ids, ...]
if layer.shape[0] == 2: # FlashAttention
return layer[:, block_ids, ...]
return None
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
def wait_for_save(self):
if self.is_producer:
assert self.p2p_nccl_engine is not None
self.p2p_nccl_engine.wait_for_sent()
def get_finished(
self, finished_req_ids: set[str],
**kwargs: Any) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies the worker-side connector of the ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer,
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
assert self.p2p_nccl_engine is not None
no_compile_layers = (
self._vllm_config.compilation_config.static_forward_context)
return self.p2p_nccl_engine.get_finished(finished_req_ids,
no_compile_layers)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
if self.is_producer:
return 0, False
num_external_tokens = (len(request.prompt_token_ids) - 1 -
num_computed_tokens)
if num_external_tokens < 0:
num_external_tokens = 0
return num_external_tokens, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
if not self.is_producer and num_external_tokens > 0:
self._requests_need_load[request.request_id] = (
request, blocks.get_block_ids()[0])
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = P2pNcclConnectorMetadata()
for new_req in scheduler_output.scheduled_new_reqs:
if self.is_producer:
num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[new_req.req_id]
num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
# the request's prompt is chunked prefill
if num_tokens < len(new_req.prompt_token_ids):
# 'CachedRequestData' has no attribute 'prompt_token_ids'
self.chunked_prefill[new_req.req_id] = (
new_req.block_ids[0], new_req.prompt_token_ids)
continue
# the request's prompt is not chunked prefill
meta.add_request(request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size)
continue
if new_req.req_id in self._requests_need_load:
meta.add_request(request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size)
self._requests_need_load.pop(new_req.req_id)
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
if self.is_producer:
num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_computed_tokens)
assert req_id in self.chunked_prefill
block_ids = new_block_ids[0]
if not resumed_from_preemption:
block_ids = (self.chunked_prefill[req_id][0] + block_ids)
prompt_token_ids = self.chunked_prefill[req_id][1]
# the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids):
self.chunked_prefill[req_id] = (block_ids,
prompt_token_ids)
continue
# the request's prompt is all prefilled finally
meta.add_request(request_id=req_id,
token_ids=prompt_token_ids,
block_ids=block_ids,
block_size=self._block_size)
self.chunked_prefill.pop(req_id, None)
continue
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not resumed_from_preemption:
break
if req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(req_id)
total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = new_block_ids[0]
meta.add_request(request_id=req_id,
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size)
self._requests_need_load.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
self.chunked_prefill.pop(request.request_id, None)
return False, None
# ==============================
# Static methods
# ==============================
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
# Regular expression to match the string hostname and integer port
if is_prefill:
pattern = r"___decode_addr_(.*):(\d+)"
else:
pattern = r"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match = re.search(pattern, request_id)
if match:
# Extract the ranks
ip = match.group(1)
port = int(match.group(2))
return ip, port
raise ValueError(
f"Request id {request_id} does not contain hostname and port")
@staticmethod
def check_tensors_except_dim(tensor1, tensor2, dim):
shape1 = tensor1.size()
shape2 = tensor2.size()
if len(shape1) != len(shape2) or not all(
s1 == s2
for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
raise NotImplementedError(
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
"and others will be supported in future PRs.")

View File

@@ -0,0 +1,550 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
import threading
import time
import typing
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional
import msgpack
import torch
import zmq
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool)
from vllm.utils import current_stream, get_ip
logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32
@contextmanager
def set_p2p_nccl_context(num_channels: str):
original_values: dict[str, Any] = {}
env_vars = [
'NCCL_MAX_NCHANNELS',
'NCCL_MIN_NCHANNELS',
'NCCL_CUMEM_ENABLE',
'NCCL_BUFFSIZE',
'NCCL_PROTO', # LL,LL128,SIMPLE
'NCCL_ALGO', # RING,TREE
]
for var in env_vars:
original_values[var] = os.environ.get(var)
logger.info("set_p2p_nccl_context, original_values: %s", original_values)
try:
os.environ['NCCL_MAX_NCHANNELS'] = num_channels
os.environ['NCCL_MIN_NCHANNELS'] = num_channels
os.environ['NCCL_CUMEM_ENABLE'] = '1'
yield
finally:
for var in env_vars:
if original_values[var] is not None:
os.environ[var] = original_values[var]
else:
os.environ.pop(var, None)
@dataclass
class SendQueueItem:
tensor_id: str
remote_address: str
tensor: torch.Tensor
class P2pNcclEngine:
def __init__(self,
local_rank: int,
config: KVTransferConfig,
hostname: str = "",
port_offset: int = 0,
library_path: Optional[str] = None) -> None:
self.config = config
self.rank = port_offset
self.local_rank = local_rank
self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path)
if not hostname:
hostname = get_ip()
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
self._hostname = hostname
self._port = port
        # Each GPU (card) corresponds to one ZMQ address.
        self.zmq_address = f"{self._hostname}:{self._port}"
        # The `http_port` must match the port of the OpenAI-compatible
        # API server.
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
# If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled.
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition()
self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream()
mem_pool_size_gb = float(
self.config.get_from_extra_config("mem_pool_size_gb",
DEFAULT_MEM_POOL_SIZE_GB))
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
1024**3)) # GB
        # The send type is one of three mutually exclusive options:
        # PUT, GET, PUT_ASYNC.
self.send_type = self.config.get_from_extra_config(
"send_type", "PUT_ASYNC")
if self.send_type == "GET":
# tensor_id: torch.Tensor
self.send_store: dict[str, torch.Tensor] = {}
else:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[SendQueueItem] = deque()
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self.send_async,
daemon=True)
self._send_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {}
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.socks: dict[str, Any] = {} # remote_address: client socket
self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
self.buffer_size = 0
self.buffer_size_threshold = float(self.config.kv_buffer_size)
self.nccl_num_channels = self.config.get_from_extra_config(
"nccl_num_channels", "8")
self._listener_thread = threading.Thread(
target=self.listen_for_requests, daemon=True)
self._listener_thread.start()
self._ping_thread = None
if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self.ping, daemon=True)
self._ping_thread.start()
logger.info(
"💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
"threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
sock.connect(f"tcp://{remote_address}")
self.socks[remote_address] = sock
if remote_address in self.comms:
logger.info("👋comm exists, remote_address:%s, comms:%s",
remote_address, self.comms)
return sock, self.comms[remote_address]
unique_id = self.nccl.ncclGetUniqueId()
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
sock.send(msgpack.dumps(data))
with torch.cuda.device(self.device):
rank = 0
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
self.zmq_address, remote_address, rank)
return self.socks[remote_address], self.comms[remote_address]
def send_tensor(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify()
return True
item = SendQueueItem(tensor_id=tensor_id,
remote_address=remote_address,
tensor=tensor)
if self.send_type == "PUT":
return self.send_sync(item)
if self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append(item)
self.send_queue_cv.notify()
return True
# GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
if tensor_size > self.buffer_size_threshold:
                logger.warning(
                    "❗[GET]tensor_id:%s, tensor_size:%d, is greater than "
                    "buffer size threshold:%d, skip send to %s, rank:%d",
tensor_id, tensor_size, self.buffer_size_threshold,
remote_address, self.rank)
return False
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
assert len(self.send_store) > 0
oldest_tensor_id = next(iter(self.send_store))
oldest_tensor = self.send_store.pop(oldest_tensor_id)
oldest_tensor_size = oldest_tensor.element_size(
) * oldest_tensor.numel()
self.buffer_size -= oldest_tensor_size
logger.debug(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tensor_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, self.buffer_size,
oldest_tensor_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
tensor_id, tensor_size, tensor.shape, self.rank,
self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
return True
def recv_tensor(
self,
tensor_id: str,
remote_address: typing.Optional[str] = None,
    ) -> Optional[torch.Tensor]:
if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
start_time = time.time()
with self.recv_store_cv:
while tensor_id not in self.recv_store:
self.recv_store_cv.wait()
tensor = self.recv_store[tensor_id]
if tensor is not None:
if isinstance(tensor, tuple):
addr, dtype, shape = tensor
tensor = self.pool.load_tensor(addr, dtype, shape,
self.device)
else:
self.buffer_size -= (tensor.element_size() *
tensor.numel())
else:
duration = time.time() - start_time
logger.warning(
"🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
"rank:%d", remote_address, tensor_id, duration * 1000,
self.rank)
return tensor
# GET
if remote_address is None:
return None
if remote_address not in self.socks:
self.create_connect(remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
data = {"cmd": "GET", "tensor_id": tensor_id}
sock.send(msgpack.dumps(data))
message = sock.recv()
data = msgpack.loads(message)
if data["ret"] != 0:
logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
remote_address, tensor_id, data["ret"])
return None
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(torch, data["dtype"]),
device=self.device)
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
return tensor
def listen_for_requests(self):
while True:
socks = dict(self.poller.poll())
if self.router_socket not in socks:
continue
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(),
rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart([remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart([remote_address, b"1"])
tensor = None
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address, remote_address.decode(),
data)
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "")
}
# LRU
self.send_store[tensor_id] = tensor
self.have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart(
[remote_address, msgpack.dumps(data)])
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self.send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream)
else:
logger.warning(
"🚧Unexpected, Received message from %s, data:%s",
remote_address, data)
def have_sent_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0]
if request_id not in self.send_request_id_to_tensor_ids:
self.send_request_id_to_tensor_ids[request_id] = set()
self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
def have_received_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0]
if request_id not in self.recv_request_id_to_tensor_ids:
self.recv_request_id_to_tensor_ids[request_id] = set()
self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
def send_async(self):
while True:
with self.send_queue_cv:
while not self.send_queue:
self.send_queue_cv.wait()
item = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
self.send_sync(item)
def wait_for_sent(self):
if self.send_type == "PUT_ASYNC":
start_time = time.time()
with self.send_queue_cv:
while self.send_queue:
self.send_queue_cv.wait()
duration = time.time() - start_time
logger.debug(
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank)
def send_sync(self, item: SendQueueItem) -> bool:
if item.remote_address is None:
return False
if item.remote_address not in self.socks:
self.create_connect(item.remote_address)
tensor = item.tensor
sock = self.socks[item.remote_address]
comm, rank = self.comms[item.remote_address]
data = {
"cmd": "PUT",
"tensor_id": item.tensor_id,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "")
}
sock.send(msgpack.dumps(data))
response = sock.recv()
if response != b"0":
logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, item.remote_address, rank, data,
tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3,
response.decode())
return False
self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC":
self.have_sent_tensor_id(item.tensor_id)
return True
    def get_finished(
            self, finished_req_ids: set[str], no_compile_layers
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies the connector of the worker-side request ids that have
        finished generating tokens.
        Returns:
            A tuple of (sending/saving ids, recving/loading ids) for requests
            that have finished asynchronous transfer.
        The finished saves/sends req ids must belong to a set provided in a
        call to this method (this call or a prior one).
        """
# Clear the buffer upon request completion.
for request_id in finished_req_ids:
for layer_name in no_compile_layers:
tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store:
with self.recv_store_cv:
tensor = self.recv_store.pop(tensor_id, None)
self.send_request_id_to_tensor_ids.pop(
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.pool.free(addr)
        # TODO: Retrieve requests that have already sent the KV cache.
        finished_sending: set[str] = set()
        # TODO: Retrieve requests that have already received the KV cache.
finished_recving: set[str] = set()
return finished_sending or None, finished_recving or None
def ping(self):
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address)
sock.connect(f"tcp://{self.proxy_address}")
data = {
"type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"zmq_address": self.zmq_address
}
while True:
sock.send(msgpack.dumps(data))
time.sleep(3)
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
with torch.cuda.stream(stream):
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
comm, cudaStream_t(stream.cuda_stream))
stream.synchronize()
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
with torch.cuda.stream(stream):
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
comm, cudaStream_t(stream.cuda_stream))
stream.synchronize()
def close(self) -> None:
self._listener_thread.join()
if self.send_type == "PUT_ASYNC":
self._send_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()
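The engine keys every transfer by a tensor_id of the form "<request_id>#<layer_name>" and later regroups finished transfers per request (see have_sent_tensor_id, have_received_tensor_id and get_finished). A minimal sketch of that bookkeeping, independent of NCCL and ZMQ; the ids below are illustrative only, not produced by the engine:

from collections import defaultdict

def group_by_request(tensor_ids: list[str]) -> dict[str, set[str]]:
    # Same convention as have_sent_tensor_id: the request id is the part
    # of the tensor_id before the '#'.
    grouped: dict[str, set[str]] = defaultdict(set)
    for tensor_id in tensor_ids:
        grouped[tensor_id.split('#')[0]].add(tensor_id)
    return dict(grouped)

finished = group_by_request([
    "req-0#model.layers.0.self_attn.attn",
    "req-0#model.layers.1.self_attn.attn",
    "req-1#model.layers.0.self_attn.attn",
])
assert set(finished) == {"req-0", "req-1"}
assert len(finished["req-0"]) == 2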

View File

@@ -0,0 +1,267 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import ctypes
import math
from dataclasses import dataclass
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class MemoryBlock:
size: int
    addr: int
class TensorMemoryPool:
    """A memory pool for managing pinned host memory allocations for tensors.
    This class implements a buddy allocation system to efficiently manage
    pinned host memory for tensor storage. It supports allocation,
    deallocation, and tensor storage/retrieval operations.
    Key Features:
    - Uses power-of-two block sizes for efficient buddy allocation
    - Supports splitting and merging of memory blocks
    - Provides methods to store CUDA tensors in pinned host memory
    - Allows loading tensors from pinned memory back to device
    - Automatically cleans up memory on destruction
    Attributes:
        max_block_size (int): Maximum block size (rounded to nearest power of two)
        min_block_size (int): Minimum block size (rounded to nearest power of two)
        free_lists (dict): Dictionary of free memory blocks by size
        allocated_blocks (dict): Dictionary of currently allocated blocks
        base_tensor (torch.Tensor): Base pinned memory tensor
        base_address (int): Base memory address of the pinned memory region
    Example:
        >>> pool = TensorMemoryPool(max_block_size=1024*1024)
        >>> tensor = torch.randn(100, device='cuda')
        >>> addr = pool.store_tensor(tensor)
        >>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
        ...                                  tensor.shape, 'cuda')
        >>> pool.free(addr)
    """
    def __init__(self, max_block_size: int, min_block_size: int = 512):
        """Initializes the memory pool with given size constraints.
        Args:
            max_block_size (int): Maximum size of memory blocks to manage
            min_block_size (int, optional): Minimum size of memory blocks
                to manage. Defaults to 512.
        Raises:
            ValueError: If block sizes are invalid or max_block_size is less
                than min_block_size
        """
if max_block_size <= 0 or min_block_size <= 0:
raise ValueError("Block sizes must be positive")
if max_block_size < min_block_size:
raise ValueError(
"Max block size must be greater than min block size")
self.max_block_size = self._round_to_power_of_two(max_block_size)
self.min_block_size = self._round_to_power_of_two(min_block_size)
self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
self.allocated_blocks: dict[int, MemoryBlock] = {}
self._initialize_free_lists()
self._allocate_pinned_memory()
atexit.register(self.cleanup)
def _round_to_power_of_two(self, size: int) -> int:
return 1 << (size - 1).bit_length()
def _initialize_free_lists(self):
size = self.max_block_size
while size >= self.min_block_size:
self.free_lists[size] = {}
size //= 2
def _allocate_pinned_memory(self):
self.base_tensor = torch.empty(self.max_block_size // 4,
dtype=torch.float32,
pin_memory=True)
self.base_address = self.base_tensor.data_ptr()
initial_block = MemoryBlock(size=self.max_block_size,
addr=self.base_address)
self.free_lists[self.max_block_size][
initial_block.addr] = initial_block
logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d",
self.base_address, self.max_block_size)
def allocate(self, size: int) -> int:
"""Allocates a memory block of at least the requested size.
Args:
size (int): Minimum size of memory to allocate
Returns:
int: Address of the allocated memory block
Raises:
ValueError: If size is invalid or insufficient memory is available
"""
if size <= 0:
raise ValueError("Allocation size must be positive")
required_size = self._round_to_power_of_two(
max(size, self.min_block_size))
if required_size > self.max_block_size:
raise ValueError("Requested size exceeds maximum block size")
current_size = required_size
while current_size <= self.max_block_size:
if self.free_lists[current_size]:
_, block = self.free_lists[current_size].popitem()
self._split_block(block, required_size)
self.allocated_blocks[block.addr] = block
return block.addr
current_size *= 2
raise ValueError("Insufficient memory")
def _split_block(self, block: MemoryBlock, required_size: int):
while (block.size > required_size
and block.size // 2 >= self.min_block_size):
buddy_size = block.size // 2
buddy_addr = block.addr + buddy_size
buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
block.size = buddy_size
self.free_lists[buddy_size][buddy.addr] = buddy
def free(self, addr: int):
"""Frees an allocated memory block.
Args:
addr (int): Address of the block to free
Raises:
ValueError: If address is invalid or not allocated
"""
if addr not in self.allocated_blocks:
raise ValueError("Invalid address to free")
block = self.allocated_blocks.pop(addr)
self._merge_buddies(block)
def _merge_buddies(self, block: MemoryBlock):
MAX_MERGE_DEPTH = 30
depth = 0
while depth < MAX_MERGE_DEPTH:
buddy_offset = block.size if (block.addr - self.base_address) % (
2 * block.size) == 0 else -block.size
buddy_addr = block.addr + buddy_offset
buddy = self.free_lists[block.size].get(buddy_addr)
if buddy:
del self.free_lists[buddy.size][buddy.addr]
merged_addr = min(block.addr, buddy.addr)
merged_size = block.size * 2
block = MemoryBlock(size=merged_size, addr=merged_addr)
depth += 1
else:
break
self.free_lists[block.size][block.addr] = block
def store_tensor(self, tensor: torch.Tensor) -> int:
"""Stores a CUDA tensor in pinned host memory.
Args:
tensor (torch.Tensor): CUDA tensor to store
Returns:
int: Address where the tensor is stored
Raises:
ValueError: If tensor is not on CUDA or allocation fails
"""
if not tensor.is_cuda:
raise ValueError("Only CUDA tensors can be stored")
size = tensor.element_size() * tensor.numel()
addr = self.allocate(size)
block = self.allocated_blocks[addr]
if block.size < size:
self.free(addr)
raise ValueError(
f"Allocated block size {block.size} is smaller than "
f"required size {size}")
try:
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
cpu_tensor = torch.frombuffer(buffer,
dtype=tensor.dtype,
count=tensor.numel()).reshape(
tensor.shape)
except ValueError as err:
self.free(addr)
raise ValueError(f"Failed to create tensor view: {err}") from err
cpu_tensor.copy_(tensor)
return addr
def load_tensor(self, addr: int, dtype: torch.dtype, shape: tuple[int,
...],
device: torch.device) -> torch.Tensor:
"""Loads a tensor from pinned host memory to the specified device.
Args:
addr (int): Address where tensor is stored
dtype (torch.dtype): Data type of the tensor
shape (tuple[int, ...]): Shape of the tensor
device: Target device for the loaded tensor
Returns:
torch.Tensor: The loaded tensor on the specified device
Raises:
ValueError: If address is invalid or sizes don't match
"""
if addr not in self.allocated_blocks:
raise ValueError("Invalid address to load")
block = self.allocated_blocks[addr]
num_elements = math.prod(shape)
dtype_size = torch.tensor([], dtype=dtype).element_size()
required_size = num_elements * dtype_size
if required_size > block.size:
raise ValueError("Requested tensor size exceeds block size")
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
cpu_tensor = torch.frombuffer(buffer, dtype=dtype,
count=num_elements).reshape(shape)
cuda_tensor = torch.empty(shape, dtype=dtype, device=device)
cuda_tensor.copy_(cpu_tensor)
return cuda_tensor
def cleanup(self):
"""Cleans up all memory resources and resets the pool state."""
self.free_lists.clear()
self.allocated_blocks.clear()
if hasattr(self, 'base_tensor'):
del self.base_tensor
def __del__(self):
self.cleanup()
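The pool's buddy-allocator arithmetic can be checked without pinned memory or a GPU. The sketch below re-implements, standalone, the size rounding of _round_to_power_of_two and the buddy-address computation of _merge_buddies (it does not call into the class):

def round_to_power_of_two(size: int) -> int:
    # mirrors TensorMemoryPool._round_to_power_of_two
    return 1 << (size - 1).bit_length()

def buddy_addr(addr: int, size: int, base_address: int) -> int:
    # mirrors the buddy computation in _merge_buddies: flip the block-size
    # bit of the block's offset from the pool base address
    offset = addr - base_address
    return addr + (size if offset % (2 * size) == 0 else -size)

assert round_to_power_of_two(512) == 512
assert round_to_power_of_two(513) == 1024

base = 0
# A 1024-byte block at offset 0 has its buddy at offset 1024, and vice versa.
assert buddy_addr(0, 1024, base) == 1024
assert buddy_addr(1024, 1024, base) == 0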

View File

@@ -0,0 +1,418 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Is store or load
is_store: bool
mm_hashes: list[str]
@staticmethod
def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
is_store: bool, mm_hashes: list[str]) -> "ReqMeta":
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta(
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
is_store=is_store,
mm_hashes=mm_hashes,
)
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta]
def __init__(self):
self.requests = []
def add_request(
self,
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store,
mm_hashes))
class SharedStorageConnector(KVConnectorBase_V1):
    # NOTE: This is a simple debug implementation of the KV connector.
    # It saves / loads the KV cache to / from disk.
    # It does extra work that will overwrite the existing prefix cache in the
    # GPU - to remove this overhead, a "mask" needs to be added to the
    # ReqMeta class.
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
transfer_config = vllm_config.kv_transfer_config
self._storage_path = transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp")
logger.info(vllm_config.kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs: Any) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1)
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
        # Get the metadata
        metadata: KVConnectorMetadata = self._get_connector_metadata()
        if metadata is None:
            logger.warning(
                "In connector.start_load_kv, but the connector metadata is None"
            )
            return
        assert isinstance(metadata, SharedStorageConnectorMetadata)
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
logger.warning(
"In connector.start_load_kv, but the attn_metadata is None")
return
# Load the KV for each request each layer
for request in metadata.requests:
if request.is_store:
continue
logger.info("Inject KV cache of %d tokens to the paged memory",
len(request.slot_mapping))
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE/MLP etc.
kv_cache_attr = getattr(layer, 'kv_cache', None)
if kv_cache_attr is None:
continue
                kv_cache_layer = kv_cache_attr[
                    forward_context.virtual_engine]
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes)
kv_cache = safetensors.torch.load_file(
filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes)
kv_cache = extract_kv_from_layer(kv_layer,
request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
def wait_for_save(self):
return
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[Optional[int], bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
        Returns:
            The number of tokens that can be loaded from the external KV
            cache beyond what is already computed, and whether they will be
            loaded asynchronously (always False for this connector).
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0, False
logger.info("External Cache Hit!")
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
if num_external_tokens > 0:
self._requests_need_load[request.request_id] = request
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = SharedStorageConnectorMetadata()
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in self._requests_need_load:
meta.add_request(
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in new_req.mm_features])
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
# but a single request can have both store and load.
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_request(new_req):
meta.add_request(
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True,
mm_hashes=[f.identifier for f in new_req.mm_features])
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
            # NOTE(rob): here we rely on the resumed requests being
            # the first N requests in the list scheduled_cached_reqs.
if not resumed_from_preemption:
break
if req_id in self._requests_need_load:
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[req_id]
total_tokens = num_computed_tokens + num_new_tokens
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = new_block_ids[0]
meta.add_request(
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in request.mm_features])
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_request(
self,
request: "Request",
) -> bool:
"""Check if the cache is hit for the request.
"""
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
foldername = self._generate_foldername_debug(
torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
[f.identifier for f in request.mm_features],
create_folder=False)
return os.path.exists(foldername)
def _generate_foldername_debug(
self,
token_ids: torch.Tensor,
mm_hashes: list[str],
create_folder=False,
) -> str:
"""Generate a folder name based on the hash of the bytes of the input
ids.
"""
token_bytes = token_ids.numpy().tobytes()
# Add mm_hashes to the bytes being hashed to avoid path traversal and
# to create a canonical key.
if mm_hashes:
mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode('utf-8')
input_ids_hash = hashlib.md5(token_bytes,
usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(
self,
layer_name: str,
token_ids: torch.Tensor,
mm_hashes: list[str],
) -> str:
"""Generate a file name based on the layer name and the hash
of the bytes of the input ids.
"""
foldername = self._generate_foldername_debug(token_ids,
mm_hashes=mm_hashes,
create_folder=True)
return os.path.join(foldername, f"{layer_name}.safetensors")
def align_to_block_size(num_tokens: int, block_size: int) -> int:
    """Align (round down) the number of tokens to a multiple of the
    block size.
    """
    return (num_tokens - 1) // block_size * block_size
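Two small pieces of arithmetic drive this connector: align_to_block_size rounds a token count down to a block boundary, and ReqMeta.make_meta expands block ids into per-token slot indices. A standalone sketch with toy numbers (the block ids and block size are illustrative only):

import torch

def align_down(num_tokens: int, block_size: int) -> int:
    # same rounding rule as align_to_block_size above
    return (num_tokens - 1) // block_size * block_size

assert align_down(33, 16) == 32   # 33 tokens -> 2 full blocks
assert align_down(32, 16) == 16   # an exact multiple is also rounded down

block_size = 4
block_ids = torch.tensor([7, 2])            # toy block numbers
block_offsets = torch.arange(0, block_size)
slot_mapping = (block_offsets.reshape(1, block_size) +
                block_ids.reshape(-1, 1) * block_size).flatten()
# Block 7 covers slots 28..31 and block 2 covers slots 8..11.
assert slot_mapping.tolist() == [28, 29, 30, 31, 8, 9, 10, 11]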

View File

@@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.
This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.
These classes above are abstracted behind class `KVCacheBufferBase`.
"""
from abc import ABC, abstractmethod
from typing import Optional
import torch
class KVCacheBufferBase(ABC):
"""
Abstract base class for a KVCache buffer.
"""
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
KVCache buffer when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVLookupBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache lookup buffer.
This class provides an abstraction for a key-value (KV) cache lookup buffer.
The key of the lookup buffer:
- input_tokens: token IDs of the request
- roi: a binary mask on top of input_tokens.
- Purpose of roi: Since KV cache may only be available for a subset of
tokens in the input (for example, when vLLM is connected to an external
KV cache service), roi specifies the subset of tokens that the KV cache
is associated with.
- NOTE: roi can be further extended to describe which part of KV the
current process is holding (each process may only hold a part of KV
due to TP and PP). This is not implemented for now.
The value of the lookup buffer:
- key: the key tensor in the KV cache
- value: the value tensor in the KV cache
- hidden: the final hidden state generated by model forwarding. This allows
vLLM to bypass further model forwarding by transmitting the hidden state.
"""
@abstractmethod
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
"""Insert into the lookup buffer.
The functionality is similar to the following python statement
```
buffer[input_tokens, roi] = [key, value, hidden]
```
FIXME: in the future, we should only have two arguments, key and value,
where key is a tensor dict and value is a tensor dict.
FIXME: we should transmit both sampler outputs and the hidden states.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
key (torch.Tensor): The key tensor in the KV cache.
value (torch.Tensor): The value tensor in the KV cache.
hidden (torch.Tensor): The final hidden state tensor generated
during model forwarding to bypass model
forwarding.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
        If `input_tokens` and `roi` are `None`, any KV cache entry in the
        buffer may be selected, returned, and removed from the buffer. This is
        useful when offloading the KV cache to a KV cache storage service.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
list[Optional[torch.Tensor]]: A list of tensors. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVStoreBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache storage buffer with key-value semantics.
    This class provides a simple key-value storage buffer abstraction with
    basic put/get operations, enabling flexible, fine-grained control over
    KVCache transfer.
The functionality is similar to a distributed key-value store, where:
- Key: A unique string identifier for the cached entry
- Value:
- Tensor to be stored and retrieved
- None (indicating deletion or empty value)
"""
@abstractmethod
def put(
self,
key: str,
value: Optional[torch.Tensor],
) -> None:
"""Store a key-value pair in the buffer.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
value (Optional[torch.Tensor]): Tensor to be stored.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def get(
self,
key: str,
) -> Optional[torch.Tensor]:
"""Retrieve a value from the buffer by key.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
Returns:
Optional[torch.Tensor]: Stored tensor if exists, None otherwise.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
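To make the put/get contract concrete, here is a toy, dict-backed implementation of KVStoreBufferBase. It is not part of vLLM; the class name and behavior are illustrative only, and it assumes vLLM and torch are importable:

from typing import Optional

import torch

from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase


class InMemoryKVStoreBuffer(KVStoreBufferBase):
    """Toy in-memory buffer illustrating the put/get/close contract."""

    def __init__(self) -> None:
        self._store: dict[str, torch.Tensor] = {}

    def put(self, key: str, value: Optional[torch.Tensor]) -> None:
        if value is None:
            self._store.pop(key, None)  # treat None as deletion / empty value
        else:
            self._store[key] = value

    def get(self, key: str) -> Optional[torch.Tensor]:
        return self._store.get(key)

    def close(self) -> None:
        self._store.clear()


buf = InMemoryKVStoreBuffer()
buf.put("req-0#layer0.key", torch.zeros(2, 3))
assert buf.get("req-0#layer0.key") is not None
assert buf.get("missing") is None
buf.close()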

View File

@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `MooncakeStore` that allows developers to
think of KV cache transfer operations as putting new KV cache entries
into a remote KVStore-based lookup buffer and getting existing KV caches
from this remote lookup buffer.
"""
import json
import os
from dataclasses import dataclass
from typing import Optional
import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVStoreBufferBase)
from vllm.logger import init_logger
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
logger = init_logger(__name__)
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: int
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
@staticmethod
def from_file(file_path: str) -> 'MooncakeStoreConfig':
"""Load the config from a JSON file."""
with open(file_path) as fin:
config = json.load(fin)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=config.get("global_segment_size",
DEFAULT_GLOBAL_SEGMENT_SIZE),
local_buffer_size=config.get("local_buffer_size",
DEFAULT_LOCAL_BUFFER_SIZE),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
)
@staticmethod
def load_from_env() -> 'MooncakeStoreConfig':
"""Load config from a file specified in the environment variable."""
config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
if config_file_path is None:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_file_path)
class MooncakeStore(KVStoreBufferBase):
def __init__(
self,
config: VllmConfig,
):
try:
from mooncake.store import MooncakeDistributedStore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
try:
self.store = MooncakeDistributedStore()
self.config = MooncakeStoreConfig.load_from_env()
logger.info("Mooncake Configuration loaded successfully.")
self.store.setup(self.config.local_hostname,
self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol, self.config.device_name,
self.config.master_server_address)
except ValueError as e:
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error(
"An error occurred while loading the configuration: %s", exc)
raise
def close(self):
# MooncakeDistributedStore will automatically call the destructor, so
# it is unnecessary to close it manually.
pass
def put(
self,
key: str,
value: Optional[torch.Tensor],
) -> None:
# A message queue needs to be introduced before making it asynchronous.
if value is not None:
self._put_impl(key, value)
def get(
self,
key: str,
) -> Optional[torch.Tensor]:
# A message queue needs to be introduced before making it asynchronous.
value = self._get_impl(key)
return value
def _put_impl(
self,
key: str,
value: torch.Tensor,
) -> None:
"""Put KVCache to Mooncake Store"""
device_id = value.device.index if value.device.type == 'cuda' else -1
device_tensor = torch.tensor(device_id, dtype=torch.int32)
value_bytes = safetensors_save({
"tensor": value,
"device_id": device_tensor
})
try:
self.store.put(key, value_bytes)
except TypeError as err:
logger.error("Failed to put value into Mooncake Store: %s", err)
raise TypeError("Mooncake Store Put Type Error.") from err
def _get_impl(
self,
key: str,
) -> Optional[torch.Tensor]:
"""Get KVCache from Mooncake Store"""
try:
data = self.store.get(key)
except TypeError as err:
logger.error("Failed to get value from Mooncake Store: %s", err)
raise TypeError("Mooncake Store Get Type Error.") from err
if data:
loaded_tensors = safetensors_load(data)
tensor = loaded_tensors["tensor"]
device_id_tensor = loaded_tensors["device_id"]
device_id = int(device_id_tensor.item())
device = torch.device(
'cuda', device_id) if device_id >= 0 else torch.device('cpu')
return tensor.to(device)
return None
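MooncakeStoreConfig.load_from_env() expects MOONCAKE_CONFIG_PATH to point at a JSON file containing the fields parsed by from_file() above. The sketch below writes such a file and reads it back; every value is a placeholder for illustration, not a recommended deployment setting:

import json
import os
import tempfile

example_config = {
    "local_hostname": "192.168.0.137",
    "metadata_server": "192.168.0.137:2379",          # placeholder address
    "protocol": "tcp",
    "device_name": "",
    "master_server_address": "192.168.0.137:50051",   # placeholder address
    # global_segment_size / local_buffer_size are omitted on purpose to show
    # that from_file() falls back to the defaults defined above.
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as fout:
    json.dump(example_config, fout)
    config_path = fout.name

# load_from_env() reads the path from this environment variable.
os.environ["MOONCAKE_CONFIG_PATH"] = config_path

with open(config_path) as fin:
    loaded = json.load(fin)
assert loaded.get("global_segment_size", 3355443200) == 3355443200  # default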

View File

@@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Implements a distributed key-value (KV) cache transfer mechanism.
Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Uses a CPU signal pipe to avoid race conditions.
- Handles buffer size constraints and provides a backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import threading
from collections import deque
from typing import Optional, Union
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVLookupBufferBase)
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class SimpleBuffer(KVLookupBufferBase):
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
buffer_size_thresh: float):
"""
signal_pipe: on CPU
        NOTE: on-device recv will block all threads in the process, making the
        KV cache producer unable to listen for new requests while transmitting
        the KV cache. Luckily, CPU recv only blocks the current thread, so we
        use CPU recv to listen for new requests.
data_pipe: on device (e.g. GPU)
"""
self.buffer: deque[list[torch.Tensor]] = deque()
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_cv = threading.Condition()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
def _matches(self, tokens_roi_sender: list[torch.Tensor],
tokens_roi_recver: list[torch.Tensor]):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
tokens_sender = tokens_roi_sender[0]
tokens_recver = tokens_roi_recver[0]
roi_sender = tokens_roi_sender[1]
roi_recver = tokens_roi_recver[1]
if tokens_recver is None:
# consumer sends an empty request
# semantics: DROP SELECT * LIMIT 1
# so any of the data in the buffer can be drop-selected
return True
# Assuming that roi is a binary mask on tokens
tokens_sender = tokens_sender[roi_sender]
tokens_recver = tokens_recver[roi_recver]
# simple common prefix matching
min_length = min(len(tokens_sender), len(tokens_recver))
if torch.allclose(tokens_sender[:min_length],
tokens_recver[:min_length]):
return min_length
return 0
def _send_tensor_and_dec_size(self,
tensor: Optional[torch.Tensor]) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]):
if isinstance(data, torch.Tensor):
return data.element_size() * data.numel()
if not data:
# cannot perform `not data` on a tensor
# so this check needs to go after the check above
return 0
raise AssertionError(f"Unknown data type {type(data)}")
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor):
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone()
if isinstance(key, torch.Tensor):
key = key.clone()
if isinstance(value, torch.Tensor):
value = value.clone()
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
data_size = sum([self._get_element_size(data) for data in buffer_item])
with self.buffer_cv:
if self.buffer_size + data_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size + data_size > self.buffer_size_threshold:
self.buffer_cv.wait()
self.buffer_size += data_size
self.buffer.append(buffer_item)
self.buffer_cv.notify()
def _is_end_signal(self, signal):
return signal is None
def drop_select_handler(self):
try:
while True:
signal = self.signal_pipe.recv_tensor()
if self._is_end_signal(signal):
logger.info("Received end signal!")
break
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
assert roi is not None, "Please provide the roi when sending "\
"drop-select request"
roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi]
def is_buffer_available(
tokens_roi_recver: list[torch.Tensor], ) -> bool:
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
for _ in range(len(self.buffer)):
if self._matches(self.buffer[0],
tokens_roi_recver) > 0:
return True
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
return False
with self.buffer_cv:
while not is_buffer_available(tokens_roi_recver):
logger.debug(
"KV transfer buffer is not available. Waiting...")
self.buffer_cv.wait()
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
self.buffer_cv.notify()
except RuntimeError as e:
if 'Connection closed by peer' not in str(e):
raise e
logger.debug("Closing drop_select_handler")
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
assert self.request_handling_thread is None, \
"drop_select should be called by the KV cache consumer "\
"(e.g. the decode vLLM instance)"
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
self.signal_pipe.send_tensor(self.normal_signal)
self.data_pipe.send_tensor(input_tokens)
self.data_pipe.send_tensor(roi)
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = (roi > 0.5)
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
return [input_tokens, roi, key, value, hidden]
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
if self.request_handling_thread is None:
self.request_handling_thread = threading.Thread(
target=self.drop_select_handler)
self.request_handling_thread.start()
def close(self):
if hasattr(self, "request_handling_thread"
) and self.request_handling_thread is not None:
self.request_handling_thread.join()
else:
            # TODO: have an explicit close signal and an explicit way to
            # check if it's the requester
self.signal_pipe.send_tensor(self.end_signal)
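The matching rule at the heart of drop_select_handler is easy to check in isolation: mask both sides with their roi, then accept on a common token prefix. A standalone sketch mirroring _matches (the token tensors are toy float values, not real token ids):

import torch

def matches(tokens_sender, roi_sender, tokens_recver, roi_recver):
    if tokens_recver is None:
        # an empty query drop-selects any buffered entry
        return True
    tokens_sender = tokens_sender[roi_sender]
    tokens_recver = tokens_recver[roi_recver]
    min_length = min(len(tokens_sender), len(tokens_recver))
    if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
        return min_length
    return 0

sender_tokens = torch.tensor([1., 2., 3., 4.])
sender_roi = torch.tensor([True, True, True, True])
assert matches(sender_tokens, sender_roi,
               torch.tensor([1., 2., 3.]),
               torch.tensor([True, True, True])) == 3
assert matches(sender_tokens, sender_roi,
               torch.tensor([9., 9.]), torch.tensor([True, True])) == 0
assert matches(sender_tokens, sender_roi, None, None) is True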

Some files were not shown because too many files have changed in this diff