update
This commit is contained in:
6
vllm/distributed/__init__.py
Normal file
6
vllm/distributed/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .communication_op import *
|
||||
from .parallel_state import *
|
||||
from .utils import *
|
||||
43
vllm/distributed/communication_op.py
Normal file
43
vllm/distributed/communication_op.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from .parallel_state import get_tp_group
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
return get_tp_group().all_reduce(input_)
|
||||
|
||||
|
||||
def tensor_model_parallel_all_gather(
|
||||
input_: torch.Tensor, dim: int = -1
|
||||
) -> torch.Tensor:
|
||||
"""All-gather the input tensor across model parallel group."""
|
||||
return get_tp_group().all_gather(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_reduce_scatter(
|
||||
input_: torch.Tensor, dim: int = -1
|
||||
) -> torch.Tensor:
|
||||
"""Reduce-Scatter the input tensor across model parallel group."""
|
||||
return get_tp_group().reduce_scatter(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_gather(
|
||||
input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> torch.Tensor | None:
|
||||
"""Gather the input tensor across model parallel group."""
|
||||
return get_tp_group().gather(input_, dst, dim)
|
||||
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
tensor_dict: dict[Any, torch.Tensor | Any] | None = None, src: int = 0
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
return tensor_dict
|
||||
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
|
||||
0
vllm/distributed/device_communicators/__init__.py
Normal file
0
vllm/distributed/device_communicators/__init__.py
Normal file
696
vllm/distributed/device_communicators/all2all.py
Normal file
696
vllm/distributed/device_communicators/all2all.py
Normal file
@@ -0,0 +1,696 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
if has_flashinfer_all2all():
|
||||
from flashinfer.comm import Mapping # type: ignore[import-not-found]
|
||||
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
|
||||
from flashinfer.comm.trtllm_alltoall import (
|
||||
MnnvlMoe, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NaiveAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
A naive implementation of all2all communication.
|
||||
It uses all-reduce under the hood, which is not
|
||||
efficient at all. The main purpose is for testing and
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_tokens_across_sp_cpu: torch.Tensor,
|
||||
is_sequence_parallel: bool,
|
||||
) -> torch.Tensor:
|
||||
assert len(x.shape) == 2
|
||||
buffer = torch.empty(
|
||||
(cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype
|
||||
)
|
||||
|
||||
rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
world_size = self.world_size if is_sequence_parallel else self.dp_world_size
|
||||
|
||||
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
|
||||
end = cu_tokens_across_sp_cpu[idx]
|
||||
get_ep_group().broadcast(buffer[start:end, :], idx)
|
||||
|
||||
return buffer
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if extra_tensors is not None:
|
||||
raise NotImplementedError(
|
||||
"extra_tensors is not supported for NaiveAll2AllManager"
|
||||
)
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(
|
||||
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
router_logits = self.naive_multicast(
|
||||
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if extra_tensors is not None:
|
||||
raise NotImplementedError(
|
||||
"extra_tensors is not supported for NaiveAll2AllManager"
|
||||
)
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(
|
||||
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
topk_weights = self.naive_multicast(
|
||||
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
topk_ids = self.naive_multicast(
|
||||
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[ep_rank]
|
||||
|
||||
all_hidden_states = get_ep_group().all_reduce(hidden_states)
|
||||
hidden_states = all_hidden_states[start:end, :]
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class AgRsAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
An implementation of all2all communication based on
|
||||
all-gather (dispatch) and reduce-scatter (combine).
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
|
||||
tensors_to_gather = [hidden_states, router_logits]
|
||||
if extra_tensors is not None:
|
||||
tensors_to_gather.extend(extra_tensors)
|
||||
|
||||
gathered_tensors = dist_group.all_gatherv(
|
||||
tensors_to_gather,
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
|
||||
if extra_tensors is not None:
|
||||
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
|
||||
return gathered_tensors[0], gathered_tensors[1]
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
|
||||
tensors_to_gather = [hidden_states, topk_weights, topk_ids]
|
||||
if extra_tensors is not None:
|
||||
tensors_to_gather.extend(extra_tensors)
|
||||
|
||||
gathered_tensors = dist_group.all_gatherv(
|
||||
tensors_to_gather,
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
|
||||
hidden_states = gathered_tensors[0]
|
||||
topk_weights = gathered_tensors[1]
|
||||
topk_ids = gathered_tensors[2]
|
||||
|
||||
if extra_tensors is None:
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reduce-scatter hidden_states across all dp ranks.
|
||||
"""
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_pplx(), (
|
||||
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install pplx_kernels."
|
||||
)
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if self.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
dist.broadcast(
|
||||
uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group,
|
||||
)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx # type: ignore[import-not-found]
|
||||
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs,
|
||||
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
|
||||
)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_finalize, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_deep_ep(), (
|
||||
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install DeepEP kernels."
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
# reasonable defaults based on profiling.
|
||||
self.num_sms = 20
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_rdma_bytes = None
|
||||
num_qps_per_rank = None
|
||||
|
||||
if self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE:
|
||||
num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_qps_per_rank = self.num_sms // 2
|
||||
else:
|
||||
num_rdma_bytes = 0
|
||||
num_qps_per_rank = 1
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
assert num_qps_per_rank is not None
|
||||
return dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
assert len(kwargs) == 0, (
|
||||
"DeepEPHTAll2AllManager expects no arguments. All the required "
|
||||
"args are computed in the Manager itself."
|
||||
)
|
||||
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
buffer_kwargs = self._make_all2all_kwargs()
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer
|
||||
)
|
||||
return handle
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
# Right now the buffers are sized for only what the kernels were
|
||||
# created with. So we can only reduce the number of SMS used
|
||||
# but not increase it.
|
||||
if num_sms > self.num_sms:
|
||||
num_sms = self.num_sms
|
||||
deep_ep.Buffer.set_num_sms(num_sms)
|
||||
|
||||
|
||||
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
max_num_tokens_per_dp_rank: int,
|
||||
token_hidden_size: int,
|
||||
num_ep_ranks: int,
|
||||
num_global_experts: int,
|
||||
num_local_experts: int,
|
||||
) -> dict[Any, Any]:
|
||||
"""
|
||||
max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank
|
||||
can dispatch all the ranks must hold the same value.
|
||||
token_hidden_size: the hidden dimension of each token.
|
||||
num_ep_ranks: the number of EP group ranks.
|
||||
num_global_experts: Number of experts in the model.
|
||||
num_local_experts: Number of experts in an EP rank.
|
||||
"""
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_qps_per_rank = num_local_experts
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
hidden=token_hidden_size,
|
||||
num_ranks=num_ep_ranks,
|
||||
num_experts=num_global_experts,
|
||||
)
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
return dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
"""
|
||||
The kwargs for DeepEPLLAll2AllManager is dictated by
|
||||
_make_all2all_kwargs.
|
||||
"""
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer
|
||||
)
|
||||
return handle
|
||||
|
||||
# DeepEP LL uses RDMA so no SMs are used for communication
|
||||
def max_sms_used(self) -> int | None:
|
||||
return 0
|
||||
|
||||
|
||||
class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on flashinfer kernels.
|
||||
"""
|
||||
|
||||
# This type lint could be removed after all of the work in
|
||||
# https://github.com/vllm-project/vllm/issues/26533 done.
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_flashinfer_all2all(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
logger.debug(
|
||||
"Initialize for flashinfer All2All rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
self.initialized = False
|
||||
self.alltoall_info = None
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
gpus_per_node: int,
|
||||
):
|
||||
"""Initialize workspace"""
|
||||
if self.initialized:
|
||||
return
|
||||
|
||||
self.cleanup()
|
||||
logger.debug("making map: rank=%d, world size=%d", rank, world_size)
|
||||
self.mapping = Mapping(
|
||||
world_size,
|
||||
rank,
|
||||
gpus_per_node,
|
||||
tp_size=world_size,
|
||||
)
|
||||
|
||||
from vllm.distributed.device_communicators.mnnvl_compat import (
|
||||
CustomCommunicator,
|
||||
)
|
||||
|
||||
dp_config = MnnvlConfig(
|
||||
comm_backend=CustomCommunicator(get_dp_group().cpu_group),
|
||||
fabric_page_size=1 << 29, # 512MB
|
||||
allocation_granularity=0, # Auto-detect
|
||||
)
|
||||
|
||||
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config)
|
||||
self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
|
||||
self.mapping, dp_config
|
||||
)
|
||||
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.gpus_per_node = gpus_per_node
|
||||
self.initialized = True
|
||||
|
||||
logger.info(
|
||||
"FlashInfer All2All initialized for rank %s, size %s", rank, world_size
|
||||
)
|
||||
|
||||
def ensure_alltoall_workspace_initialized(self):
|
||||
"""Ensure workspace is initialized"""
|
||||
if not has_flashinfer_all2all():
|
||||
return False
|
||||
|
||||
if self.world_size <= 1:
|
||||
return False
|
||||
|
||||
if not self.initialized:
|
||||
self.initialize(
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
gpus_per_node=torch.cuda.device_count,
|
||||
)
|
||||
return self.initialized
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
return self
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up workspace"""
|
||||
if (
|
||||
self.initialized
|
||||
and self.workspace_tensor is not None
|
||||
and self.prepare_workspace_tensor is not None
|
||||
):
|
||||
try:
|
||||
del self.workspace_tensor
|
||||
del self.prepare_workspace_tensor
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
|
||||
finally:
|
||||
self.workspace_tensor = None
|
||||
self.prepare_workspace_tensor = None
|
||||
self.mapping = None
|
||||
self.initialized = False
|
||||
|
||||
|
||||
class MoriAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
assert has_mori(), (
|
||||
"MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
|
||||
" to install MoRI kernels."
|
||||
) # noqa
|
||||
import mori
|
||||
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
torch._C._distributed_c10d._register_process_group("mori", cpu_group)
|
||||
mori.shmem.shmem_torch_process_group_init("mori")
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
rank: int,
|
||||
num_ep_ranks: int,
|
||||
input_dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype,
|
||||
token_hidden_size: int,
|
||||
scale_dim: int,
|
||||
scale_type_size: int,
|
||||
max_num_tokens_per_dp_rank: int,
|
||||
num_local_experts: int,
|
||||
num_experts_per_token: int,
|
||||
):
|
||||
import mori # type: ignore[import-not-found]
|
||||
|
||||
from vllm.platforms.rocm import on_gfx942, on_gfx950
|
||||
|
||||
assert on_gfx942() or on_gfx950(), (
|
||||
"mori currently only support arch gfx942 and gfx950"
|
||||
)
|
||||
|
||||
if not self.internode:
|
||||
# single node
|
||||
kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
|
||||
rdma_block_num = 0
|
||||
warp_num_per_block = 16
|
||||
block_num = 80
|
||||
else:
|
||||
# multi node
|
||||
kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
|
||||
if on_gfx942():
|
||||
warp_num_per_block = 16
|
||||
block_num = 32
|
||||
rdma_block_num = 16
|
||||
elif on_gfx950():
|
||||
warp_num_per_block = 8
|
||||
block_num = 64
|
||||
rdma_block_num = 32
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"mori currently only support arch gfx942 and gfx950"
|
||||
)
|
||||
|
||||
return dict(
|
||||
rank=rank,
|
||||
world_size=num_ep_ranks,
|
||||
data_type=quant_dtype,
|
||||
hidden_dim=token_hidden_size,
|
||||
scale_dim=scale_dim,
|
||||
scale_type_size=scale_type_size,
|
||||
max_token_type_size=input_dtype.itemsize,
|
||||
max_num_inp_token_per_rank=max_num_tokens_per_dp_rank,
|
||||
num_experts_per_rank=num_local_experts,
|
||||
num_experts_per_token=num_experts_per_token,
|
||||
warp_num_per_block=warp_num_per_block,
|
||||
block_num=block_num,
|
||||
kernel_type=kernel_type,
|
||||
rdma_block_num=rdma_block_num,
|
||||
gpu_per_node=min(8, num_ep_ranks),
|
||||
)
|
||||
|
||||
def _make_handle(self, **kwargs):
|
||||
import mori # type: ignore[import-not-found]
|
||||
|
||||
mori_config = mori.ops.EpDispatchCombineConfig(**kwargs)
|
||||
handle = mori.ops.EpDispatchCombineOp(mori_config)
|
||||
return handle
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import mori # type: ignore[import-not-found]
|
||||
|
||||
mori_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("MoRI all2all args %s", mori_kwargs)
|
||||
handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
|
||||
mori_kwargs, self._make_handle
|
||||
)
|
||||
return handle
|
||||
344
vllm/distributed/device_communicators/all_reduce_utils.py
Normal file
344
vllm/distributed/device_communicators/all_reduce_utils.py
Normal file
@@ -0,0 +1,344 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ctypes
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MiB = 1024 * 1024
|
||||
# Max size for each world size in case symmetric memory is available
|
||||
# For different SM architectures
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES = {
|
||||
"9.0": {
|
||||
2: 64 * MiB, # 64 MB
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: MiB // 2, # 512 KB
|
||||
8: MiB // 4, # 256 KB
|
||||
},
|
||||
"10.0": {
|
||||
2: 2 * MiB, # 2 MB
|
||||
4: 2 * MiB, # 2 MB
|
||||
6: 1 * MiB, # 1 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
},
|
||||
}
|
||||
|
||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||
"9.0": {
|
||||
2: 64 * MiB, # 64 MB
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: 64 * MiB, # 64 MB
|
||||
8: 64 * MiB, # 64 MB
|
||||
},
|
||||
"10.0": {
|
||||
2: 8 * MiB, # 8 MB
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: 128 * MiB, # 128 MB
|
||||
8: 128 * MiB, # 128 MB
|
||||
},
|
||||
}
|
||||
|
||||
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
|
||||
"min_world_size": 4,
|
||||
"thresholds": {
|
||||
4: 2 * MiB, # 2 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
},
|
||||
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
|
||||
}
|
||||
|
||||
|
||||
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
return False
|
||||
|
||||
if not is_symmetric_memory_enabled():
|
||||
return False
|
||||
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
|
||||
return False
|
||||
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
|
||||
if threshold is not None and input_tensor.nbytes >= threshold:
|
||||
return True
|
||||
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
|
||||
|
||||
|
||||
def producer(
|
||||
batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: str | None = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for i in batch_src:
|
||||
lib.cudaSetDevice(i)
|
||||
pointer = lib.cudaMalloc(1024)
|
||||
lib.cudaMemset(pointer, 1, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
producer_queue.put(handle)
|
||||
open_success = consumer_queue.get()
|
||||
if open_success:
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.put(0)
|
||||
consumer_queue.get()
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def consumer(
|
||||
batch_tgt: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: str | None = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for j in batch_tgt:
|
||||
lib.cudaSetDevice(j)
|
||||
handle = producer_queue.get()
|
||||
open_success = False
|
||||
try:
|
||||
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
|
||||
open_success = True
|
||||
except RuntimeError:
|
||||
# cannot error out here, because the producer process
|
||||
# is still waiting for the response.
|
||||
pass
|
||||
consumer_queue.put(open_success)
|
||||
if open_success:
|
||||
# modify the memory
|
||||
lib.cudaMemset(pointer, 2, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.get()
|
||||
consumer_queue.put(0)
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def can_actually_p2p(
|
||||
batch_src: Sequence[int],
|
||||
batch_tgt: Sequence[int],
|
||||
) -> Sequence[bool]:
|
||||
"""
|
||||
Usually, checking if P2P access is enabled can be done by
|
||||
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
|
||||
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
|
||||
returns `True` even if P2P access is not actually possible.
|
||||
See https://github.com/vllm-project/vllm/issues/2728 and
|
||||
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
|
||||
Therefore, we have to perform a real P2P access to check if it is actually
|
||||
possible.
|
||||
|
||||
Note on p2p and cuda IPC:
|
||||
Usually, one process uses one GPU:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|
||||
We need to combine p2p and cuda IPC, so that:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|shared|
|
||||
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
|
||||
That is to say, process src creates a tensor in GPU src, passes IPC handle to
|
||||
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
|
||||
tensor in process tgt will be reflected in the tensor in process src, because
|
||||
they are the same memory segment.
|
||||
It is important to note that process tgt accesses the tensor in GPU tgt, not
|
||||
GPU src. That's why we need p2p access.
|
||||
|
||||
The most time-consuming part is the process creation. To avoid creating
|
||||
processes for every pair of GPUs, we use batched testing. We create two
|
||||
processes for testing all pairs of GPUs in batch. The trick is to reset
|
||||
the device after each test (which is not available in PyTorch).
|
||||
""" # noqa
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
# pass the CUDA_VISIBLE_DEVICES to the child process
|
||||
# to make sure they see the same set of GPUs
|
||||
|
||||
# make sure the processes are spawned
|
||||
smp = mp.get_context("spawn")
|
||||
producer_queue = smp.Queue()
|
||||
consumer_queue = smp.Queue()
|
||||
result_queue = smp.Queue()
|
||||
p_src = smp.Process(
|
||||
target=producer,
|
||||
args=(
|
||||
batch_src,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_tgt = smp.Process(
|
||||
target=consumer,
|
||||
args=(
|
||||
batch_tgt,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_src.start()
|
||||
p_tgt.start()
|
||||
p_src.join()
|
||||
p_tgt.join()
|
||||
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
|
||||
result: list[bool] = []
|
||||
for src, tgt in zip(batch_src, batch_tgt):
|
||||
a = result_queue.get()
|
||||
b = result_queue.get()
|
||||
if a != b:
|
||||
logger.warning(
|
||||
"Two processes do not agree on the P2P access"
|
||||
" status on %d -> %d, treat as disabled.",
|
||||
src,
|
||||
tgt,
|
||||
)
|
||||
result.append(False)
|
||||
else:
|
||||
result.append(a)
|
||||
return result
|
||||
|
||||
|
||||
# why do we need this cache?
|
||||
# we are testing peer-to-peer (p2p) access between GPUs,across processes.
|
||||
# if we test it every time, it will be very slow, because we need to create
|
||||
# N * N * 2 processes, where N is the world size. This is very slow.
|
||||
# to reduce the time, we use a cache file to store the p2p access status.
|
||||
# the cache file is generated by the master process if it does not exist.
|
||||
# then all the processes can read the cache file to check the p2p access status.
|
||||
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
|
||||
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: dict[str, bool] | None = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
"""Check if GPU src can access GPU tgt."""
|
||||
|
||||
# if the cache variable is already calculated,
|
||||
# read from the cache instead of checking it again
|
||||
global _gpu_p2p_access_cache
|
||||
if _gpu_p2p_access_cache is not None:
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = cuda_device_count_stateless()
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
||||
path = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
|
||||
)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
|
||||
if (not is_distributed or get_world_group().local_rank == 0) and (
|
||||
not os.path.exists(path)
|
||||
):
|
||||
# only the local master process (with local_rank == 0) can
|
||||
# enter this block to calculate the cache
|
||||
logger.info("generating GPU P2P access cache in %s", path)
|
||||
cache: dict[str, bool] = {}
|
||||
ids = list(range(num_dev))
|
||||
# batch of all pairs of GPUs
|
||||
batch_src, batch_tgt = zip(*list(product(ids, ids)))
|
||||
# NOTE: we use `subprocess` rather than `multiprocessing` here
|
||||
# because the caller might not have `if __name__ == "__main__":`,
|
||||
# in that case we cannot use spawn method in multiprocessing.
|
||||
# However, `can_actually_p2p` requires spawn method.
|
||||
# The fix is, we use `subprocess` to call the function,
|
||||
# where we have `if __name__ == "__main__":` in this file.
|
||||
|
||||
# use a temporary file to store the result
|
||||
# we don't use the output of the subprocess directly,
|
||||
# because the subprocess might produce logging output
|
||||
with tempfile.NamedTemporaryFile() as output_file:
|
||||
input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
|
||||
returned = subprocess.run(
|
||||
[sys.executable, __file__], input=input_bytes, capture_output=True
|
||||
)
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
except Exception as e:
|
||||
# wrap raised exception to provide more information
|
||||
raise RuntimeError(
|
||||
f"Error happened when batch testing "
|
||||
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
|
||||
f"{returned.stderr.decode()}"
|
||||
) from e
|
||||
with open(output_file.name, "rb") as f:
|
||||
result = pickle.load(f)
|
||||
for _i, _j, r in zip(batch_src, batch_tgt, result):
|
||||
cache[f"{_i}->{_j}"] = r
|
||||
with open(path, "w") as f:
|
||||
json.dump(cache, f, indent=4)
|
||||
if is_distributed:
|
||||
get_world_group().barrier()
|
||||
logger.info("reading GPU P2P access cache from %s", path)
|
||||
with open(path) as f:
|
||||
cache = json.load(f)
|
||||
_gpu_p2p_access_cache = cache
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
|
||||
result = can_actually_p2p(batch_src, batch_tgt)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(pickle.dumps(result))
|
||||
@@ -0,0 +1,362 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
import ixformer.distributed as ixfd
|
||||
import os
|
||||
|
||||
class Cache:
|
||||
def __init__(self):
|
||||
self._cache: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
|
||||
def get_or_create(self, kwargs, func):
|
||||
# Create a hashable key from the kwargs
|
||||
key = tuple(sorted((k, v) for k, v in kwargs.items()))
|
||||
|
||||
with self._lock:
|
||||
instance = self._cache.get(key)
|
||||
if instance is None:
|
||||
instance = func(**kwargs)
|
||||
self._cache[key] = instance
|
||||
return instance
|
||||
|
||||
|
||||
class All2AllManagerBase:
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group,
|
||||
get_tp_group,
|
||||
in_the_same_node_as,
|
||||
)
|
||||
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
# no self.ep_group since self.ep_group is still in construction
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
# based on the kwargs.
|
||||
# different layers can have different configs,
|
||||
# e.g. one layer has hidden size 1024, another has 2048.
|
||||
# usually the underlying implementation caches the handle
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
# Subclasses should either:
|
||||
# - implement handling for extra_tensors, or
|
||||
# - raise a clear error if extra_tensors is not supported.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
# Subclasses should either:
|
||||
# - implement handling for extra_tensors, or
|
||||
# - raise a clear error if extra_tensors is not supported.
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
pass
|
||||
|
||||
def max_sms_used(self) -> int | None:
|
||||
return None # None means it could use the whole GPU
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceCommunicatorBase:
|
||||
"""
|
||||
Base class for device-specific communicator.
|
||||
It can use the `cpu_group` to initialize the communicator.
|
||||
If the device has PyTorch integration (PyTorch can recognize its
|
||||
communication backend), the `device_group` will also be given.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
self.device = device or torch.device("cpu")
|
||||
self.cpu_group = cpu_group
|
||||
self.device_group = device_group
|
||||
self.unique_name = unique_name
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
all2all_backend = None
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
config = get_current_vllm_config_or_none()
|
||||
if config is not None:
|
||||
# as long as we use data parallel (coupled data parallel
|
||||
# where all data parallel ranks execute forward together),
|
||||
# we initialize the all2all manager used in expert parallel.
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
all2all_backend = config.parallel_config.all2all_backend
|
||||
|
||||
self.is_ep_communicator = "ep" in unique_name
|
||||
self.use_all2all = self.is_ep_communicator and use_ep
|
||||
self.all2all_backend = all2all_backend
|
||||
self.all2all_manager: All2AllManagerBase | None = None
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
# All-gather.
|
||||
if self.use_vllm_comm:
|
||||
ixfd.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim]
|
||||
+ (self.world_size * input_size[dim],)
|
||||
+ input_size[dim + 1 :]
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: torch.Tensor | list[torch.Tensor],
|
||||
dim: int = 0,
|
||||
sizes: list[int] | None = None,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
||||
|
||||
output_tensor = torch.empty(
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
# Perform reduce-scatter operation
|
||||
torch.distributed.reduce_scatter_tensor(
|
||||
output_tensor, input_tensor, group=self.device_group
|
||||
)
|
||||
|
||||
# Reshape before returning
|
||||
return output_tensor.movedim(0, dim).contiguous()
|
||||
|
||||
def reduce_scatterv(
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Allocate output tensor.
|
||||
if self.rank_in_group == dst:
|
||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
else:
|
||||
gather_list = None
|
||||
# Gather.
|
||||
if self.use_vllm_comm:
|
||||
ixfd.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(
|
||||
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepare the communication buffer for the model.
|
||||
"""
|
||||
if not self.is_ep_communicator:
|
||||
return
|
||||
|
||||
moe_modules = [
|
||||
module
|
||||
for module in model.modules()
|
||||
# TODO(bnell): Should use isinstance but can't. Maybe search for
|
||||
# presence of quant_method.maybe_init_modular_kernel?
|
||||
if (
|
||||
module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE"
|
||||
)
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.maybe_init_modular_kernel()
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
if extra_tensors is not None:
|
||||
return hidden_states, router_logits, extra_tensors
|
||||
return hidden_states, router_logits
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
if extra_tensors is not None:
|
||||
return hidden_states, topk_weights, topk_ids, extra_tensors
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states
|
||||
301
vllm/distributed/device_communicators/cpu_communicator.py
Normal file
301
vllm/distributed/device_communicators/cpu_communicator.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.utils import pickle
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CpuCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
self.dist_module = torch.distributed
|
||||
|
||||
if (
|
||||
(
|
||||
current_platform.get_cpu_architecture() == CpuArchEnum.X86
|
||||
or current_platform.get_cpu_architecture() == CpuArchEnum.ARM
|
||||
)
|
||||
and hasattr(torch.ops._C, "init_shm_manager")
|
||||
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
|
||||
):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend != "naive": # type: ignore[has-type]
|
||||
logger.warning(
|
||||
"`%s` all2all manager is not supported on CPU. "
|
||||
"Falling back to `naive` all2all manager for CPU.",
|
||||
self.all2all_backend, # type: ignore[has-type]
|
||||
)
|
||||
self.all2all_backend = "naive"
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
|
||||
def all_reduce(self, input_):
|
||||
self.dist_module.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Allocate output tensor.
|
||||
if self.rank_in_group == dst:
|
||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
else:
|
||||
gather_list = None
|
||||
|
||||
# Gather.
|
||||
self.dist_module.gather(
|
||||
input_, gather_list, dst=self.ranks[dst], group=self.device_group
|
||||
)
|
||||
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
# All-gather.
|
||||
self.dist_module.all_gather_into_tensor(
|
||||
output_tensor, input_, group=self.device_group
|
||||
)
|
||||
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim]
|
||||
+ (self.world_size * input_size[dim],)
|
||||
+ input_size[dim + 1 :]
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
def send_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
dst: int,
|
||||
) -> None:
|
||||
return self.dist_module.send_tensor_dict(tensor_dict, dst)
|
||||
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
src: int,
|
||||
) -> dict[str, torch.Tensor | Any]:
|
||||
return self.dist_module.recv_tensor_dict(src)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.combine(
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
unique_name = communicator.unique_name
|
||||
instance_identifier = f"{instance_identifier}-{unique_name}"
|
||||
self.communicator = communicator
|
||||
|
||||
group_ranks = [str(rank) for rank in self.communicator.ranks]
|
||||
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
|
||||
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
thread_num_tensor = torch.tensor(
|
||||
[torch.get_num_threads()],
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
thread_num_tensor,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.communicator.device_group,
|
||||
)
|
||||
thread_num = thread_num_tensor.item()
|
||||
|
||||
handle = torch.ops._C.init_shm_manager(
|
||||
self.group_name,
|
||||
self.communicator.world_size,
|
||||
self.communicator.rank,
|
||||
thread_num,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
torch.ops._C.join_shm_manager(
|
||||
handle,
|
||||
self.group_name,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
|
||||
return handle
|
||||
|
||||
def all_reduce(
|
||||
self, input: torch.Tensor, group: ProcessGroup | None = None
|
||||
) -> None:
|
||||
torch.ops._C.shm_allreduce(self.handle, input)
|
||||
|
||||
def gather(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
gather_list: list[torch.Tensor] | None,
|
||||
dst: int = -1,
|
||||
group: ProcessGroup | None = None,
|
||||
) -> None:
|
||||
# Note: different from the torch gather, here we use local dst rank.
|
||||
torch.ops._C.shm_gather(
|
||||
self.handle,
|
||||
input,
|
||||
gather_list,
|
||||
torch.distributed.get_group_rank(group, dst),
|
||||
)
|
||||
|
||||
def all_gather_into_tensor(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
group: ProcessGroup | None = None,
|
||||
) -> None:
|
||||
torch.ops._C.shm_all_gather(self.handle, input, output)
|
||||
|
||||
def send_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
dst: int,
|
||||
) -> None:
|
||||
key_list = list(tensor_dict.keys())
|
||||
value_list = list(tensor_dict.values())
|
||||
size_list = []
|
||||
for v in value_list:
|
||||
if not isinstance(v, torch.Tensor):
|
||||
raise RuntimeError("CpuCommunicator only supports sending tensors.")
|
||||
size_list.append(v.size())
|
||||
key_size_tensor = torch.frombuffer(
|
||||
pickle.dumps([key_list, size_list]), dtype=torch.uint8
|
||||
)
|
||||
value_list.append(key_size_tensor)
|
||||
|
||||
torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst)
|
||||
|
||||
return None
|
||||
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
src: int,
|
||||
) -> dict[str, torch.Tensor | Any]:
|
||||
tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src)
|
||||
|
||||
value_list: list[torch.Tensor] = tensor_list[:-1]
|
||||
key_size_tensor = tensor_list[-1]
|
||||
|
||||
key_size = pickle.loads(key_size_tensor.numpy().tobytes())
|
||||
key_list = key_size[0]
|
||||
size_list = key_size[1]
|
||||
assert len(key_list) == len(size_list)
|
||||
assert len(key_list) == len(value_list)
|
||||
|
||||
tensor_dict: dict[str, torch.Tensor] = {}
|
||||
for key, size, t in zip(key_list, size_list, value_list):
|
||||
tensor_dict[key] = t.view(size)
|
||||
return tensor_dict
|
||||
434
vllm/distributed/device_communicators/cuda_communicator.py
Normal file
434
vllm/distributed/device_communicators/cuda_communicator.py
Normal file
@@ -0,0 +1,434 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
should_nccl_symm_mem_allreduce,
|
||||
)
|
||||
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
import ixformer.distributed as ixfd
|
||||
import os
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CudaCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
if "tp" not in unique_name:
|
||||
# custom allreduce or torch symm mem can be used only by tp
|
||||
use_custom_allreduce = False
|
||||
use_torch_symm_mem = False
|
||||
use_flashinfer_allreduce = False
|
||||
else:
|
||||
from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE
|
||||
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
||||
use_flashinfer_allreduce = envs.VLLM_ALLREDUCE_USE_FLASHINFER
|
||||
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
self.use_flashinfer_allreduce = use_flashinfer_allreduce
|
||||
|
||||
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) not in ["1", "Y", "y"]
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce,
|
||||
)
|
||||
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
|
||||
FlashInferAllReduce,
|
||||
)
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import (
|
||||
QuickAllReduce,
|
||||
)
|
||||
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
||||
|
||||
self.pynccl_comm: PyNcclCommunicator | None = None
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
if is_symmetric_memory_enabled():
|
||||
register_nccl_symmetric_ops(self.pynccl_comm)
|
||||
|
||||
self.ca_comm: CustomAllreduce | None = None
|
||||
self.qr_comm: QuickAllReduce | None = None
|
||||
self.symm_mem_comm: SymmMemCommunicator | None = None
|
||||
self.fi_ar_comm: FlashInferAllReduce | None = None
|
||||
|
||||
if use_torch_symm_mem and current_platform.is_cuda():
|
||||
self.symm_mem_comm = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self.use_flashinfer_allreduce and self.world_size > 1:
|
||||
self.fi_ar_comm = FlashInferAllReduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
self.ca_comm = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
symm_mem_enabled=(
|
||||
self.symm_mem_comm is not None and not self.symm_mem_comm.disabled
|
||||
),
|
||||
)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
# Initialize a custom quick all-reduce implementation for AMD.
|
||||
# Quick reduce is designed as a complement to custom allreduce.
|
||||
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
|
||||
# If it's a rocm, 'use_custom_allreduce==True' means it must
|
||||
# currently be an MI300 series.
|
||||
self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device)
|
||||
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "allgather_reducescatter":
|
||||
from .all2all import AgRsAll2AllManager
|
||||
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "mori":
|
||||
from .all2all import MoriAll2AllManager
|
||||
|
||||
self.all2all_manager = MoriAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "flashinfer_all2allv":
|
||||
from .all2all import FlashInferAllToAllManager
|
||||
|
||||
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
|
||||
|
||||
logger.info_once(
|
||||
"Using %s all2all manager.",
|
||||
self.all2all_manager.__class__.__name__,
|
||||
scope="global",
|
||||
)
|
||||
|
||||
def all_reduce(self, input_):
|
||||
# since currently we perform copy input -> symm_input -> out-of-place AR
|
||||
# return symm_output, we don't need to check if input is symmetric
|
||||
if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce(
|
||||
self.pynccl_comm.world_size, input_
|
||||
):
|
||||
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
|
||||
if out is not None:
|
||||
return out
|
||||
# always try quick reduce first, then flashinfer, then custom allreduce,
|
||||
# and then pynccl. (quick reduce just for ROCM MI3*)
|
||||
qr_comm = self.qr_comm
|
||||
if (
|
||||
qr_comm is not None
|
||||
and not qr_comm.disabled
|
||||
and qr_comm.should_quick_allreduce(input_)
|
||||
):
|
||||
out = qr_comm.quick_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
fi_ar_comm = self.fi_ar_comm
|
||||
if (
|
||||
fi_ar_comm is not None
|
||||
and not fi_ar_comm.disabled
|
||||
and fi_ar_comm.should_use_fi_ar(input_)
|
||||
):
|
||||
out = fi_ar_comm.all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
ca_comm = self.ca_comm
|
||||
if (
|
||||
ca_comm is not None
|
||||
and not ca_comm.disabled
|
||||
and ca_comm.should_custom_ar(input_)
|
||||
):
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
symm_mem_comm = self.symm_mem_comm
|
||||
if symm_mem_comm is not None and symm_mem_comm.should_use_symm_mem(input_):
|
||||
out = symm_mem_comm.all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
if self.use_vllm_comm:
|
||||
# torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
|
||||
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
|
||||
else:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is None or pynccl_comm.disabled:
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
assert pynccl_comm is not None
|
||||
out = pynccl_comm.all_reduce(input_)
|
||||
if out is None:
|
||||
# fall back to the default all-reduce using PyTorch.
|
||||
# this usually happens during testing.
|
||||
# when we run the model, allreduce only happens for the TP
|
||||
# group, where we always have either custom allreduce or pynccl.
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
||||
|
||||
output = torch.empty(
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
# pynccl_comm.reduce_scatter(output, input_tensor)
|
||||
torch.distributed.reduce_scatter_tensor(output,
|
||||
input_tensor,
|
||||
group=self.device_group)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def reduce_scatterv(
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
|
||||
):
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
assert input_tensor.shape[0] == sum(sizes)
|
||||
chunk_size = sizes[self.rank_in_group]
|
||||
else:
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
||||
|
||||
output = torch.empty(
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
|
||||
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
|
||||
else:
|
||||
pynccl_comm.reduce_scatter(output, input_tensor)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.send(tensor, dst)
|
||||
# else:
|
||||
# torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.send(tensor, self.ranks[dst], self.device_group)
|
||||
else:
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(
|
||||
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
# pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.recv(tensor, src)
|
||||
# else:
|
||||
# torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.recv(tensor, self.ranks[src], self.device_group)
|
||||
else:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
if self.pynccl_comm is not None:
|
||||
self.pynccl_comm = None
|
||||
if self.ca_comm is not None:
|
||||
self.ca_comm = None
|
||||
if self.fi_ar_comm is not None:
|
||||
self.fi_ar_comm.destroy()
|
||||
self.fi_ar_comm = None
|
||||
if self.all2all_manager is not None:
|
||||
self.all2all_manager.destroy()
|
||||
self.all2all_manager = None
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: torch.Tensor | list[torch.Tensor],
|
||||
dim: int = 0,
|
||||
sizes: list[int] | None = None,
|
||||
):
|
||||
if dim != 0:
|
||||
raise NotImplementedError("only dim 0 all-gatherv is supported")
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None and not pynccl_comm.disabled
|
||||
|
||||
# 'sizes' is not needed if all inputs in the same group have the same
|
||||
# shape
|
||||
if sizes is not None and all(s == sizes[0] for s in sizes):
|
||||
sizes = None
|
||||
|
||||
def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):
|
||||
input_size = input_.size()
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
assert input_.shape[dim] == sizes[self.rank_in_group], (
|
||||
f"{input_.shape[dim]} != {sizes[self.rank_in_group]}"
|
||||
)
|
||||
output_size = (sum(sizes),) + input_size[1:]
|
||||
else:
|
||||
output_size = (input_size[0] * world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
if sizes is not None:
|
||||
pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
|
||||
else:
|
||||
pynccl_comm.all_gather(output_tensor, input_)
|
||||
return output_tensor
|
||||
|
||||
if isinstance(input_, torch.Tensor):
|
||||
return _all_gather_single(input_, sizes)
|
||||
|
||||
output_list = []
|
||||
pynccl_comm.group_start()
|
||||
for inp in input_:
|
||||
output_list.append(_all_gather_single(inp, sizes=sizes))
|
||||
pynccl_comm.group_end()
|
||||
|
||||
return output_list
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
extra_residual:torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
# return self.all2all_manager.dispatch(
|
||||
# hidden_states,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# is_sequence_parallel,
|
||||
# extra_tensors=extra_tensors,
|
||||
# )
|
||||
hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, extra_residual, router_logits)
|
||||
return hidden_states, extra_residual, router_logits
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.combine(
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
190
vllm/distributed/device_communicators/cuda_wrapper.py
Normal file
190
vllm/distributed/device_communicators/cuda_wrapper.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This file is a pure Python wrapper for the cudart library.
|
||||
It avoids the need to compile a separate shared library, and is
|
||||
convenient for use when we just need to call a few functions.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import find_loaded_library
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === export types and functions from cudart to Python ===
|
||||
# for the original cudart definition, please check
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
|
||||
|
||||
cudaError_t = ctypes.c_int
|
||||
cudaMemcpyKind = ctypes.c_int
|
||||
|
||||
|
||||
class cudaIpcMemHandle_t(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
class CudaRTLibrary:
|
||||
exported_functions = [
|
||||
# cudaError_t cudaSetDevice ( int device )
|
||||
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
|
||||
# cudaError_t cudaDeviceSynchronize ( void )
|
||||
Function("cudaDeviceSynchronize", cudaError_t, []),
|
||||
# cudaError_t cudaDeviceReset ( void )
|
||||
Function("cudaDeviceReset", cudaError_t, []),
|
||||
# const char* cudaGetErrorString ( cudaError_t error )
|
||||
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
|
||||
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
|
||||
Function(
|
||||
"cudaMalloc",
|
||||
cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
|
||||
),
|
||||
# cudaError_t cudaFree ( void* devPtr )
|
||||
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
|
||||
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
|
||||
Function(
|
||||
"cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
|
||||
),
|
||||
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
|
||||
Function(
|
||||
"cudaMemcpy",
|
||||
cudaError_t,
|
||||
[ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
|
||||
),
|
||||
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
|
||||
Function(
|
||||
"cudaIpcGetMemHandle",
|
||||
cudaError_t,
|
||||
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
|
||||
),
|
||||
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
|
||||
Function(
|
||||
"cudaIpcOpenMemHandle",
|
||||
cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
|
||||
),
|
||||
]
|
||||
|
||||
# https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Runtime_API_functions_supported_by_HIP.html # noqa
|
||||
cuda_to_hip_mapping = {
|
||||
"cudaSetDevice": "hipSetDevice",
|
||||
"cudaDeviceSynchronize": "hipDeviceSynchronize",
|
||||
"cudaDeviceReset": "hipDeviceReset",
|
||||
"cudaGetErrorString": "hipGetErrorString",
|
||||
"cudaMalloc": "hipMalloc",
|
||||
"cudaFree": "hipFree",
|
||||
"cudaMemset": "hipMemset",
|
||||
"cudaMemcpy": "hipMemcpy",
|
||||
"cudaIpcGetMemHandle": "hipIpcGetMemHandle",
|
||||
"cudaIpcOpenMemHandle": "hipIpcOpenMemHandle",
|
||||
}
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: str | None = None):
|
||||
if so_file is None:
|
||||
so_file = find_loaded_library("libcudart")
|
||||
if so_file is None:
|
||||
# libcudart is not loaded in the current process, try hip
|
||||
so_file = find_loaded_library("libamdhip64")
|
||||
# should be safe to assume now that we are using ROCm
|
||||
# as the following assertion should error out if the
|
||||
# libhiprtc library is also not loaded
|
||||
if so_file is None:
|
||||
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
|
||||
assert so_file is not None, (
|
||||
"libcudart is not loaded in the current process, "
|
||||
"try setting VLLM_CUDART_SO_PATH"
|
||||
)
|
||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in CudaRTLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in CudaRTLibrary.exported_functions:
|
||||
f = getattr(
|
||||
self.lib,
|
||||
CudaRTLibrary.cuda_to_hip_mapping[func.name]
|
||||
if current_platform.is_rocm()
|
||||
else func.name,
|
||||
)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def CUDART_CHECK(self, result: cudaError_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.cudaGetErrorString(result)
|
||||
raise RuntimeError(f"CUDART error: {error_str}")
|
||||
|
||||
def cudaGetErrorString(self, error: cudaError_t) -> str:
|
||||
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
|
||||
|
||||
def cudaSetDevice(self, device: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
|
||||
|
||||
def cudaDeviceSynchronize(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
|
||||
|
||||
def cudaDeviceReset(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
|
||||
|
||||
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
|
||||
return devPtr
|
||||
|
||||
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
|
||||
|
||||
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
|
||||
|
||||
def cudaMemcpy(
|
||||
self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
|
||||
) -> None:
|
||||
cudaMemcpyDefault = 4
|
||||
kind = cudaMemcpyDefault
|
||||
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
|
||||
|
||||
def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
||||
handle = cudaIpcMemHandle_t()
|
||||
self.CUDART_CHECK(
|
||||
self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
|
||||
)
|
||||
return handle
|
||||
|
||||
def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
||||
cudaIpcMemLazyEnablePeerAccess = 1
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(
|
||||
self.funcs["cudaIpcOpenMemHandle"](
|
||||
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
|
||||
)
|
||||
)
|
||||
return devPtr
|
||||
326
vllm/distributed/device_communicators/custom_all_reduce.py
Normal file
326
vllm/distributed/device_communicators/custom_all_reduce.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES,
|
||||
gpu_p2p_access_check,
|
||||
)
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
try:
|
||||
ops.meta_size()
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
# For CPUs
|
||||
custom_ar = False
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
continue
|
||||
if envs.VLLM_SKIP_P2P_CHECK:
|
||||
logger.debug("Skipping P2P check and trusting the driver's P2P report.")
|
||||
return torch.cuda.can_device_access_peer(rank, i)
|
||||
if not gpu_p2p_access_check(rank, i):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
|
||||
# max_size: max supported allreduce size
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: int | str | torch.device,
|
||||
max_size=8192 * 1024,
|
||||
symm_mem_enabled=False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bound to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self._IS_CAPTURING = False
|
||||
self.disabled = True
|
||||
|
||||
if not custom_ar:
|
||||
# disable because of missing custom allreduce library
|
||||
# e.g. in a non-GPU environment
|
||||
logger.info(
|
||||
"Custom allreduce is disabled because "
|
||||
"of missing custom allreduce library"
|
||||
)
|
||||
return
|
||||
|
||||
self.group = group
|
||||
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"CustomAllreduce should be attached to a non-NCCL group."
|
||||
)
|
||||
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom allreduce for multi-node case.
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because this process group"
|
||||
" spans across nodes."
|
||||
)
|
||||
return
|
||||
|
||||
rank = dist.get_rank(group=self.group)
|
||||
self.rank = rank
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
if world_size == 1:
|
||||
# No need to initialize custom allreduce for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled due to an unsupported world"
|
||||
" size: %d. Supported world sizes: %s. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly.",
|
||||
world_size,
|
||||
str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
device_capability = current_platform.get_device_capability()
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and symm_mem_enabled
|
||||
and device_capability is not None
|
||||
):
|
||||
device_capability_str = device_capability.as_version_str()
|
||||
if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
|
||||
max_size = min(
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size],
|
||||
max_size,
|
||||
)
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
|
||||
]
|
||||
dist.all_gather(gather_list, tensor, group=self.group)
|
||||
physical_device_ids = [t.item() for t in gather_list]
|
||||
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
assert current_platform.is_cuda_alike()
|
||||
fully_connected = current_platform.is_fully_connected(physical_device_ids)
|
||||
if world_size > 2 and not fully_connected:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because it's not supported on"
|
||||
" more than two PCIe-only GPUs. To silence this warning, "
|
||||
"specify disable_custom_all_reduce=True explicitly."
|
||||
)
|
||||
return
|
||||
# test P2P capability, this checks software/cudaruntime support
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
||||
if not current_platform.is_rocm() and not _can_p2p(rank, world_size):
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because your platform lacks "
|
||||
"GPU P2P capability or P2P test failed. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly."
|
||||
)
|
||||
return
|
||||
|
||||
self.disabled = False
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Metadata composes of two parts: metadata for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
self.meta_ptrs = self.create_shared_buffer(
|
||||
ops.meta_size() + max_size, group=group, uncached=True
|
||||
)
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
# are first copied into this buffer before allreduce is performed
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
# This is a buffer for storing the tuples of pointers pointing to
|
||||
# IPC buffers from all ranks. Each registered tuple has size of
|
||||
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
||||
# is enough for 131072 such tuples. The largest model I've seen only
|
||||
# needs less than 10000 of registered tuples.
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self.max_size = max_size
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.fully_connected = fully_connected
|
||||
self._ptr = ops.init_custom_ar(
|
||||
self.meta_ptrs, self.rank_data, rank, self.fully_connected
|
||||
)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
|
||||
@contextmanager
|
||||
def capture(self):
|
||||
"""
|
||||
The main responsibility of this context manager is the
|
||||
`register_graph_buffers` call at the end of the context.
|
||||
It records all the buffer addresses used in the CUDA graph.
|
||||
"""
|
||||
try:
|
||||
self._IS_CAPTURING = True
|
||||
yield
|
||||
finally:
|
||||
self._IS_CAPTURING = False
|
||||
if not self.disabled:
|
||||
self.register_graph_buffers()
|
||||
|
||||
def register_graph_buffers(self):
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
# We cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
all_data: list[list[list[int] | None]]
|
||||
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
|
||||
all_data[self.rank] = [handle, offset]
|
||||
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
||||
for i, rank in enumerate(ranks):
|
||||
dist.broadcast_object_list(
|
||||
all_data[i], src=rank, group=self.group, device="cpu"
|
||||
)
|
||||
# Unpack list of tuples to tuple of lists.
|
||||
handles = cast(list[list[int]], [d[0] for d in all_data])
|
||||
offsets = cast(list[list[int]], [d[1] for d in all_data])
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
# custom allreduce requires input byte size to be multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
if not is_weak_contiguous(inp):
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if self.world_size == 2 or self.fully_connected:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
def all_reduce(
|
||||
self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
|
||||
):
|
||||
"""Performs an out-of-place all reduce.
|
||||
|
||||
If registered is True, this assumes inp's pointer is already
|
||||
IPC-registered. Otherwise, inp is first copied into a pre-registered
|
||||
buffer.
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
if registered:
|
||||
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
||||
else:
|
||||
ops.all_reduce(
|
||||
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
|
||||
)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> torch.Tensor | None:
|
||||
"""The main allreduce API that provides support for cuda graph."""
|
||||
# When custom allreduce is disabled, this will be None.
|
||||
if self.disabled or not self.should_custom_ar(input):
|
||||
return None
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
return self.all_reduce(input, registered=True)
|
||||
else:
|
||||
# If warm up, mimic the allocation pattern since custom
|
||||
# allreduce is out-of-place.
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
# Note: outside of cuda graph context, custom allreduce incurs a
|
||||
# cost of cudaMemcpy, which should be small (<=1% of overall
|
||||
# latency) compared to the performance gain of using custom kernels
|
||||
return self.all_reduce(input, registered=False)
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
if ops is not None:
|
||||
ops.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
|
||||
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(
|
||||
size_in_bytes: int,
|
||||
group: ProcessGroup | None = None,
|
||||
uncached: bool | None = False,
|
||||
) -> list[int]:
|
||||
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
|
||||
|
||||
world_size = dist.get_world_size(group=group)
|
||||
rank = dist.get_rank(group=group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
pointers: list[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer) # type: ignore
|
||||
else:
|
||||
pointers.append(ops.open_mem_handle(h))
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(
|
||||
pointers: list[int],
|
||||
group: ProcessGroup | None = None,
|
||||
rank: int | None = None,
|
||||
) -> None:
|
||||
if rank is None:
|
||||
rank = dist.get_rank(group=group)
|
||||
if ops is not None:
|
||||
ops.free_shared_buffer(pointers[rank])
|
||||
252
vllm/distributed/device_communicators/flashinfer_all_reduce.py
Normal file
252
vllm/distributed/device_communicators/flashinfer_all_reduce.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.compilation import PassConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
fi_ar_available = False
|
||||
try:
|
||||
import flashinfer.comm as flashinfer_comm # type: ignore[no-redef]
|
||||
from flashinfer.comm.mnnvl import (
|
||||
TorchDistBackend, # type: ignore[import-not-found, no-redef]
|
||||
)
|
||||
|
||||
fi_ar_available = hasattr(flashinfer_comm, "allreduce_fusion")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Global workspace for standalone allreduce and non-quant ar+rms fusion
|
||||
_fi_ar_workspace = None
|
||||
# Extra workspace for quant fusion patterns (only supported by trtllm backend)
|
||||
# Only created if primary workspace is not already trtllm
|
||||
_fi_ar_quant_workspace = None
|
||||
|
||||
|
||||
def get_fi_ar_workspace():
|
||||
return _fi_ar_workspace
|
||||
|
||||
|
||||
def get_fi_ar_quant_workspace():
|
||||
return _fi_ar_quant_workspace
|
||||
|
||||
|
||||
def initialize_fi_ar_workspace(
|
||||
world_size: int,
|
||||
rank: int,
|
||||
max_token_num: int,
|
||||
hidden_dim: int,
|
||||
dtype: torch.dtype,
|
||||
group: ProcessGroup,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the workspace if not already initialized.
|
||||
|
||||
Currently, this function is called by either the AllReduceFusionPass
|
||||
or the FlashInferAllReduce backend for standalone allreduce.
|
||||
If the fusion pass is enabled via
|
||||
--compilation-config.pass_config.fuse_allreduce_rms=true,
|
||||
it will create the workspace first, and the standalone backend
|
||||
will reuse the workspace. Otherwise, the standalone backend will
|
||||
create the workspace.
|
||||
"""
|
||||
global _fi_ar_workspace
|
||||
if _fi_ar_workspace is not None:
|
||||
return
|
||||
|
||||
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
|
||||
comm_backend = TorchDistBackend(group=group)
|
||||
_fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
|
||||
backend=backend,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
max_token_num=max_token_num,
|
||||
hidden_dim=hidden_dim,
|
||||
dtype=dtype,
|
||||
comm_backend=comm_backend,
|
||||
)
|
||||
assert _fi_ar_workspace is not None
|
||||
logger.debug(
|
||||
"Initialized FlashInfer All Reduce workspace: backend=%s, "
|
||||
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
|
||||
backend,
|
||||
world_size,
|
||||
rank,
|
||||
max_token_num,
|
||||
hidden_dim,
|
||||
dtype,
|
||||
)
|
||||
|
||||
|
||||
def initialize_fi_ar_quant_workspace(
|
||||
world_size: int,
|
||||
rank: int,
|
||||
max_token_num: int,
|
||||
hidden_dim: int,
|
||||
dtype: torch.dtype,
|
||||
group: ProcessGroup,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the workspace used by quantization fusion patterns.
|
||||
|
||||
Currently this always creates a workspace for trtllm backend as only it
|
||||
supports quantization fusion (FP8/FP4). If the primary workspace
|
||||
is already trtllm, the quant workspace aliases to it.
|
||||
"""
|
||||
global _fi_ar_quant_workspace
|
||||
if _fi_ar_quant_workspace is not None:
|
||||
return
|
||||
|
||||
# If primary workspace is already trtllm, reuse it
|
||||
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
|
||||
_fi_ar_quant_workspace = _fi_ar_workspace
|
||||
return
|
||||
|
||||
comm_backend = TorchDistBackend(group=group)
|
||||
_fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
|
||||
backend="trtllm",
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
max_token_num=max_token_num,
|
||||
hidden_dim=hidden_dim,
|
||||
dtype=dtype,
|
||||
comm_backend=comm_backend,
|
||||
)
|
||||
assert _fi_ar_quant_workspace is not None
|
||||
logger.debug(
|
||||
"Initialized FlashInfer All Reduce workspace: backend=trtllm, "
|
||||
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
|
||||
world_size,
|
||||
rank,
|
||||
max_token_num,
|
||||
hidden_dim,
|
||||
dtype,
|
||||
)
|
||||
|
||||
|
||||
def destroy_fi_ar_workspace():
|
||||
global _fi_ar_workspace
|
||||
global _fi_ar_quant_workspace
|
||||
if (
|
||||
_fi_ar_quant_workspace is not None
|
||||
and _fi_ar_quant_workspace is not _fi_ar_workspace
|
||||
):
|
||||
_fi_ar_quant_workspace.destroy()
|
||||
_fi_ar_quant_workspace = None
|
||||
if _fi_ar_workspace is not None:
|
||||
_fi_ar_workspace.destroy()
|
||||
_fi_ar_workspace = None
|
||||
|
||||
|
||||
class FlashInferAllReduce:
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: int | str | torch.device,
|
||||
):
|
||||
self.disabled = True
|
||||
|
||||
if not fi_ar_available:
|
||||
logger.info(
|
||||
"FlashInfer All Reduce is disabled because flashinfer is not available"
|
||||
)
|
||||
return
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
logger.info(
|
||||
"FlashInfer All Reduce is disabled because it requires CUDA platform"
|
||||
)
|
||||
return
|
||||
|
||||
self.group = group
|
||||
self.world_size = dist.get_world_size(self.group)
|
||||
self.rank = dist.get_rank(self.group)
|
||||
self.device = device
|
||||
if self.world_size == 1:
|
||||
return
|
||||
|
||||
# Use the same threshold as the allreduce-rms fusion pass
|
||||
# TODO: tune the threshold
|
||||
MiB = 1024 * 1024
|
||||
max_workspace_size = PassConfig.default_fi_allreduce_fusion_max_size_mb().get(
|
||||
self.world_size, None
|
||||
)
|
||||
if not max_workspace_size:
|
||||
logger.warning(
|
||||
"FlashInfer All Reduce is disabled because it "
|
||||
"is not supported for world_size=%d.",
|
||||
self.world_size,
|
||||
)
|
||||
return
|
||||
self.max_workspace_size = max_workspace_size * MiB
|
||||
self.max_num_tokens = 0
|
||||
self.disabled = False
|
||||
|
||||
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
|
||||
"""Ensure the all reduce workspace is initialized."""
|
||||
if get_fi_ar_workspace() is not None:
|
||||
return True
|
||||
if self.max_num_tokens == 0:
|
||||
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
|
||||
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
|
||||
try:
|
||||
initialize_fi_ar_workspace(
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
max_token_num=self.max_num_tokens,
|
||||
hidden_dim=hidden_dim,
|
||||
dtype=dtype,
|
||||
group=self.group,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to initialize FlashInfer All Reduce workspace: %s. "
|
||||
"FlashInfer All Reduce will be disabled.",
|
||||
e,
|
||||
)
|
||||
self.disabled = True
|
||||
return False
|
||||
|
||||
def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
|
||||
if self.disabled:
|
||||
return False
|
||||
|
||||
if not input_tensor.is_cuda:
|
||||
return False
|
||||
|
||||
if not input_tensor.is_contiguous():
|
||||
return False
|
||||
|
||||
if len(input_tensor.shape) != 2:
|
||||
return False
|
||||
|
||||
num_tokens, hidden_dim = input_tensor.shape
|
||||
if not self.max_num_tokens:
|
||||
element_size = torch.tensor([], dtype=input_tensor.dtype).element_size()
|
||||
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
|
||||
|
||||
if num_tokens > self.max_num_tokens:
|
||||
return False
|
||||
|
||||
return self._ensure_workspace(hidden_dim, input_tensor.dtype)
|
||||
|
||||
def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
workspace = get_fi_ar_workspace()
|
||||
return flashinfer_comm.allreduce_fusion(
|
||||
input=input_tensor,
|
||||
workspace=workspace,
|
||||
pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce,
|
||||
)
|
||||
|
||||
def destroy(self):
|
||||
if not self.disabled:
|
||||
destroy_fi_ar_workspace()
|
||||
38
vllm/distributed/device_communicators/mnnvl_compat.py
Normal file
38
vllm/distributed/device_communicators/mnnvl_compat.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch.distributed as dist
|
||||
from flashinfer.comm.mnnvl import CommBackend as CommBackend
|
||||
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
|
||||
assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
|
||||
|
||||
|
||||
class CustomCommunicator(CommBackend):
|
||||
def __init__(self, group):
|
||||
self._group = group
|
||||
|
||||
def Get_rank(self) -> int:
|
||||
return self._group.rank()
|
||||
|
||||
def Get_size(self) -> int:
|
||||
return self._group.size()
|
||||
|
||||
def allgather(self, data: int):
|
||||
gathered = [None] * self.Get_size()
|
||||
dist.all_gather_object(gathered, data, group=self._group)
|
||||
return gathered
|
||||
|
||||
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
|
||||
# are unimplemented on vLLM side. If we need to utilize these
|
||||
# methods in the future, can create a concrete implementation.
|
||||
def bcast(self, data: Any, root: int) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def barrier(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def Split(self, color: int, key: int) -> "CustomCommunicator":
|
||||
return self
|
||||
386
vllm/distributed/device_communicators/pynccl.py
Normal file
386
vllm/distributed/device_communicators/pynccl.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
# ===================== import region =====================
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary,
|
||||
buffer_type,
|
||||
cudaStream_t,
|
||||
ncclComm_t,
|
||||
ncclDataTypeEnum,
|
||||
ncclRedOpTypeEnum,
|
||||
ncclUniqueId,
|
||||
)
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import current_stream
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_NCCL_SYMM_OPS_REGISTERED = False
|
||||
|
||||
|
||||
def register_nccl_symmetric_ops(pynccl_comm):
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
nccl_symm_mem_context,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
global _NCCL_SYMM_OPS_REGISTERED
|
||||
if _NCCL_SYMM_OPS_REGISTERED:
|
||||
return
|
||||
_NCCL_SYMM_OPS_REGISTERED = True
|
||||
|
||||
def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with nccl_symm_mem_context(pynccl_comm):
|
||||
symm_input = torch.empty_like(input_tensor)
|
||||
symm_output = torch.empty_like(input_tensor)
|
||||
symm_input.copy_(input_tensor)
|
||||
symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
|
||||
return symm_output
|
||||
|
||||
def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(input_tensor)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="all_reduce_symmetric_with_copy",
|
||||
op_func=all_reduce_symmetric_with_copy_impl,
|
||||
fake_impl=all_reduce_symmetric_with_copy_fake,
|
||||
)
|
||||
|
||||
|
||||
class PyNcclCommunicator:
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup | StatelessProcessGroup,
|
||||
device: int | str | torch.device,
|
||||
library_path: str | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the PyNcclCommunicator to. If None,
|
||||
it will be bound to f"cuda:{local_rank}".
|
||||
library_path: the path to the NCCL library. If None, it will
|
||||
use the default library path.
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device.
|
||||
"""
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"PyNcclCommunicator should be attached to a non-NCCL group."
|
||||
)
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
else:
|
||||
self.rank = group.rank
|
||||
self.world_size = group.world_size
|
||||
|
||||
self.group = group
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
try:
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing NCCL library
|
||||
# e.g. in a non-GPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
self.nccl_version = self.nccl.ncclGetRawVersion()
|
||||
if self.rank == 0:
|
||||
# get the unique id from NCCL
|
||||
self.unique_id = self.nccl.ncclGetUniqueId()
|
||||
logger.info_once(
|
||||
"vLLM is using nccl==%s", self.nccl.ncclGetVersion(), scope="local"
|
||||
)
|
||||
else:
|
||||
# construct an empty unique id
|
||||
self.unique_id = ncclUniqueId()
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
dist.broadcast(tensor, src=ranks[0], group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
else:
|
||||
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
# nccl communicator and stream will use this device
|
||||
# `torch.cuda.device` is a context manager that changes the
|
||||
# current cuda device to the specified one
|
||||
with torch.cuda.device(device):
|
||||
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank
|
||||
)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(
|
||||
self,
|
||||
in_tensor: torch.Tensor,
|
||||
out_tensor: torch.Tensor = None,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}"
|
||||
)
|
||||
|
||||
if out_tensor is None:
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllReduce(
|
||||
buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
return out_tensor
|
||||
|
||||
def all_gather(
|
||||
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
sizes: list[int],
|
||||
stream=None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
assert output_tensor.shape[0] == sum(sizes)
|
||||
split_offset = 0
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
dst_slice = output_tensor[split_offset : split_offset + split_size]
|
||||
self.nccl.ncclBroadcast(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(dst_slice.data_ptr()),
|
||||
dst_slice.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
root,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
def reduce_scatter(
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclReduceScatter(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
output_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def reduce_scatterv(
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
sizes: list[int],
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
|
||||
split_offset = 0
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
chunk = input_tensor[split_offset : split_offset + split_size, ...]
|
||||
self.nccl.ncclReduce(
|
||||
buffer_type(chunk.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
chunk.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
root,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
dst,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def recv(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if src == self.rank:
|
||||
sendbuff = buffer_type(tensor.data_ptr())
|
||||
# NCCL requires the sender also to have a receive buffer
|
||||
recvbuff = buffer_type(tensor.data_ptr())
|
||||
else:
|
||||
sendbuff = buffer_type()
|
||||
recvbuff = buffer_type(tensor.data_ptr())
|
||||
self.nccl.ncclBroadcast(
|
||||
sendbuff,
|
||||
recvbuff,
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def group_start(self):
|
||||
self.nccl.ncclGroupStart()
|
||||
|
||||
def group_end(self):
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
def register_comm_window(self, tensor: torch.Tensor):
|
||||
return self.nccl.ncclCommWindowRegister(
|
||||
self.comm,
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel() * tensor.element_size(),
|
||||
1,
|
||||
)
|
||||
|
||||
def register_comm_window_raw(self, ptr: int, size: int):
|
||||
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
|
||||
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
191
vllm/distributed/device_communicators/pynccl_allocator.py
Normal file
191
vllm/distributed/device_communicators/pynccl_allocator.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import atexit
|
||||
import contextlib
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.cuda.memory import CUDAPluggableAllocator
|
||||
from torch.utils.cpp_extension import load_inline
|
||||
|
||||
from vllm import envs
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.nccl import find_nccl_include_paths
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
nccl_allocator_source = """
|
||||
#include <nccl.h>
|
||||
extern "C" {
|
||||
|
||||
void* nccl_alloc_plug(size_t size, int device, void* stream) {
|
||||
void* ptr;
|
||||
ncclResult_t err = ncclMemAlloc(&ptr, size);
|
||||
return ptr;
|
||||
|
||||
}
|
||||
|
||||
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
|
||||
ncclResult_t err = ncclMemFree(ptr);
|
||||
}
|
||||
|
||||
}
|
||||
"""
|
||||
|
||||
_allocator = None
|
||||
_allocator_wrapper = None
|
||||
_mem_pool = None
|
||||
_registered_base_addrs = set()
|
||||
_graph_pool_id = None
|
||||
_nccl_allocator_failed_to_compile = False
|
||||
_cached_pool_snapshot = None
|
||||
|
||||
|
||||
def is_symmetric_memory_enabled():
|
||||
global _nccl_allocator_failed_to_compile
|
||||
return envs.VLLM_USE_NCCL_SYMM_MEM and not _nccl_allocator_failed_to_compile
|
||||
|
||||
|
||||
def is_symmetric_memory_tensor(tensor: torch.Tensor):
|
||||
if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
|
||||
return False
|
||||
for segment in _cached_pool_snapshot:
|
||||
for block in segment["blocks"]:
|
||||
if block["address"] == tensor.untyped_storage().data_ptr():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def set_graph_pool_id(graph_pool_id: Any) -> None:
|
||||
global _graph_pool_id
|
||||
_graph_pool_id = graph_pool_id
|
||||
|
||||
|
||||
def compile_nccl_allocator():
|
||||
global _allocator, _allocator_wrapper, _nccl_allocator_failed_to_compile
|
||||
if not current_platform.is_cuda():
|
||||
_nccl_allocator_failed_to_compile = True
|
||||
return
|
||||
try:
|
||||
out_dir = tempfile.gettempdir()
|
||||
nccl_allocator_libname = "nccl_allocator"
|
||||
nccl_include_paths = find_nccl_include_paths()
|
||||
load_inline(
|
||||
name=nccl_allocator_libname,
|
||||
cpp_sources=nccl_allocator_source,
|
||||
with_cuda=True,
|
||||
extra_ldflags=["-lnccl"],
|
||||
verbose=envs.VLLM_LOGGING_LEVEL == "DEBUG",
|
||||
is_python_module=False,
|
||||
build_directory=out_dir,
|
||||
extra_include_paths=nccl_include_paths,
|
||||
)
|
||||
_allocator_wrapper = CUDAPluggableAllocator(
|
||||
f"{out_dir}/{nccl_allocator_libname}.so",
|
||||
"nccl_alloc_plug",
|
||||
"nccl_free_plug",
|
||||
)
|
||||
_allocator = _allocator_wrapper.allocator()
|
||||
except Exception as e:
|
||||
_nccl_allocator_failed_to_compile = True
|
||||
logger.warning(
|
||||
"Failed to compile NCCL memory allocator. "
|
||||
"Symmetric memory will be disabled. "
|
||||
"This is expected if NCCL headers are not available. "
|
||||
"optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory "
|
||||
"containing the NCCL header. "
|
||||
"Error: %s",
|
||||
str(e),
|
||||
)
|
||||
|
||||
|
||||
def get_nccl_mem_pool():
|
||||
global _mem_pool, _nccl_allocator_failed_to_compile
|
||||
if _mem_pool is None and not _nccl_allocator_failed_to_compile:
|
||||
compile_nccl_allocator()
|
||||
if _allocator is not None:
|
||||
_mem_pool = torch.cuda.MemPool(_allocator)
|
||||
return _mem_pool
|
||||
|
||||
|
||||
def _cleanup_nccl_mem_pool():
|
||||
global _mem_pool
|
||||
_mem_pool = None
|
||||
|
||||
|
||||
def _cleanup_nccl_allocator_wrapper():
|
||||
global _allocator_wrapper
|
||||
_allocator_wrapper = None
|
||||
|
||||
|
||||
atexit.register(_cleanup_nccl_mem_pool)
|
||||
atexit.register(_cleanup_nccl_allocator_wrapper)
|
||||
|
||||
|
||||
class nccl_symm_mem_context:
|
||||
def __init__(
|
||||
self,
|
||||
pynccl_comm: PyNcclCommunicator,
|
||||
disabled: bool = False,
|
||||
):
|
||||
self.disabled = (
|
||||
disabled
|
||||
or not is_symmetric_memory_enabled()
|
||||
or pynccl_comm.world_size == 1
|
||||
or not current_platform.is_cuda()
|
||||
or get_nccl_mem_pool() is None
|
||||
or version.parse(torch.__version__) < version.parse("2.8.0.a0")
|
||||
)
|
||||
if self.disabled:
|
||||
self.pynccl_comm: PyNcclCommunicator | None = None
|
||||
self._mem_pool_ctx: contextlib.AbstractContextManager[Any] = (
|
||||
contextlib.nullcontext()
|
||||
)
|
||||
self.is_graph_capture = None
|
||||
self.device = None
|
||||
else:
|
||||
self.pynccl_comm = pynccl_comm
|
||||
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
|
||||
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
|
||||
self.device = torch.cuda.current_device()
|
||||
|
||||
def __enter__(self):
|
||||
if self.disabled:
|
||||
return self
|
||||
assert self.pynccl_comm is not None, (
|
||||
"Symmetric memory requires pynccl to be initialized"
|
||||
)
|
||||
assert self.pynccl_comm.nccl_version >= 22703, (
|
||||
"NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
|
||||
)
|
||||
if self.is_graph_capture:
|
||||
assert _graph_pool_id is not None, (
|
||||
"graph_pool_id is not set under graph capture"
|
||||
)
|
||||
# Pause graph memory pool to use symmetric memory with cuda graph
|
||||
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
|
||||
self._mem_pool_ctx.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.disabled:
|
||||
return
|
||||
global _cached_pool_snapshot
|
||||
global _registered_base_addrs
|
||||
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
|
||||
_pool = get_nccl_mem_pool()
|
||||
assert _pool is not None
|
||||
_cached_pool_snapshot = _pool.snapshot()
|
||||
assert self.pynccl_comm is not None
|
||||
for segment in _cached_pool_snapshot:
|
||||
if segment["address"] not in _registered_base_addrs:
|
||||
self.pynccl_comm.register_comm_window_raw(
|
||||
segment["address"], segment["total_size"]
|
||||
)
|
||||
_registered_base_addrs.add(segment["address"])
|
||||
if self.is_graph_capture:
|
||||
torch._C._cuda_beginAllocateCurrentThreadToPool(self.device, _graph_pool_id)
|
||||
571
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
571
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
@@ -0,0 +1,571 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# This file is a pure Python wrapper for the NCCL library.
|
||||
# The main purpose is to use NCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the NCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
|
||||
# doable, but we often encounter issues related with nccl versions, and need
|
||||
# to switch between different versions of NCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of NCCL by
|
||||
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.nccl import find_nccl_library
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
ncclComm_t = ctypes.c_void_p
|
||||
ncclWindow_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class ncclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
cudaStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
ncclInt32 = 2
|
||||
ncclInt = 2
|
||||
ncclUint32 = 3
|
||||
ncclInt64 = 4
|
||||
ncclUint64 = 5
|
||||
ncclFloat16 = 6
|
||||
ncclHalf = 6
|
||||
ncclFloat32 = 7
|
||||
ncclFloat = 7
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclFloat8e4m3 = 10
|
||||
ncclNumTypes = 11
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
if dtype == current_platform.fp8_dtype():
|
||||
return cls.ncclFloat8e4m3
|
||||
raise ValueError(
|
||||
f"Unsupported dtype {dtype}: should be one of "
|
||||
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
|
||||
" float8e4m3."
|
||||
)
|
||||
|
||||
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
ncclMin = 3
|
||||
ncclAvg = 4
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.ncclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.ncclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.ncclMin
|
||||
if op == ReduceOp.AVG:
|
||||
return cls.ncclAvg
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
class NCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* ncclGetErrorString(ncclResult_t result)
|
||||
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
Function(
|
||||
"ncclCommInitRank",
|
||||
ncclResult_t,
|
||||
[ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],
|
||||
),
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function(
|
||||
"ncclAllReduce",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclRedOp_t,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, int root,
|
||||
# ncclComm_t comm, cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function(
|
||||
"ncclReduce",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclRedOp_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclAllGather(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function(
|
||||
"ncclAllGather",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclReduceScatter(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function(
|
||||
"ncclReduceScatter",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclRedOp_t,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclSend(
|
||||
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||
# int dest, ncclComm_t comm, cudaStream_t stream);
|
||||
Function(
|
||||
"ncclSend",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclRecv(
|
||||
# void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
# int src, ncclComm_t comm, cudaStream_t stream);
|
||||
Function(
|
||||
"ncclRecv",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclBroadcast(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, int root, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
Function(
|
||||
"ncclBroadcast",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
# it is better not to call it at all.
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
# ncclResult_t ncclGroupStart();
|
||||
Function("ncclGroupStart", ncclResult_t, []),
|
||||
# ncclResult_t ncclGroupEnd();
|
||||
Function("ncclGroupEnd", ncclResult_t, []),
|
||||
# ncclResult_t ncclCommWindowRegister(
|
||||
# ncclComm_t comm, void* buff, size_t size,
|
||||
# ncclWindow_t* win, int winFlags);
|
||||
Function(
|
||||
"ncclCommWindowRegister",
|
||||
ncclResult_t,
|
||||
[
|
||||
ncclComm_t,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ctypes.POINTER(ncclWindow_t),
|
||||
ctypes.c_int,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclCommWindowDeregister(
|
||||
# ncclComm_t comm, ncclWindow_t win);
|
||||
Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: str | None = None):
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
NCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = NCCLLibrary.path_to_library_cache[so_file]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load NCCL library from %s. "
|
||||
"It is expected if you are not running on NVIDIA/AMD GPUs."
|
||||
"Otherwise, the nccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s. "
|
||||
"If you already have the library, please set the "
|
||||
"environment variable VLLM_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.",
|
||||
so_file,
|
||||
platform.platform(),
|
||||
)
|
||||
raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: dict[str, Any] = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
try:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
except AttributeError:
|
||||
if func.name in [
|
||||
"ncclCommWindowRegister",
|
||||
"ncclCommWindowDeregister",
|
||||
]:
|
||||
if envs.VLLM_USE_NCCL_SYMM_MEM:
|
||||
logger.warning_once(
|
||||
"The symbol %s is not found in the NCCL "
|
||||
"library %s. To enable VLLM_USE_NCCL_SYMM_MEM "
|
||||
" please update your NCCL version to >= "
|
||||
"2.27.03.",
|
||||
func.name,
|
||||
so_file,
|
||||
)
|
||||
if current_platform.is_rocm():
|
||||
# Having an exception here on ROCm platform is
|
||||
# not allowed during graph capturing
|
||||
continue
|
||||
raise
|
||||
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def ncclGetErrorString(self, result: ncclResult_t) -> str:
|
||||
return self._funcs["ncclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def NCCL_CHECK(self, result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.ncclGetErrorString(result)
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
def ncclGetRawVersion(self) -> int:
|
||||
version = ctypes.c_int()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
||||
# something like 21903
|
||||
return version.value
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
version_str = str(self.ncclGetRawVersion())
|
||||
# something like 21903 --> "2.19.3"
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
patch = version_str[3:].lstrip("0")
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
def ncclGetUniqueId(self) -> ncclUniqueId:
|
||||
unique_id = ncclUniqueId()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
|
||||
if len(data) != 128:
|
||||
raise ValueError(
|
||||
f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes"
|
||||
)
|
||||
unique_id = ncclUniqueId()
|
||||
ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
|
||||
return unique_id
|
||||
|
||||
def ncclCommInitRank(
|
||||
self, world_size: int, unique_id: ncclUniqueId, rank: int
|
||||
) -> ncclComm_t:
|
||||
comm = ncclComm_t()
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclCommInitRank"](
|
||||
ctypes.byref(comm), world_size, unique_id, rank
|
||||
)
|
||||
)
|
||||
return comm
|
||||
|
||||
def ncclAllReduce(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclAllReduce"](
|
||||
sendbuff, recvbuff, count, datatype, op, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclReduce(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
root: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclReduce"](
|
||||
sendbuff, recvbuff, count, datatype, op, root, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclReduceScatter(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclReduceScatter"](
|
||||
sendbuff, recvbuff, count, datatype, op, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclAllGather(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# which is an aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclAllGather"](
|
||||
sendbuff, recvbuff, count, datatype, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclSend(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
dest: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)
|
||||
)
|
||||
|
||||
def ncclRecv(
|
||||
self,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
src: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
|
||||
)
|
||||
|
||||
def ncclBroadcast(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
root: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclBroadcast"](
|
||||
sendbuff, recvbuff, count, datatype, root, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||
|
||||
def ncclGroupStart(self) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
|
||||
|
||||
def ncclGroupEnd(self) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
|
||||
|
||||
def ncclCommWindowRegister(
|
||||
self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
|
||||
) -> ncclWindow_t:
|
||||
window = ncclWindow_t()
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclCommWindowRegister"](
|
||||
comm, buff, size, ctypes.byref(window), win_flags
|
||||
)
|
||||
)
|
||||
return window
|
||||
|
||||
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary",
|
||||
"ncclDataTypeEnum",
|
||||
"ncclRedOpTypeEnum",
|
||||
"ncclUniqueId",
|
||||
"ncclComm_t",
|
||||
"cudaStream_t",
|
||||
"buffer_type",
|
||||
]
|
||||
290
vllm/distributed/device_communicators/quick_all_reduce.py
Normal file
290
vllm/distributed/device_communicators/quick_all_reduce.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
ops.qr_max_size()
|
||||
quick_ar = True
|
||||
except Exception:
|
||||
# For CPUs and CUDA
|
||||
quick_ar = False
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
class QuickReduceRegime(Enum):
|
||||
FP = 0
|
||||
INT8 = 1
|
||||
INT6 = 2
|
||||
INT4 = 3
|
||||
NONE = 4
|
||||
|
||||
|
||||
MB = 1024 * 1024
|
||||
|
||||
|
||||
class QuickAllReduce:
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 8]
|
||||
_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
|
||||
# The following data is based on kernel tests.
|
||||
# In this order [FP, INT8, INT6, INT4].
|
||||
_QR_MIN_SIZE = {
|
||||
(torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
|
||||
(torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
|
||||
(torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
|
||||
(torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
|
||||
(torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
|
||||
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
|
||||
}
|
||||
|
||||
def __init__(self, group: ProcessGroup, device: int | str | torch.device) -> None:
|
||||
"""
|
||||
Custom allreduce provides non-destructive acceleration and is
|
||||
available for CUDA and ROCm MI300 series.
|
||||
|
||||
Custom quick allreduce leverages quantization for further
|
||||
acceleration on ROCm. It currently supports Q8, Q6, and Q4
|
||||
quantization formats and FP(float16, bfloat16).
|
||||
|
||||
Quick allreduce is designed as a complement to custom allreduce.
|
||||
Its initialization requires even stricter conditions.
|
||||
|
||||
Only the ROCm MI300 series is supported for quick allreduce at
|
||||
this time.
|
||||
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bound to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self.disabled = True
|
||||
if not self._rocm_arch_available():
|
||||
logger.debug(
|
||||
"Custom quick allreduce is only supported on ROCm MI300 series."
|
||||
)
|
||||
return
|
||||
|
||||
if not quick_ar:
|
||||
# disable because of missing quick reduce library
|
||||
# e.g. in a cuda environment
|
||||
logger.info(
|
||||
"Custom quick allreduce is disabled because "
|
||||
"of missing custom quick allreduce library"
|
||||
)
|
||||
return
|
||||
|
||||
self.group = group
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"Custom quick allreduce should be attached to a non-NCCL group."
|
||||
)
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom quick allreduce for
|
||||
# multi-node case.
|
||||
logger.warning(
|
||||
"Custom quick allreduce is disabled because this "
|
||||
"process group spans across nodes."
|
||||
)
|
||||
return
|
||||
rank = dist.get_rank(group=self.group)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
if world_size == 1:
|
||||
# No need to initialize QuickReduce for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
|
||||
logger.warning(
|
||||
"Custom quick allreduce is disabled due to an "
|
||||
"unsupported world size: %d. Supported world sizes: %s.",
|
||||
world_size,
|
||||
str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu")
|
||||
for _ in range(self.world_size)
|
||||
]
|
||||
dist.all_gather(gather_list, tensor, group=self.group)
|
||||
physical_device_ids = [t.item() for t in gather_list]
|
||||
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom quick allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
assert current_platform.is_cuda_alike()
|
||||
self.fully_connected = current_platform.is_fully_connected(physical_device_ids)
|
||||
if self.world_size > 2 and not self.fully_connected:
|
||||
logger.debug(
|
||||
"Custom quick allreduce is disabled because it's not supported "
|
||||
"on more than two PCIe-only GPUs. "
|
||||
)
|
||||
return
|
||||
|
||||
self.init_quick_all_reduce()
|
||||
|
||||
def init_quick_all_reduce(self):
|
||||
# On RocM, bfloat16 kernels are slower than fp16
|
||||
# due to slower match operations
|
||||
# If environment variable is set to 1, we convert input to fp16
|
||||
self.use_fp16_kernels = envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16
|
||||
regime_str = envs.VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
|
||||
if regime_str not in QuickReduceRegime.__members__:
|
||||
logger.warning(
|
||||
"Custom quick allreduce:",
|
||||
f"Invalid quantization level: {regime_str}. "
|
||||
"Supported levels: "
|
||||
f"{list(QuickReduceRegime.__members__.keys())}",
|
||||
)
|
||||
return
|
||||
|
||||
if regime_str == "NONE":
|
||||
logger.debug(
|
||||
"Custom quick allreduce is disabled based "
|
||||
"on env variable "
|
||||
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
|
||||
)
|
||||
return
|
||||
self.qr_quant_level = QuickReduceRegime[regime_str]
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
if (
|
||||
vllm_config is not None
|
||||
and hasattr(vllm_config, "model_config")
|
||||
and hasattr(vllm_config.model_config, "dtype")
|
||||
):
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in [torch.float16, torch.bfloat16]:
|
||||
logger.debug(
|
||||
"Custom quick allreduce disabled: only supports "
|
||||
"float16 and float16, but get %s.",
|
||||
dtype,
|
||||
)
|
||||
return
|
||||
|
||||
if dtype == torch.bfloat16 and self.use_fp16_kernels:
|
||||
logger.info(
|
||||
"Custom quick allreduce: BF16 inputs will be converted "
|
||||
"to FP16 to improve performance. set "
|
||||
"envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 "
|
||||
"to turn off."
|
||||
)
|
||||
|
||||
# VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
|
||||
qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
|
||||
if qr_max_size is not None:
|
||||
if qr_max_size < 1:
|
||||
logger.info(
|
||||
"You should not set a max_size smaller than 1MB, which can "
|
||||
"lead to error or degradation to custom allreduce or rccl."
|
||||
)
|
||||
qr_max_size = qr_max_size * MB
|
||||
self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
|
||||
self.qr_max_size = qr_max_size if qr_max_size is not None else ops.qr_max_size()
|
||||
self.create_shared_buffer()
|
||||
self.disabled = False
|
||||
|
||||
def _rocm_arch_available(self):
|
||||
if not current_platform.is_rocm():
|
||||
return False
|
||||
try:
|
||||
props = torch.cuda.get_device_properties(0)
|
||||
gcn_arch = getattr(props, "gcnArchName", "")
|
||||
supported_archs = ["gfx94", "gfx95"]
|
||||
return any(gfx in gcn_arch for gfx in supported_archs)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
|
||||
return False
|
||||
|
||||
def create_shared_buffer(self):
|
||||
"""
|
||||
Creates a shared buffer for quickreduce.
|
||||
Has to be called after init_custom_qr
|
||||
"""
|
||||
handle = ops.qr_get_handle(self._ptr)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=self.group)
|
||||
ops.qr_open_handles(self._ptr, handles)
|
||||
|
||||
def should_quick_allreduce(self, inp: torch.Tensor):
|
||||
"""
|
||||
Check if quickreduce is available
|
||||
"""
|
||||
if self.disabled:
|
||||
return False
|
||||
if inp.dtype not in self._SUPPORTED_DTYPES:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
# custom quick allreduce requires input byte size to be
|
||||
# multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
if not is_weak_contiguous(inp):
|
||||
return False
|
||||
dtype = inp.dtype
|
||||
if self.use_fp16_kernels:
|
||||
dtype = torch.float16
|
||||
return (
|
||||
inp_size <= self.qr_max_size
|
||||
and inp_size
|
||||
>= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
|
||||
)
|
||||
|
||||
def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
|
||||
"""Performs an out-of-place custom quick all reduce."""
|
||||
# quick allreduce doesn't require a separate graph mode,
|
||||
# as QR uses static IPC buffer.
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
ops.qr_all_reduce(
|
||||
self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
|
||||
)
|
||||
return out
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and getattr(self, "_ptr", None):
|
||||
if ops is not None:
|
||||
ops.qr_destroy(self._ptr)
|
||||
self._ptr = 0
|
||||
self.disabled = True
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
259
vllm/distributed/device_communicators/ray_communicator.py
Normal file
259
vllm/distributed/device_communicators/ray_communicator.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from ray.exceptions import RayChannelError
|
||||
from ray.experimental.channel.communicator import Communicator, TorchTensorAllocator
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import current_stream
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RayPPCommunicator(Communicator):
|
||||
"""
|
||||
Communicator to be used for pipeline parallelism in Ray Compiled Graph.
|
||||
This is wraps around the vLLM _PP GroupCoordinator.
|
||||
|
||||
This class is not thread-safe.
|
||||
"""
|
||||
|
||||
_comm: DeviceCommunicatorBase | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
world_size: int,
|
||||
comm_id: Any,
|
||||
rank: int | None,
|
||||
actor_handles: list["ray.actor.ActorHandle"],
|
||||
cuda_stream: torch.cuda.Stream | None,
|
||||
use_communication_streams: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize a RayPPCommunicator that can be used to communicate with
|
||||
other Ray Compiled Graph actors for pipeline parallelism.
|
||||
|
||||
Args:
|
||||
world_size: The number of participating actors.
|
||||
comm_id: A unique communicator ID. This is just to conform with
|
||||
the Ray Communicator API and is not used.
|
||||
rank: The rank of this actor. If None, then the caller is not a
|
||||
participant of the RayPPCommunicator group (e.g., the Ray
|
||||
driver).
|
||||
actor_handles: A list of actor handles.
|
||||
cuda_stream: A CUDA stream to dispatch communication ops to. This
|
||||
is not supported.
|
||||
use_communication_streams: Whether to use communication streams.
|
||||
This is not supported.
|
||||
"""
|
||||
self._world_size = world_size
|
||||
self._rank: int | None = None
|
||||
self._actor_handles = actor_handles
|
||||
if use_communication_streams:
|
||||
raise NotImplementedError("use_communication_streams is not supported")
|
||||
if cuda_stream is not None and cuda_stream != current_stream():
|
||||
raise ValueError(
|
||||
"cuda_stream other than the current stream is not supported"
|
||||
)
|
||||
|
||||
if rank is not None:
|
||||
# Rank is not None, this is Ray worker
|
||||
assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"
|
||||
|
||||
self._comm = get_pp_group().device_communicator
|
||||
assert self._comm is not None
|
||||
|
||||
# Since we wrap around the vLLM _PP communicator, we use
|
||||
# the rank from the vLLM communicator, and ignore the rank
|
||||
# passed in from Ray.
|
||||
# TODO(rui): refactor the Ray Communicator API so that
|
||||
# it also supports no rank passed in.
|
||||
self._rank = self._comm.rank_in_group
|
||||
|
||||
self._build_actor_rank_mapping()
|
||||
else:
|
||||
# Rank is None, this is Ray driver
|
||||
self._comm = None
|
||||
|
||||
self._closed = False
|
||||
|
||||
def _build_actor_rank_mapping(self):
|
||||
"""
|
||||
Use collective communication to build a mapping from actor IDs to ranks.
|
||||
This should be called once during initialization.
|
||||
"""
|
||||
if self._comm is None:
|
||||
return {}
|
||||
|
||||
current_actor = ray.get_runtime_context().current_actor
|
||||
actor_id_str = current_actor._actor_id.hex()
|
||||
|
||||
# Ray actor IDs are 32-character hex strings (128 bits)
|
||||
ACTOR_ID_LEN = 32
|
||||
actor_id_bytes = bytearray(actor_id_str.encode("utf-8"))
|
||||
assert len(actor_id_bytes) == ACTOR_ID_LEN, (
|
||||
f"Unexpected actor ID length: {len(actor_id_bytes)}"
|
||||
)
|
||||
|
||||
actor_id_tensor = torch.frombuffer(actor_id_bytes, dtype=torch.uint8).to(
|
||||
self._comm.device
|
||||
)
|
||||
|
||||
# All-gather full actor IDs from all actors
|
||||
gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)
|
||||
|
||||
# Build mapping: actor_id -> device_comm_rank
|
||||
self._actor_id_to_rank = {}
|
||||
for rank in range(self._world_size):
|
||||
start_idx = rank * ACTOR_ID_LEN
|
||||
end_idx = (rank + 1) * ACTOR_ID_LEN
|
||||
actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy().tobytes()
|
||||
actor_id = actor_bytes.decode("utf-8")
|
||||
self._actor_id_to_rank[actor_id] = rank
|
||||
|
||||
def initialize(self, rank: int) -> None:
|
||||
# No additional initialization is needed.
|
||||
pass
|
||||
|
||||
def get_actor_handles(self) -> list["ray.actor.ActorHandle"]:
|
||||
return self._actor_handles
|
||||
|
||||
def get_rank(self, actor: ray.actor.ActorHandle) -> int:
|
||||
"""
|
||||
Return the given actor's rank using device communicator collective ops.
|
||||
"""
|
||||
assert hasattr(self, "_actor_id_to_rank"), (
|
||||
"Actor rank mapping not built. "
|
||||
"This should have been done during initialization."
|
||||
)
|
||||
|
||||
actor_id_str = actor._actor_id.hex()
|
||||
|
||||
if actor_id_str in self._actor_id_to_rank:
|
||||
return self._actor_id_to_rank[actor_id_str] # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Actor {actor} not found in communicator group")
|
||||
|
||||
def get_self_rank(self) -> int | None:
|
||||
"""
|
||||
Return this actor's rank.
|
||||
"""
|
||||
return self._rank
|
||||
|
||||
def get_world_size(self) -> int:
|
||||
"""
|
||||
Return the number of ranks in the RayPPCommunicator group.
|
||||
"""
|
||||
return self._world_size
|
||||
|
||||
def send(self, buf: "torch.Tensor", peer_rank: int) -> None:
|
||||
"""
|
||||
Send a torch.Tensor to a peer.
|
||||
|
||||
This returns when the send kernel has been queued, but the kernel may
|
||||
not have completed. Therefore, the caller should ensure that there are
|
||||
no concurrent writes to the sent `buf` until the send has finished.
|
||||
That is, either all writes should be submitted on the current stream
|
||||
(self._cuda_stream) or, if on a different stream, that stream should
|
||||
synchronize with the current stream.
|
||||
|
||||
Args:
|
||||
buf: The torch.Tensor to send. It should already be on this
|
||||
actor's default device.
|
||||
peer_rank: The rank of the actor to send to.
|
||||
"""
|
||||
if self._closed:
|
||||
raise RayChannelError("RayPPCommunicator has been destroyed.")
|
||||
|
||||
assert self._comm is not None
|
||||
self._comm.send(buf, peer_rank)
|
||||
|
||||
def recv(
|
||||
self,
|
||||
shape: tuple[int, ...],
|
||||
dtype: "torch.dtype",
|
||||
peer_rank: int,
|
||||
allocator: TorchTensorAllocator,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Receive a torch.Tensor from a peer and synchronize the current stream.
|
||||
|
||||
After this call returns, the receive buffer is safe to read from
|
||||
any stream. An RayChannelError will be raised if an error occurred
|
||||
(e.g., remote actor died), and the buffer is not safe to read.
|
||||
|
||||
Args:
|
||||
shape: The shape of the tensor to receive.
|
||||
dtype: The dtype of the tensor to receive.
|
||||
peer_rank: The rank of the actor to receive from.
|
||||
allocator: The allocator to use to create the received tensor.
|
||||
This is ignored for this implementation.
|
||||
"""
|
||||
if self._closed:
|
||||
raise RayChannelError("RayPPCommunicator has been destroyed.")
|
||||
|
||||
assert self._comm is not None
|
||||
size = torch.Size(shape)
|
||||
buf = self._comm.recv(size, dtype, src=peer_rank)
|
||||
|
||||
# Buffer values are undefined if NCCL ops are aborted. Therefore, we
|
||||
# need to synchronize here and check that the channel is still
|
||||
# open to ensure that the receive buffer is valid.
|
||||
# TODO(swang): Avoid CUDA synchronization.
|
||||
current_stream().synchronize()
|
||||
|
||||
if self._closed:
|
||||
raise RayChannelError("RayPPCommunicator has been destroyed.")
|
||||
return buf
|
||||
|
||||
def allgather(
|
||||
self,
|
||||
send_buf: "torch.Tensor",
|
||||
recv_buf: "torch.Tensor",
|
||||
):
|
||||
raise NotImplementedError("allgather is not supported")
|
||||
|
||||
def allreduce(
|
||||
self,
|
||||
send_buf: "torch.Tensor",
|
||||
recv_buf: "torch.Tensor",
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
):
|
||||
raise NotImplementedError("allreduce is not supported")
|
||||
|
||||
def reducescatter(
|
||||
self,
|
||||
send_buf: "torch.Tensor",
|
||||
recv_buf: "torch.Tensor",
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
):
|
||||
raise NotImplementedError("reducescatter is not supported")
|
||||
|
||||
@property
|
||||
def recv_stream(self):
|
||||
return torch.cuda.StreamContext(current_stream())
|
||||
|
||||
@property
|
||||
def send_stream(self):
|
||||
return torch.cuda.StreamContext(current_stream())
|
||||
|
||||
def destroy(self) -> None:
|
||||
# Just sets a flag, vLLM manages the lifecycle of the underlying
|
||||
# _PP GroupCoordinator.
|
||||
self._closed = True
|
||||
|
||||
def get_transport_name(self) -> str:
|
||||
return "nccl"
|
||||
|
||||
@classmethod
|
||||
def generate_communicator_id(cls) -> Any:
|
||||
return uuid.uuid4()
|
||||
784
vllm/distributed/device_communicators/shm_broadcast.py
Normal file
784
vllm/distributed/device_communicators/shm_broadcast.py
Normal file
@@ -0,0 +1,784 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from pickle import PickleBuffer
|
||||
from threading import Event
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import zmq
|
||||
from torch.distributed import ProcessGroup
|
||||
from zmq import ( # type: ignore
|
||||
IPV6, # type: ignore
|
||||
SUB,
|
||||
SUBSCRIBE,
|
||||
XPUB,
|
||||
XPUB_VERBOSE,
|
||||
Context,
|
||||
)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
get_open_port,
|
||||
get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SizedBuffer
|
||||
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
|
||||
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
|
||||
|
||||
|
||||
# Memory fence for cross-process shared memory visibility.
|
||||
# Required for correct producer-consumer synchronization when using
|
||||
# shared memory without locks.
|
||||
_memory_fence_lock = threading.Lock()
|
||||
|
||||
|
||||
def memory_fence():
|
||||
"""
|
||||
Full memory barrier for shared memory synchronization.
|
||||
|
||||
Ensures all prior memory writes are visible to other processes before
|
||||
any subsequent reads. This is critical for lock-free producer-consumer
|
||||
patterns using shared memory.
|
||||
|
||||
Implementation acquires and immediately releases a lock. Python's
|
||||
threading.Lock provides sequentially consistent memory barrier semantics
|
||||
across all major platforms (POSIX, Windows). This is a lightweight
|
||||
operation (~20ns) that guarantees:
|
||||
- All stores before the barrier are visible to other threads/processes
|
||||
- All loads after the barrier see the latest values
|
||||
"""
|
||||
# Lock acquire/release provides full memory barrier semantics.
|
||||
# Using context manager ensures lock release even on exceptions.
|
||||
with _memory_fence_lock:
|
||||
pass
|
||||
|
||||
|
||||
def to_bytes_big(value: int, size: int) -> bytes:
|
||||
return value.to_bytes(size, byteorder="big")
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def long_wait_time_msg(threshold: int) -> str:
|
||||
return (
|
||||
"No available shared memory broadcast block found "
|
||||
f"in {threshold} seconds. This typically happens "
|
||||
"when some processes are hanging or doing some "
|
||||
"time-consuming work (e.g. compilation, "
|
||||
"weight/kv cache quantization)."
|
||||
)
|
||||
|
||||
|
||||
class SpinTimer:
|
||||
def record_activity(self):
|
||||
pass
|
||||
|
||||
def spin(self):
|
||||
sched_yield()
|
||||
|
||||
|
||||
class SpinSleepTimer(SpinTimer):
|
||||
"""
|
||||
In setups which have long inactivity periods it is desirable to reduce
|
||||
system power consumption when vllm does nothing. This would lead to more
|
||||
CPU thermal headroom when a request eventually comes, especially when
|
||||
multiple GPUs are connected as each GPU would otherwise pin one thread at
|
||||
100% CPU usage.
|
||||
|
||||
The simplest solution is to reduce polling frequency when there is no
|
||||
activity for a certain period of time.
|
||||
"""
|
||||
|
||||
def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
|
||||
self.last_activity = time.monotonic()
|
||||
self.busy_loop_s = busy_loop_s
|
||||
self.wait_sleep_s = wait_sleep_s
|
||||
|
||||
def record_activity(self):
|
||||
self.last_activity = time.monotonic()
|
||||
|
||||
def spin(self):
|
||||
curr_time = time.monotonic()
|
||||
if curr_time >= self.last_activity + self.busy_loop_s:
|
||||
time.sleep(self.wait_sleep_s)
|
||||
else:
|
||||
sched_yield()
|
||||
|
||||
|
||||
class ShmRingBuffer:
|
||||
def __init__(
|
||||
self,
|
||||
n_reader: int,
|
||||
max_chunk_bytes: int,
|
||||
max_chunks: int,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""
|
||||
A shared memory ring buffer implementation for broadcast communication.
|
||||
Essentially, it is a queue where only one will `enqueue` and multiple
|
||||
will `dequeue`. The max size of each item, together with the max number
|
||||
of items that can be stored in the buffer are known in advance.
|
||||
In this case, we don't need to synchronize the access to
|
||||
the buffer.
|
||||
|
||||
Buffer memory layout:
|
||||
data metadata
|
||||
| |
|
||||
| (current_idx) | (current_idx)
|
||||
v v
|
||||
+-------------------------------+----------------------------------------+
|
||||
| chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
|
||||
+-------------------------------+----------------------------------------+
|
||||
| max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
|
||||
|
||||
metadata memory layout: each byte is a flag, the first byte is the written
|
||||
flag, and the rest are reader flags. The flags are set to 0 by default.
|
||||
+--------------+--------------+--------------+-----+--------------+
|
||||
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
|
||||
+--------------+--------------+--------------+-----+--------------+
|
||||
|
||||
The state of metadata is as follows:
|
||||
|
||||
(case 1) 0???...???: the block is not written yet, cannot read, can write
|
||||
(case 2) 1000...000: the block is just written, can read, cannot write
|
||||
(case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
|
||||
(case 4) 1111...111: the block is written and read by all readers, cannot read, can write
|
||||
|
||||
State transition for readers:
|
||||
|
||||
When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
|
||||
Only after the caller finishes reading the block, the reader can mark the block as read.
|
||||
Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
|
||||
|
||||
State transition for writer:
|
||||
|
||||
When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
|
||||
to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
|
||||
can reset the reader flags to 0, and mark the block as written (from 0 to 1).
|
||||
NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
|
||||
|
||||
During creation, `name` is None and the buffer is created. We can pass the
|
||||
created object to other processes by pickling it. The other processes will
|
||||
get the name of the shared memory and open it, so that they can access the
|
||||
same shared memory buffer.
|
||||
""" # noqa
|
||||
self.n_reader = n_reader
|
||||
self.metadata_size = 1 + n_reader
|
||||
self.max_chunk_bytes = max_chunk_bytes
|
||||
self.max_chunks = max_chunks
|
||||
self.total_bytes_of_buffer = (
|
||||
self.max_chunk_bytes + self.metadata_size
|
||||
) * self.max_chunks
|
||||
self.data_offset = 0
|
||||
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
|
||||
|
||||
if name is None:
|
||||
# we are creating a buffer
|
||||
self.is_creator = True
|
||||
self.shared_memory = shared_memory.SharedMemory(
|
||||
create=True, size=self.total_bytes_of_buffer
|
||||
)
|
||||
# initialize the metadata section to 0
|
||||
with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
|
||||
torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
|
||||
else:
|
||||
# we are opening an existing buffer
|
||||
self.is_creator = False
|
||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||
# Python incorrectly tracks shared memory even if it is not
|
||||
# created by the process. The following patch is a workaround.
|
||||
with patch(
|
||||
"multiprocessing.resource_tracker.register",
|
||||
lambda *args, **kwargs: None,
|
||||
):
|
||||
try:
|
||||
self.shared_memory = shared_memory.SharedMemory(name=name)
|
||||
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
|
||||
# Some platforms allocate memory based on page size,
|
||||
# so the shared memory block size may be larger or equal
|
||||
# to the requested size. The size parameter is ignored
|
||||
# when attaching to an existing block.
|
||||
assert self.shared_memory.size >= self.total_bytes_of_buffer
|
||||
except FileNotFoundError:
|
||||
# we might deserialize the object in a different node
|
||||
# in this case, this object is not used,
|
||||
# and we should suppress the error
|
||||
pass
|
||||
|
||||
def handle(self):
|
||||
return (
|
||||
self.n_reader,
|
||||
self.max_chunk_bytes,
|
||||
self.max_chunks,
|
||||
self.shared_memory.name,
|
||||
)
|
||||
|
||||
def __reduce__(self):
|
||||
return (
|
||||
self.__class__,
|
||||
self.handle(),
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "shared_memory"):
|
||||
self.shared_memory.close()
|
||||
if self.is_creator:
|
||||
self.shared_memory.unlink()
|
||||
|
||||
@contextmanager
|
||||
def get_data(self, current_idx: int):
|
||||
start = self.data_offset + current_idx * self.max_chunk_bytes
|
||||
end = start + self.max_chunk_bytes
|
||||
with self.shared_memory.buf[start:end] as buf:
|
||||
yield buf
|
||||
|
||||
@contextmanager
|
||||
def get_metadata(self, current_idx: int):
|
||||
start = self.metadata_offset + current_idx * self.metadata_size
|
||||
end = start + self.metadata_size
|
||||
with self.shared_memory.buf[start:end] as buf:
|
||||
yield buf
|
||||
|
||||
|
||||
@dataclass
|
||||
class Handle:
|
||||
local_reader_ranks: list[int] = field(default_factory=list)
|
||||
|
||||
buffer_handle: tuple[int, int, int, str] | None = None
|
||||
local_subscribe_addr: str | None = None
|
||||
remote_subscribe_addr: str | None = None
|
||||
remote_addr_ipv6: bool = False
|
||||
|
||||
|
||||
class MessageQueue:
|
||||
def __init__(
|
||||
self,
|
||||
n_reader, # number of all readers
|
||||
n_local_reader, # number of local readers through shared memory
|
||||
local_reader_ranks: list[int] | None = None,
|
||||
# Default of 24MiB chosen to be large enough to accommodate grammar
|
||||
# bitmask tensors for large batches (1024 requests).
|
||||
max_chunk_bytes: int = 1024 * 1024 * 24,
|
||||
max_chunks: int = 10,
|
||||
connect_ip: str | None = None,
|
||||
):
|
||||
if local_reader_ranks is None:
|
||||
local_reader_ranks = list(range(n_local_reader))
|
||||
else:
|
||||
assert len(local_reader_ranks) == n_local_reader
|
||||
self.n_local_reader = n_local_reader
|
||||
n_remote_reader = n_reader - n_local_reader
|
||||
self.n_remote_reader = n_remote_reader
|
||||
|
||||
context = Context()
|
||||
|
||||
if n_local_reader > 0:
|
||||
# for local readers, we will:
|
||||
# 1. create a shared memory ring buffer to communicate small data
|
||||
# 2. create a publish-subscribe socket to communicate large data
|
||||
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)
|
||||
|
||||
# XPUB is very similar to PUB,
|
||||
# except that it can receive subscription messages
|
||||
# to confirm the number of subscribers
|
||||
self.local_socket = context.socket(XPUB)
|
||||
# set the verbose option so that we can receive every subscription
|
||||
# message. otherwise, we will only receive the first subscription
|
||||
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
||||
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
local_subscribe_addr = get_open_zmq_ipc_path()
|
||||
logger.debug("Binding to %s", local_subscribe_addr)
|
||||
self.local_socket.bind(local_subscribe_addr)
|
||||
|
||||
self.current_idx = 0
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
local_subscribe_addr = None
|
||||
self.local_socket = None
|
||||
self.current_idx = -1
|
||||
|
||||
remote_addr_ipv6 = False
|
||||
if n_remote_reader > 0:
|
||||
# for remote readers, we will:
|
||||
# create a publish-subscribe socket to communicate large data
|
||||
if not connect_ip:
|
||||
connect_ip = get_ip()
|
||||
self.remote_socket = context.socket(XPUB)
|
||||
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
remote_subscribe_port = get_open_port()
|
||||
if is_valid_ipv6_address(connect_ip):
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
remote_addr_ipv6 = True
|
||||
connect_ip = f"[{connect_ip}]"
|
||||
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
self.remote_socket.bind(socket_addr)
|
||||
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
else:
|
||||
remote_subscribe_addr = None
|
||||
self.remote_socket = None
|
||||
|
||||
self._is_writer = True
|
||||
self._is_local_reader = False
|
||||
self.local_reader_rank = -1
|
||||
# rank does not matter for remote readers
|
||||
self._is_remote_reader = False
|
||||
self._read_spin_timer = SpinTimer()
|
||||
|
||||
self.handle = Handle(
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
buffer_handle=self.buffer.handle() if self.buffer is not None else None,
|
||||
local_subscribe_addr=local_subscribe_addr,
|
||||
remote_subscribe_addr=remote_subscribe_addr,
|
||||
remote_addr_ipv6=remote_addr_ipv6,
|
||||
)
|
||||
|
||||
logger.debug("vLLM message queue communication handle: %s", self.handle)
|
||||
|
||||
def export_handle(self) -> Handle:
|
||||
return self.handle
|
||||
|
||||
@staticmethod
|
||||
def create_from_handle(handle: Handle, rank) -> "MessageQueue":
|
||||
self = MessageQueue.__new__(MessageQueue)
|
||||
self.handle = handle
|
||||
self._is_writer = False
|
||||
|
||||
context = Context()
|
||||
|
||||
if rank in handle.local_reader_ranks:
|
||||
assert handle.buffer_handle is not None
|
||||
self.buffer = ShmRingBuffer(*handle.buffer_handle)
|
||||
self.current_idx = 0
|
||||
self.local_reader_rank = handle.local_reader_ranks.index(rank)
|
||||
self._is_local_reader = True
|
||||
self._is_remote_reader = False
|
||||
|
||||
self.local_socket = context.socket(SUB)
|
||||
self.local_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
socket_addr = handle.local_subscribe_addr
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.local_socket.connect(socket_addr)
|
||||
|
||||
self.remote_socket = None
|
||||
|
||||
self._read_spin_timer = (
|
||||
SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
|
||||
)
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
self.current_idx = -1
|
||||
self.local_reader_rank = -1
|
||||
self._is_local_reader = False
|
||||
self._is_remote_reader = True
|
||||
|
||||
self.local_socket = None
|
||||
|
||||
self.remote_socket = context.socket(SUB)
|
||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
if handle.remote_addr_ipv6:
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
socket_addr = handle.remote_subscribe_addr
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.remote_socket.connect(socket_addr)
|
||||
|
||||
return self
|
||||
|
||||
def wait_until_ready(self):
|
||||
"""This is a collective operation. All processes (including the
|
||||
readers and the writer) should call this function.
|
||||
"""
|
||||
if self._is_writer:
|
||||
# wait for all readers to connect
|
||||
|
||||
# local readers
|
||||
for i in range(self.n_local_reader):
|
||||
# wait for subscription messages from all local readers
|
||||
self.local_socket.recv()
|
||||
if self.n_local_reader > 0:
|
||||
# send a message to all local readers
|
||||
# to make sure the publish channel is working
|
||||
self.local_socket.send(b"READY")
|
||||
|
||||
# remote readers
|
||||
for i in range(self.n_remote_reader):
|
||||
# wait for subscription messages from all remote readers
|
||||
self.remote_socket.recv()
|
||||
if self.n_remote_reader > 0:
|
||||
# send a message to all remote readers
|
||||
# to make sure the publish channel is working
|
||||
self.remote_socket.send(b"READY")
|
||||
elif self._is_local_reader:
|
||||
# wait for the writer to send a message
|
||||
recv = self.local_socket.recv()
|
||||
assert recv == b"READY"
|
||||
elif self._is_remote_reader:
|
||||
# wait for the writer to send a message
|
||||
recv = self.remote_socket.recv()
|
||||
assert recv == b"READY"
|
||||
|
||||
@contextmanager
|
||||
def acquire_write(self, timeout: float | None = None):
|
||||
assert self._is_writer, "Only writers can acquire write"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
while True:
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
# Memory fence ensures we see the latest read flags from readers.
|
||||
# Without this, we may read stale flags from our CPU cache and
|
||||
# spin indefinitely even though readers have completed.
|
||||
memory_fence()
|
||||
read_count = sum(metadata_buffer[1:])
|
||||
written_flag = metadata_buffer[0]
|
||||
if written_flag and read_count != self.buffer.n_reader:
|
||||
# this block is written and not read by all readers
|
||||
# for writers, `self.current_idx` is the next block to write
|
||||
# if this block is not ready to write,
|
||||
# we need to wait until it is read by all readers
|
||||
|
||||
# Release the processor to other threads
|
||||
sched_yield()
|
||||
|
||||
# if we time out, raise an exception
|
||||
elapsed = time.monotonic() - start_time
|
||||
if timeout is not None and elapsed > timeout:
|
||||
raise TimeoutError
|
||||
|
||||
# if we wait for a long time, log a message
|
||||
if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
|
||||
logger.info(
|
||||
long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
continue
|
||||
# found a block that is either
|
||||
# (1) not written
|
||||
# (2) read by all readers
|
||||
|
||||
# mark the block as not written
|
||||
metadata_buffer[0] = 0
|
||||
# let caller write to the buffer
|
||||
with self.buffer.get_data(self.current_idx) as buf:
|
||||
yield buf
|
||||
|
||||
# caller has written to the buffer
|
||||
# NOTE: order is important here
|
||||
# first set the read flags to 0
|
||||
# then set the written flag to 1
|
||||
# otherwise, the readers may think they already read the block
|
||||
for i in range(1, self.buffer.n_reader + 1):
|
||||
# set read flag to 0, meaning it is not read yet
|
||||
metadata_buffer[i] = 0
|
||||
# Memory fence here ensures the order of the buffer and flag
|
||||
# writes. This guarantees that when `metadata_buffer[0] = 1` is
|
||||
# visible to readers, `buf` can be completely ready. Without
|
||||
# this, some CPU architectures with weak ordering may incur
|
||||
# memory inconsistency.
|
||||
memory_fence()
|
||||
# mark the block as written
|
||||
metadata_buffer[0] = 1
|
||||
# Memory fence ensures the write is visible to readers on other cores
|
||||
# before we proceed. Without this, readers may spin indefinitely
|
||||
# waiting for a write that's stuck in our CPU's store buffer.
|
||||
memory_fence()
|
||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||
break
|
||||
|
||||
@contextmanager
|
||||
def acquire_read(
|
||||
self,
|
||||
timeout: float | None = None,
|
||||
cancel: Event | None = None,
|
||||
indefinite: bool = False,
|
||||
):
|
||||
assert self._is_local_reader, "Only readers can acquire read"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
while True:
|
||||
# Memory fence ensures we see the latest writes from the writer.
|
||||
# Without this, we may read stale flags from our CPU cache
|
||||
# and spin indefinitely even though writer has updated them.
|
||||
memory_fence()
|
||||
read_flag = metadata_buffer[self.local_reader_rank + 1]
|
||||
written_flag = metadata_buffer[0]
|
||||
if not written_flag or read_flag:
|
||||
# this block is either
|
||||
# (1) not written
|
||||
# (2) already read by this reader
|
||||
|
||||
# for readers, `self.current_idx` is the next block to read
|
||||
# if this block is not ready,
|
||||
# we need to wait until it is written
|
||||
|
||||
# Release the processor to other threads
|
||||
self._read_spin_timer.spin()
|
||||
|
||||
if cancel is not None and cancel.is_set():
|
||||
raise RuntimeError("cancelled")
|
||||
|
||||
# if we time out, raise an exception
|
||||
elapsed = time.monotonic() - start_time
|
||||
if timeout is not None and elapsed > timeout:
|
||||
raise TimeoutError
|
||||
|
||||
# if we wait for a long time, log a message
|
||||
if not indefinite and (
|
||||
elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
|
||||
):
|
||||
logger.info(
|
||||
long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
continue
|
||||
# found a block that is not read by this reader
|
||||
# let caller read from the buffer
|
||||
with self.buffer.get_data(self.current_idx) as buf:
|
||||
yield buf
|
||||
|
||||
# caller has read from the buffer
|
||||
# set the read flag
|
||||
metadata_buffer[self.local_reader_rank + 1] = 1
|
||||
# Memory fence ensures the read flag is visible to the writer.
|
||||
# Without this, writer may not see our read completion and
|
||||
# could wait indefinitely for all readers to finish.
|
||||
memory_fence()
|
||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||
|
||||
self._read_spin_timer.record_activity()
|
||||
break
|
||||
|
||||
def enqueue(self, obj, timeout: float | None = None):
|
||||
"""Write to message queue with optional timeout (in seconds)"""
|
||||
assert self._is_writer, "Only writers can enqueue"
|
||||
all_buffers: list[SizedBuffer] = [b""]
|
||||
total_bytes = 6 # 2 bytes for oob buffer count, 4 for main buffer size
|
||||
|
||||
def oob_callback(buf: PickleBuffer) -> bool:
|
||||
raw_buf = buf.raw()
|
||||
if len(raw_buf) < 1024 * 1024:
|
||||
# In-line buffers smaller than 1MiB.
|
||||
return True
|
||||
all_buffers.append(raw_buf)
|
||||
nonlocal total_bytes
|
||||
total_bytes += len(raw_buf) + 4
|
||||
return False
|
||||
|
||||
all_buffers[0] = pickle.dumps(
|
||||
obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
|
||||
)
|
||||
if self.n_local_reader > 0:
|
||||
if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
|
||||
with self.acquire_write(timeout) as buf:
|
||||
buf[0] = 1 # overflow
|
||||
self.local_socket.send_multipart(all_buffers, copy=False)
|
||||
else:
|
||||
# Byte 0: 0
|
||||
# Bytes 1-2: Count of buffers
|
||||
# Then each buffer follows, preceded by 4 bytes containing its length:
|
||||
# [4 byte int L][L bytes of buffer content] ...
|
||||
with self.acquire_write(timeout) as buf:
|
||||
buf[0] = 0 # not overflow
|
||||
offset = 3
|
||||
buf[1:offset] = to_bytes_big(len(all_buffers), 2) # oob buf count
|
||||
for buffer in all_buffers:
|
||||
buf_len = len(buffer)
|
||||
# prepend each buffer with 4 bytes containing its size.
|
||||
buf_offset = offset + 4
|
||||
buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
|
||||
buf[buf_offset : (offset := buf_offset + buf_len)] = buffer
|
||||
|
||||
if self.n_remote_reader > 0:
|
||||
self.remote_socket.send_multipart(all_buffers, copy=False)
|
||||
|
||||
def dequeue(
|
||||
self,
|
||||
timeout: float | None = None,
|
||||
cancel: Event | None = None,
|
||||
indefinite: bool = False,
|
||||
):
|
||||
"""Read from message queue with optional timeout (in seconds)"""
|
||||
if self._is_local_reader:
|
||||
with self.acquire_read(timeout, cancel, indefinite) as buf:
|
||||
overflow = buf[0] == 1
|
||||
if not overflow:
|
||||
offset = 3
|
||||
buf_count = from_bytes_big(buf[1:offset])
|
||||
all_buffers = []
|
||||
for i in range(buf_count):
|
||||
buf_offset = offset + 4
|
||||
buf_len = from_bytes_big(buf[offset:buf_offset])
|
||||
offset = buf_offset + buf_len
|
||||
all_buffers.append(buf[buf_offset:offset])
|
||||
obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:])
|
||||
if overflow:
|
||||
obj = MessageQueue.recv(self.local_socket, timeout)
|
||||
elif self._is_remote_reader:
|
||||
obj = MessageQueue.recv(self.remote_socket, timeout)
|
||||
else:
|
||||
raise RuntimeError("Only readers can dequeue")
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def recv(socket: zmq.Socket, timeout: float | None) -> Any:
|
||||
timeout_ms = None if timeout is None else int(timeout * 1000)
|
||||
if not socket.poll(timeout=timeout_ms):
|
||||
raise TimeoutError
|
||||
recv, *recv_oob = socket.recv_multipart(copy=False)
|
||||
return pickle.loads(recv, buffers=recv_oob)
|
||||
|
||||
def broadcast_object(self, obj=None):
|
||||
if self._is_writer:
|
||||
self.enqueue(obj)
|
||||
return obj
|
||||
return self.dequeue()
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group_single_reader(
|
||||
pg: ProcessGroup,
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
reader_rank: int = 0,
|
||||
blocking: bool = False,
|
||||
) -> tuple["MessageQueue", list[Handle]]:
|
||||
"""
|
||||
Creates a MessageQueue for a process group with a single reader.
|
||||
|
||||
This method is designed for scenarios where only one process (the reader)
|
||||
will consume messages, and all other processes are writers. It sets up
|
||||
the shared memory buffer and communication handles accordingly, and
|
||||
gathers the handles from all processes to the reader.
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup): The torch distributed process group.
|
||||
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
|
||||
max_chunks (int): Maximum number of chunks in the buffer.
|
||||
reader_rank (int, optional): The global rank that will act as the reader.
|
||||
Defaults to 0.
|
||||
blocking (bool, optional): If True, blocks until all processes are ready.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
tuple[MessageQueue, list[Handle]]:
|
||||
The MessageQueue instance for the calling process,
|
||||
and a list of handles (only non-empty for the reader process).
|
||||
"""
|
||||
local_size = current_platform.device_count()
|
||||
rank = dist.get_rank()
|
||||
same_node = rank // local_size == reader_rank // local_size
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=1,
|
||||
n_local_reader=1 if same_node else 0,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
|
||||
dist.gather_object(handle, handles, dst=reader_rank, group=pg)
|
||||
if blocking:
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io, cast(list[Handle], handles or [])
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group(
|
||||
pg: ProcessGroup | StatelessProcessGroup,
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
writer_rank: int = 0,
|
||||
external_writer_handle=None,
|
||||
blocking: bool = True,
|
||||
) -> "MessageQueue":
|
||||
"""
|
||||
Creates a MessageQueue for a distributed process group with one writer and
|
||||
multiple readers.
|
||||
|
||||
This method is designed for scenarios where one process (the writer) sends
|
||||
messages, and all other processes (the readers) receive messages. It sets up
|
||||
the shared memory buffer and socket communication handles accordingly, and
|
||||
broadcasts the handle from the writer to all readers.
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
|
||||
group.
|
||||
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
|
||||
max_chunks (int): Maximum number of chunks in the buffer.
|
||||
writer_rank (int, optional): The global rank that will act as the writer.
|
||||
Defaults to 0.
|
||||
external_writer_handle (Handle, optional): Used when there is a handle
|
||||
from an external Message Queue. If provided, use this handle to init
|
||||
PG writer message queue instead of creating a new one. Defaults to None.
|
||||
blocking (bool, optional): If True, blocks until all processes are ready.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
MessageQueue: The MessageQueue instance for the calling process.
|
||||
|
||||
"""
|
||||
if isinstance(pg, ProcessGroup):
|
||||
group_rank = dist.get_rank(pg)
|
||||
group_world_size = dist.get_world_size(pg)
|
||||
global_ranks = dist.get_process_group_ranks(pg)
|
||||
else:
|
||||
group_rank = pg.rank
|
||||
group_world_size = pg.world_size
|
||||
global_ranks = list(range(pg.world_size))
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
|
||||
status = in_the_same_node_as(pg, source_rank=writer_rank)
|
||||
if group_rank == writer_rank:
|
||||
if external_writer_handle is not None:
|
||||
buffer_io = MessageQueue.create_from_handle(
|
||||
external_writer_handle, group_rank
|
||||
)
|
||||
else:
|
||||
same_node_ranks = [i for i, s in enumerate(status) if s]
|
||||
n_reader = group_world_size - 1
|
||||
n_local_reader = len(same_node_ranks) - 1
|
||||
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=n_reader,
|
||||
n_local_reader=n_local_reader,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
if isinstance(pg, ProcessGroup):
|
||||
dist.broadcast_object_list(
|
||||
[handle], src=global_ranks[writer_rank], group=pg
|
||||
)
|
||||
else:
|
||||
pg.broadcast_obj(handle, writer_rank)
|
||||
else:
|
||||
if isinstance(pg, ProcessGroup):
|
||||
recv = [None]
|
||||
dist.broadcast_object_list(
|
||||
recv, src=global_ranks[writer_rank], group=pg
|
||||
)
|
||||
handle = recv[0] # type: ignore
|
||||
else:
|
||||
handle = pg.broadcast_obj(None, writer_rank)
|
||||
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
|
||||
if blocking:
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io
|
||||
707
vllm/distributed/device_communicators/shm_object_storage.py
Normal file
707
vllm/distributed/device_communicators/shm_object_storage.py
Normal file
@@ -0,0 +1,707 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable
|
||||
from contextlib import contextmanager, suppress
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from multiprocessing import shared_memory
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SingleWriterShmRingBuffer:
|
||||
"""
|
||||
A single-writer, multiple-reader ring buffer implementation using shared
|
||||
memory. This class provides a thread-safe ring buffer where one process
|
||||
can write data while multiple processes/threads can read from it.
|
||||
|
||||
Architecture:
|
||||
- Uses shared memory for cross-process communication
|
||||
- Maintains metadata for each allocated buffer chunk in the writer process
|
||||
- Supports custom "is_free_fn" functions to determine when buffers can be
|
||||
reused
|
||||
- Each buffer chunk contains: `[4-byte id][4-byte size][actual_data]`
|
||||
|
||||
Key Concepts:
|
||||
- monotonic_id_start/end: Track the range of active buffer IDs
|
||||
- data_buffer_start/end: Track the physical memory range in use
|
||||
- Automatic wraparound when reaching buffer end
|
||||
- Lazy garbage collection based on is_free_fn checks
|
||||
|
||||
Example Usage Scenarios:
|
||||
|
||||
Scenario 1: Simple Linear Allocation
|
||||
```
|
||||
Buffer size: 100 bytes
|
||||
Initial state: [................................................. ]
|
||||
^start=end(0)
|
||||
|
||||
After allocating 20 bytes (id=0):
|
||||
[id:0|size:20|data........][...................................]
|
||||
^start(0) ^end(28)
|
||||
|
||||
After allocating 30 bytes (id=1):
|
||||
[id:0|size:20|data........][id:1|size:30|data..............][..]
|
||||
^start(0) ^end(66)
|
||||
```
|
||||
|
||||
Scenario 2: Memory Reclamation
|
||||
```
|
||||
Before freeing (both buffers still in use):
|
||||
[id:0|size:20|data........][id:1|size:30|data..............][..]
|
||||
^start(0) ^end(66)
|
||||
|
||||
After id:0 is marked free by readers:
|
||||
[FREED.................... ][id:1|size:30|data..............][..]
|
||||
^start(28) ^end(66)
|
||||
|
||||
After both are freed:
|
||||
[FREED..............................................][..]
|
||||
^start=end(66)
|
||||
```
|
||||
|
||||
Scenario 3: Wraparound Allocation (continuing from Scenario 2)
|
||||
```
|
||||
Starting from after memory reclamation in Scenario 2:
|
||||
[FREED..............................................][..]
|
||||
^start=end(66)
|
||||
|
||||
Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound:
|
||||
[id:2|size:40|data........................][FREED.............][..]
|
||||
^end(148) ^start(66)
|
||||
```
|
||||
|
||||
Scenario 4: Error Handling - Out of Space
|
||||
```
|
||||
Starting from after wraparound allocation in Scenario 3:
|
||||
[id:2|size:40|data........................][FREED.............][..]
|
||||
^end(148) ^start(66)
|
||||
|
||||
Trying to allocate 20 more bytes:
|
||||
occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100)
|
||||
-> Raises MemoryError: "Not enough space in the data buffer"
|
||||
```
|
||||
|
||||
Thread Safety:
|
||||
- Single writer: Only one process/thread should write (allocate_buf)
|
||||
- Multiple readers: Multiple processes/threads can read (access_buf)
|
||||
- Reader synchronization handled by is_free_fn callback
|
||||
- Writer handles garbage collection (free_buf) based on reader feedback
|
||||
|
||||
Memory Layout per Buffer Chunk:
|
||||
`[4-byte monotonic_id][4-byte chunk_size][actual_data...]`
|
||||
^metadata_start ^data_start
|
||||
|
||||
The monotonic_id ensures data integrity - readers can verify they're
|
||||
accessing the correct data even after buffer wraparound or reuse.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_buffer_size: int,
|
||||
name: str | None = None,
|
||||
create: bool = False,
|
||||
):
|
||||
self.data_buffer_size = data_buffer_size
|
||||
self.is_writer = create
|
||||
|
||||
self.ID_NBYTES = 4
|
||||
self.ID_MAX = 2**31 # exclusive, so 2**31 - 1 is the max value
|
||||
self.SIZE_NBYTES = 4
|
||||
# 4 bytes for id, 4 bytes for buffer size
|
||||
self.MD_SIZE = self.ID_NBYTES + self.SIZE_NBYTES
|
||||
self.monotonic_id_end = 0
|
||||
self.monotonic_id_start = 0
|
||||
self.data_buffer_start = 0
|
||||
self.data_buffer_end = 0
|
||||
|
||||
if create:
|
||||
logger.debug("Creating new shared memory buffer: %s", name)
|
||||
# we are creating a buffer
|
||||
self.metadata: dict[int, int] = {} # monotonic_id -> start address
|
||||
self.shared_memory = shared_memory.SharedMemory(
|
||||
create=True, size=self.data_buffer_size, name=name
|
||||
)
|
||||
else:
|
||||
# we are opening an existing buffer
|
||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||
# Python incorrectly tracks shared memory even if it is not
|
||||
# created by the process. The following patch is a workaround.
|
||||
with patch(
|
||||
"multiprocessing.resource_tracker.register",
|
||||
lambda *args, **kwargs: None,
|
||||
):
|
||||
self.shared_memory = shared_memory.SharedMemory(name=name)
|
||||
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
|
||||
# Some platforms allocate memory based on page size,
|
||||
# so the shared memory block size may be larger or equal
|
||||
# to the requested size. The size parameter is ignored
|
||||
# when attaching to an existing block.
|
||||
assert self.shared_memory.size >= self.data_buffer_size
|
||||
|
||||
logger.debug(
|
||||
"Shared memory created/opened with name: %s, size: %d",
|
||||
self.shared_memory.name,
|
||||
self.data_buffer_size,
|
||||
)
|
||||
|
||||
def handle(self):
|
||||
return (
|
||||
self.data_buffer_size,
|
||||
self.shared_memory.name,
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the ring buffer."""
|
||||
assert self.is_writer, "Only the writer can clear the buffer."
|
||||
self.metadata.clear()
|
||||
self.monotonic_id_end = 0
|
||||
self.monotonic_id_start = 0
|
||||
self.data_buffer_start = 0
|
||||
self.data_buffer_end = 0
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the shared memory."""
|
||||
if hasattr(self, "shared_memory"):
|
||||
self.shared_memory.close()
|
||||
if self.is_writer:
|
||||
with suppress(FileNotFoundError):
|
||||
self.shared_memory.unlink()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def int2byte(self, integer: int) -> bytes:
|
||||
"""Convert an integer to bytes."""
|
||||
return integer.to_bytes(self.ID_NBYTES, "little", signed=True)
|
||||
|
||||
def byte2int(self, byte_data: bytes) -> int:
|
||||
"""Convert bytes back to an integer."""
|
||||
return int.from_bytes(byte_data, "little", signed=True)
|
||||
|
||||
def allocate_buf(self, size: int) -> tuple[int, int]:
|
||||
"""
|
||||
Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory.
|
||||
Memory layout:
|
||||
`[4-byte monotonic_id][4-byte size][buffer data...]`
|
||||
"""
|
||||
assert self.is_writer, "Only the writer can allocate buffers."
|
||||
assert size > 0, "Size must be greater than 0"
|
||||
size += self.MD_SIZE # add metadata size to the buffer size
|
||||
# reset to beginning if the buffer does have enough contiguous space
|
||||
buffer_end_reset = self.data_buffer_end % self.data_buffer_size
|
||||
if buffer_end_reset + size > self.data_buffer_size:
|
||||
buffer_end_reset = (
|
||||
self.data_buffer_end // self.data_buffer_size + 1
|
||||
) * self.data_buffer_size
|
||||
else: # no reset needed
|
||||
buffer_end_reset = self.data_buffer_end
|
||||
|
||||
# check if we have enough space in the data buffer
|
||||
# i.e. if the new end (self.data_buffer_end + size)
|
||||
# exceeds the start of the data buffer
|
||||
occupied_size_new = buffer_end_reset + size - self.data_buffer_start
|
||||
if occupied_size_new > self.data_buffer_size:
|
||||
raise MemoryError(
|
||||
"Not enough space in the data buffer, "
|
||||
"try calling free_buf() to free up space"
|
||||
)
|
||||
self.data_buffer_end = buffer_end_reset
|
||||
|
||||
# first 4 bytes as the monotonic id
|
||||
buf_idx = self.data_buffer_end % self.data_buffer_size
|
||||
self.shared_memory.buf[buf_idx : buf_idx + self.ID_NBYTES] = self.int2byte(
|
||||
self.monotonic_id_end
|
||||
)
|
||||
# next 4 bytes as the size of the data buffer
|
||||
self.shared_memory.buf[buf_idx + self.ID_NBYTES : buf_idx + self.MD_SIZE] = (
|
||||
self.int2byte(size)
|
||||
)
|
||||
|
||||
# record metadata
|
||||
self.metadata[self.monotonic_id_end % self.ID_MAX] = self.data_buffer_end
|
||||
# update buffer and monotonic id indices
|
||||
current_buffer_end = self.data_buffer_end
|
||||
current_id_end = self.monotonic_id_end
|
||||
self.data_buffer_end += size
|
||||
self.monotonic_id_end = (self.monotonic_id_end + 1) % self.ID_MAX
|
||||
return current_buffer_end, current_id_end
|
||||
|
||||
@contextmanager
|
||||
def access_buf(self, address: int):
|
||||
buf_idx = address % self.data_buffer_size
|
||||
|
||||
# read metadata
|
||||
metadata_buff = self.shared_memory.buf[buf_idx : buf_idx + self.MD_SIZE]
|
||||
id = self.byte2int(metadata_buff[: self.ID_NBYTES])
|
||||
size = self.byte2int(metadata_buff[self.ID_NBYTES : self.MD_SIZE])
|
||||
|
||||
# yield the data buffer and metadata
|
||||
data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE : buf_idx + size]
|
||||
with (
|
||||
memoryview(data_buff) as data_view,
|
||||
):
|
||||
yield data_view, (id, size)
|
||||
|
||||
def free_buf(
|
||||
self,
|
||||
is_free_fn: Callable[[int, memoryview], bool],
|
||||
nbytes: int | None = None,
|
||||
) -> Iterable[int]:
|
||||
"""
|
||||
Free a buffer of the given size. This is a no-op in shared memory,
|
||||
but we need to keep track of the metadata.
|
||||
|
||||
If freed memory spreads across the end and start of the ring buffer,
|
||||
the actual freed memory will be in two segments. In this case there
|
||||
still might not be a contiguous space of `nbytes` available.
|
||||
|
||||
Args:
|
||||
nbytes (int, optional): The size of the buffer to free. If None,
|
||||
frees the maximum size of the ring buffer.
|
||||
"""
|
||||
|
||||
assert self.is_writer, "Only the writer can free buffers."
|
||||
logger.debug(
|
||||
"Freeing up space in the ring buffer, "
|
||||
"monotonic_id_start: %d, monotonic_id_end: %d",
|
||||
self.monotonic_id_start,
|
||||
self.monotonic_id_end,
|
||||
)
|
||||
monotonic_id_before = self.monotonic_id_start
|
||||
# if nbytes is None, free up the maximum size of the ring buffer
|
||||
if nbytes is None:
|
||||
nbytes = self.data_buffer_size
|
||||
freed_bytes = 0
|
||||
while self.monotonic_id_start in self.metadata and freed_bytes < nbytes:
|
||||
address = self.metadata[self.monotonic_id_start]
|
||||
with self.access_buf(address) as (data_buff, metadata):
|
||||
if is_free_fn(self.monotonic_id_start, data_buff):
|
||||
# check passed, we can free the buffer
|
||||
del self.metadata[self.monotonic_id_start]
|
||||
self.monotonic_id_start = (
|
||||
self.monotonic_id_start + 1
|
||||
) % self.ID_MAX
|
||||
if self.monotonic_id_start in self.metadata:
|
||||
# pointing to the start addr of next allocation
|
||||
self.data_buffer_start += (
|
||||
self.metadata[self.monotonic_id_start]
|
||||
- self.data_buffer_start
|
||||
) % self.data_buffer_size
|
||||
else:
|
||||
# no remaining allocation, reset to zero
|
||||
self.data_buffer_start = self.data_buffer_end = 0
|
||||
freed_bytes += metadata[1]
|
||||
else:
|
||||
# there are still readers, we cannot free the buffer
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
"Freed %d bytes from the ring buffer, "
|
||||
"monotonic_id_start: %d, monotonic_id_end: %d",
|
||||
freed_bytes,
|
||||
self.monotonic_id_start,
|
||||
self.monotonic_id_end,
|
||||
)
|
||||
|
||||
# buffer wrap around
|
||||
if self.data_buffer_start >= self.data_buffer_size:
|
||||
self.data_buffer_start -= self.data_buffer_size
|
||||
self.data_buffer_end -= self.data_buffer_size
|
||||
|
||||
monotonic_id_after = self.monotonic_id_start
|
||||
# id wrap around
|
||||
if monotonic_id_after >= monotonic_id_before:
|
||||
return range(monotonic_id_before, monotonic_id_after)
|
||||
else:
|
||||
return chain(
|
||||
range(monotonic_id_before, self.ID_MAX), range(0, monotonic_id_after)
|
||||
)
|
||||
|
||||
|
||||
class ObjectSerde(ABC):
|
||||
@abstractmethod
|
||||
def serialize(self, value: Any) -> tuple[Any, int, bytes, int]:
|
||||
"""Serialize an object to bytes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def deserialize(self, data: memoryview) -> Any:
|
||||
"""Deserialize bytes back to an object."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MsgpackSerde(ObjectSerde):
|
||||
def __init__(self):
|
||||
# Delayed import to avoid circular dependency
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
self.encoder = MsgpackEncoder()
|
||||
self.tensor_decoder = MsgpackDecoder(torch.Tensor, share_mem=False)
|
||||
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem, share_mem=False)
|
||||
self._mm_kwargs_item_cls = MultiModalKwargsItem
|
||||
|
||||
def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]:
|
||||
len_arr = None
|
||||
if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)):
|
||||
type_name = type(value).__name__
|
||||
value = self.encoder.encode(value)
|
||||
len_arr = [len(s) for s in value]
|
||||
nbytes = sum(len_arr)
|
||||
else:
|
||||
value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
type_name = type(value).__name__
|
||||
nbytes = len(value)
|
||||
|
||||
object_metadata = (type_name, nbytes, len_arr)
|
||||
serialized_metadata = pickle.dumps(
|
||||
object_metadata, protocol=pickle.HIGHEST_PROTOCOL
|
||||
)
|
||||
return value, nbytes, serialized_metadata, len(serialized_metadata)
|
||||
|
||||
def deserialize(self, data_view: memoryview) -> Any:
|
||||
# pickle.loads do not read past the end of a pickled object
|
||||
# within a large buffer, so we can skip storing the metadata size
|
||||
type_name, nbytes, len_arr = pickle.loads(data_view)
|
||||
serialized_data = data_view[-nbytes:]
|
||||
|
||||
if type_name == torch.Tensor.__name__:
|
||||
obj = []
|
||||
start_idx = 0
|
||||
for length in len_arr:
|
||||
item_bytes = serialized_data[start_idx : start_idx + length]
|
||||
obj.append(item_bytes)
|
||||
start_idx += length
|
||||
obj = self.tensor_decoder.decode(obj)
|
||||
elif type_name == self._mm_kwargs_item_cls.__name__:
|
||||
obj = []
|
||||
start_idx = 0
|
||||
for length in len_arr:
|
||||
item_bytes = serialized_data[start_idx : start_idx + length]
|
||||
obj.append(item_bytes)
|
||||
start_idx += length
|
||||
obj = self.mm_decoder.decode(obj)
|
||||
elif type_name == bytes.__name__:
|
||||
obj = pickle.loads(serialized_data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported object type '{type_name}' in metadata")
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShmObjectStorageHandle:
|
||||
max_object_size: int
|
||||
n_readers: int
|
||||
ring_buffer_handle: tuple[int, str]
|
||||
serde_class: type[ObjectSerde]
|
||||
reader_lock: LockType | None
|
||||
|
||||
|
||||
class SingleWriterShmObjectStorage:
|
||||
"""
|
||||
A single-writer, multiple-reader object storage system built on top of a
|
||||
shared memory ring buffer. Provides key-value storage with automatic memory
|
||||
management and cross-process serialization support.
|
||||
|
||||
This storage system follows a FIFO (First-In-First-Out) eviction policy
|
||||
where the oldest objects are automatically freed when memory runs low.
|
||||
Memory is reclaimed based on reader reference counting - objects are only
|
||||
freed when all readers have finished accessing them.
|
||||
|
||||
Architecture:
|
||||
- Single writer process can put(key, value) objects
|
||||
- Multiple reader processes can get(address, monotonic_id) objects
|
||||
- Built on SingleWriterShmRingBuffer for efficient shared memory management
|
||||
- Thread-safe operations with reader synchronization via locks
|
||||
|
||||
Key Features:
|
||||
- FIFO Eviction: Oldest objects are evicted first when memory is full
|
||||
- Reference Counting: Objects are only freed when no readers are
|
||||
accessing them
|
||||
- Duplicate Key Handling: Existing keys are not overwritten, just
|
||||
re-referenced
|
||||
- Customized Serialization: By default uses Msgpack for efficient
|
||||
serialization of Python objects, but can be extended for custom types
|
||||
- Cross-Process Safety: Uses shared memory with proper synchronization
|
||||
- Automatic Cleanup: Garbage collection happens transparently during
|
||||
allocation
|
||||
|
||||
Memory Layout per Object:
|
||||
`[4-byte reference_count][metadata_size][serialized_object_data]`
|
||||
|
||||
Thread Safety:
|
||||
- Writer operations (put, clear) are single-threaded by design
|
||||
- Reader operations (get) are thread-safe with lock-based reference
|
||||
counting
|
||||
- Memory reclamation is handled exclusively by the writer process
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_object_size: int,
|
||||
n_readers: int,
|
||||
ring_buffer: SingleWriterShmRingBuffer,
|
||||
serde_class: type[ObjectSerde] = MsgpackSerde,
|
||||
reader_lock: LockType | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the object storage.
|
||||
|
||||
Args:
|
||||
max_object_size: Maximum size for a single object in bytes.
|
||||
n_readers: Number of reader processes that can access the storage.
|
||||
ring_buffer: The shared memory ring buffer for storing objects.
|
||||
serde_class: Serializer/deserializer for objects.
|
||||
reader_lock: Optional lock for synchronizing reader access.
|
||||
Raises:
|
||||
ValueError: If reader_lock is None for readers.
|
||||
"""
|
||||
|
||||
self.max_object_size = max_object_size
|
||||
self.n_readers = n_readers
|
||||
self.serde_class = serde_class
|
||||
self.ser_de = serde_class()
|
||||
self.ring_buffer = ring_buffer
|
||||
self.is_writer = self.ring_buffer.is_writer
|
||||
|
||||
self.flag_bytes = 4 # for in-use flag
|
||||
|
||||
if self.is_writer:
|
||||
# Key-value mapping: key -> (address, monotonic_id)
|
||||
self.key_index: dict[str, tuple[int, int]] = {}
|
||||
# Reverse mapping: monotonic_id -> key
|
||||
self.id_index: dict[int, str] = {}
|
||||
# Writer flag to track in-use status: monotonic_id -> count
|
||||
self.writer_flag: dict[int, int] = {}
|
||||
else:
|
||||
if reader_lock is None:
|
||||
raise ValueError("Lock must be provided for readers.")
|
||||
|
||||
self._reader_lock = reader_lock
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the object storage."""
|
||||
if self.is_writer:
|
||||
self.ring_buffer.clear()
|
||||
self.key_index.clear()
|
||||
self.id_index.clear()
|
||||
self.writer_flag.clear()
|
||||
logger.debug("Object storage cleared and reinitialized.")
|
||||
|
||||
def copy_to_buffer(
|
||||
self,
|
||||
data: bytes | list[bytes],
|
||||
data_bytes: int,
|
||||
metadata: bytes,
|
||||
md_bytes: int,
|
||||
data_view: memoryview,
|
||||
) -> None:
|
||||
data_view[self.flag_bytes : self.flag_bytes + md_bytes] = metadata
|
||||
if isinstance(data, bytes):
|
||||
data_view[-data_bytes:] = data
|
||||
elif isinstance(data, list):
|
||||
start_idx = self.flag_bytes + md_bytes
|
||||
for item_bytes in data:
|
||||
item_size = len(item_bytes)
|
||||
data_view[start_idx : start_idx + item_size] = item_bytes
|
||||
start_idx += item_size
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type for serialization: {type(data)}")
|
||||
|
||||
def increment_writer_flag(self, id: int) -> None:
|
||||
"""Set the in-use flag for the writer."""
|
||||
self.writer_flag[id] = self.writer_flag.get(id, 0) + 1
|
||||
|
||||
def increment_reader_flag(self, data_view: memoryview) -> None:
|
||||
"""Set the in-use flag for the reader."""
|
||||
# >0 for in-use flag
|
||||
reader_count = self.ring_buffer.byte2int(data_view)
|
||||
data_view[:] = self.ring_buffer.int2byte(reader_count + 1)
|
||||
|
||||
def free_unused(self) -> None:
|
||||
"""Free unused buffers in the ring buffer."""
|
||||
# try to free up 2*max_object_size bytes of space in the ring buffer,
|
||||
# since the buffer might be fragmented
|
||||
freed_ids = self.ring_buffer.free_buf(
|
||||
self.default_is_free_check, 2 * self.max_object_size
|
||||
)
|
||||
# update the metadata after freeing up space
|
||||
for freed_id in freed_ids:
|
||||
key_to_free = self.id_index[freed_id]
|
||||
del self.key_index[key_to_free]
|
||||
del self.id_index[freed_id]
|
||||
del self.writer_flag[freed_id]
|
||||
|
||||
def is_cached(self, key: str) -> bool:
|
||||
"""
|
||||
Check if the object with the given key is cached.
|
||||
"""
|
||||
return key in self.key_index
|
||||
|
||||
def get_cached(self, key: str) -> tuple[int, int]:
|
||||
"""
|
||||
Get the cached object by key if it exists.
|
||||
"""
|
||||
address, monotonic_id = self.key_index[key]
|
||||
self.increment_writer_flag(monotonic_id)
|
||||
return address, monotonic_id
|
||||
|
||||
def put(self, key: str, value: Any) -> tuple[int, int]:
|
||||
"""
|
||||
Store a key-value pair in the object storage.
|
||||
Attempts to free max_object_size bytes using FIFO order
|
||||
when the ring buffer runs out of space during a put() operation.
|
||||
|
||||
Args:
|
||||
key: String key to identify the object
|
||||
value: Any serializable Python object
|
||||
|
||||
Raises:
|
||||
MemoryError: If there's not enough space in the buffer
|
||||
ValueError: If the serialized object is too large
|
||||
ValueError: If the key already exists in the storage
|
||||
"""
|
||||
if key in self.key_index:
|
||||
raise ValueError(f"Key '{key}' already exists in the storage.")
|
||||
|
||||
object_data, data_bytes, object_metadata, md_bytes = self.ser_de.serialize(
|
||||
value
|
||||
)
|
||||
buffer_size = self.flag_bytes + data_bytes + md_bytes
|
||||
# Sanity checks
|
||||
if buffer_size > self.max_object_size:
|
||||
raise ValueError(
|
||||
f"Serialized object size ({buffer_size} bytes) exceeds "
|
||||
f"max object size ({self.max_object_size} bytes)"
|
||||
)
|
||||
|
||||
# Allocate new buffer
|
||||
try:
|
||||
address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
|
||||
except MemoryError:
|
||||
self.free_unused()
|
||||
# try again after freeing up space
|
||||
address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
|
||||
|
||||
# Write data to buffer
|
||||
with self.ring_buffer.access_buf(address) as (data_view, metadata):
|
||||
data_view[: self.flag_bytes] = self.ring_buffer.int2byte(0)
|
||||
self.copy_to_buffer(
|
||||
object_data, data_bytes, object_metadata, md_bytes, data_view
|
||||
)
|
||||
self.increment_writer_flag(monotonic_id)
|
||||
|
||||
# Update key index
|
||||
self.key_index[key] = (address, monotonic_id)
|
||||
self.id_index[monotonic_id] = key
|
||||
return address, monotonic_id
|
||||
|
||||
def get(self, address: int, monotonic_id: int) -> Any:
|
||||
# Read data from buffer
|
||||
with self.ring_buffer.access_buf(address) as (data_view, buf_metadata):
|
||||
# check id from metadata
|
||||
if buf_metadata[0] != monotonic_id:
|
||||
raise ValueError(
|
||||
f"Data for address:id '{address}:{monotonic_id}'"
|
||||
" has been modified or is invalid."
|
||||
)
|
||||
|
||||
obj = self.ser_de.deserialize(data_view[self.flag_bytes :])
|
||||
|
||||
# decrease the in-use flag for reader reads
|
||||
if self._reader_lock is not None:
|
||||
with self._reader_lock:
|
||||
self.increment_reader_flag(data_view[: self.flag_bytes])
|
||||
else:
|
||||
# if self._reader_lock is None, it means we are the writer
|
||||
# in this case, we do not need to decrease the reader count
|
||||
assert self.is_writer
|
||||
|
||||
return obj
|
||||
|
||||
def touch(
|
||||
self,
|
||||
key: str,
|
||||
address: int = 0,
|
||||
monotonic_id: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Touch an existing cached item to update its eviction status.
|
||||
|
||||
For writers (ShmObjectStoreSenderCache): Increment writer_flag
|
||||
For readers (ShmObjectStoreReceiverCache): Increment reader_count
|
||||
|
||||
Args:
|
||||
key: String key of the object to touch
|
||||
address: Address of the object (only for readers)
|
||||
monotonic_id: Monotonic ID of the object (only for readers)
|
||||
|
||||
"""
|
||||
if self._reader_lock is None:
|
||||
if key not in self.key_index:
|
||||
return None
|
||||
address, monotonic_id = self.key_index[key]
|
||||
# Writer side: increment writer_flag to raise eviction threshold
|
||||
self.increment_writer_flag(monotonic_id)
|
||||
else:
|
||||
with (
|
||||
self._reader_lock,
|
||||
self.ring_buffer.access_buf(address) as (data_view, _),
|
||||
):
|
||||
reader_count = self.ring_buffer.byte2int(data_view[: self.flag_bytes])
|
||||
|
||||
# NOTE(Long):
|
||||
# Avoid increasing flag on newly added item (sync with sender)
|
||||
# Since when a new item is added
|
||||
# pre-touch has no effect on writer side
|
||||
if reader_count >= self.n_readers:
|
||||
self.increment_reader_flag(data_view[: self.flag_bytes])
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the shared memory."""
|
||||
self.ring_buffer.close()
|
||||
|
||||
def handle(self):
|
||||
"""Get handle for sharing across processes."""
|
||||
return ShmObjectStorageHandle(
|
||||
max_object_size=self.max_object_size,
|
||||
n_readers=self.n_readers,
|
||||
ring_buffer_handle=self.ring_buffer.handle(),
|
||||
serde_class=self.serde_class,
|
||||
reader_lock=self._reader_lock,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_from_handle(
|
||||
handle: ShmObjectStorageHandle,
|
||||
) -> "SingleWriterShmObjectStorage":
|
||||
logger.debug("Creating storage from handle: %s", handle)
|
||||
ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle)
|
||||
return SingleWriterShmObjectStorage(
|
||||
max_object_size=handle.max_object_size,
|
||||
n_readers=handle.n_readers,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=handle.serde_class,
|
||||
reader_lock=handle.reader_lock,
|
||||
)
|
||||
|
||||
def default_is_free_check(self, id: int, buf: memoryview) -> bool:
|
||||
"""
|
||||
Default is_free function that checks if the first 4 bytes are zero.
|
||||
This indicates that the buffer is free.
|
||||
"""
|
||||
reader_count = int.from_bytes(buf[0:4], "little", signed=True)
|
||||
writer_count = self.writer_flag[id]
|
||||
return reader_count >= writer_count * self.n_readers
|
||||
156
vllm/distributed/device_communicators/symm_mem.py
Normal file
156
vllm/distributed/device_communicators/symm_mem.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
try:
|
||||
import torch.distributed._symmetric_memory as torch_symm_mem
|
||||
|
||||
symm_mem_available = True
|
||||
except ImportError:
|
||||
symm_mem_available = False
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SymmMemCommunicator:
|
||||
_WORLD_SIZES_MULTIMEM = {
|
||||
"9.0": [4, 6, 8],
|
||||
"10.0": [6, 8],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: int | str | torch.device,
|
||||
# add options for testing
|
||||
force_multimem: bool | None = None,
|
||||
max_size_override: int | None = None,
|
||||
):
|
||||
self.disabled = True
|
||||
|
||||
if not symm_mem_available:
|
||||
return
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
logger.warning("SymmMemCommunicator: symmetric memory is not available.")
|
||||
return
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
torch.cuda.set_device(device)
|
||||
self.dtype = torch.bfloat16
|
||||
self.device = device
|
||||
self.group = group
|
||||
self.world_size = dist.get_world_size(self.group)
|
||||
capability = current_platform.get_device_capability()
|
||||
if capability is None:
|
||||
logger.warning(
|
||||
"SymmMemCommunicator: device capability is unknown, "
|
||||
"communicator is not available."
|
||||
)
|
||||
return
|
||||
self.device_capability = capability.as_version_str()
|
||||
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
|
||||
logger.warning(
|
||||
"SymmMemCommunicator: Device capability %s not supported, "
|
||||
"communicator is not available.",
|
||||
self.device_capability,
|
||||
)
|
||||
return
|
||||
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
|
||||
logger.warning(
|
||||
"SymmMemCommunicator: World size %d not supported, "
|
||||
"communicator is not available.",
|
||||
self.world_size,
|
||||
)
|
||||
return
|
||||
# Use override max_size if provided, otherwise use default
|
||||
if max_size_override is not None:
|
||||
self.max_size = max_size_override
|
||||
logger.info(
|
||||
"SymmMemCommunicator: Using override max_size: %s bytes",
|
||||
self.max_size,
|
||||
)
|
||||
else:
|
||||
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
|
||||
self.world_size
|
||||
]
|
||||
try:
|
||||
self.buffer = torch_symm_mem.empty(
|
||||
self.max_size // self.dtype.itemsize,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
|
||||
except RuntimeError as e:
|
||||
logger.warning_once(
|
||||
"SymmMemCommunicator: symmetric memory initialization failed: %s "
|
||||
"Communicator is not available. To suppress this warning set "
|
||||
"VLLM_ALLREDUCE_USE_SYMM_MEM=0",
|
||||
str(e),
|
||||
)
|
||||
return
|
||||
if handle.multicast_ptr == 0:
|
||||
logger.warning(
|
||||
"SymmMemCommunicator: symmetric memory "
|
||||
"multicast operations are not supported."
|
||||
)
|
||||
return
|
||||
self.force_multimem = force_multimem
|
||||
self.disabled = False
|
||||
if vllm_is_batch_invariant():
|
||||
self.disabled = True
|
||||
|
||||
def should_use_symm_mem(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
return False
|
||||
if inp.dtype != self.dtype:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
if inp_size % 4 != 0:
|
||||
return False
|
||||
return inp_size < self.max_size
|
||||
|
||||
def all_reduce(
|
||||
self, inp: torch.Tensor, *, out: torch.Tensor | None = None
|
||||
) -> torch.Tensor | None:
|
||||
if not self.should_use_symm_mem(inp):
|
||||
return None
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
self.buffer[: inp.numel()].copy_(inp.view(-1))
|
||||
|
||||
# Determine which algorithm to use
|
||||
use_multimem = False
|
||||
if self.force_multimem is not None:
|
||||
# Test override: use forced setting
|
||||
use_multimem = self.force_multimem
|
||||
else:
|
||||
# Normal logic: use multimem for supported world sizes
|
||||
use_multimem = (
|
||||
self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]
|
||||
)
|
||||
|
||||
if use_multimem:
|
||||
torch.ops.symm_mem.multimem_all_reduce_(
|
||||
self.buffer[: inp.numel()], "sum", self.group.group_name
|
||||
)
|
||||
else:
|
||||
torch.ops.symm_mem.two_shot_all_reduce_(
|
||||
self.buffer[: inp.numel()], "sum", self.group.group_name
|
||||
)
|
||||
out.copy_(self.buffer[: inp.numel()].view(out.shape))
|
||||
return out
|
||||
257
vllm/distributed/device_communicators/xpu_communicator.py
Normal file
257
vllm/distributed/device_communicators/xpu_communicator.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XpuCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
|
||||
elif self.all2all_backend == "allgather_reducescatter":
|
||||
from .all2all import AgRsAll2AllManager
|
||||
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
logger.info("Using AgRs manager on XPU device.")
|
||||
|
||||
else: # type: ignore[has-type]
|
||||
logger.warning(
|
||||
"`%s` all2all manager is not supported on XPU. "
|
||||
"Falling back to AgRs manager for XPU, "
|
||||
"which is the Default backend",
|
||||
self.all2all_backend, # type: ignore[has-type]
|
||||
)
|
||||
from .all2all import AgRsAll2AllManager
|
||||
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
logger.info("Using AgRs manager on XPU device.")
|
||||
|
||||
def all_reduce(self, input_) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
|
||||
world_size = self.world_size
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
||||
|
||||
output = torch.empty(
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
dist.reduce_scatter_tensor(output, input_tensor)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def reduce_scatterv(
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
|
||||
):
|
||||
world_size = self.world_size
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
assert input_tensor.shape[0] == sum(sizes)
|
||||
chunk_size = sizes[self.rank_in_group]
|
||||
else:
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
||||
|
||||
output = torch.empty(
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
|
||||
# if inputs shape in different ranks is not the same using reduce_scatter
|
||||
input_splits = list(input_tensor.split(sizes, dim=0))
|
||||
dist.reduce_scatter(output, input_splits)
|
||||
else:
|
||||
dist.reduce_scatter_tensor(output, input_tensor)
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: torch.Tensor | list[torch.Tensor],
|
||||
dim: int = 0,
|
||||
sizes: list[int] | None = None,
|
||||
):
|
||||
if dim != 0:
|
||||
raise NotImplementedError("only dim 0 all-gatherv is supported")
|
||||
world_size = self.world_size
|
||||
|
||||
# 'sizes' is not needed if all inputs in the same group have the same
|
||||
# shape
|
||||
if sizes is not None and all(s == sizes[0] for s in sizes):
|
||||
sizes = None
|
||||
|
||||
def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):
|
||||
input_size = input_.size()
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
assert input_.shape[dim] == sizes[self.rank_in_group], (
|
||||
f"{input_.shape[dim]} != {sizes[self.rank_in_group]}"
|
||||
)
|
||||
output_size = (sum(sizes),) + input_size[1:]
|
||||
else:
|
||||
output_size = (input_size[0] * world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
|
||||
if sizes is not None:
|
||||
all_gather_list = []
|
||||
for size in sizes:
|
||||
all_gather_list.append(
|
||||
torch.empty(
|
||||
(size,) + input_.shape[1:],
|
||||
dtype=input_.dtype,
|
||||
device=input_.device,
|
||||
)
|
||||
)
|
||||
dist.all_gather(all_gather_list, input_)
|
||||
output_tensor = torch.cat(all_gather_list, dim=0)
|
||||
else:
|
||||
dist.all_gather([output_tensor], input_)
|
||||
return output_tensor
|
||||
|
||||
if isinstance(input_, torch.Tensor):
|
||||
return _all_gather_single(input_, sizes)
|
||||
|
||||
output_list = []
|
||||
for inp in input_:
|
||||
output_list.append(_all_gather_single(inp, sizes=sizes))
|
||||
return output_list
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> torch.Tensor | None:
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
# For xpu path, gather doesn't work properly together with ray
|
||||
# cluster so we use all_gather instead for now.
|
||||
input_size = input_.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
(self.world_size,) + input_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
|
||||
if self.rank_in_group == dst:
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim]
|
||||
+ (self.world_size * input_size[dim],)
|
||||
+ input_size[dim + 1 :]
|
||||
)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
|
||||
dist.broadcast(input_, src=src, group=self.device_group)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.combine(
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
14
vllm/distributed/ec_transfer/__init__.py
Normal file
14
vllm/distributed/ec_transfer/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.distributed.ec_transfer.ec_transfer_state import (
|
||||
ensure_ec_transfer_initialized,
|
||||
get_ec_transfer,
|
||||
has_ec_transfer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_ec_transfer",
|
||||
"ensure_ec_transfer_initialized",
|
||||
"has_ec_transfer",
|
||||
]
|
||||
252
vllm/distributed/ec_transfer/ec_connector/base.py
Normal file
252
vllm/distributed/ec_transfer/ec_connector/base.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
ECConnectorBase Class for Distributed Encoder Cache &
|
||||
P2P Encoder cache communication in V1
|
||||
|
||||
The class provides the following primitives:
|
||||
Scheduler-side: runs in the scheduler, binds metadata, which
|
||||
is used by the worker-side to load/save Encoder cache.
|
||||
check_caches_exist() - Check whether Encoder cache of requests exist
|
||||
update_state_after_alloc() - update ECConnector state after
|
||||
allocate. This will decide to load the cache or not
|
||||
request_finished() - called when a request is finished,
|
||||
free the cache with the requests
|
||||
|
||||
Worker-side: runs in each worker, loads/saves Encoder Cache to/from
|
||||
the Connector based on the metadata.
|
||||
start_load_ec() - starts loading all ECs (maybe async)
|
||||
wait_for_save() - blocks until all saves are done
|
||||
|
||||
get_finished() - called with ids of finished requests, returns
|
||||
ids of requests that have completed async sending/recving.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import ECConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ECConnectorRole(enum.Enum):
|
||||
# Connector running in the scheduler process
|
||||
SCHEDULER = 0
|
||||
|
||||
# Connector running in the worker process
|
||||
WORKER = 1
|
||||
|
||||
|
||||
class ECConnectorMetadata(ABC): # noqa: B024
|
||||
"""
|
||||
Abstract Metadata used to communicate between the
|
||||
Scheduler ECConnector and Worker ECConnector.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ECConnectorBase(ABC):
|
||||
def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
|
||||
self._connector_metadata: ECConnectorMetadata | None = None
|
||||
self._vllm_config = vllm_config
|
||||
self._role = role
|
||||
if vllm_config.ec_transfer_config is not None:
|
||||
self._is_producer = vllm_config.ec_transfer_config.is_ec_producer
|
||||
self._is_consumer = vllm_config.ec_transfer_config.is_ec_consumer
|
||||
else:
|
||||
raise ValueError("ec_transfer_config must be set for ECConnectorBase")
|
||||
|
||||
@property
|
||||
def role(self) -> ECConnectorRole:
|
||||
return self._role
|
||||
|
||||
@property
|
||||
def is_producer(self) -> bool:
|
||||
return self._is_producer
|
||||
|
||||
@property
|
||||
def is_consumer(self) -> bool:
|
||||
return self._is_consumer
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(self, connector_metadata: ECConnectorMetadata) -> None:
|
||||
"""Set the connector metadata from the scheduler.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
before the model execution. The metadata will be used for runtime
|
||||
EC cache loading.
|
||||
|
||||
Args:
|
||||
connector_metadata (dict): the connector metadata.
|
||||
"""
|
||||
self._connector_metadata = connector_metadata
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
"""Clear the connector metadata.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
after the model execution.
|
||||
"""
|
||||
self._connector_metadata = None
|
||||
|
||||
def _get_connector_metadata(self) -> ECConnectorMetadata:
|
||||
"""Get the connector metadata.
|
||||
|
||||
This function should only be called inside the connector.
|
||||
|
||||
Returns:
|
||||
ConnectorMetadata: the connector metadata.
|
||||
"""
|
||||
|
||||
# Should only be called while set to valid metadata.
|
||||
assert self._connector_metadata is not None
|
||||
return self._connector_metadata
|
||||
|
||||
def register_caches(
|
||||
self,
|
||||
ec_caches: dict[str, torch.Tensor],
|
||||
):
|
||||
"""
|
||||
Initialize with the EC caches.
|
||||
Args:
|
||||
ec_caches: dictionary of encoder cache
|
||||
"""
|
||||
# TODO: Implement this later for P2P feature
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def start_load_caches(
|
||||
self, encoder_cache: dict[str, torch.Tensor], **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Start loading the cache from the connector into vLLM's encoder cache.
|
||||
|
||||
This method loads the encoder cache based on metadata provided by the scheduler.
|
||||
It is called before `_gather_mm_embeddings` for the EC Connector. For EC,
|
||||
the `encoder_cache` and `mm_hash` are stored in `kwargs`.
|
||||
|
||||
Args:
|
||||
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
|
||||
data hashes (`mm_hash`) to encoder cache tensors.
|
||||
kwargs (dict): Additional keyword arguments for the connector.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_caches(
|
||||
self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Save the encoder cache to the connector.
|
||||
|
||||
This method saves the encoder cache from the worker's local storage
|
||||
to shared storage or another external connector.
|
||||
|
||||
Args:
|
||||
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
|
||||
data hashes (`mm_hash`) to encoder cache tensors.
|
||||
mm_hash (str): The hash of the multimodal data whose cache is being saved.
|
||||
kwargs (dict): Additional keyword arguments for the connector.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens on the worker.
|
||||
The scheduler process (via the Executors) will use this output
|
||||
to track which workers are done.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
return None, None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
@abstractmethod
|
||||
def has_cache_item(
|
||||
self,
|
||||
identifier: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a single encoder cache exists
|
||||
|
||||
Args:
|
||||
identifier (str): the identifier of the media.
|
||||
|
||||
Returns:
|
||||
A bool where value is True if cache exist for
|
||||
the media
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_state_after_alloc(self, request: "Request", index: int):
|
||||
"""
|
||||
Update ECConnector state to decide allocate cache for requests
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> ECConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def update_connector_output(self, connector_output: ECConnectorOutput):
|
||||
"""
|
||||
Update ECConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (ECConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
return
|
||||
|
||||
def request_finished(
|
||||
self, request: "Request"
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its encoder cache is freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and cached
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
"""
|
||||
return False, None
|
||||
198
vllm/distributed/ec_transfer/ec_connector/example_connector.py
Normal file
198
vllm/distributed/ec_transfer/ec_connector/example_connector.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import safetensors
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import (
|
||||
ECConnectorBase,
|
||||
ECConnectorMetadata,
|
||||
ECConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MMMeta:
|
||||
mm_hash: str
|
||||
num_token: int
|
||||
|
||||
@staticmethod
|
||||
def make_meta(mm_hash, num_token) -> "MMMeta":
|
||||
return MMMeta(mm_hash=mm_hash, num_token=num_token)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ECExampleConnectorMetadata(ECConnectorMetadata):
|
||||
mm_datas: list[MMMeta]
|
||||
|
||||
def __init__(self):
|
||||
self.mm_datas = []
|
||||
|
||||
def add_mm_data(self, mm_data: MMMeta):
|
||||
self.mm_datas.append(mm_data)
|
||||
|
||||
|
||||
class ECExampleConnector(ECConnectorBase):
|
||||
# NOTE: This is Simple debug implementation of the EC connector.
|
||||
# It save / load the EC cache to / from the disk.
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
# req_id -> index
|
||||
self._mm_datas_need_loads: dict[str, int] = {}
|
||||
transfer_config = vllm_config.ec_transfer_config
|
||||
if transfer_config is not None:
|
||||
self._storage_path = transfer_config.get_from_extra_config(
|
||||
"shared_storage_path", "/tmp"
|
||||
)
|
||||
logger.debug(transfer_config)
|
||||
logger.debug("Shared storage path is %s", self._storage_path)
|
||||
else:
|
||||
raise ValueError("ec_transfer_config must be set for ECConnectorBase")
|
||||
|
||||
def start_load_caches(self, encoder_cache, **kwargs) -> None:
|
||||
"""
|
||||
Start loading the cache from the connector into vLLM's encoder cache.
|
||||
|
||||
This method loads the encoder cache based on metadata provided by the scheduler.
|
||||
It is called before `_gather_mm_embeddings` for the EC Connector. For EC,
|
||||
the `encoder_cache` and `mm_hash` are stored in `kwargs`.
|
||||
|
||||
Args:
|
||||
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
|
||||
data hashes (`mm_hash`) to encoder cache tensors.
|
||||
kwargs (dict): Additional keyword arguments for the connector.
|
||||
"""
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Get the metadata
|
||||
metadata: ECConnectorMetadata = self._get_connector_metadata()
|
||||
assert isinstance(metadata, ECExampleConnectorMetadata)
|
||||
assert encoder_cache is not None
|
||||
if metadata is None:
|
||||
logger.warning(
|
||||
"In connector.start_load_caches, but the connector metadata is None"
|
||||
)
|
||||
return
|
||||
# Load the EC for each mm data
|
||||
for mm_data in metadata.mm_datas:
|
||||
if mm_data.mm_hash in encoder_cache:
|
||||
continue
|
||||
filename = self._generate_filename_debug(mm_data.mm_hash)
|
||||
ec_cache = safetensors.torch.load_file(
|
||||
filename, device=current_platform.device_type
|
||||
)["ec_cache"]
|
||||
encoder_cache[mm_data.mm_hash] = ec_cache
|
||||
logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)
|
||||
|
||||
def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
|
||||
"""
|
||||
Save the encoder cache to the connector.
|
||||
|
||||
This method saves the encoder cache from the worker's local storage
|
||||
to shared storage or another external connector.
|
||||
|
||||
Args:
|
||||
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
|
||||
data hashes (`mm_hash`) to encoder cache tensors.
|
||||
mm_hash (str): The hash of the multimodal data whose cache is being saved.
|
||||
kwargs (dict): Additional keyword arguments for the connector.
|
||||
"""
|
||||
# Return if it is PD Instance
|
||||
if not self.is_producer:
|
||||
return
|
||||
filename = self._generate_filename_debug(mm_hash)
|
||||
ec_cache = encoder_cache[mm_hash]
|
||||
tensors = {"ec_cache": ec_cache.detach().cpu()}
|
||||
safetensors.torch.save_file(tensors, filename)
|
||||
logger.debug("Save cache successful for mm_hash %s", mm_hash)
|
||||
|
||||
def has_cache_item(
|
||||
self,
|
||||
identifier: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if cache exist externally for the media
|
||||
|
||||
Args:
|
||||
identifier (str): the identifier of the media.
|
||||
|
||||
Returns:
|
||||
Bool indicate that media exists in cache or not
|
||||
"""
|
||||
return self._found_match_for_mm_data(identifier)
|
||||
|
||||
def update_state_after_alloc(
|
||||
self,
|
||||
request: "Request",
|
||||
index: int,
|
||||
) -> None:
|
||||
"""
|
||||
Update ECConnector state after encoder cache allocation.
|
||||
"""
|
||||
mm_hash = request.mm_features[index].identifier
|
||||
num_encoder_token = request.get_num_encoder_embeds(index)
|
||||
# Insert mm_hash only if this block has not been recorded yet.
|
||||
self._mm_datas_need_loads[mm_hash] = num_encoder_token
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> ECConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
This only build for load mm_data only
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
meta = ECExampleConnectorMetadata()
|
||||
for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
|
||||
meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
|
||||
self._mm_datas_need_loads.clear()
|
||||
return meta
|
||||
|
||||
# ==============================
|
||||
# Helper functions
|
||||
# ==============================
|
||||
|
||||
def _found_match_for_mm_data(self, mm_hash) -> bool:
|
||||
"""Check if the cache is hit for the request."""
|
||||
filename = self._generate_filename_debug(mm_hash)
|
||||
return os.path.exists(filename)
|
||||
|
||||
def _generate_foldername_debug(
|
||||
self,
|
||||
mm_hash: str,
|
||||
create_folder: bool = True, # <- now defaults to True
|
||||
) -> str:
|
||||
"""
|
||||
Return the folder in which the cache for this mm_hash lives.
|
||||
If `create_folder` is True (default) the directory is created
|
||||
recursively the first time it is needed.
|
||||
"""
|
||||
foldername = os.path.join(self._storage_path, mm_hash)
|
||||
if create_folder:
|
||||
os.makedirs(foldername, exist_ok=True)
|
||||
return foldername
|
||||
|
||||
def _generate_filename_debug(self, mm_hash: str) -> str:
|
||||
"""
|
||||
Return the full path of the safetensors file for this mm_hash.
|
||||
Ensures the parent directory exists because
|
||||
`_generate_foldername_debug` is called with its default
|
||||
(`create_folder=True`).
|
||||
"""
|
||||
foldername = self._generate_foldername_debug(mm_hash) # <- folder auto-created
|
||||
return os.path.join(foldername, "encoder_cache.safetensors")
|
||||
85
vllm/distributed/ec_transfer/ec_connector/factory.py
Normal file
85
vllm/distributed/ec_transfer/ec_connector/factory.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import (
|
||||
ECConnectorBase,
|
||||
ECConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ECTransferConfig, VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ECConnectorFactory:
|
||||
_registry: dict[str, Callable[[], type[ECConnectorBase]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
|
||||
"""Register a connector with a lazy-loading module and class name."""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Connector '{name}' is already registered.")
|
||||
|
||||
def loader() -> type[ECConnectorBase]:
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
cls._registry[name] = loader
|
||||
|
||||
@classmethod
|
||||
def create_connector(
|
||||
cls,
|
||||
config: "VllmConfig",
|
||||
role: ECConnectorRole,
|
||||
) -> ECConnectorBase:
|
||||
ec_transfer_config = config.ec_transfer_config
|
||||
if ec_transfer_config is None:
|
||||
raise ValueError("ec_transfer_config must be set to create a connector")
|
||||
connector_cls = cls.get_connector_class(ec_transfer_config)
|
||||
logger.info(
|
||||
"Creating connector with name: %s and engine_id: %s",
|
||||
connector_cls.__name__,
|
||||
ec_transfer_config.engine_id,
|
||||
)
|
||||
# Connector is explicitly separated into two roles.
|
||||
# Scheduler connector:
|
||||
# - Co-locate with scheduler process
|
||||
# - Should only be used inside the Scheduler class
|
||||
# Worker connector:
|
||||
# - Co-locate with worker process
|
||||
return connector_cls(config, role)
|
||||
|
||||
@classmethod
|
||||
def get_connector_class(
|
||||
cls, ec_transfer_config: "ECTransferConfig"
|
||||
) -> type[ECConnectorBase]:
|
||||
"""Get the connector class by name."""
|
||||
connector_name = ec_transfer_config.ec_connector
|
||||
if connector_name is None:
|
||||
raise ValueError("EC connect must not be None")
|
||||
elif connector_name in cls._registry:
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
else:
|
||||
connector_module_path = ec_transfer_config.ec_connector_module_path
|
||||
if connector_module_path is None:
|
||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
||||
connector_module = importlib.import_module(connector_module_path)
|
||||
connector_cls = getattr(connector_module, connector_name)
|
||||
return connector_cls
|
||||
|
||||
|
||||
# Register various connectors here.
|
||||
# The registration should not be done in each individual file, as we want to
|
||||
# only load the files corresponding to the current connector.
|
||||
|
||||
ECConnectorFactory.register_connector(
|
||||
"ECExampleConnector",
|
||||
"vllm.distributed.ec_transfer.ec_connector.example_connector",
|
||||
"ECExampleConnector",
|
||||
)
|
||||
42
vllm/distributed/ec_transfer/ec_transfer_state.py
Normal file
42
vllm/distributed/ec_transfer/ec_transfer_state.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import (
|
||||
ECConnectorBase,
|
||||
ECConnectorRole,
|
||||
)
|
||||
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
_EC_CONNECTOR_AGENT: ECConnectorBase | None = None
|
||||
|
||||
|
||||
def get_ec_transfer() -> ECConnectorBase:
|
||||
assert _EC_CONNECTOR_AGENT is not None, "disaggregated EC cache is not initialized"
|
||||
return _EC_CONNECTOR_AGENT
|
||||
|
||||
|
||||
def has_ec_transfer() -> bool:
|
||||
return _EC_CONNECTOR_AGENT is not None
|
||||
|
||||
|
||||
def ensure_ec_transfer_initialized(vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Initialize EC cache connector.
|
||||
"""
|
||||
|
||||
global _EC_CONNECTOR_AGENT
|
||||
|
||||
if vllm_config.ec_transfer_config is None:
|
||||
return
|
||||
|
||||
if (
|
||||
vllm_config.ec_transfer_config.is_ec_transfer_instance
|
||||
and _EC_CONNECTOR_AGENT is None
|
||||
):
|
||||
_EC_CONNECTOR_AGENT = ECConnectorFactory.create_connector(
|
||||
config=vllm_config, role=ECConnectorRole.WORKER
|
||||
)
|
||||
3
vllm/distributed/eplb/__init__.py
Normal file
3
vllm/distributed/eplb/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Expert parallelism load balancer (EPLB)."""
|
||||
192
vllm/distributed/eplb/async_worker.py
Normal file
192
vllm/distributed/eplb/async_worker.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
The async worker that transfers experts in the background.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.parallel_state import get_eplb_group
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .rebalance_execute import transfer_layer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .eplb_state import EplbModelState, EplbState
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def start_async_worker(
|
||||
state: "EplbState",
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
is_profile: bool = False,
|
||||
) -> threading.Thread:
|
||||
eplb_group = get_eplb_group().device_group
|
||||
rank = eplb_group.rank()
|
||||
device_index = state.cuda_device_index
|
||||
assert state.is_async
|
||||
|
||||
def thread_target() -> None:
|
||||
assert device_index is not None
|
||||
torch.cuda.set_device(device_index)
|
||||
cuda_stream = torch.cuda.Stream(device=device_index)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
transfer_run_periodically(
|
||||
state=state,
|
||||
eplb_group=eplb_group,
|
||||
cuda_stream=cuda_stream,
|
||||
is_profile=is_profile,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - diagnostic path
|
||||
logger.exception("async loop error (Rank %d): %s", rank, str(exc))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
thread = threading.Thread(target=thread_target, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
|
||||
def run_rebalance_experts(
|
||||
model_state: "EplbModelState",
|
||||
eplb_state: "EplbState",
|
||||
physical_to_logical_map_cpu: torch.Tensor,
|
||||
) -> None:
|
||||
assert model_state.eplb_stats is not None
|
||||
eplb_stats = model_state.eplb_stats
|
||||
|
||||
# Wait for the main thread's all-reduce and clone to complete before
|
||||
# accessing the global_expert_load_window tensor.
|
||||
assert model_state.window_ready_event is not None
|
||||
model_state.window_ready_event.wait()
|
||||
model_state.window_ready_event = None
|
||||
|
||||
# Move the global expert load window to CPU for computation.
|
||||
global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
|
||||
# Compute new expert mappings for the model
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = eplb_state.policy.rebalance_experts(
|
||||
global_expert_load_window,
|
||||
eplb_stats.num_replicas,
|
||||
eplb_stats.num_groups,
|
||||
eplb_stats.num_nodes,
|
||||
eplb_stats.num_gpus,
|
||||
physical_to_logical_map_cpu,
|
||||
)
|
||||
assert new_physical_to_logical_map.device == torch.device("cpu")
|
||||
|
||||
model_state.new_physical_to_logical_map = new_physical_to_logical_map
|
||||
|
||||
max_slots = model_state.logical_to_physical_map.shape[-1]
|
||||
padded_logical = torch.nn.functional.pad(
|
||||
new_logical_to_physical_map,
|
||||
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
|
||||
value=-1,
|
||||
).to(model_state.logical_to_physical_map.device)
|
||||
new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
|
||||
model_state.new_logical_to_physical_map = padded_logical
|
||||
model_state.new_logical_replica_count = new_replica
|
||||
|
||||
|
||||
async def transfer_run_periodically(
|
||||
state: "EplbState",
|
||||
eplb_group: ProcessGroup,
|
||||
cuda_stream: torch.cuda.Stream,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> None:
|
||||
while True:
|
||||
await asyncio.to_thread(state.rearrange_event.wait)
|
||||
logger.info("async worker woke up for EPLB transfer")
|
||||
|
||||
assert state.is_async
|
||||
for model_state in state.model_states.values():
|
||||
rebalancing_algorithm_executed = False
|
||||
physical_to_logical_map_cpu = None
|
||||
current_num_layers = model_state.model.num_moe_layers
|
||||
while (
|
||||
model_state.rebalanced
|
||||
and model_state.layer_to_transfer < current_num_layers
|
||||
):
|
||||
if not model_state.ep_buffer_ready and model_state.rebalanced:
|
||||
# Polling the lock directly in the async thread avoids
|
||||
# the thread switch overhead of asyncio.to_thread.
|
||||
# This is typically faster than offloading to a worker thread.
|
||||
while not model_state.buffer_lock.acquire(blocking=False):
|
||||
await asyncio.sleep(0)
|
||||
try:
|
||||
if model_state.layer_to_transfer >= current_num_layers:
|
||||
break
|
||||
if (
|
||||
not rebalancing_algorithm_executed
|
||||
or model_state.new_physical_to_logical_map is None
|
||||
):
|
||||
# Move the physical_to_logical_map to CPU
|
||||
# for rebalancing and transfer_layer.
|
||||
physical_to_logical_map_cpu = (
|
||||
model_state.physical_to_logical_map.cpu()
|
||||
)
|
||||
run_rebalance_experts(
|
||||
model_state, state, physical_to_logical_map_cpu
|
||||
)
|
||||
rebalancing_algorithm_executed = True
|
||||
logger.info(
|
||||
"Async worker computed new indices for model %s",
|
||||
model_state.model_name,
|
||||
)
|
||||
|
||||
assert model_state.new_physical_to_logical_map is not None
|
||||
assert physical_to_logical_map_cpu is not None
|
||||
|
||||
layer_idx = model_state.layer_to_transfer
|
||||
old_layer_indices = physical_to_logical_map_cpu[layer_idx]
|
||||
new_layer_indices = model_state.new_physical_to_logical_map[
|
||||
layer_idx
|
||||
]
|
||||
|
||||
# Wait for the main thread to finish consuming the buffer
|
||||
# before initiating an EPLB transfer on another layer.
|
||||
if model_state.buffer_consumed_event is not None:
|
||||
cuda_stream.wait_event(model_state.buffer_consumed_event)
|
||||
model_state.buffer_consumed_event = None
|
||||
|
||||
(
|
||||
model_state.is_unchanged,
|
||||
model_state.is_received_locally,
|
||||
model_state.recv_metadata,
|
||||
) = await transfer_layer(
|
||||
old_layer_indices=old_layer_indices,
|
||||
new_layer_indices=new_layer_indices,
|
||||
expert_weights=model_state.model.expert_weights[layer_idx],
|
||||
expert_weights_buffer=model_state.expert_buffer,
|
||||
ep_group=eplb_group,
|
||||
is_profile=is_profile,
|
||||
cuda_stream=cuda_stream,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
event = torch.cuda.Event(blocking=False)
|
||||
cuda_stream.record_event(event)
|
||||
model_state.buffer_ready_event = event
|
||||
model_state.ep_buffer_ready = 1
|
||||
finally:
|
||||
model_state.buffer_lock.release()
|
||||
else:
|
||||
if not model_state.rebalanced:
|
||||
break
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
state.rearrange_event.clear()
|
||||
1250
vllm/distributed/eplb/eplb_state.py
Normal file
1250
vllm/distributed/eplb/eplb_state.py
Normal file
File diff suppressed because it is too large
Load Diff
54
vllm/distributed/eplb/eplb_utils.py
Normal file
54
vllm/distributed/eplb/eplb_utils.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions for EPLB (Expert Parallel Load Balancing)."""
|
||||
|
||||
import os
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
|
||||
"""
|
||||
Override environment variables for EPLB when specific conditions are met.
|
||||
|
||||
Args:
|
||||
parallel_config: The parallel configuration object.
|
||||
"""
|
||||
is_data_parallel = parallel_config.data_parallel_size > 1
|
||||
is_eplb_enabled = parallel_config.enable_eplb
|
||||
async_eplb = parallel_config.eplb_config.use_async
|
||||
is_deepep_ll = parallel_config.all2all_backend == "deepep_low_latency"
|
||||
|
||||
# Override NCCL_MAX_CTAS to avoid hangs when using async EPLB with the
|
||||
# DeepEP low-latency backend.
|
||||
#
|
||||
# The hang happens when two ranks interleave kernel launches differently
|
||||
# between NCCL collectives (used by async EPLB weight exchange) and DeepEP
|
||||
# low-latency (LL) kernels. DeepEP LL uses a cooperative launch and tries
|
||||
# to reserve a large fraction of the GPU's SMs; if those SMs are currently
|
||||
# occupied by NCCL, the DeepEP LL launch blocks until enough SMs are
|
||||
# freed.
|
||||
#
|
||||
# If rank A enters DeepEP LL in main thread while rank B is still executing
|
||||
# NCCL in async thread, rank A can block waiting for SMs, while rank B can
|
||||
# block inside NCCL waiting for rank A to participate in the collective.
|
||||
# This circular wait causes a deadlock.
|
||||
# Limiting NCCL occupancy via NCCL_MAX_CTAS leaves space for the DeepEP
|
||||
# cooperative kernel to launch and complete, breaking the deadlock.
|
||||
# See: https://github.com/deepseek-ai/DeepEP/issues/496
|
||||
if is_data_parallel and is_eplb_enabled and is_deepep_ll and async_eplb:
|
||||
current_value_str = os.getenv("NCCL_MAX_CTAS")
|
||||
|
||||
if current_value_str and current_value_str.isdigit():
|
||||
return
|
||||
|
||||
override_value = 8
|
||||
os.environ["NCCL_MAX_CTAS"] = str(override_value)
|
||||
logger.info_once(
|
||||
f"EPLB: Setting NCCL_MAX_CTAS={override_value} "
|
||||
"for expert parallel with EPLB and deepep_low_latency backend",
|
||||
scope="global",
|
||||
)
|
||||
19
vllm/distributed/eplb/policy/__init__.py
Normal file
19
vllm/distributed/eplb/policy/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import get_args
|
||||
|
||||
from vllm.config.parallel import EPLBPolicyOption
|
||||
|
||||
from .abstract import AbstractEplbPolicy
|
||||
from .default import DefaultEplbPolicy
|
||||
|
||||
EPLB_POLICIES = {"default": DefaultEplbPolicy}
|
||||
|
||||
# Ensure that the EPLB_POLICIES keys match the EPLBPolicyOption values
|
||||
assert set(EPLB_POLICIES.keys()) == set(get_args(EPLBPolicyOption))
|
||||
|
||||
__all__ = [
|
||||
"AbstractEplbPolicy",
|
||||
"DefaultEplbPolicy",
|
||||
"EPLB_POLICIES",
|
||||
]
|
||||
43
vllm/distributed/eplb/policy/abstract.py
Normal file
43
vllm/distributed/eplb/policy/abstract.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbstractEplbPolicy(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def rebalance_experts(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_ranks: int,
|
||||
old_global_expert_indices: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics
|
||||
for all logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of
|
||||
`num_ranks`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes
|
||||
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
||||
old_global_expert_indices: [layers, num_logical_experts], the old global
|
||||
expert indices. Used to avoid unnecessary weight copying
|
||||
for experts moving within one rank.
|
||||
Returns:
|
||||
physical_to_logical_map: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
logical_to_physical_map: [layers, num_logical_experts, X],
|
||||
the replica indices for each expert
|
||||
expert_count: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
"""
|
||||
raise NotImplementedError
|
||||
376
vllm/distributed/eplb/policy/default.py
Normal file
376
vllm/distributed/eplb/policy/default.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Expert parallelism load balancer (EPLB) for vLLM.
|
||||
|
||||
This module implements the core rearrangement algorithm.
|
||||
|
||||
The rearrangement algorithm is adapted from
|
||||
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
|
||||
|
||||
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
|
||||
on how the EPLB algorithm works.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .abstract import AbstractEplbPolicy
|
||||
|
||||
|
||||
class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
@classmethod
|
||||
def balanced_packing(
|
||||
cls, weight: np.ndarray, num_packs: int
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Pack n weighted objects to m packs, such that each bin contains exactly
|
||||
n/m objects and the weights of all packs are as balanced as possible.
|
||||
|
||||
Parameters:
|
||||
weight: [X, n], the weight of each item
|
||||
num_packs: number of packs
|
||||
|
||||
Returns:
|
||||
pack_index: [X, n], the pack index of each item
|
||||
rank_in_pack: [X, n], the rank of the item in the pack
|
||||
"""
|
||||
num_layers, num_groups = weight.shape
|
||||
assert num_groups % num_packs == 0
|
||||
groups_per_pack = num_groups // num_packs
|
||||
|
||||
if groups_per_pack == 1:
|
||||
pack_index = np.tile(np.arange(num_groups, dtype=np.int64), (num_layers, 1))
|
||||
rank_in_pack = np.zeros_like(pack_index, dtype=np.int64)
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
# Sort and get indices in decending order
|
||||
indices = np.argsort(-weight, axis=-1)
|
||||
|
||||
pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
|
||||
pack_weights = np.zeros((num_layers, num_packs), dtype=np.float64)
|
||||
pack_items = np.zeros((num_layers, num_packs), dtype=np.int64)
|
||||
|
||||
# Run the packing algorithm
|
||||
for layer_idx in range(num_layers):
|
||||
weights_row = pack_weights[layer_idx]
|
||||
items_row = pack_items[layer_idx]
|
||||
|
||||
for group in indices[layer_idx]:
|
||||
# Pick the lightest pack; full packs are masked out by inf.
|
||||
pack = int(np.argmin(weights_row))
|
||||
|
||||
pack_index[layer_idx, group] = pack
|
||||
rank_in_pack[layer_idx, group] = items_row[pack]
|
||||
weights_row[pack] += weight[layer_idx, group]
|
||||
items_row[pack] += 1
|
||||
if items_row[pack] == groups_per_pack:
|
||||
# Mark as unavailable for future selections.
|
||||
weights_row[pack] = np.inf
|
||||
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
@classmethod
|
||||
def replicate_experts(
|
||||
cls, weight: np.ndarray, num_phy: int
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
|
||||
load of all replicas is minimized.
|
||||
|
||||
Parameters:
|
||||
weight: [X, num_log]
|
||||
num_phy: total number of experts after replication
|
||||
|
||||
Returns:
|
||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||
replica_idx: [X, num_phy], the index of the replica for each logical expert
|
||||
logcnt: [X, num_log], number of replicas for each logical expert
|
||||
"""
|
||||
n, num_log = weight.shape
|
||||
num_redundant = num_phy - num_log
|
||||
assert num_redundant >= 0
|
||||
phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
|
||||
replica_idx = np.zeros((n, num_phy), dtype=np.int64)
|
||||
logcnt = np.ones((n, num_log), dtype=np.int64)
|
||||
arangen = np.arange(n, dtype=np.int64)
|
||||
for i in range(num_log, num_phy):
|
||||
redundant_indices = np.argmax(weight / logcnt, axis=-1)
|
||||
phy2log[:, i] = redundant_indices
|
||||
replica_idx[:, i] = logcnt[arangen, redundant_indices]
|
||||
logcnt[arangen, redundant_indices] += 1
|
||||
return phy2log, replica_idx, logcnt
|
||||
|
||||
@classmethod
|
||||
def rebalance_experts_hierarchical(
|
||||
cls,
|
||||
weight: np.ndarray,
|
||||
num_physical_experts: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Parameters:
|
||||
weight: [num_moe_layers, num_logical_experts]
|
||||
num_physical_experts: number of physical experts after replication
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
phy2log: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
pphy_replicas_idx: [layers, num_logical_experts, X],
|
||||
the replica indices for each expert
|
||||
logcnt: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
assert num_logical_experts % num_groups == 0
|
||||
group_size = num_logical_experts // num_groups
|
||||
assert num_groups % num_nodes == 0
|
||||
groups_per_node = num_groups // num_nodes
|
||||
assert num_gpus % num_nodes == 0
|
||||
assert num_physical_experts % num_gpus == 0
|
||||
phy_experts_per_gpu = num_physical_experts // num_gpus
|
||||
|
||||
def inverse(perm: np.ndarray) -> np.ndarray:
|
||||
inv = np.empty_like(perm)
|
||||
row_idx = np.arange(perm.shape[0])[:, None]
|
||||
col_idx = np.arange(perm.shape[1], dtype=np.int64)
|
||||
inv[row_idx, perm] = col_idx
|
||||
return inv
|
||||
|
||||
# Step 1: pack groups to nodes
|
||||
tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(
|
||||
axis=-1
|
||||
)
|
||||
group_pack_index, group_rank_in_pack = cls.balanced_packing(
|
||||
tokens_per_group, num_nodes
|
||||
)
|
||||
# Map each logical expert into a node-local ordering based on packed groups.
|
||||
log2mlog = (
|
||||
(
|
||||
(group_pack_index * groups_per_node + group_rank_in_pack)[..., None]
|
||||
* group_size
|
||||
)
|
||||
+ np.arange(group_size, dtype=np.int64)
|
||||
).reshape(num_layers, num_logical_experts)
|
||||
mlog2log = inverse(log2mlog)
|
||||
|
||||
# Step 2: construct redundant experts within nodes
|
||||
# Reorder weights into the node-local layout so replication is done per node.
|
||||
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
|
||||
-1, num_logical_experts // num_nodes
|
||||
)
|
||||
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
|
||||
tokens_per_mlog, num_physical_experts // num_nodes
|
||||
)
|
||||
|
||||
# Step 3: pack physical_experts to GPUs
|
||||
# Effective per-physical load = logical load divided by replica count.
|
||||
tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=1)
|
||||
pack_index, rank_in_pack = cls.balanced_packing(
|
||||
tokens_per_phy, num_gpus // num_nodes
|
||||
)
|
||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||
pphy2phy = inverse(phy2pphy)
|
||||
|
||||
# Reorder node-local logical indices into the post-packing physical order.
|
||||
pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=1)
|
||||
pphy2mlog = (
|
||||
pphy2mlog.reshape(num_layers, num_nodes, -1)
|
||||
+ np.arange(
|
||||
0,
|
||||
num_logical_experts,
|
||||
num_logical_experts // num_nodes,
|
||||
dtype=np.int64,
|
||||
)[None, :, None]
|
||||
).reshape(num_layers, -1)
|
||||
# Map node-local logical indices back to global logical expert ids.
|
||||
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
|
||||
# Reorder replica ranks to the post-packing physical ordering.
|
||||
pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape(
|
||||
num_layers, -1
|
||||
)
|
||||
# Convert replica counts back to the original logical ordering.
|
||||
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1)
|
||||
return pphy2log, pphy_replicas_idx, logcnt
|
||||
|
||||
@classmethod
|
||||
def preserve_intragpu_slots(
|
||||
cls,
|
||||
phy2log: np.ndarray,
|
||||
phy_replicas_idx: np.ndarray,
|
||||
num_ranks: int,
|
||||
old_phy2log: np.ndarray,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Reorder the new mapping per GPU so that experts that remain on the same GPU
|
||||
keep their previous slot positions when possible. Incoming experts to that GPU
|
||||
fill any remaining available slots. This is applied only when the number of GPUs
|
||||
is unchanged and the slots per GPU remain the same between
|
||||
the old and new mappings.
|
||||
"""
|
||||
num_phy_experts = phy2log.shape[1]
|
||||
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
|
||||
return phy2log, phy_replicas_idx
|
||||
|
||||
# Move to CPU and convert to NumPy for processing
|
||||
slots_per_gpu = num_phy_experts // num_ranks
|
||||
num_layers = phy2log.shape[0]
|
||||
|
||||
post_phy2log = phy2log.copy()
|
||||
post_phy_replicas_idx = phy_replicas_idx.copy()
|
||||
|
||||
for gpu_idx in range(num_ranks):
|
||||
start = gpu_idx * slots_per_gpu
|
||||
end = start + slots_per_gpu
|
||||
# Experts across all layers for this GPU
|
||||
old_local = old_phy2log[:, start:end] # [layers, slots]
|
||||
new_local = phy2log[:, start:end] # [layers, slots]
|
||||
new_ridx = phy_replicas_idx[:, start:end] # [layers, slots]
|
||||
|
||||
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
||||
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
||||
|
||||
# First pass: preserve same-logical experts in their previous slots
|
||||
for slot_idx in range(slots_per_gpu):
|
||||
# matches: [layers, slots], True where new local experts have
|
||||
# the same logical value as the old from 'slot_idx' and not checked yet
|
||||
matches = (new_local == old_local[:, slot_idx][:, None]) & (
|
||||
~used_new_indices
|
||||
)
|
||||
has_any = matches.any(axis=1)
|
||||
if np.any(has_any):
|
||||
first_idx = np.argmax(matches, axis=1)
|
||||
layer_indices = np.nonzero(has_any)[0]
|
||||
matched_new_positions = first_idx[layer_indices]
|
||||
post_phy2log[layer_indices, start + slot_idx] = new_local[
|
||||
layer_indices, matched_new_positions
|
||||
]
|
||||
post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[
|
||||
layer_indices, matched_new_positions
|
||||
]
|
||||
used_new_indices[layer_indices, matched_new_positions] = True
|
||||
preserved_positions[layer_indices, slot_idx] = True
|
||||
|
||||
# Second pass: fill remaining slots with remaining new experts
|
||||
remaining_mask = ~used_new_indices # [layers, slots]
|
||||
fill_mask = ~preserved_positions # [layers, slots]
|
||||
if remaining_mask.any() and fill_mask.any():
|
||||
idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
|
||||
# Sentinel value for unavailable positions.
|
||||
large = slots_per_gpu + 1
|
||||
# Priorities: keep original index for available spots, set sentinel
|
||||
# for unavailable; lower is earlier.
|
||||
remaining_priority = np.where(remaining_mask, idx_base, large)
|
||||
fill_priority = np.where(fill_mask, idx_base, large)
|
||||
# Sort to get ordered indices of available src/dst positions per layer.
|
||||
remaining_indices = np.argsort(remaining_priority, axis=1)
|
||||
fill_indices = np.argsort(fill_priority, axis=1)
|
||||
# Fill count per layer (cannot exceed either side).
|
||||
remaining_counts = remaining_mask.sum(axis=1)
|
||||
fill_counts = fill_mask.sum(axis=1)
|
||||
take_counts = np.minimum(remaining_counts, fill_counts)
|
||||
# Assign remaining new experts to remaining slots per layer.
|
||||
for layer_idx in range(num_layers):
|
||||
k = int(take_counts[layer_idx])
|
||||
if k <= 0:
|
||||
continue
|
||||
src_pos = remaining_indices[layer_idx, :k]
|
||||
dst_pos = fill_indices[layer_idx, :k]
|
||||
post_phy2log[layer_idx, start + dst_pos] = new_local[
|
||||
layer_idx, src_pos
|
||||
]
|
||||
post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[
|
||||
layer_idx, src_pos
|
||||
]
|
||||
|
||||
return post_phy2log, post_phy_replicas_idx
|
||||
|
||||
@classmethod
|
||||
def rebalance_experts(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_ranks: int,
|
||||
old_global_expert_indices: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics for all
|
||||
logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of
|
||||
`num_gpus`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g, NVLink) is faster
|
||||
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
||||
old_global_expert_indices: [layers, num_logical_experts], the old global
|
||||
expert indices. Used to avoid unnecessary weight copying
|
||||
for experts moving within one rank.
|
||||
Returns:
|
||||
phy2log: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
log2phy: [layers, num_logical_experts, X],
|
||||
the replica indices for each expert
|
||||
logcnt: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
"""
|
||||
device = weight.device
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
weight_np = weight.float().cpu().numpy()
|
||||
old_phy2log_np = (
|
||||
old_global_expert_indices.cpu().numpy()
|
||||
if old_global_expert_indices is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if num_groups % num_nodes == 0:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log_np, phy_replicas_idx_np, logcnt_np = (
|
||||
cls.rebalance_experts_hierarchical(
|
||||
weight_np, num_replicas, num_groups, num_nodes, num_ranks
|
||||
)
|
||||
)
|
||||
else:
|
||||
# use global load-balance policy
|
||||
phy2log_np, phy_replicas_idx_np, logcnt_np = (
|
||||
cls.rebalance_experts_hierarchical(
|
||||
weight_np, num_replicas, 1, 1, num_ranks
|
||||
)
|
||||
)
|
||||
|
||||
# Optional postprocessing to preserve slots for experts moving
|
||||
# within the same GPU
|
||||
# Only apply when the number of GPUs and slots per GPU remain unchanged.
|
||||
# Helps to avoid unnecessary weight copying when experts move
|
||||
# within the same GPU.
|
||||
if old_global_expert_indices is not None:
|
||||
phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots(
|
||||
phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np
|
||||
)
|
||||
num_redundant_experts = num_replicas - num_logical_experts
|
||||
maxlogcnt = num_redundant_experts + 1
|
||||
log2phy_np = np.full(
|
||||
(num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
|
||||
)
|
||||
layer_indices = np.arange(num_layers)[:, None]
|
||||
replica_indices = np.tile(
|
||||
np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
|
||||
)
|
||||
log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices
|
||||
|
||||
phy2log = torch.from_numpy(phy2log_np).to(device)
|
||||
log2phy = torch.from_numpy(log2phy_np).to(device)
|
||||
logcnt = torch.from_numpy(logcnt_np).to(device)
|
||||
return phy2log, log2phy, logcnt
|
||||
708
vllm/distributed/eplb/rebalance_execute.py
Normal file
708
vllm/distributed/eplb/rebalance_execute.py
Normal file
@@ -0,0 +1,708 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
The actual execution of the rearrangement.
|
||||
|
||||
This involves the exchange of expert weights between GPUs.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributed import (
|
||||
P2POp,
|
||||
ProcessGroup,
|
||||
all_gather,
|
||||
batch_isend_irecv,
|
||||
get_global_rank,
|
||||
)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecvMetadata:
|
||||
"""Metadata describing remote receives during EPLB rebalancing."""
|
||||
|
||||
recv_primary_mask: np.ndarray
|
||||
"""Mask of (num_local_experts,) indicating primary experts received."""
|
||||
recv_count: int
|
||||
"""Number of received experts for the layer."""
|
||||
recv_expert_ids: np.ndarray
|
||||
"""Expert ids (num_local_experts,) of remote primary experts."""
|
||||
recv_dst_rows: np.ndarray
|
||||
"""Target expert indices (num_local_experts,) in local tensors to send."""
|
||||
|
||||
|
||||
# Type alias for the result of move_to_buffer or transfer_layer
|
||||
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]
|
||||
|
||||
|
||||
def get_ep_ranks_with_experts_batch(
|
||||
expert_ids: np.ndarray,
|
||||
num_local_experts: int,
|
||||
old_indices: np.ndarray,
|
||||
new_indices: np.ndarray,
|
||||
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
|
||||
"""
|
||||
Get the ranks of the experts that need to be exchanged.
|
||||
|
||||
Args:
|
||||
expert_ids: 1D array of expert indices to query.
|
||||
num_local_experts: The number of local experts.
|
||||
old_indices: The old indices of the experts.
|
||||
new_indices: The new indices of the experts.
|
||||
|
||||
Returns:
|
||||
A tuple of two dictionaries mapping expert_id to:
|
||||
- ranks_to_send: The ranks that have this expert and need to send.
|
||||
- ranks_to_recv: The ranks that need to receive this expert.
|
||||
"""
|
||||
ranks_to_send_map: dict[int, list[int]] = {}
|
||||
ranks_to_recv_map: dict[int, list[int]] = {}
|
||||
|
||||
# Fast path: if no experts, return empty dicts
|
||||
if expert_ids.size == 0:
|
||||
return ranks_to_send_map, ranks_to_recv_map
|
||||
|
||||
unique_experts = np.unique(expert_ids)
|
||||
num_positions = len(old_indices)
|
||||
position_indices = np.arange(num_positions, dtype=np.int32)
|
||||
|
||||
# Vectorized approach: find all positions matching any query expert in one pass
|
||||
# Use np.isin to get boolean masks for all relevant positions at once
|
||||
old_relevant_mask = np.isin(old_indices, unique_experts)
|
||||
new_relevant_mask = np.isin(new_indices, unique_experts)
|
||||
|
||||
# Process old_indices (send ranks)
|
||||
if np.any(old_relevant_mask):
|
||||
old_relevant_positions = position_indices[old_relevant_mask]
|
||||
old_relevant_experts = old_indices[old_relevant_mask]
|
||||
old_relevant_ranks = old_relevant_positions // num_local_experts
|
||||
|
||||
# Sort by expert first, then by position (to maintain first-appearance order)
|
||||
sort_order = np.lexsort((old_relevant_positions, old_relevant_experts))
|
||||
sorted_experts = old_relevant_experts[sort_order]
|
||||
sorted_ranks = old_relevant_ranks[sort_order]
|
||||
|
||||
# Find boundaries where expert changes
|
||||
expert_boundaries = np.concatenate(
|
||||
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
|
||||
)
|
||||
|
||||
# For each expert, extract unique ranks in order of first appearance
|
||||
for i in range(len(expert_boundaries) - 1):
|
||||
start, end = expert_boundaries[i], expert_boundaries[i + 1]
|
||||
expert = int(sorted_experts[start])
|
||||
expert_ranks = sorted_ranks[start:end]
|
||||
|
||||
# Get unique ranks preserving order
|
||||
_, unique_idx = np.unique(expert_ranks, return_index=True)
|
||||
unique_ranks = expert_ranks[np.sort(unique_idx)]
|
||||
ranks_to_send_map[expert] = unique_ranks.tolist()
|
||||
|
||||
# Process new_indices (recv ranks)
|
||||
if np.any(new_relevant_mask):
|
||||
new_relevant_positions = position_indices[new_relevant_mask]
|
||||
new_relevant_experts = new_indices[new_relevant_mask]
|
||||
new_relevant_ranks = new_relevant_positions // num_local_experts
|
||||
|
||||
# Sort by expert first, then by position
|
||||
sort_order = np.lexsort((new_relevant_positions, new_relevant_experts))
|
||||
sorted_experts = new_relevant_experts[sort_order]
|
||||
sorted_ranks = new_relevant_ranks[sort_order]
|
||||
|
||||
# Find boundaries where expert changes
|
||||
expert_boundaries = np.concatenate(
|
||||
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
|
||||
)
|
||||
|
||||
# For each expert, extract unique ranks and exclude local copies
|
||||
for i in range(len(expert_boundaries) - 1):
|
||||
start, end = expert_boundaries[i], expert_boundaries[i + 1]
|
||||
expert = int(sorted_experts[start])
|
||||
expert_ranks = sorted_ranks[start:end]
|
||||
|
||||
# Get unique ranks preserving order
|
||||
_, unique_idx = np.unique(expert_ranks, return_index=True)
|
||||
unique_ranks = expert_ranks[np.sort(unique_idx)]
|
||||
|
||||
# Remove ranks that have local copies (in send map)
|
||||
send_ranks_set = set(ranks_to_send_map.get(expert, []))
|
||||
recv_ranks_actual = [
|
||||
int(r) for r in unique_ranks if r not in send_ranks_set
|
||||
]
|
||||
ranks_to_recv_map[expert] = recv_ranks_actual
|
||||
|
||||
# Handle experts that only appear in old (send only) or new (recv only)
|
||||
for expert in unique_experts:
|
||||
expert = int(expert)
|
||||
if expert not in ranks_to_send_map:
|
||||
ranks_to_send_map[expert] = []
|
||||
if expert not in ranks_to_recv_map:
|
||||
ranks_to_recv_map[expert] = []
|
||||
|
||||
return ranks_to_send_map, ranks_to_recv_map
|
||||
|
||||
|
||||
def move_to_buffer(
|
||||
num_local_experts: int,
|
||||
old_indices: np.ndarray,
|
||||
new_indices: np.ndarray,
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
expert_weights_buffers: Sequence[torch.Tensor],
|
||||
cuda_stream: torch.cuda.Stream | None,
|
||||
ep_group: ProcessGroup,
|
||||
) -> MoveToBufferResult:
|
||||
"""
|
||||
Rearranges expert weights during EPLB rebalancing.
|
||||
|
||||
Args:
|
||||
num_local_experts: Number of local experts.
|
||||
old_indices: (num_experts_total,) ndarray of current (old)
|
||||
global-to-local expert assignments.
|
||||
new_indices: (num_experts_total,) ndarray of desired (new)
|
||||
global-to-local assignments after rebalance.
|
||||
expert_weights: Original expert weights for the layer.
|
||||
expert_weights_buffers: Intermediate buffers (one per tensor).
|
||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||
ep_group: Distributed process group for expert parallel comms.
|
||||
|
||||
Returns:
|
||||
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
|
||||
is unchanged after rebalance.
|
||||
is_received_locally (np.ndarray): (num_local_experts,), True where a row
|
||||
can be updated from local data.
|
||||
RecvMetadata: Metadata needed for completing remote weight transfers.
|
||||
"""
|
||||
assert old_indices.shape == new_indices.shape
|
||||
ep_rank = ep_group.rank()
|
||||
|
||||
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
|
||||
send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
|
||||
send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
|
||||
recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
|
||||
recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32)
|
||||
|
||||
base = ep_rank * num_local_experts
|
||||
local_rows = np.arange(num_local_experts, dtype=np.int32)
|
||||
local_global = base + local_rows
|
||||
|
||||
old_local_expert_ids = old_indices[local_global]
|
||||
new_local_expert_ids = new_indices[local_global]
|
||||
|
||||
# Unchanged mask
|
||||
is_unchanged = old_local_expert_ids == new_local_expert_ids
|
||||
|
||||
# Local receive eligibility
|
||||
new_valid = new_local_expert_ids != -1
|
||||
can_recv_local = np.isin(
|
||||
new_local_expert_ids, old_local_expert_ids, assume_unique=False
|
||||
)
|
||||
is_received_locally = np.logical_or(
|
||||
is_unchanged, np.logical_and(new_valid, can_recv_local)
|
||||
)
|
||||
|
||||
# Send map: first src row per unique expert present locally in old mapping
|
||||
send_count = 0
|
||||
valid_old = old_local_expert_ids != -1
|
||||
if np.any(valid_old):
|
||||
uniq_experts, first_idx = np.unique(
|
||||
old_local_expert_ids[valid_old], return_index=True
|
||||
)
|
||||
filtered_rows = local_rows[valid_old]
|
||||
src_rows = filtered_rows[first_idx]
|
||||
send_count = int(uniq_experts.shape[0])
|
||||
send_expert_ids[:send_count] = uniq_experts
|
||||
send_src_rows[:send_count] = src_rows
|
||||
|
||||
# Recv map: primary dst per unique expert needed remotely
|
||||
recv_count = 0
|
||||
need_recv_mask = np.logical_and(~is_received_locally, new_valid)
|
||||
if np.any(need_recv_mask):
|
||||
desired_experts = new_local_expert_ids[need_recv_mask]
|
||||
desired_dsts = local_rows[need_recv_mask]
|
||||
uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True)
|
||||
dst_rows = desired_dsts[uniq_indices]
|
||||
recv_count = int(uniq_recv_experts.shape[0])
|
||||
recv_expert_ids[:recv_count] = uniq_recv_experts
|
||||
recv_dst_rows[:recv_count] = dst_rows
|
||||
recv_primary_mask[dst_rows] = True
|
||||
|
||||
eligible_local_buffer_mask = np.logical_and(~is_unchanged, is_received_locally)
|
||||
|
||||
# 1. Local moves into tmp buffers
|
||||
if bool(eligible_local_buffer_mask.any()) and send_count > 0:
|
||||
dest_indices = np.nonzero(eligible_local_buffer_mask)[0].tolist()
|
||||
expert_to_src_map = dict(
|
||||
zip(send_expert_ids[:send_count], send_src_rows[:send_count])
|
||||
)
|
||||
for dst in dest_indices:
|
||||
expert = new_local_expert_ids[dst]
|
||||
src_local = expert_to_src_map.get(expert, -1)
|
||||
if src_local != -1:
|
||||
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||
b[dst].copy_(w[src_local], non_blocking=True)
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
|
||||
# Pre-compute global ranks mapping
|
||||
ep_size = ep_group.size()
|
||||
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
|
||||
|
||||
# 2. Post sends
|
||||
if send_count > 0:
|
||||
experts = send_expert_ids[:send_count]
|
||||
srcs = send_src_rows[:send_count]
|
||||
order = np.argsort(experts, kind="stable")
|
||||
experts = experts[order]
|
||||
srcs = srcs[order]
|
||||
|
||||
send_map, recv_map = get_ep_ranks_with_experts_batch(
|
||||
experts,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
new_indices,
|
||||
)
|
||||
|
||||
for expert, src in zip(experts.tolist(), srcs.tolist()):
|
||||
ranks_to_send = send_map[expert]
|
||||
ranks_to_recv = recv_map[expert]
|
||||
if not ranks_to_send or not ranks_to_recv:
|
||||
continue
|
||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||
sender_pos = ranks_to_send.index(ep_rank)
|
||||
recv_begin = sender_pos * num_dst_per_sender
|
||||
recv_end = recv_begin + num_dst_per_sender
|
||||
recv_ranks = ranks_to_recv[recv_begin:recv_end]
|
||||
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
||||
recver_pos = remainder_start + sender_pos
|
||||
if recver_pos < len(ranks_to_recv):
|
||||
recv_ranks.append(ranks_to_recv[recver_pos])
|
||||
for dst in recv_ranks:
|
||||
dst_global = rank_to_global[dst]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
w[src],
|
||||
dst_global,
|
||||
)
|
||||
for w in expert_weights
|
||||
]
|
||||
|
||||
# 3. Post recvs
|
||||
if recv_count > 0:
|
||||
experts = recv_expert_ids[:recv_count]
|
||||
dsts = recv_dst_rows[:recv_count]
|
||||
order = np.argsort(experts, kind="stable")
|
||||
experts = experts[order]
|
||||
dsts = dsts[order]
|
||||
|
||||
send_map, recv_map = get_ep_ranks_with_experts_batch(
|
||||
experts,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
new_indices,
|
||||
)
|
||||
|
||||
for expert, dst in zip(experts.tolist(), dsts.tolist()):
|
||||
ranks_to_send = send_map[expert]
|
||||
ranks_to_recv = recv_map[expert]
|
||||
if not ranks_to_send or not ranks_to_recv:
|
||||
continue
|
||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||
recver_pos = ranks_to_recv.index(ep_rank)
|
||||
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
||||
if recver_pos < remainder_start:
|
||||
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
||||
else:
|
||||
src = ranks_to_send[recver_pos - remainder_start]
|
||||
src_global = rank_to_global[src]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
b[dst],
|
||||
src_global,
|
||||
)
|
||||
for b in expert_weights_buffers
|
||||
]
|
||||
|
||||
# 4. Execute the P2P operations. The real communication happens here.
|
||||
if p2p_ops and cuda_stream is not None:
|
||||
with torch.cuda.stream(cuda_stream):
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
elif p2p_ops:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# wait for the communication to finish
|
||||
return (
|
||||
is_unchanged,
|
||||
is_received_locally,
|
||||
RecvMetadata(
|
||||
recv_primary_mask=recv_primary_mask,
|
||||
recv_count=recv_count,
|
||||
recv_expert_ids=recv_expert_ids,
|
||||
recv_dst_rows=recv_dst_rows,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def move_from_buffer(
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
expert_weights_buffers: list[torch.Tensor],
|
||||
is_unchanged: np.ndarray,
|
||||
is_received_locally: np.ndarray,
|
||||
recv_metadata: RecvMetadata,
|
||||
new_indices: np.ndarray,
|
||||
ep_rank: int,
|
||||
) -> None:
|
||||
"""
|
||||
Copies expert weights from communication buffers back to the target weight tensors
|
||||
after EPLB rebalancing.
|
||||
|
||||
Args:
|
||||
expert_weights: List of the actual MoE layer weights used in the execution.
|
||||
expert_weights_buffers: Intermediate buffers containing the experts weights
|
||||
after the transfer is completed.
|
||||
is_unchanged: (num_local_experts,), True where an expert row is unchanged.
|
||||
is_received_locally: (num_local_experts,), True where a row is updated locally.
|
||||
recv_metadata: RecvMetadata containing remote receive metadata.
|
||||
new_indices: (num_experts_total,) mapping from local rows to desired
|
||||
(possibly global) expert id, after rebalance.
|
||||
ep_rank: Rank of the process in the expert parallel group.
|
||||
"""
|
||||
recv_primary_mask = recv_metadata.recv_primary_mask
|
||||
recv_count = recv_metadata.recv_count
|
||||
recv_expert_ids = recv_metadata.recv_expert_ids
|
||||
recv_dst_rows = recv_metadata.recv_dst_rows
|
||||
num_local_experts = is_unchanged.shape[0]
|
||||
|
||||
# Mask for rows to copy back from buffers:
|
||||
# copy if locally received OR remote primary recv
|
||||
copy_mask = np.logical_or(is_received_locally, recv_primary_mask)
|
||||
dest_mask_np = np.logical_and(~is_unchanged, copy_mask)
|
||||
if bool(dest_mask_np.any()):
|
||||
dest_indices = np.nonzero(dest_mask_np)[0].tolist()
|
||||
for dst in dest_indices:
|
||||
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||
w[dst].copy_(b[dst], non_blocking=True)
|
||||
|
||||
if recv_count == 0:
|
||||
return
|
||||
|
||||
# Duplicate remote received rows to non-primary duplicate dsts
|
||||
base = ep_rank * num_local_experts
|
||||
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
|
||||
duplicate_mask = np.logical_and(
|
||||
np.logical_and(~is_unchanged, ~is_received_locally),
|
||||
np.logical_and(~recv_primary_mask, local_experts != -1),
|
||||
)
|
||||
# All received experts are unique in the destination, so no need to copy duplicates
|
||||
if not bool(duplicate_mask.any()):
|
||||
return
|
||||
|
||||
dup_dst_rows = np.nonzero(duplicate_mask)[0]
|
||||
dup_experts = local_experts[dup_dst_rows]
|
||||
|
||||
prim_experts = recv_expert_ids[:recv_count]
|
||||
prim_dsts = recv_dst_rows[:recv_count]
|
||||
order = np.argsort(prim_experts, kind="stable")
|
||||
prim_experts_sorted = prim_experts[order]
|
||||
prim_dsts_sorted = prim_dsts[order]
|
||||
pos = np.searchsorted(prim_experts_sorted, dup_experts)
|
||||
valid = np.logical_and(
|
||||
pos < prim_experts_sorted.shape[0],
|
||||
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
|
||||
== dup_experts,
|
||||
)
|
||||
if not bool(valid.any()):
|
||||
return
|
||||
|
||||
matched_dst_rows = dup_dst_rows[valid]
|
||||
matched_src_rows = prim_dsts_sorted[pos[valid]]
|
||||
|
||||
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
|
||||
for w in expert_weights:
|
||||
w[dst].copy_(w[src], non_blocking=True)
|
||||
|
||||
|
||||
async def transfer_layer(
|
||||
old_layer_indices: torch.Tensor,
|
||||
new_layer_indices: torch.Tensor,
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
expert_weights_buffer: Sequence[torch.Tensor],
|
||||
ep_group: ProcessGroup,
|
||||
is_profile: bool = False,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> MoveToBufferResult:
|
||||
"""
|
||||
Rearranges the expert weights in place according to the new expert indices.
|
||||
|
||||
The value of the indices arguments are logical indices of the experts,
|
||||
while keys are physical.
|
||||
|
||||
Args:
|
||||
old_layer_indices: Shape (num_physical_experts,).
|
||||
new_layer_indices: Shape (num_physical_experts,).
|
||||
expert_weights: Iterable of weight tensors for this layer, each with shape
|
||||
(num_local_physical_experts, hidden_size_i).
|
||||
For example, a linear layer may have up and down projection.
|
||||
expert_weights_buffer: Intermediate buffers (one per weight tensor).
|
||||
ep_group: The device process group for expert parallelism.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
This is used during profile run, where we only perform dummy
|
||||
communications to reserve enough memory for the buffers.
|
||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||
rank_mapping: Optional rank mapping for elastic expert parallelism.
|
||||
|
||||
Returns:
|
||||
is_unchanged (np.ndarray): (num_local_experts,), True where expert
|
||||
is left unchanged.
|
||||
is_received_locally (np.ndarray): (num_local_experts,), True where expert
|
||||
can be received locally.
|
||||
RecvMetadata: Metadata needed for completing remote weight transfers.
|
||||
"""
|
||||
ep_size = ep_group.size()
|
||||
if rank_mapping is not None:
|
||||
# Add a layer dimension for compatibility with mapping functions
|
||||
old_layer_indices_2d = old_layer_indices.unsqueeze(0)
|
||||
new_layer_indices_2d = new_layer_indices.unsqueeze(0)
|
||||
|
||||
if len(rank_mapping) == ep_group.size():
|
||||
# scale down
|
||||
new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping(
|
||||
new_layer_indices_2d,
|
||||
rank_mapping,
|
||||
)
|
||||
else:
|
||||
# scale up
|
||||
old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping(
|
||||
old_layer_indices_2d,
|
||||
rank_mapping,
|
||||
ep_group.size(),
|
||||
)
|
||||
|
||||
# Remove the layer dimension
|
||||
old_layer_indices = old_layer_indices_2d.squeeze(0)
|
||||
new_layer_indices = new_layer_indices_2d.squeeze(0)
|
||||
|
||||
assert old_layer_indices.shape == new_layer_indices.shape
|
||||
num_physical_experts = old_layer_indices.shape[0]
|
||||
assert len(expert_weights[0]) >= 1
|
||||
num_local_physical_experts = expert_weights[0].shape[0]
|
||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||
|
||||
old_layer_indices_np = old_layer_indices.cpu().numpy()
|
||||
new_layer_indices_np = new_layer_indices.cpu().numpy()
|
||||
|
||||
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
|
||||
num_local_experts=num_local_physical_experts,
|
||||
old_indices=old_layer_indices_np,
|
||||
new_indices=new_layer_indices_np,
|
||||
expert_weights=expert_weights,
|
||||
expert_weights_buffers=expert_weights_buffer,
|
||||
cuda_stream=cuda_stream,
|
||||
ep_group=ep_group,
|
||||
)
|
||||
return is_unchanged, is_received_locally, recv_metadata
|
||||
|
||||
|
||||
def rearrange_expert_weights_inplace(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||
ep_group: ProcessGroup,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Rearranges the expert weights in place according to the new expert indices.
|
||||
|
||||
The value of the indices arguments are logical indices of the experts,
|
||||
while keys are physical.
|
||||
|
||||
Args:
|
||||
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
expert_weights: A sequence of shape (num_moe_layers)(weight_count)
|
||||
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
||||
For example, a linear layer may have up and down projection,
|
||||
so weight_count = 2. Each weight's hidden size can be different.
|
||||
ep_group: The device process group for expert parallelism.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
This is used during profile run, where we only perform dummy
|
||||
communications to reserve enough memory for the buffers.
|
||||
rank_mapping: A dictionary mapping old rank to new rank.
|
||||
"""
|
||||
if rank_mapping is not None:
|
||||
if len(rank_mapping) == ep_group.size():
|
||||
# scale down
|
||||
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
|
||||
new_global_expert_indices,
|
||||
rank_mapping,
|
||||
)
|
||||
else:
|
||||
# scale up
|
||||
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
|
||||
old_global_expert_indices,
|
||||
rank_mapping,
|
||||
ep_group.size(),
|
||||
)
|
||||
|
||||
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
|
||||
|
||||
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
|
||||
assert len(expert_weights) == num_moe_layers
|
||||
assert len(expert_weights[0]) >= 1
|
||||
|
||||
num_local_physical_experts = expert_weights[0][0].shape[0]
|
||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
||||
|
||||
ep_size = ep_group.size()
|
||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||
|
||||
first_layer_weights = list(expert_weights[0])
|
||||
# Buffers to hold the expert weights during the exchange.
|
||||
# NOTE: Currently we assume the same weights across different layers
|
||||
# have the same shape.
|
||||
weights_buffer: list[torch.Tensor] = [
|
||||
torch.empty_like(w) for w in first_layer_weights
|
||||
]
|
||||
if is_profile:
|
||||
# Reserve communication buffers via a minimal dummy all_gather on first layer
|
||||
for weight, buffer in zip(expert_weights[0], weights_buffer):
|
||||
dummy_recv_buffer = [buffer for _ in range(ep_size)]
|
||||
torch.distributed.barrier()
|
||||
all_gather(
|
||||
dummy_recv_buffer,
|
||||
weight,
|
||||
group=ep_group,
|
||||
)
|
||||
return
|
||||
|
||||
# NOTE(bowen): We need this synchronize to run, but I don't know why.
|
||||
# If you figure out the reason, please let me know -- thank you!
|
||||
torch.cuda.synchronize()
|
||||
|
||||
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
|
||||
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
|
||||
|
||||
for layer_idx in range(num_moe_layers):
|
||||
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
|
||||
num_local_experts=num_local_physical_experts,
|
||||
old_indices=old_global_expert_indices_cpu[layer_idx],
|
||||
new_indices=new_global_expert_indices_cpu[layer_idx],
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffers=weights_buffer,
|
||||
cuda_stream=None,
|
||||
ep_group=ep_group,
|
||||
)
|
||||
|
||||
move_from_buffer(
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffers=weights_buffer,
|
||||
is_unchanged=is_unchanged,
|
||||
is_received_locally=is_received_locally,
|
||||
recv_metadata=recv_metadata,
|
||||
new_indices=new_global_expert_indices_cpu[layer_idx],
|
||||
ep_rank=ep_group.rank(),
|
||||
)
|
||||
|
||||
|
||||
def _map_old_expert_indices_with_rank_mapping(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
rank_mapping: dict[int, int],
|
||||
new_ep_size: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Map the old global expert indices to the new global expert indices.
|
||||
|
||||
Args:
|
||||
old_global_expert_indices:
|
||||
Shape (num_layers, old_ep_size * num_local_physical_experts).
|
||||
rank_mapping: Mapping from old rank to new rank.
|
||||
new_ep_size: New expert parallelism size.
|
||||
|
||||
Returns:
|
||||
Mapped expert indices with shape
|
||||
(num_layers, new_ep_size * num_local_physical_experts).
|
||||
"""
|
||||
num_layers, old_num_physical_experts = old_global_expert_indices.shape
|
||||
assert rank_mapping, "Rank mapping is required"
|
||||
|
||||
# Get sizes from parameters and rank_mapping
|
||||
old_ep_size = len(rank_mapping)
|
||||
num_local_physical_experts = old_num_physical_experts // old_ep_size
|
||||
new_num_physical_experts = new_ep_size * num_local_physical_experts
|
||||
|
||||
# Create mapped tensor with new shape, initialized to -1
|
||||
mapped_expert_indices = torch.full(
|
||||
(num_layers, new_num_physical_experts),
|
||||
fill_value=-1,
|
||||
dtype=old_global_expert_indices.dtype,
|
||||
device=old_global_expert_indices.device,
|
||||
)
|
||||
|
||||
# Handle rank mapping (scale up/down with rank changes)
|
||||
for old_rank in range(old_ep_size):
|
||||
new_rank = rank_mapping.get(old_rank)
|
||||
if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
|
||||
# This old rank exists in the new configuration
|
||||
old_start_idx = old_rank * num_local_physical_experts
|
||||
old_end_idx = (old_rank + 1) * num_local_physical_experts
|
||||
new_start_idx = new_rank * num_local_physical_experts
|
||||
new_end_idx = (new_rank + 1) * num_local_physical_experts
|
||||
|
||||
mapped_expert_indices[:, new_start_idx:new_end_idx] = (
|
||||
old_global_expert_indices[:, old_start_idx:old_end_idx]
|
||||
)
|
||||
# If new_rank is None or >= new_ep_size, the experts remain -1
|
||||
# (scale down case)
|
||||
|
||||
return mapped_expert_indices
|
||||
|
||||
|
||||
def _map_new_expert_indices_with_rank_mapping(
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
rank_mapping: dict[int, int],
|
||||
) -> torch.Tensor:
|
||||
num_layers, new_num_physical_experts = new_global_expert_indices.shape
|
||||
assert rank_mapping, "Rank mapping is required"
|
||||
|
||||
# Get sizes from parameters and rank_mapping
|
||||
old_ep_size = len(rank_mapping)
|
||||
new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
|
||||
num_local_physical_experts = new_num_physical_experts // new_ep_size
|
||||
old_num_physical_experts = old_ep_size * num_local_physical_experts
|
||||
|
||||
mapped_expert_indices = torch.full(
|
||||
(num_layers, old_num_physical_experts),
|
||||
fill_value=-1,
|
||||
dtype=new_global_expert_indices.dtype,
|
||||
device=new_global_expert_indices.device,
|
||||
)
|
||||
|
||||
for old_rank in range(old_ep_size):
|
||||
new_rank = rank_mapping[old_rank]
|
||||
if new_rank >= 0 and new_rank < new_ep_size:
|
||||
old_start_idx = old_rank * num_local_physical_experts
|
||||
old_end_idx = (old_rank + 1) * num_local_physical_experts
|
||||
new_start_idx = new_rank * num_local_physical_experts
|
||||
new_end_idx = (new_rank + 1) * num_local_physical_experts
|
||||
|
||||
mapped_expert_indices[:, old_start_idx:old_end_idx] = (
|
||||
new_global_expert_indices[:, new_start_idx:new_end_idx]
|
||||
)
|
||||
|
||||
return mapped_expert_indices
|
||||
|
||||
|
||||
__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"]
|
||||
513
vllm/distributed/kv_events.py
Normal file
513
vllm/distributed/kv_events.py
Normal file
@@ -0,0 +1,513 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter, deque
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from itertools import count
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EventBatch(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False, # type: ignore[call-arg]
|
||||
):
|
||||
ts: float
|
||||
events: list[Any]
|
||||
data_parallel_rank: int | None = None
|
||||
|
||||
|
||||
class KVCacheEvent(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False, # type: ignore[call-arg]
|
||||
tag=True,
|
||||
):
|
||||
"""Base class for all KV cache-related events"""
|
||||
|
||||
|
||||
MEDIUM_GPU = "GPU"
|
||||
|
||||
|
||||
class BlockStored(KVCacheEvent):
|
||||
block_hashes: list[ExternalBlockHash]
|
||||
parent_block_hash: ExternalBlockHash | None
|
||||
token_ids: list[int]
|
||||
block_size: int
|
||||
|
||||
lora_id: int | None
|
||||
"""Deprecated: use `lora_name` for KV block key hash.
|
||||
Retained for backward compatibility.
|
||||
"""
|
||||
|
||||
medium: str | None
|
||||
lora_name: str | None
|
||||
|
||||
extra_keys: list[tuple[Any, ...] | None] | None = None
|
||||
"""Extra keys used in block hash computation, one entry per block in
|
||||
block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
|
||||
prompt embedding hashes, etc. for that specific block. Exposed for external
|
||||
KV cache consumers to reconstruct block hashes.
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
tuple(self.block_hashes),
|
||||
self.parent_block_hash,
|
||||
tuple(self.token_ids),
|
||||
self.block_size,
|
||||
self.lora_id,
|
||||
self.medium,
|
||||
tuple(self.extra_keys) if self.extra_keys else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
|
||||
block_hashes: list[ExternalBlockHash]
|
||||
medium: str | None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((tuple(self.block_hashes), self.medium))
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
|
||||
pass
|
||||
|
||||
|
||||
class KVEventBatch(EventBatch):
|
||||
events: list[BlockStored | BlockRemoved | AllBlocksCleared]
|
||||
|
||||
|
||||
class KVEventAggregator:
|
||||
"""
|
||||
Aggregates KV events across multiple workers.
|
||||
Tracks how many times each event appears and returns only those
|
||||
that were emitted by all workers.
|
||||
"""
|
||||
|
||||
__slots__ = ("_event_counter", "_num_workers")
|
||||
|
||||
def __init__(self, num_workers: int) -> None:
|
||||
if num_workers <= 0:
|
||||
raise ValueError("num_workers must be greater than zero.")
|
||||
self._event_counter: Counter[KVCacheEvent] = Counter()
|
||||
self._num_workers: int = num_workers
|
||||
|
||||
def add_events(self, events: list[KVCacheEvent]) -> None:
|
||||
"""
|
||||
Add events from a worker batch.
|
||||
|
||||
:param events: List of KVCacheEvent objects.
|
||||
"""
|
||||
if not isinstance(events, list):
|
||||
raise TypeError("events must be a list of KVCacheEvent.")
|
||||
self._event_counter.update(events)
|
||||
|
||||
def get_common_events(self) -> list[KVCacheEvent]:
|
||||
"""
|
||||
Return events that appeared in all workers.
|
||||
|
||||
:return: List of events present in all workers.
|
||||
"""
|
||||
return [
|
||||
event
|
||||
for event, count in self._event_counter.items()
|
||||
if count == self._num_workers
|
||||
]
|
||||
|
||||
def get_all_events(self) -> list[KVCacheEvent]:
|
||||
"""
|
||||
Return all events for all workers.
|
||||
|
||||
:return: List of events for all workers.
|
||||
"""
|
||||
return list(self._event_counter.elements())
|
||||
|
||||
def clear_events(self) -> None:
|
||||
"""
|
||||
Clear all tracked events.
|
||||
"""
|
||||
self._event_counter.clear()
|
||||
|
||||
def increment_workers(self, count: int = 1) -> None:
|
||||
"""
|
||||
Increment the number of workers contributing events.
|
||||
|
||||
:param count: Number to increment the workers by.
|
||||
"""
|
||||
if count <= 0:
|
||||
raise ValueError("count must be positive.")
|
||||
self._num_workers += count
|
||||
|
||||
def reset_workers(self) -> None:
|
||||
"""
|
||||
Reset the number of workers to 1.
|
||||
"""
|
||||
self._num_workers = 1
|
||||
|
||||
def get_number_of_workers(self) -> int:
|
||||
"""
|
||||
Return the number of workers.
|
||||
|
||||
:return: int number of workers.
|
||||
"""
|
||||
return self._num_workers
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<KVEventAggregator workers={self._num_workers}, "
|
||||
f"events={len(self._event_counter)}>"
|
||||
)
|
||||
|
||||
|
||||
class KVConnectorKVEvents(ABC):
|
||||
"""
|
||||
Abstract base class for KV events.
|
||||
Acts as a container for KV events from the connector.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_events(self, events: list[KVCacheEvent]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def aggregate(self) -> "KVConnectorKVEvents":
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def increment_workers(self, count: int = 1) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_all_events(self) -> list[KVCacheEvent]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_number_of_workers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear_events(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EventPublisher(ABC):
|
||||
"""Lightweight publisher for EventBatch batches with data parallelism
|
||||
support.
|
||||
|
||||
In data parallel setups, each DP rank runs its own EventPublisher instance
|
||||
to avoid duplicate events and ensure proper event attribution:
|
||||
|
||||
- Each DP rank creates a separate publisher
|
||||
- Publishers automatically annotate events with their data_parallel_rank
|
||||
- This allows consumers to distinguish events from different DP ranks
|
||||
|
||||
The publisher is responsible for adding DP metadata since the scheduler
|
||||
operates independently of DP topology and shouldn't need DP awareness.
|
||||
"""
|
||||
|
||||
def __init__(self, data_parallel_rank: int = 0) -> None:
|
||||
self._data_parallel_rank = data_parallel_rank
|
||||
|
||||
@abstractmethod
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
"""Emit events in order.
|
||||
|
||||
Implementations should guarantee at-least-once delivery and
|
||||
monotonic ordering (e.g., via sequence numbers).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the publisher."""
|
||||
|
||||
|
||||
class NullEventPublisher(EventPublisher):
|
||||
"""No-op implementation (default when disabled)."""
|
||||
|
||||
def publish(self, events) -> None:
|
||||
return
|
||||
|
||||
def shutdown(self) -> None:
|
||||
return
|
||||
|
||||
|
||||
class ZmqEventPublisher(EventPublisher):
|
||||
"""Reliable PUB/ROUTER publisher with an in-memory replay buffer.
|
||||
|
||||
Spawns a separate thread to handle publishing from a queue.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
endpoint:
|
||||
PUB address. Use `tcp://*:5557` to bind or `tcp://host:5557` to
|
||||
connect.
|
||||
replay_endpoint:
|
||||
Optional ROUTER address for replay requests. When given, subscribers can
|
||||
request missed batches by sending the starting sequence number as an
|
||||
8-byte big-endian integer.
|
||||
buffer_steps:
|
||||
Number of past batches to keep for replay.
|
||||
hwm:
|
||||
ZeroMQ high-water-mark for PUB socket.
|
||||
max_queue_size:
|
||||
Maximum number of events to buffer in memory.
|
||||
topic:
|
||||
Topic to publish events to.
|
||||
"""
|
||||
|
||||
SHUTDOWN_TIMEOUT: float = 1.0
|
||||
END_SEQ = (-1).to_bytes(8, "big", signed=True)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_parallel_rank: int,
|
||||
endpoint: str = "tcp://*:5557",
|
||||
replay_endpoint: str | None = None,
|
||||
buffer_steps: int = 10_000,
|
||||
hwm: int = 100_000,
|
||||
max_queue_size: int = 100_000,
|
||||
topic: str = "",
|
||||
) -> None:
|
||||
# Storage
|
||||
super().__init__(data_parallel_rank)
|
||||
self._event_queue = Queue[EventBatch | None](maxsize=max_queue_size)
|
||||
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
|
||||
|
||||
# ZMQ sockets
|
||||
self._ctx = zmq.Context.instance()
|
||||
self._pub: zmq.Socket | None = None
|
||||
self._replay: zmq.Socket | None = None
|
||||
self._dp_rank = data_parallel_rank
|
||||
|
||||
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
|
||||
self._replay_endpoint = self.offset_endpoint_port(
|
||||
replay_endpoint, self._dp_rank
|
||||
)
|
||||
self._hwm = hwm
|
||||
self._socket_setup()
|
||||
|
||||
# Payload
|
||||
self._seq_gen = count()
|
||||
self._topic_bytes = topic.encode("utf-8")
|
||||
|
||||
# Thread
|
||||
self._running = True
|
||||
logger.info("Starting ZMQ publisher thread")
|
||||
|
||||
self._thread = threading.Thread(
|
||||
target=self._publisher_thread, daemon=True, name="zmq-publisher"
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
if not self._running:
|
||||
raise RuntimeError("Publisher is closed")
|
||||
if events.data_parallel_rank is None:
|
||||
events.data_parallel_rank = self._data_parallel_rank
|
||||
self._event_queue.put(events)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Stop the publisher thread and clean up resources."""
|
||||
self._running = False
|
||||
self._event_queue.put_nowait(None)
|
||||
|
||||
start = time.time()
|
||||
pending_items = True
|
||||
while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
|
||||
pending_items = not self._event_queue.empty()
|
||||
if pending_items:
|
||||
time.sleep(0.1)
|
||||
|
||||
if pending_items:
|
||||
logger.warning(
|
||||
"Warning: Queue still has %s items after %s seconds timeout",
|
||||
self._event_queue.qsize(),
|
||||
self.SHUTDOWN_TIMEOUT,
|
||||
)
|
||||
|
||||
if self._thread.is_alive():
|
||||
self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
|
||||
|
||||
# Clean up ZMQ resources
|
||||
try:
|
||||
if self._pub is not None:
|
||||
self._pub.close(linger=0)
|
||||
if self._replay is not None:
|
||||
self._replay.close(linger=0)
|
||||
finally:
|
||||
pass # Do not terminate context; other sockets may use it
|
||||
|
||||
def _socket_setup(self) -> None:
|
||||
"""Initialize sockets
|
||||
https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
|
||||
"""
|
||||
if self._pub is None:
|
||||
self._pub = self._ctx.socket(zmq.PUB)
|
||||
self._pub.set_hwm(self._hwm)
|
||||
# Heuristic: bind if wildcard / * present, else connect.
|
||||
# bind stable, connect volatile convention
|
||||
if self._endpoint is not None and (
|
||||
"*" in self._endpoint
|
||||
or "::" in self._endpoint
|
||||
or self._endpoint.startswith("ipc://")
|
||||
or self._endpoint.startswith("inproc://")
|
||||
):
|
||||
self._pub.bind(self._endpoint)
|
||||
elif self._endpoint is not None:
|
||||
self._pub.connect(self._endpoint)
|
||||
|
||||
# Set up replay socket: use ROUTER
|
||||
# 1) handles multiple REQ clients (identities)
|
||||
# 2) lets us send back one request → many replies (streamed events)
|
||||
# 3) works in our non‑blocking poll loop alongside PUB
|
||||
if self._replay_endpoint is not None:
|
||||
self._replay = self._ctx.socket(zmq.ROUTER)
|
||||
self._replay.bind(self._replay_endpoint)
|
||||
|
||||
def _publisher_thread(self) -> None:
|
||||
"""Background thread that processes the event queue."""
|
||||
self._pack = msgspec.msgpack.Encoder()
|
||||
|
||||
assert self._pub is not None # narrows type for mypy
|
||||
|
||||
while self._running or self._event_queue.qsize() > 0:
|
||||
# --- replay (non-critical) ---------------------------------
|
||||
if self._replay is not None and self._replay.poll(0):
|
||||
try:
|
||||
self._service_replay()
|
||||
except Exception as e:
|
||||
logger.exception("Error in replay: %s", e)
|
||||
|
||||
# --- main queue (critical) ---------------------------------
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
if event is None:
|
||||
break # Sentinel received, exit thread
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
try:
|
||||
seq = next(self._seq_gen)
|
||||
|
||||
payload = self._pack.encode(event)
|
||||
seq_bytes = seq.to_bytes(8, "big")
|
||||
self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))
|
||||
|
||||
self._buffer.append((seq, payload))
|
||||
self._event_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
# Publishing failed; back-off a bit to avoid a tight error loop
|
||||
logger.exception("Error in publisher thread: %s", e)
|
||||
time.sleep(0.1)
|
||||
|
||||
def _service_replay(self) -> None:
|
||||
"""If a replay request is waiting, send buffered batches."""
|
||||
assert self._replay is not None # narrows type for mypy
|
||||
|
||||
frame = self._replay.recv_multipart()
|
||||
if len(frame) != 3:
|
||||
logger.warning("Invalid replay request: %s", frame)
|
||||
return
|
||||
client_id, _, start_seq_bytes = frame
|
||||
start_seq = int.from_bytes(start_seq_bytes, "big")
|
||||
|
||||
for seq, buf in self._buffer:
|
||||
if seq >= start_seq:
|
||||
# [identity, empty_delim, seq_bytes, payload]
|
||||
# (identity, empty_delim) are stripped off by the router
|
||||
# receiving payload is (seq_bytes, payload)
|
||||
self._replay.send_multipart(
|
||||
(client_id, b"", seq.to_bytes(8, "big"), buf)
|
||||
)
|
||||
# Send end of sequence marker
|
||||
# receiving payload is (-1, b""")
|
||||
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
||||
|
||||
@staticmethod
|
||||
def offset_endpoint_port(
|
||||
endpoint: str | None, data_parallel_rank: int
|
||||
) -> str | None:
|
||||
"""Helper function to offset the port in an endpoint by
|
||||
the data parallel rank.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint string
|
||||
(e.g., "tcp://*:5557" or "inproc://cache")
|
||||
data_parallel_rank: The data parallel rank to offset by
|
||||
|
||||
Returns:
|
||||
The endpoint with the port offset by data_parallel_rank
|
||||
or suffix appended
|
||||
"""
|
||||
# Do nothing if input is None or data_parallel_rank is 0
|
||||
if not endpoint or data_parallel_rank == 0:
|
||||
return endpoint
|
||||
|
||||
if "inproc" in endpoint:
|
||||
return f"{endpoint}_dp{data_parallel_rank}"
|
||||
if "tcp" in endpoint:
|
||||
if endpoint and ":" in endpoint:
|
||||
# Get everything after the last colon (the port)
|
||||
last_colon_idx = endpoint.rfind(":")
|
||||
base_addr = endpoint[:last_colon_idx]
|
||||
base_port = int(endpoint[last_colon_idx + 1 :])
|
||||
new_port = base_port + data_parallel_rank
|
||||
return f"{base_addr}:{new_port}"
|
||||
return endpoint
|
||||
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
|
||||
|
||||
|
||||
class EventPublisherFactory:
|
||||
_registry: dict[str, Callable[..., EventPublisher]] = {
|
||||
"null": NullEventPublisher,
|
||||
"zmq": ZmqEventPublisher,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
|
||||
if name in cls._registry:
|
||||
raise KeyError(f"publisher '{name}' already registered")
|
||||
cls._registry[name] = ctor
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, config: KVEventsConfig | None, data_parallel_rank: int = 0
|
||||
) -> EventPublisher:
|
||||
"""Create publisher from a config mapping."""
|
||||
if (
|
||||
config is None
|
||||
or not config.enable_kv_cache_events
|
||||
or config.publisher == "null"
|
||||
):
|
||||
return NullEventPublisher()
|
||||
|
||||
config_dict = asdict(config)
|
||||
|
||||
kind = config_dict.pop("publisher")
|
||||
config_dict.pop("enable_kv_cache_events")
|
||||
try:
|
||||
constructor = cls._registry[kind]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Unknown event publisher '{kind}'") from exc
|
||||
return constructor(data_parallel_rank=data_parallel_rank, **config_dict)
|
||||
29
vllm/distributed/kv_transfer/README.md
Normal file
29
vllm/distributed/kv_transfer/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
|
||||
# Distributed KV cache transfer
|
||||
|
||||
This folder implements distributed KV cache transfer across vLLM instances.
|
||||
Currently the main use case is for disaggregated prefilling.
|
||||
|
||||
## Abstractions
|
||||
|
||||
The KV cache transfer contains three layer of abstractions:
|
||||
|
||||
- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`.
|
||||
- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics).
|
||||
- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`.
|
||||
|
||||
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
|
||||
|
||||
NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed
|
||||
communication service already supports key-value-based lookup (like redis or
|
||||
RDMA database).
|
||||
|
||||
NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates.
|
||||
|
||||
## Disaggregated prefilling
|
||||
|
||||
The example usage is in [this file](../../../examples/online_serving/disaggregated_prefill.sh).
|
||||
|
||||
Here is the diagram of how we run disaggregated prefilling.
|
||||
|
||||

|
||||
20
vllm/distributed/kv_transfer/__init__.py
Normal file
20
vllm/distributed/kv_transfer/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
KVConnectorBaseType,
|
||||
ensure_kv_transfer_initialized,
|
||||
ensure_kv_transfer_shutdown,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_kv_transfer_group",
|
||||
"has_kv_transfer_group",
|
||||
"is_v1_kv_transfer_group",
|
||||
"ensure_kv_transfer_initialized",
|
||||
"ensure_kv_transfer_shutdown",
|
||||
"KVConnectorBaseType",
|
||||
]
|
||||
BIN
vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg
Normal file
BIN
vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 139 KiB |
10
vllm/distributed/kv_transfer/kv_connector/base.py
Normal file
10
vllm/distributed/kv_transfer/kv_connector/base.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Defines the base type for KV cache connectors."""
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
|
||||
KVConnectorBase = KVConnectorBase_V1
|
||||
KVConnectorBaseType = KVConnectorBase_V1
|
||||
|
||||
__all__ = ["KVConnectorBase", "KVConnectorBaseType"]
|
||||
203
vllm/distributed/kv_transfer/kv_connector/factory.py
Normal file
203
vllm/distributed/kv_transfer/kv_connector/factory.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import (
|
||||
KVConnectorBase,
|
||||
KVConnectorBaseType,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorRole,
|
||||
supports_hma,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KVConnectorFactory:
|
||||
_registry: dict[str, Callable[[], type[KVConnectorBase]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
|
||||
"""Register a connector with a lazy-loading module and class name."""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Connector '{name}' is already registered.")
|
||||
|
||||
def loader() -> type[KVConnectorBase]:
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
cls._registry[name] = loader
|
||||
|
||||
@classmethod
|
||||
def create_connector(
|
||||
cls,
|
||||
config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
) -> KVConnectorBase:
|
||||
kv_transfer_config = config.kv_transfer_config
|
||||
if kv_transfer_config is None:
|
||||
raise ValueError("kv_transfer_config must be set to create a connector")
|
||||
connector_cls, compat_sig = cls._get_connector_class_with_compat(
|
||||
kv_transfer_config
|
||||
)
|
||||
|
||||
# check if the connector supports HMA
|
||||
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
|
||||
if hma_enabled and not supports_hma(connector_cls):
|
||||
raise ValueError(
|
||||
f"Connector {connector_cls.__name__} does not support HMA but "
|
||||
f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Creating v1 connector with name: %s and engine_id: %s",
|
||||
connector_cls.__name__,
|
||||
kv_transfer_config.engine_id,
|
||||
)
|
||||
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
||||
# Scheduler connector:
|
||||
# - Co-locate with scheduler process
|
||||
# - Should only be used inside the Scheduler class
|
||||
# Worker connector:
|
||||
# - Co-locate with worker process
|
||||
# - Should only be used inside the forward context & attention layer
|
||||
# We build separately to enforce strict separation
|
||||
if compat_sig:
|
||||
# Old signature: __init__(self, vllm_config, role)
|
||||
return connector_cls(config, role)
|
||||
else:
|
||||
# New signature: __init__(self, vllm_config, role, kv_cache_config)
|
||||
return connector_cls(config, role, kv_cache_config)
|
||||
|
||||
@classmethod
|
||||
def get_connector_class_by_name(
|
||||
cls, connector_name: str
|
||||
) -> type[KVConnectorBaseType]:
|
||||
"""Get a registered connector class by name.
|
||||
|
||||
Raises ValueError if the connector is not registered.
|
||||
|
||||
Args:
|
||||
connector_name: Name of the registered connector.
|
||||
|
||||
Returns:
|
||||
The connector class.
|
||||
"""
|
||||
if connector_name not in cls._registry:
|
||||
raise ValueError(f"Connector '{connector_name}' is not registered.")
|
||||
return cls._registry[connector_name]()
|
||||
|
||||
@classmethod
|
||||
def _get_connector_class_with_compat(
|
||||
cls, kv_transfer_config: "KVTransferConfig"
|
||||
) -> tuple[type[KVConnectorBaseType], bool]:
|
||||
connector_name = kv_transfer_config.kv_connector
|
||||
if connector_name is None:
|
||||
raise ValueError("Connector name is not set in KVTransferConfig")
|
||||
compat_sig = False
|
||||
if connector_name in cls._registry:
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
else:
|
||||
connector_module_path = kv_transfer_config.kv_connector_module_path
|
||||
if connector_module_path is None:
|
||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
||||
connector_module = importlib.import_module(connector_module_path)
|
||||
try:
|
||||
connector_cls = getattr(connector_module, connector_name)
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
f"Class {connector_name} not found in {connector_module_path}"
|
||||
) from e
|
||||
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
|
||||
if not supports_kw(connector_cls, "kv_cache_config"):
|
||||
compat_sig = True
|
||||
logger.warning(
|
||||
"Connector %s uses deprecated signature with 2 required arguments. "
|
||||
"Please update to include kv_cache_config as the second argument.",
|
||||
connector_cls.__name__,
|
||||
)
|
||||
return connector_cls, compat_sig
|
||||
|
||||
@classmethod
|
||||
def get_connector_class(
|
||||
cls, kv_transfer_config: "KVTransferConfig"
|
||||
) -> type[KVConnectorBaseType]:
|
||||
"""Get the connector class by name."""
|
||||
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
|
||||
return connector_cls
|
||||
|
||||
|
||||
# Register various connectors here.
|
||||
# The registration should not be done in each individual file, as we want to
|
||||
# only load the files corresponding to the current connector.
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"ExampleConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.example_connector",
|
||||
"ExampleConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"P2pNcclConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
|
||||
"P2pNcclConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"LMCacheConnectorV1",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
|
||||
"LMCacheConnectorV1",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"LMCacheMPConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_mp_connector",
|
||||
"LMCacheMPConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"NixlConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
|
||||
"NixlConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MultiConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
|
||||
"MultiConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MoRIIOConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
|
||||
"MoRIIOConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"OffloadingConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
|
||||
"OffloadingConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"DecodeBenchConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
|
||||
"DecodeBenchConnector",
|
||||
)
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
|
||||
"MooncakeConnector",
|
||||
)
|
||||
502
vllm/distributed/kv_transfer/kv_connector/utils.py
Normal file
502
vllm/distributed/kv_transfer/kv_connector/utils.py
Normal file
@@ -0,0 +1,502 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
KV cache helper for store.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
EngineId = str
|
||||
|
||||
|
||||
def get_kv_connector_cache_layout():
|
||||
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
|
||||
# used for faster transfer.
|
||||
vllm_config = get_current_vllm_config()
|
||||
kv_config = vllm_config.kv_transfer_config
|
||||
if kv_config is not None:
|
||||
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
|
||||
required_kvcache_layout = connector_cls.get_required_kvcache_layout(vllm_config)
|
||||
if required_kvcache_layout is not None:
|
||||
return required_kvcache_layout
|
||||
logger.info_once(
|
||||
"Connectors do not specify a kv cache layout, defaulting to NHD."
|
||||
)
|
||||
return "NHD"
|
||||
|
||||
|
||||
class KVOutputAggregator:
|
||||
"""Utility class to aggregate the output of all workers into a single
|
||||
output corresponding to Rank 0 for scheduler."""
|
||||
|
||||
def __init__(self, expected_finished_count: int):
|
||||
# Complete transfer tracker. Used to track finished requests
|
||||
# [req_id -> n_remaining_workers]
|
||||
self._recv_remaining_count = dict[str, int]()
|
||||
self._send_remaining_count = dict[str, int]()
|
||||
self._expected_finished_count = expected_finished_count
|
||||
|
||||
@classmethod
|
||||
def from_connector(cls, connector: "KVConnectorBase", world_size: int):
|
||||
return cls(connector.get_finished_count() or world_size)
|
||||
|
||||
def aggregate(
|
||||
self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
|
||||
) -> ModelRunnerOutput | None:
|
||||
if not outputs[output_rank]:
|
||||
return None
|
||||
|
||||
# Aggregate kv_connector_output from all workers
|
||||
|
||||
def update_finished_set(
|
||||
req_ids: set[str] | None,
|
||||
remaining_count_dict: dict[str, int],
|
||||
finished_set: set[str],
|
||||
) -> None:
|
||||
for req_id in req_ids or ():
|
||||
remaining_count = remaining_count_dict.get(
|
||||
req_id, self._expected_finished_count
|
||||
)
|
||||
remaining_count_dict[req_id] = remaining_count - 1
|
||||
if remaining_count_dict[req_id] == 0:
|
||||
finished_set.add(req_id)
|
||||
del remaining_count_dict[req_id]
|
||||
|
||||
finished_sending = set[str]()
|
||||
finished_recving = set[str]()
|
||||
aggregated_kv_connector_stats = None
|
||||
combined_kv_cache_events = None
|
||||
invalid_block_ids = set[int]()
|
||||
for model_runner_output in outputs:
|
||||
assert model_runner_output is not None
|
||||
kv_output = model_runner_output.kv_connector_output
|
||||
if not kv_output:
|
||||
continue
|
||||
# Allow the worker to dynamically update the expected number of
|
||||
# finished sending/recving for new requests.
|
||||
if (
|
||||
kv_output.expected_finished_count > 0
|
||||
and kv_output.expected_finished_count != self._expected_finished_count
|
||||
):
|
||||
logger.debug(
|
||||
"Expected finished requests updated from %d to %d",
|
||||
self._expected_finished_count,
|
||||
kv_output.expected_finished_count,
|
||||
)
|
||||
self._expected_finished_count = kv_output.expected_finished_count
|
||||
|
||||
update_finished_set(
|
||||
kv_output.finished_sending, self._send_remaining_count, finished_sending
|
||||
)
|
||||
update_finished_set(
|
||||
kv_output.finished_recving, self._recv_remaining_count, finished_recving
|
||||
)
|
||||
|
||||
# Aggregate kv_connector_stats from all workers.
|
||||
if aggregated_kv_connector_stats is None:
|
||||
# Use the first worker's kv_connector_stats as accumulator.
|
||||
aggregated_kv_connector_stats = kv_output.kv_connector_stats
|
||||
elif kv_connector_stats := kv_output.kv_connector_stats:
|
||||
if aggregated_kv_connector_stats is None:
|
||||
aggregated_kv_connector_stats = kv_connector_stats
|
||||
else:
|
||||
assert isinstance(
|
||||
aggregated_kv_connector_stats, type(kv_connector_stats)
|
||||
)
|
||||
aggregated_kv_connector_stats = (
|
||||
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
|
||||
)
|
||||
|
||||
# Combine kv_cache_events from all workers.
|
||||
if combined_kv_cache_events is None:
|
||||
# Use the first worker's kv_cache events as start event list.
|
||||
combined_kv_cache_events = kv_output.kv_cache_events
|
||||
elif kv_cache_events := kv_output.kv_cache_events:
|
||||
assert isinstance(
|
||||
combined_kv_cache_events,
|
||||
type(kv_cache_events),
|
||||
)
|
||||
worker_kv_cache_events = kv_cache_events.get_all_events()
|
||||
combined_kv_cache_events.add_events(worker_kv_cache_events)
|
||||
combined_kv_cache_events.increment_workers(1)
|
||||
|
||||
invalid_block_ids |= kv_output.invalid_block_ids
|
||||
|
||||
# select output of the worker specified by output_rank
|
||||
output = outputs[output_rank]
|
||||
|
||||
assert output is not None
|
||||
output.kv_connector_output = KVConnectorOutput(
|
||||
finished_sending=finished_sending or None,
|
||||
finished_recving=finished_recving or None,
|
||||
kv_connector_stats=aggregated_kv_connector_stats or None,
|
||||
kv_cache_events=combined_kv_cache_events or None,
|
||||
invalid_block_ids=invalid_block_ids,
|
||||
expected_finished_count=self._expected_finished_count,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _make_src_and_dst_indices(
|
||||
src_block_ids: list[int],
|
||||
dst_block_ids: list[int],
|
||||
src_device: torch.device | str,
|
||||
dst_device: torch.device | str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
src_indices = torch.tensor(src_block_ids, device=src_device, dtype=torch.int64)
|
||||
dst_indices = torch.tensor(dst_block_ids, device=dst_device, dtype=torch.int64)
|
||||
return src_indices, dst_indices
|
||||
|
||||
|
||||
def copy_kv_blocks(
|
||||
src_kv_caches: dict[str, torch.Tensor],
|
||||
dst_kv_caches: dict[str, torch.Tensor],
|
||||
src_block_ids: list[int],
|
||||
dst_block_ids: list[int],
|
||||
direction: Literal["h2d", "d2h"],
|
||||
) -> None:
|
||||
"""Copy kv blocks between different buffers."""
|
||||
if (
|
||||
not src_kv_caches
|
||||
or not dst_kv_caches
|
||||
or not src_block_ids
|
||||
or not dst_block_ids
|
||||
or len(src_block_ids) != len(dst_block_ids)
|
||||
):
|
||||
return
|
||||
|
||||
src_device = next(iter(src_kv_caches.values())).device
|
||||
dst_device = next(iter(dst_kv_caches.values())).device
|
||||
|
||||
src_indices, dst_indices = _make_src_and_dst_indices(
|
||||
src_block_ids=src_block_ids,
|
||||
dst_block_ids=dst_block_ids,
|
||||
src_device=src_device,
|
||||
dst_device=dst_device,
|
||||
)
|
||||
|
||||
if direction == "h2d":
|
||||
copy_fn = current_platform.insert_blocks_to_device
|
||||
else:
|
||||
copy_fn = current_platform.swap_out_blocks_to_host
|
||||
for layer_name in src_kv_caches:
|
||||
src_tensor = src_kv_caches[layer_name]
|
||||
dst_tensor = dst_kv_caches[layer_name]
|
||||
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
|
||||
|
||||
|
||||
def kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio):
|
||||
"""
|
||||
Transforms the layout of received KV cache blocks to the local block_size.
|
||||
(Only works for local blocksize > remote blocksize)
|
||||
|
||||
example:
|
||||
local blocksize = 16 tokens, remote blocksize = 4 tokens
|
||||
local block[0] = remote block[0, 1, 2, 3]
|
||||
remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
|
||||
local is |h0-b0..................|h1-b0..................|...
|
||||
permute is to:
|
||||
1. view => view remote as n_blocks * remote_shape(H,remoteN,D)
|
||||
2. permute => (H, nblocks, remoteN, D)
|
||||
3. flatten => (H, localN, D)
|
||||
"""
|
||||
blocks_to_update = cache.index_select(0, indices)
|
||||
# use physical order
|
||||
blocks_to_update = blocks_to_update.permute(0, 2, 1, 3)
|
||||
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
|
||||
remote_block_size = block_size // block_size_ratio
|
||||
n_blocks = block_size_ratio
|
||||
|
||||
permuted_blocks = (
|
||||
blocks_to_update.reshape(-1, n_blocks, n_kv_heads, remote_block_size, head_size)
|
||||
.permute(0, 2, 1, 3, 4)
|
||||
.flatten(2, 3)
|
||||
)
|
||||
permuted_blocks = permuted_blocks.permute(0, 2, 1, 3)
|
||||
cache.index_copy_(0, indices, permuted_blocks)
|
||||
|
||||
|
||||
def kv_postprocess_layout_on_receive(cache, indices):
|
||||
"""Transforms the layout of received KV cache blocks to the local format.
|
||||
|
||||
This method corrects layout mismatches from direct memory copies by
|
||||
permuting the tensor dimensions.
|
||||
|
||||
- **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
|
||||
- **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`
|
||||
|
||||
Implementation:
|
||||
- x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
|
||||
- permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
|
||||
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
|
||||
|
||||
"""
|
||||
blocks_to_update = cache.index_select(0, indices)
|
||||
target_shape = list(blocks_to_update.shape)
|
||||
target_shape[0] = -1
|
||||
inv_order = [0, 2, 1, 3]
|
||||
src_shape = tuple(target_shape[i] for i in inv_order)
|
||||
blocks_to_update = cache.index_select(0, indices)
|
||||
permuted_blocks = blocks_to_update.reshape(src_shape).permute(*inv_order)
|
||||
cache.index_copy_(0, indices, permuted_blocks)
|
||||
|
||||
|
||||
def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio):
|
||||
"""
|
||||
Transforms the layout of received KV cache to the local block_size and HND.
|
||||
(Only works for local blocksize > remote blocksize)
|
||||
|
||||
prefill is HND, smaller block_size
|
||||
decode(local) is NHD, larger block_size
|
||||
"""
|
||||
blocks_to_update = cache.index_select(0, indices)
|
||||
|
||||
block_size, n_kv_heads, head_size = blocks_to_update.shape[1:]
|
||||
remote_block_size = block_size // block_size_ratio
|
||||
n_blocks = block_size_ratio
|
||||
|
||||
permuted_blocks = (
|
||||
blocks_to_update.reshape(-1, n_blocks, n_kv_heads, remote_block_size, head_size)
|
||||
.permute(0, 1, 3, 2, 4)
|
||||
.flatten(1, 2)
|
||||
)
|
||||
cache.index_copy_(0, indices, permuted_blocks)
|
||||
|
||||
|
||||
def yield_req_data(
|
||||
scheduler_output,
|
||||
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
|
||||
"""
|
||||
Yields:
|
||||
(req_id, new_block_id_groups, preempted)
|
||||
"""
|
||||
# new requests
|
||||
for req_data in scheduler_output.scheduled_new_reqs:
|
||||
yield req_data.req_id, req_data.block_ids, False
|
||||
|
||||
# cached requests
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
yield from zip(
|
||||
cached_reqs.req_ids,
|
||||
cached_reqs.new_block_ids,
|
||||
(req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TpKVTopology:
|
||||
"""
|
||||
Helper class for tensor parallel and KV topology information for
|
||||
mapping between local and remote TP workers.
|
||||
"""
|
||||
|
||||
tp_rank: int
|
||||
remote_tp_size: dict[EngineId, int]
|
||||
is_mla: bool
|
||||
total_num_kv_heads: int
|
||||
attn_backend: type[AttentionBackend]
|
||||
engine_id: EngineId
|
||||
remote_block_size: dict[EngineId, int]
|
||||
tensor_shape: torch.Size | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Figure out whether the first dimension of the cache is K/V
|
||||
# or num_blocks. This is used to register the memory regions correctly.
|
||||
_MOCK_BLOCK_SIZE = 16
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
|
||||
)
|
||||
logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
|
||||
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
|
||||
# we just mock num_blocks to 1 for the dimension check below.
|
||||
self._is_kv_layout_blocks_first = (
|
||||
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
|
||||
)
|
||||
|
||||
self._cross_layers_blocks = False
|
||||
if self.tensor_shape is not None:
|
||||
self._cross_layers_blocks = (
|
||||
len(self.tensor_shape) == len(kv_cache_shape) + 1
|
||||
)
|
||||
|
||||
if self._cross_layers_blocks:
|
||||
logger.debug("Using cross-layer KV cache")
|
||||
# prepend layers dimension
|
||||
_MOCK_NUM_LAYERS = 80
|
||||
kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
|
||||
try:
|
||||
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
|
||||
include_num_layers_dimension=self._cross_layers_blocks
|
||||
)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
|
||||
|
||||
# In case of cross layers permute kv_cache_shape according to
|
||||
# stride_order to retrieve physical position of block_size
|
||||
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
|
||||
|
||||
# In the default non-cross layers layout the block_size position
|
||||
# is logical while in the cross layers case it is the physical
|
||||
# position. This matches the shape of the actual kv cache tensors
|
||||
# passed at register_kv_caches()/register_cross_layers_kv_cache()
|
||||
block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)
|
||||
|
||||
assert block_size_position is not None
|
||||
self._block_size_position = -(len(kv_cache_shape) - block_size_position)
|
||||
|
||||
@property
|
||||
def is_kv_layout_blocks_first(self) -> bool:
|
||||
return self._is_kv_layout_blocks_first
|
||||
|
||||
@property
|
||||
def split_k_and_v(self) -> bool:
|
||||
# Whether to register regions for K and V separately (when present).
|
||||
return not (
|
||||
self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
|
||||
)
|
||||
|
||||
@property
|
||||
def tp_size(self) -> int:
|
||||
return self.remote_tp_size[self.engine_id]
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self.remote_block_size[self.engine_id]
|
||||
|
||||
@property
|
||||
def cross_layers_blocks(self) -> bool:
|
||||
return self._cross_layers_blocks
|
||||
|
||||
@property
|
||||
def block_size_position(self) -> int:
|
||||
return self._block_size_position
|
||||
|
||||
def tp_ratio(
|
||||
self,
|
||||
remote_tp_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the tensor parallel ratio between local and remote TP.
|
||||
We can think of it as the number of local TP workers-per-remote TP
|
||||
workers. Local workers will read from the same remote TP worker in
|
||||
groups of size `tp_ratio`.If remote tp_size > local tp_size, the
|
||||
ratio is flipped (remote_size/local_size) and the returned value is
|
||||
negative.
|
||||
"""
|
||||
if self.tp_size >= remote_tp_size:
|
||||
assert self.tp_size % remote_tp_size == 0, (
|
||||
f"Local tensor parallel size {self.tp_size} is not divisible "
|
||||
f"by remote tensor parallel size {remote_tp_size}."
|
||||
)
|
||||
return self.tp_size // remote_tp_size
|
||||
|
||||
assert remote_tp_size % self.tp_size == 0, (
|
||||
f"Remote tensor parallel size {remote_tp_size} is not divisible "
|
||||
f"by local tensor parallel size {self.tp_size}."
|
||||
)
|
||||
# P TP > D TP case, return the ratio as negative
|
||||
return -remote_tp_size // self.tp_size
|
||||
|
||||
def block_size_ratio(
|
||||
self,
|
||||
remote_block_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the block size ratio between local and remote TP.
|
||||
"""
|
||||
assert self.block_size % remote_block_size == 0, (
|
||||
f"Local block size {self.block_size} is not divisible "
|
||||
f"by remote block size {remote_block_size} or vice versa."
|
||||
)
|
||||
return self.block_size // remote_block_size
|
||||
|
||||
def tp_ratio_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> int:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.tp_ratio(remote_tp_size)
|
||||
|
||||
def block_size_ratio_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> int:
|
||||
remote_block_size = self.remote_block_size[remote_engine_id]
|
||||
return self.block_size_ratio(remote_block_size)
|
||||
|
||||
def is_kv_replicated(self, engine_id: EngineId) -> bool:
|
||||
"""
|
||||
Whether the KV cache is replicated across TP workers due to the
|
||||
number of TP workers being greater than the number of KV heads.
|
||||
"""
|
||||
tp_size = self.remote_tp_size[engine_id]
|
||||
return tp_size // self.total_num_kv_heads >= 1
|
||||
|
||||
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
|
||||
# MLA is always replicated as the hidden dim can't be split.
|
||||
return self.is_mla or self.is_kv_replicated(remote_engine_id)
|
||||
|
||||
def get_target_remote_ranks(
|
||||
self,
|
||||
remote_tp_size: int,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get the remote TP rank (on P) that the current local TP rank
|
||||
(on D) will read from. When remote tp_size > local tp_size, we
|
||||
read from multiple remote ranks.
|
||||
"""
|
||||
tp_ratio = self.tp_ratio(remote_tp_size)
|
||||
if tp_ratio > 0:
|
||||
return [self.tp_rank // tp_ratio]
|
||||
|
||||
# P TP > D TP case, D reads from |tp_ratio| remote workers.
|
||||
tp_ratio = -tp_ratio
|
||||
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
|
||||
|
||||
def get_target_remote_ranks_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> list[int]:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.get_target_remote_ranks(remote_tp_size)
|
||||
|
||||
|
||||
def get_current_attn_backend(vllm_config: VllmConfig):
|
||||
layer_type = cast(type[Any], AttentionLayerBase)
|
||||
layers = get_layers_from_vllm_config(vllm_config, layer_type, None)
|
||||
if layers:
|
||||
backend = next(iter(layers.values())).get_attn_backend()
|
||||
else:
|
||||
# Fallback for tests, when static_forward_context is empty.
|
||||
logger.debug(
|
||||
"No layers found in the vLLM config. "
|
||||
"Falling back to default attention backend."
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
|
||||
backend = get_attn_backend(
|
||||
head_size=vllm_config.model_config.get_head_size(),
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
use_mla=vllm_config.model_config.use_mla,
|
||||
)
|
||||
return backend
|
||||
19
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
19
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
SupportsHMA,
|
||||
supports_hma,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import ( # noqa: E501
|
||||
DecodeBenchConnector,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"KVConnectorRole",
|
||||
"KVConnectorBase_V1",
|
||||
"supports_hma",
|
||||
"SupportsHMA",
|
||||
"DecodeBenchConnector",
|
||||
]
|
||||
607
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
607
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
@@ -0,0 +1,607 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
|
||||
communication in vLLM v1
|
||||
|
||||
The class provides the following primitives:
|
||||
Scheduler-side: runs in the scheduler, binds metadata, which
|
||||
is used by the worker-side to load/save KV cache.
|
||||
get_num_new_matched_tokens() - get number of new tokens
|
||||
that exist in the remote KV cache. Might be called multiple
|
||||
times for a given request and should be side-effect free.
|
||||
update_state_after_alloc() - update KVConnector state after
|
||||
temporary buffer alloc by the CacheManager.
|
||||
update_connector_output() - update KVConnector state after
|
||||
output is received from worker-side connectors.
|
||||
request_finished() - called once when a request is finished,
|
||||
with the computed kv cache blocks for the request.
|
||||
Returns whether KV cache should be freed now or if the
|
||||
connector now assumes responsibility for freeing the
|
||||
the blocks asynchronously. Also optionally returns KV
|
||||
transfer params.
|
||||
take_events() - returns new KV events that were collected
|
||||
by the connector since the last call.
|
||||
|
||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||
the Connector based on the metadata.
|
||||
handle_preemptions() - called if there are preempted requests,
|
||||
before their blocks are overwritten
|
||||
|
||||
start_load_kv() - starts loading all KVs (maybe async)
|
||||
wait_for_layer_load() - blocks until layer i load is done
|
||||
|
||||
save_kv_layer() - starts saving KV for layer i (maybe async)
|
||||
wait_for_save() - blocks until all saves are done
|
||||
|
||||
get_finished() - called with ids of finished requests, returns
|
||||
ids of requests that have completed async sending/recving.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
|
||||
CopyBlocksOp = Callable[
|
||||
[
|
||||
dict[str, torch.Tensor],
|
||||
dict[str, torch.Tensor],
|
||||
list[int],
|
||||
list[int],
|
||||
Literal["h2d", "d2h"],
|
||||
],
|
||||
None,
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SupportsHMA(ABC):
|
||||
"""
|
||||
The class that indicates the corresponding connector supports hybrid memory
|
||||
allocator (HMA).
|
||||
This is required to use the connector together with hybrid memory allocator.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished for all kv cache groups,
|
||||
before its blocks are freed for each group.
|
||||
|
||||
NOTE(Kuntai): This function is only supported by connectors that support HMA.
|
||||
|
||||
The connector may assumes responsibility for freeing the blocks
|
||||
asynchronously by returning True.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def supports_hma(connector: Any) -> bool:
|
||||
if isinstance(connector, type):
|
||||
return issubclass(connector, SupportsHMA)
|
||||
else:
|
||||
return isinstance(connector, SupportsHMA)
|
||||
|
||||
|
||||
class KVConnectorRole(enum.Enum):
|
||||
# Connector running in the scheduler process
|
||||
SCHEDULER = 0
|
||||
|
||||
# Connector running in the worker process
|
||||
WORKER = 1
|
||||
|
||||
|
||||
class KVConnectorHandshakeMetadata(ABC): # noqa: B024
|
||||
"""
|
||||
Metadata used for out of band connector handshake between
|
||||
P/D workers. This needs to serializeable.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class KVConnectorMetadata(ABC): # noqa: B024
|
||||
"""
|
||||
Abstract Metadata used to communicate between the
|
||||
Scheduler KVConnector and Worker KVConnector.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
Base class for KV connectors.
|
||||
"""
|
||||
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
"""
|
||||
Indicates whether this connector prefers KV blocks that hold KV data for all
|
||||
layers, which can speed up KV data transfers. Defaults to False.
|
||||
"""
|
||||
return False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
logger.warning(
|
||||
"Initializing KVConnectorBase_V1. This API is experimental and "
|
||||
"subject to change in the future as we iterate the design."
|
||||
)
|
||||
self._connector_metadata: KVConnectorMetadata | None = None
|
||||
self._vllm_config = vllm_config
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
self._kv_transfer_config = vllm_config.kv_transfer_config
|
||||
else:
|
||||
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
|
||||
self._kv_cache_config = kv_cache_config
|
||||
if self._kv_cache_config is None:
|
||||
logger.warning(
|
||||
"KVConnectorBase_V1 initialized without kv_cache_config. "
|
||||
"This is deprecated - please update your connector to accept "
|
||||
"kv_cache_config as the third constructor argument and pass it "
|
||||
"to super().__init__()."
|
||||
)
|
||||
self._role = role
|
||||
|
||||
@property
|
||||
def role(self) -> KVConnectorRole:
|
||||
return self._role
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
"""Set the connector metadata from the scheduler.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
before the model execution. The metadata will be used for runtime
|
||||
KV cache loading and saving.
|
||||
|
||||
Args:
|
||||
connector_metadata (dict): the connector metadata.
|
||||
"""
|
||||
self._connector_metadata = connector_metadata
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
"""Clear the connector metadata.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
after the model execution.
|
||||
"""
|
||||
self._connector_metadata = None
|
||||
|
||||
def _get_connector_metadata(self) -> KVConnectorMetadata:
|
||||
"""Get the connector metadata.
|
||||
|
||||
This function should only be called inside the connector.
|
||||
|
||||
Returns:
|
||||
ConnectorMetadata: the connector metadata.
|
||||
"""
|
||||
# Should only be called while set to valid metadata.
|
||||
assert self._connector_metadata is not None
|
||||
return self._connector_metadata
|
||||
|
||||
def has_connector_metadata(self) -> bool:
|
||||
"""Check whether the connector metadata is currently set.
|
||||
|
||||
Returns:
|
||||
bool: True if connector metadata exists, False otherwise.
|
||||
"""
|
||||
return self._connector_metadata is not None
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Initialize with the KV caches. Useful for pre-registering the
|
||||
KV Caches in the KVConnector (e.g. for NIXL).
|
||||
|
||||
Args:
|
||||
kv_caches: dictionary of layer names, kv cache
|
||||
"""
|
||||
return
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
|
||||
):
|
||||
"""
|
||||
Initialize with a single KV cache tensor used by all layers.
|
||||
The first dimension should be num_layers.
|
||||
This function will only be called for models with uniform layers,
|
||||
and only if the prefers_cross_layer_blocks is set to True.
|
||||
Only one of the functions
|
||||
{register_kv_caches, register_cross_layers_kv_cache} will be called.
|
||||
|
||||
Args:
|
||||
kv_cache: a cross-layers kv cache tensor
|
||||
attn_backend: The attention backend that corresponds to all layers
|
||||
"""
|
||||
return
|
||||
|
||||
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
|
||||
"""
|
||||
Set the xPU-specific ops for copying KV between host and device.
|
||||
Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
|
||||
"""
|
||||
return
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
"""
|
||||
Handle preempted requests BEFORE their blocks are overwritten.
|
||||
Needed for connectors which use async saves (e.g., OffloadingConnector)
|
||||
"""
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Start saving a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_save(self):
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens on the worker.
|
||||
The scheduler process (via the Executors) will use this output
|
||||
to track which workers are done.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
|
||||
"""
|
||||
Get the set of block IDs that failed to load.
|
||||
|
||||
Returns:
|
||||
Set of block IDs that encountered load errors.
|
||||
Empty set if no load errors occurred.
|
||||
|
||||
Notes:
|
||||
- Applies to both sync- and async-loading requests.
|
||||
- Async loading: failed blocks may be reported in any forward pass
|
||||
up to and including the pass where the request ID is returned by
|
||||
`get_finished()`. Even if failures occur, the request must still
|
||||
be reported via `get_finished()`, and the failed block IDs must
|
||||
appear here no later than that same pass.
|
||||
- Sync loading: failed blocks should be reported in the forward
|
||||
pass in which they are detected.
|
||||
"""
|
||||
return set()
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Shutdown the connector. This is called when the worker process
|
||||
is shutting down to ensure that all the async operations are
|
||||
completed and the connector is cleaned up properly.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_kv_connector_stats(self) -> "KVConnectorStats | None":
|
||||
"""
|
||||
Get the KV connector stats collected during the last interval.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> "KVConnectorKVEvents | None":
|
||||
"""
|
||||
Get the KV connector kv cache events collected during the last interval.
|
||||
This function should be called by the model runner every time after the
|
||||
model execution and before cleanup.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||
"""
|
||||
Get the KVConnector handshake metadata for this connector.
|
||||
This metadata is used for out-of-band connector handshake
|
||||
between P/D workers.
|
||||
|
||||
Returns:
|
||||
KVConnectorHandshakeMetadata: the handshake metadata.
|
||||
None if no handshake metadata is available.
|
||||
"""
|
||||
return None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
@abstractmethod
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- An optional number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
If None, it means that the connector needs more time to
|
||||
determine the number of matched tokens, and the scheduler
|
||||
should query for this request again later.
|
||||
- `True` if external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps). Must be
|
||||
'False' if the first element is 0.
|
||||
|
||||
Notes:
|
||||
The connector should only consider the largest prefix of prompt-
|
||||
tokens for which KV cache is actually available at the time of the
|
||||
call. If the cache cannot be loaded for some tokens (e.g., due to
|
||||
connectivity issues or eviction), those tokens must not be taken
|
||||
into account.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If get_num_new_matched_tokens previously returned True for a
|
||||
request, this function may be called twice for that same request -
|
||||
first when blocks are allocated for the connector tokens to be
|
||||
asynchronously loaded into, and second when any additional blocks
|
||||
are allocated, after the load/transfer is complete.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
blocks (KVCacheBlocks): the blocks allocated for the request.
|
||||
num_external_tokens (int): the number of tokens that will be
|
||||
loaded from the external KV cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
return
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished, before its blocks are
|
||||
freed.
|
||||
|
||||
The connector may assumes responsibility for freeing the blocks
|
||||
asynchronously by returning True.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
return False, None
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||
"""
|
||||
Take the KV cache events from the connector.
|
||||
|
||||
Yields:
|
||||
New KV cache events since the last call.
|
||||
"""
|
||||
return ()
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
|
||||
"""
|
||||
Get the required KV cache layout for this connector.
|
||||
Args:
|
||||
vllm_config (VllmConfig): the vllm config.
|
||||
|
||||
Returns:
|
||||
str: the required KV cache layout. e.g. HND, or NHD.
|
||||
None if the connector does not require a specific layout.
|
||||
"""
|
||||
|
||||
if cls is KVConnectorBase_V1:
|
||||
raise TypeError(
|
||||
"get_required_kvcache_layout should not be called "
|
||||
"on the abstract base class"
|
||||
)
|
||||
return None
|
||||
|
||||
def get_finished_count(self) -> int | None:
|
||||
"""
|
||||
Get the count of requests expected to complete send/receive operations
|
||||
via this connector. This method is used to initialize the
|
||||
KVOutputAggregator, overwriting the default world_size.
|
||||
|
||||
Returns:
|
||||
int: expected sending or receiving completion count.
|
||||
"""
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls, data: dict[str, Any] | None = None
|
||||
) -> "KVConnectorStats | None":
|
||||
"""
|
||||
KVConnectorStats resolution method. This method allows dynamically
|
||||
registered connectors to return their own KVConnectorStats object,
|
||||
which can implement custom aggregation logic on the data dict.
|
||||
"""
|
||||
return None
|
||||
|
||||
def set_xfer_handshake_metadata(
|
||||
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||
) -> None:
|
||||
"""
|
||||
Set the KV connector handshake metadata for this connector.
|
||||
|
||||
Args:
|
||||
metadata (KVConnectorHandshakeMetadata): the handshake metadata to set.
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
) -> "KVConnectorPromMetrics | None":
|
||||
"""
|
||||
Create a KVConnectorPromMetrics subclass which should register
|
||||
per-connector Prometheus metrics and implement observe() to
|
||||
expose connector transfer stats via Prometheus.
|
||||
"""
|
||||
return None
|
||||
|
||||
def reset_cache(self) -> bool | None:
|
||||
"""
|
||||
Reset the connector's internal cache.
|
||||
|
||||
Returns:
|
||||
bool: True if the cache was successfully reset, False otherwise.
|
||||
"""
|
||||
logger.debug(
|
||||
"Connector cache reset requested, but %s does not implement reset_cache().",
|
||||
type(self).__name__,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,419 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
DecodeBenchConnector: A KV Connector for decode instance performance testing.
|
||||
|
||||
This connector emulates a prefill-decode disaggregated setting by filling
|
||||
the KV cache with dummy values, allowing measurement of decoder performance
|
||||
under larger input sequence lengths (ISL) in resource-limited environments.
|
||||
|
||||
Usage:
|
||||
To use this connector for benchmarking, configure it in the kv_transfer_config:
|
||||
|
||||
Example:
|
||||
vllm serve <model> --kv-transfer-config '{
|
||||
"kv_connector": "DecodeBenchConnector",
|
||||
"kv_role": "kv_both",
|
||||
"kv_connector_extra_config": {
|
||||
"fill_mean": 0.015,
|
||||
"fill_std": 0.0
|
||||
}
|
||||
}'
|
||||
|
||||
Then run your benchmark with desired input/output lengths:
|
||||
vllm bench serve --base-url http://127.0.0.1:8000 --model <model> \\
|
||||
--dataset-name random --random-input-len 40000 \\
|
||||
--random-output-len 100 --max-concurrency 10
|
||||
|
||||
Configuration options (via kv_connector_extra_config):
|
||||
- fill_mean (float): Mean value for random normal fill (default: 0.015)
|
||||
- fill_std (float): Standard deviation for random fill (default: 0.0)
|
||||
Set to 0 for constant values, >0 for random sampling
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodeBenchConnectorMetadata(KVConnectorMetadata):
|
||||
"""Metadata for DecodeBenchConnector.
|
||||
|
||||
Contains information about which requests need their KV cache filled
|
||||
with dummy values for benchmarking purposes.
|
||||
"""
|
||||
|
||||
# request_id -> (block_ids_per_group, num_tokens_to_fill)
|
||||
# block_ids_per_group is a tuple of lists, one per KV cache group
|
||||
# For standard attention: single group, e.g., ([1, 2, 3],)
|
||||
# For MLA: multiple groups, e.g., ([1, 2], [1, 2])
|
||||
reqs_to_fill: dict[str, tuple[tuple[list[int], ...], int]]
|
||||
|
||||
|
||||
class DecodeBenchConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
A KV Connector for decode instance performance testing.
|
||||
|
||||
This connector fills the KV cache with dummy (non-zero) values to
|
||||
emulate a prefill-decode disaggregated setting, enabling performance
|
||||
testing of the decoder with larger input sequence lengths.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
|
||||
self.connector_worker: DecodeBenchConnectorWorker | None = None
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = DecodeBenchConnectorScheduler(vllm_config)
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_worker = DecodeBenchConnectorWorker(vllm_config)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata, DecodeBenchConnectorMetadata)
|
||||
self.connector_worker.start_fill_kv(self._connector_metadata)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
# All operations are synchronous, so nothing to wait for
|
||||
pass
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# This connector doesn't save KV cache (benchmarking only)
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
# This connector doesn't save KV cache (benchmarking only)
|
||||
pass
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
)
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens
|
||||
)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
assert self.connector_scheduler is not None
|
||||
self.connector_scheduler.request_finished(request)
|
||||
return False, None
|
||||
|
||||
|
||||
class DecodeBenchConnectorScheduler:
|
||||
"""Scheduler-side implementation for DecodeBenchConnector."""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
# Track which requests have already been filled
|
||||
self._filled_requests: set[str] = set()
|
||||
|
||||
# Track pending fills for the current scheduler step
|
||||
# request_id -> (block_ids_per_group, num_tokens_to_fill)
|
||||
# Note: _pending_fills doesn't need explicit cleanup - it's cleared
|
||||
# after build_connector_meta() is called in the same scheduler step
|
||||
self._pending_fills: dict[str, tuple[tuple[list[int], ...], int]] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
For new requests, return the number of tokens that should be filled
|
||||
with dummy KV cache values.
|
||||
|
||||
Returns:
|
||||
(num_tokens_to_fill, is_async)
|
||||
- num_tokens_to_fill: number of uncomputed tokens minus 1
|
||||
(we fill everything except the last token for decode)
|
||||
- is_async: False (synchronous filling)
|
||||
"""
|
||||
req_id = request.request_id
|
||||
|
||||
# Only fill once per request on first scheduling
|
||||
if req_id in self._filled_requests:
|
||||
return 0, False
|
||||
|
||||
# Calculate how many tokens we need to fill
|
||||
# Fill all uncomputed tokens except the last one (which will be decoded)
|
||||
# This simulates having processed a long prefill
|
||||
num_uncomputed_tokens = request.num_tokens - num_computed_tokens
|
||||
num_tokens_to_fill = max(0, num_uncomputed_tokens - 1)
|
||||
|
||||
if num_tokens_to_fill == 0:
|
||||
return 0, False
|
||||
|
||||
# Return False for synchronous operation - the fill is fast enough
|
||||
# that async overhead isn't worth it
|
||||
return num_tokens_to_fill, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Called after blocks are allocated. Store the block IDs so we can
|
||||
fill them with dummy values.
|
||||
|
||||
Supports both standard attention (single KV cache group) and MLA
|
||||
(multiple KV cache groups).
|
||||
"""
|
||||
req_id = request.request_id
|
||||
|
||||
if num_external_tokens == 0:
|
||||
return
|
||||
|
||||
# Get the block IDs that were allocated
|
||||
# block_groups is a tuple of lists, one per KV cache group
|
||||
# For standard attention: 1 group
|
||||
# For MLA: multiple groups (one per attention type)
|
||||
block_groups = blocks.get_block_ids()
|
||||
|
||||
# Calculate how many blocks we need to fill
|
||||
# num_external_tokens are the tokens we said we'd provide
|
||||
num_blocks_to_fill = cdiv(num_external_tokens, self.block_size)
|
||||
|
||||
# Extract the first num_blocks_to_fill blocks from each group
|
||||
# All groups should have the same block IDs for the same request
|
||||
block_ids_per_group = tuple(
|
||||
group_blocks[:num_blocks_to_fill] for group_blocks in block_groups
|
||||
)
|
||||
|
||||
# Store the blocks to fill for all group. _pending_fills doesn't need cleanup
|
||||
# as it's cleared after build_connector_meta
|
||||
self._pending_fills[req_id] = (
|
||||
block_ids_per_group,
|
||||
num_external_tokens,
|
||||
)
|
||||
self._filled_requests.add(req_id)
|
||||
|
||||
logger.debug(
|
||||
"DecodeBenchConnector: Allocated %d blocks across %d KV cache groups "
|
||||
"for request %s",
|
||||
num_blocks_to_fill,
|
||||
len(block_groups),
|
||||
req_id,
|
||||
)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build metadata containing information about which blocks to fill
|
||||
with dummy KV values.
|
||||
"""
|
||||
meta = DecodeBenchConnectorMetadata(reqs_to_fill=self._pending_fills.copy())
|
||||
|
||||
# Clear pending fills after building metadata
|
||||
self._pending_fills.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def request_finished(self, request: "Request"):
|
||||
"""
|
||||
Called when a request has finished. Clean up any state.
|
||||
"""
|
||||
self._filled_requests.discard(request.request_id)
|
||||
|
||||
|
||||
class DecodeBenchConnectorWorker:
|
||||
"""Worker-side implementation for DecodeBenchConnector."""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
# Get fill parameters from extra config
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
assert kv_transfer_config is not None
|
||||
self.fill_mean = kv_transfer_config.get_from_extra_config("fill_mean", 0.015)
|
||||
self.fill_std = kv_transfer_config.get_from_extra_config("fill_std", 0.0)
|
||||
|
||||
# Will be populated via register_kv_caches
|
||||
self.kv_caches: dict[str, torch.Tensor] | None = None
|
||||
|
||||
# Mapping from KV cache group index to list of layer names in that group
|
||||
self.group_to_layers: dict[int, list[str]] | None = None
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Store references to the KV cache tensors and build group mapping."""
|
||||
self.kv_caches = kv_caches
|
||||
|
||||
# For simplicity, assume all layers belong to group 0 (standard attention)
|
||||
# For MLA models with multiple groups, the metadata will handle the mapping
|
||||
# We just need to fill the blocks specified in the metadata
|
||||
self.group_to_layers = {0: list(kv_caches.keys())}
|
||||
|
||||
logger.debug(
|
||||
"DecodeBenchConnector: Registered %d KV cache layers",
|
||||
len(kv_caches),
|
||||
)
|
||||
|
||||
def start_fill_kv(self, metadata: DecodeBenchConnectorMetadata):
|
||||
"""
|
||||
Fill the allocated KV cache blocks with dummy (non-zero) values.
|
||||
|
||||
This simulates having a populated KV cache from a prefill phase,
|
||||
allowing decode performance testing with larger context sizes.
|
||||
|
||||
Supports both standard attention (single group) and MLA (multiple groups).
|
||||
"""
|
||||
if not metadata.reqs_to_fill:
|
||||
return
|
||||
|
||||
assert self.kv_caches is not None, "KV caches must be registered before filling"
|
||||
assert self.group_to_layers is not None, "Group mapping must be initialized"
|
||||
|
||||
for req_id, (block_ids_per_group, num_tokens) in metadata.reqs_to_fill.items():
|
||||
# Fill blocks for each KV cache group
|
||||
for group_idx, block_ids in enumerate(block_ids_per_group):
|
||||
self._fill_blocks(group_idx, block_ids, num_tokens)
|
||||
|
||||
logger.debug(
|
||||
"DecodeBenchConnector: Filled %d blocks (%d tokens) across %d groups "
|
||||
"for request %s",
|
||||
len(block_ids_per_group[0]) if block_ids_per_group else 0,
|
||||
num_tokens,
|
||||
len(block_ids_per_group),
|
||||
req_id,
|
||||
)
|
||||
|
||||
def _fill_blocks(self, group_idx: int, block_ids: list[int], num_tokens: int):
|
||||
"""
|
||||
Fill specified blocks with dummy non-zero values for a specific KV cache group.
|
||||
|
||||
Args:
|
||||
group_idx: The KV cache group index to fill
|
||||
block_ids: List of block IDs to fill in this group
|
||||
num_tokens: Total number of tokens to fill across these blocks
|
||||
"""
|
||||
if not block_ids:
|
||||
return
|
||||
|
||||
assert self.kv_caches is not None
|
||||
assert self.group_to_layers is not None
|
||||
|
||||
# Get the layers that belong to this group
|
||||
layer_names = self.group_to_layers.get(group_idx, [])
|
||||
|
||||
# Fill only the layers in this group
|
||||
for layer_name in layer_names:
|
||||
if layer_name not in self.kv_caches:
|
||||
logger.warning(
|
||||
"DecodeBenchConnector: Layer %s not found in KV caches", layer_name
|
||||
)
|
||||
continue
|
||||
|
||||
kv_cache = self.kv_caches[layer_name]
|
||||
|
||||
# Convert block_ids to tensor on device
|
||||
block_ids_tensor = torch.tensor(
|
||||
block_ids, dtype=torch.long, device=kv_cache.device
|
||||
)
|
||||
|
||||
# Filter invalid block IDs
|
||||
valid_mask = block_ids_tensor < kv_cache.shape[0]
|
||||
valid_block_ids = block_ids_tensor[valid_mask]
|
||||
|
||||
if len(valid_block_ids) == 0:
|
||||
continue
|
||||
|
||||
# Create fill values - either constant or random
|
||||
block_shape = kv_cache.shape[1:]
|
||||
if self.fill_std > 0:
|
||||
# Random normal sampling
|
||||
fill_values = torch.normal(
|
||||
mean=self.fill_mean,
|
||||
std=self.fill_std,
|
||||
size=(len(valid_block_ids),) + block_shape,
|
||||
dtype=kv_cache.dtype,
|
||||
device=kv_cache.device,
|
||||
)
|
||||
else:
|
||||
# Constant fill value
|
||||
fill_values = torch.full(
|
||||
(len(valid_block_ids),) + block_shape,
|
||||
self.fill_mean,
|
||||
dtype=kv_cache.dtype,
|
||||
device=kv_cache.device,
|
||||
)
|
||||
|
||||
# Batch fill operation
|
||||
kv_cache[valid_block_ids] = fill_values
|
||||
|
||||
logger.debug(
|
||||
"DecodeBenchConnector: Filled %d blocks in group %d with %s values "
|
||||
"(mean=%.3f, std=%.3f)",
|
||||
len(block_ids),
|
||||
group_idx,
|
||||
"random" if self.fill_std > 0 else "constant",
|
||||
self.fill_mean,
|
||||
self.fill_std,
|
||||
)
|
||||
@@ -0,0 +1,442 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
# Slot mappings, should have the same length as token_ids
|
||||
slot_mapping: torch.Tensor
|
||||
# Is store or load
|
||||
is_store: bool
|
||||
mm_hashes: list[str]
|
||||
|
||||
@staticmethod
|
||||
def make_meta(
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
is_store: bool,
|
||||
mm_hashes: list[str],
|
||||
) -> "ReqMeta":
|
||||
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
|
||||
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
|
||||
block_ids_tensor = torch.tensor(block_ids)
|
||||
num_blocks = block_ids_tensor.shape[0]
|
||||
block_offsets = torch.arange(0, block_size)
|
||||
slot_mapping = (
|
||||
block_offsets.reshape((1, block_size))
|
||||
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
|
||||
)
|
||||
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
|
||||
return ReqMeta(
|
||||
token_ids=token_ids_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
is_store=is_store,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExampleConnectorMetadata(KVConnectorMetadata):
|
||||
requests: list[ReqMeta] = field(default_factory=list)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
is_store: bool,
|
||||
mm_hashes: list[str],
|
||||
) -> None:
|
||||
self.requests.append(
|
||||
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes)
|
||||
)
|
||||
|
||||
|
||||
class ExampleConnector(KVConnectorBase_V1):
|
||||
# NOTE: This is Simple debug implementation of the KV connector.
|
||||
# It save / load the KV cache to / from the disk.
|
||||
# It does extra work which will overwrite the existing prefix-cache in GPU
|
||||
# - to remove the overhead, need to add some "mask" in the ReqMeta class
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
self._requests_need_load: dict[str, Request] = {}
|
||||
self._storage_path = self._kv_transfer_config.get_from_extra_config(
|
||||
"shared_storage_path", "/tmp"
|
||||
)
|
||||
logger.info(self._kv_transfer_config)
|
||||
logger.info("Shared storage path is %s", self._storage_path)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""Start loading the KV cache from the connector buffer to vLLM's
|
||||
paged KV buffer.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
"""
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
def inject_kv_into_layer(
|
||||
dst_kv_cache_layer: torch.Tensor,
|
||||
src_kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
"""Inject the KV cache into the layer.
|
||||
|
||||
Args:
|
||||
dst_kv_cache_layer (torch.Tensor): the destination KV cache
|
||||
layer. In shape [2, num_pages, page_size, xxx] if not
|
||||
using MLA, [num_pages, page_size, xxx] otherwise.
|
||||
src_kv_cache (torch.Tensor): the source KV cache. In shape
|
||||
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
|
||||
otherwise.
|
||||
slot_mapping (torch.Tensor): the slot mapping. In shape
|
||||
[num_tokens].
|
||||
"""
|
||||
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
num_pages = dst_kv_cache_layer_shape[0]
|
||||
page_size = dst_kv_cache_layer_shape[1]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
num_pages * page_size, -1
|
||||
)
|
||||
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
|
||||
else:
|
||||
num_pages = dst_kv_cache_layer_shape[1]
|
||||
page_size = dst_kv_cache_layer_shape[2]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
2, num_pages * page_size, -1
|
||||
)
|
||||
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
|
||||
|
||||
# Get the metadata
|
||||
metadata: KVConnectorMetadata = self._get_connector_metadata()
|
||||
assert isinstance(metadata, ExampleConnectorMetadata)
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
logger.warning("In connector.start_load_kv, but the attn_metadata is None")
|
||||
return
|
||||
|
||||
# Load the KV for each request each layer
|
||||
for request in metadata.requests:
|
||||
if request.is_store:
|
||||
continue
|
||||
logger.info(
|
||||
"Inject KV cache of %d tokens to the paged memory",
|
||||
len(request.slot_mapping),
|
||||
)
|
||||
for layer_name in forward_context.no_compile_layers:
|
||||
layer = forward_context.no_compile_layers[layer_name]
|
||||
|
||||
# Only process layers that have kv_cache
|
||||
# attribute (attention layers) Skip non-attention
|
||||
# layers like FusedMoE/MLP etc.
|
||||
kv_cache_attr = getattr(layer, "kv_cache", None)
|
||||
if kv_cache_attr is None:
|
||||
continue
|
||||
|
||||
kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
|
||||
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids, request.mm_hashes
|
||||
)
|
||||
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
|
||||
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""Blocking until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
return
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
to the connector.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
|
||||
def extract_kv_from_layer(
|
||||
layer: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Extract the KV cache from the layer.
|
||||
|
||||
Assume the shape of the layer is (2, num_pages, page_size, xxx)
|
||||
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
|
||||
"""
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
num_pages, page_size = layer.shape[0], layer.shape[1]
|
||||
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
|
||||
num_pages, page_size = layer.shape[1], layer.shape[2]
|
||||
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
|
||||
|
||||
connector_metadata = self._get_connector_metadata()
|
||||
assert isinstance(connector_metadata, ExampleConnectorMetadata)
|
||||
for request in connector_metadata.requests:
|
||||
if request.is_store:
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids, request.mm_hashes
|
||||
)
|
||||
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
|
||||
tensors = {"kv_cache": kv_cache.detach().cpu()}
|
||||
safetensors.torch.save_file(tensors, filename)
|
||||
|
||||
def wait_for_save(self):
|
||||
return
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
# NOTE: in this debug implementation, we assume that the prompt is
|
||||
# cached_prompt + newly_generated_single_token
|
||||
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
|
||||
|
||||
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
|
||||
# with the block granularity. And it expects the returned blocks and
|
||||
# num_computed_tokens to also be aligned with the block granularity.
|
||||
if not self._found_match_for_request(request):
|
||||
return 0, False
|
||||
|
||||
logger.info("External Cache Hit!")
|
||||
|
||||
# Now, first num_tokens_to_check tokens are hit, we need to prepare
|
||||
# the metadata for the worker connector to correctly load the KV
|
||||
token_ids = request.prompt_token_ids or []
|
||||
num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
|
||||
|
||||
return num_tokens_to_check - num_computed_tokens, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If blocks were allocated, add to _requests_need_load,
|
||||
such that we load the KVs in the next forward pass.
|
||||
"""
|
||||
if num_external_tokens > 0:
|
||||
self._requests_need_load[request.request_id] = request
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
meta = ExampleConnectorMetadata()
|
||||
|
||||
total_need_load = 0
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
token_ids = new_req.prompt_token_ids or []
|
||||
mm_hashes = [f.identifier for f in new_req.mm_features]
|
||||
if new_req.req_id in self._requests_need_load:
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=False,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
total_need_load += 1
|
||||
else:
|
||||
# NOTE: here, we set the store and load being exclusive,
|
||||
# but a single request can have both store and load.
|
||||
# NOTE(rob): for this debug implementation, we only cache
|
||||
# the original prompt tokens.
|
||||
if not self._found_match_for_prompt(token_ids, mm_hashes):
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=True,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
|
||||
if not resumed_from_preemption or req_id not in self._requests_need_load:
|
||||
continue
|
||||
|
||||
num_computed_tokens = cached_reqs.num_computed_tokens[i]
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
|
||||
# NOTE(rob): cached_req_data does not have the full
|
||||
# list of token ids (only new tokens). So we look it
|
||||
# up in the actual request object.
|
||||
request = self._requests_need_load[req_id]
|
||||
total_tokens = num_computed_tokens + num_new_tokens
|
||||
token_ids = request.all_token_ids[:total_tokens]
|
||||
|
||||
# NOTE(rob): For resumed req, new_block_ids is all
|
||||
# of the block_ids for the request.
|
||||
assert new_block_ids is not None
|
||||
block_ids = new_block_ids[0]
|
||||
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
block_size=self._block_size,
|
||||
is_store=False,
|
||||
mm_hashes=[f.identifier for f in request.mm_features],
|
||||
)
|
||||
total_need_load += 1
|
||||
|
||||
assert total_need_load == len(self._requests_need_load)
|
||||
self._requests_need_load.clear()
|
||||
return meta
|
||||
|
||||
# ==============================
|
||||
# Helper functions
|
||||
# ==============================
|
||||
|
||||
def _found_match_for_request(
|
||||
self,
|
||||
request: "Request",
|
||||
) -> bool:
|
||||
"""Check if the cache is hit for the request."""
|
||||
return self._found_match_for_prompt(
|
||||
list(request.prompt_token_ids or []),
|
||||
[f.identifier for f in request.mm_features],
|
||||
)
|
||||
|
||||
def _found_match_for_prompt(
|
||||
self,
|
||||
prompt_token_ids: list[int],
|
||||
mm_hashes: list[str],
|
||||
) -> bool:
|
||||
num_tokens_to_check = align_to_block_size(
|
||||
len(prompt_token_ids) - 1, self._block_size
|
||||
)
|
||||
foldername = self._generate_foldername_debug(
|
||||
torch.tensor(prompt_token_ids)[:num_tokens_to_check],
|
||||
mm_hashes,
|
||||
create_folder=False,
|
||||
)
|
||||
return os.path.exists(foldername)
|
||||
|
||||
def _generate_foldername_debug(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
mm_hashes: list[str],
|
||||
create_folder=False,
|
||||
) -> str:
|
||||
"""Generate a folder name based on the hash of the bytes of the input
|
||||
ids.
|
||||
"""
|
||||
token_bytes = token_ids.numpy().tobytes()
|
||||
# Add mm_hashes to the bytes being hashed to avoid path traversal and
|
||||
# to create a canonical key.
|
||||
if mm_hashes:
|
||||
mm_str = "-".join(mm_hashes)
|
||||
token_bytes += mm_str.encode("utf-8")
|
||||
input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest()
|
||||
|
||||
foldername = os.path.join(self._storage_path, input_ids_hash)
|
||||
if create_folder:
|
||||
os.makedirs(foldername, exist_ok=True)
|
||||
return foldername
|
||||
|
||||
def _generate_filename_debug(
|
||||
self,
|
||||
layer_name: str,
|
||||
token_ids: torch.Tensor,
|
||||
mm_hashes: list[str],
|
||||
) -> str:
|
||||
"""Generate a file name based on the layer name and the hash
|
||||
of the bytes of the input ids.
|
||||
"""
|
||||
foldername = self._generate_foldername_debug(
|
||||
token_ids, mm_hashes=mm_hashes, create_folder=True
|
||||
)
|
||||
return os.path.join(foldername, f"{layer_name}.safetensors")
|
||||
|
||||
|
||||
def align_to_block_size(num_tokens: int, block_size) -> int:
|
||||
"""Align the number of tokens to the block size."""
|
||||
return (num_tokens - 1) // block_size * block_size
|
||||
@@ -0,0 +1,344 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import (
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
KVConnectorKVEvents,
|
||||
KVEventAggregator,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LMCacheKVEvents(KVConnectorKVEvents):
|
||||
"""
|
||||
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
|
||||
"""
|
||||
|
||||
def __init__(self, num_workers: int) -> None:
|
||||
self._aggregator = KVEventAggregator(num_workers)
|
||||
|
||||
def add_events(self, events: list[KVCacheEvent]) -> None:
|
||||
self._aggregator.add_events(events)
|
||||
|
||||
def aggregate(self) -> "LMCacheKVEvents":
|
||||
"""
|
||||
Aggregate KV events and retain only common events.
|
||||
"""
|
||||
common_events = self._aggregator.get_common_events()
|
||||
self._aggregator.clear_events()
|
||||
self._aggregator.add_events(common_events)
|
||||
self._aggregator.reset_workers()
|
||||
return self
|
||||
|
||||
def increment_workers(self, count: int = 1) -> None:
|
||||
self._aggregator.increment_workers(count)
|
||||
|
||||
def get_all_events(self) -> list[KVCacheEvent]:
|
||||
return self._aggregator.get_all_events()
|
||||
|
||||
def get_number_of_workers(self) -> int:
|
||||
return self._aggregator.get_number_of_workers()
|
||||
|
||||
def clear_events(self) -> None:
|
||||
self._aggregator.clear_events()
|
||||
self._aggregator.reset_workers()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<LMCacheKVEvents events={self.get_all_events()}>"
|
||||
|
||||
|
||||
class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
|
||||
)
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
use_native = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"use_native", False
|
||||
)
|
||||
if use_native:
|
||||
logger.info("Initializing native LMCache connector")
|
||||
# lazy import
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import lmcache_integration
|
||||
|
||||
_adapter = lmcache_integration.vllm_v1_adapter
|
||||
|
||||
cls = _adapter.LMCacheConnectorV1Impl
|
||||
else:
|
||||
logger.info("Initializing latest dev LMCache connector")
|
||||
# lazy import
|
||||
from lmcache.integration.vllm.vllm_v1_adapter import (
|
||||
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
|
||||
)
|
||||
|
||||
cls = LMCacheConnectorLatestImpl
|
||||
|
||||
self._lmcache_engine = cls(vllm_config, role, self)
|
||||
|
||||
self._kv_cache_events: LMCacheKVEvents | None = None
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Initialize with the KV caches. Useful for pre-registering the
|
||||
KV Caches in the KVConnector (e.g. for NIXL).
|
||||
|
||||
Args:
|
||||
kv_caches: dictionary of layer names, kv cache
|
||||
"""
|
||||
if hasattr(self._lmcache_engine, "register_kv_caches"):
|
||||
self._lmcache_engine.register_kv_caches(kv_caches)
|
||||
else:
|
||||
logger.warning(
|
||||
"LMCache engine does not support register_kv_caches, "
|
||||
"please check and use the latest version"
|
||||
)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
self._lmcache_engine.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Start saving the a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
self._lmcache_engine.save_kv_layer(
|
||||
layer_name, kv_layer, attn_metadata, **kwargs
|
||||
)
|
||||
|
||||
def wait_for_save(self):
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
self._lmcache_engine.wait_for_save()
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
return self._lmcache_engine.get_finished(finished_req_ids)
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
|
||||
"""
|
||||
Get the set of block IDs that failed to load.
|
||||
|
||||
Returns:
|
||||
Set of block IDs that encountered load errors.
|
||||
Empty set if no load errors occurred.
|
||||
"""
|
||||
method = getattr(self._lmcache_engine, "get_block_ids_with_load_errors", None)
|
||||
if callable(method):
|
||||
return method()
|
||||
|
||||
# Fallback for older versions that don't support this method
|
||||
return set()
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
|
||||
"""
|
||||
Get the KV connector kv cache events collected during the last interval.
|
||||
"""
|
||||
|
||||
events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined]
|
||||
if not events:
|
||||
return None
|
||||
|
||||
blocks: list[BlockStored] = [
|
||||
BlockStored(
|
||||
block_hashes=e.block_hashes,
|
||||
parent_block_hash=e.parent_block_hash,
|
||||
token_ids=e.token_ids,
|
||||
lora_id=e.lora_id,
|
||||
block_size=e.block_size,
|
||||
medium=e.medium,
|
||||
lora_name=getattr(e, "lora_name", None),
|
||||
)
|
||||
for e in events
|
||||
]
|
||||
|
||||
lmcache_kv_events = LMCacheKVEvents(num_workers=1)
|
||||
lmcache_kv_events.add_events(blocks)
|
||||
return lmcache_kv_events
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
return self._lmcache_engine.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
), False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
"""
|
||||
self._lmcache_engine.update_state_after_alloc(request, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
return self._lmcache_engine.build_connector_meta(scheduler_output)
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
# Get the KV events
|
||||
kv_cache_events = connector_output.kv_cache_events
|
||||
if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
|
||||
return
|
||||
|
||||
if self._kv_cache_events is None:
|
||||
self._kv_cache_events = kv_cache_events
|
||||
else:
|
||||
self._kv_cache_events.add_events(kv_cache_events.get_all_events())
|
||||
self._kv_cache_events.increment_workers(
|
||||
kv_cache_events.get_number_of_workers()
|
||||
)
|
||||
return
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
return self._lmcache_engine.request_finished(request, block_ids)
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||
"""
|
||||
Take the KV cache events from the connector.
|
||||
|
||||
Yields:
|
||||
New KV cache events since the last call.
|
||||
"""
|
||||
if self._kv_cache_events is not None:
|
||||
self._kv_cache_events.aggregate()
|
||||
kv_cache_events = self._kv_cache_events.get_all_events()
|
||||
yield from kv_cache_events
|
||||
self._kv_cache_events.clear_events()
|
||||
self._kv_cache_events = None
|
||||
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from . import multi_process_adapter, vllm_v1_adapter
|
||||
from .multi_process_adapter import (
|
||||
LMCacheMPSchedulerAdapter,
|
||||
LMCacheMPWorkerAdapter,
|
||||
LoadStoreOp,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"vllm_v1_adapter",
|
||||
"multi_process_adapter",
|
||||
"LMCacheMPSchedulerAdapter",
|
||||
"LMCacheMPWorkerAdapter",
|
||||
"LoadStoreOp",
|
||||
]
|
||||
@@ -0,0 +1,666 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from lmcache.utils import _lmcache_nvtx_annotate, init_logger
|
||||
from lmcache.v1.multiprocess.custom_types import (
|
||||
CudaIPCWrapper,
|
||||
IPCCacheEngineKey,
|
||||
KVCache,
|
||||
)
|
||||
from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture
|
||||
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache:
|
||||
logger.info("KV caches keys are %s", list(kv_caches.keys()))
|
||||
return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()]
|
||||
|
||||
|
||||
def striding_block_hashes(
|
||||
block_hashes: list[bytes], blocks_in_chunk: int
|
||||
) -> Iterable[bytes]:
|
||||
"""Extract chunk-level hashes from block hashes by striding.
|
||||
|
||||
In hash-based vLLM, each vLLM block has its own hash. LMCache chunks
|
||||
span ``blocks_in_chunk`` consecutive blocks. The representative hash
|
||||
for a chunk is the hash of the **last** block in that chunk (because
|
||||
each block hash already encodes its prefix). So we start at index
|
||||
``blocks_in_chunk - 1`` and stride by ``blocks_in_chunk``.
|
||||
"""
|
||||
return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)
|
||||
|
||||
|
||||
def send_lmcache_request(
|
||||
mq_client: MessageQueueClient,
|
||||
request_type: RequestType,
|
||||
payloads: list[Any],
|
||||
) -> MessagingFuture[Any]:
|
||||
"""
|
||||
Helper function to send the request to the LMCache multiprocess server
|
||||
|
||||
Args:
|
||||
mq_client: The LMCache multiprocess mode message queue client
|
||||
request_type: The request type
|
||||
payloads: The request payloads
|
||||
|
||||
Returns:
|
||||
A messaging future for the request
|
||||
"""
|
||||
|
||||
future = mq_client.submit_request(
|
||||
request_type, payloads, get_response_class(request_type)
|
||||
)
|
||||
return future
|
||||
|
||||
|
||||
def get_lmcache_chunk_size(
|
||||
mq_client: MessageQueueClient,
|
||||
) -> int:
|
||||
"""
|
||||
Helper function to get the LMCache chunk size from the server
|
||||
|
||||
Args:
|
||||
mq_client: The LMCache multiprocess mode message queue client
|
||||
|
||||
Returns:
|
||||
An integer representing the LMCache chunk size
|
||||
"""
|
||||
future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, [])
|
||||
chunk_size = future.result()
|
||||
return chunk_size
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadStoreOp:
|
||||
block_ids: list[int]
|
||||
"""Block ids for the load/store operation"""
|
||||
|
||||
token_ids: list[int] | None = None
|
||||
"""Token IDs for the load/store operation (token mode)"""
|
||||
|
||||
block_hashes: list[bytes] | None = None
|
||||
"""Block hashes for the load/store operation (hash mode)"""
|
||||
|
||||
start: int = 0
|
||||
"""Start token index (token mode only)"""
|
||||
|
||||
end: int = 0
|
||||
"""End token index (token mode only)"""
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.block_ids)
|
||||
|
||||
|
||||
StoreResult = bool
|
||||
RetrieveResult = list[bool]
|
||||
LookupResult = int
|
||||
|
||||
|
||||
class LMCacheMPSchedulerAdapter:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
context: zmq.Context,
|
||||
model_name: str,
|
||||
world_size: int,
|
||||
kv_rank: int,
|
||||
vllm_block_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
server_url: The server URL for the LMCache message queue
|
||||
context: The ZMQ context
|
||||
|
||||
model_name: The model name used for LMCache keys
|
||||
world_size: The world size used for LMCache keys
|
||||
kv_rank: The kv rank used for LMCache keys
|
||||
vllm_block_size: The block size used in vLLM
|
||||
"""
|
||||
self.mq_client = MessageQueueClient(server_url, context)
|
||||
|
||||
# Request futures
|
||||
self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}
|
||||
|
||||
self.model_name = model_name
|
||||
self.world_size = world_size
|
||||
self.worker_id = kv_rank
|
||||
|
||||
# Read chunk size from lmcache
|
||||
self.chunk_size = get_lmcache_chunk_size(self.mq_client)
|
||||
assert self.chunk_size % vllm_block_size == 0, (
|
||||
"LMCache chunk size should be a multiple of vLLM block size"
|
||||
)
|
||||
self.blocks_in_chunk = self.chunk_size // vllm_block_size
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def maybe_submit_lookup_request(
|
||||
self,
|
||||
request_id: str,
|
||||
block_hashes: list[bytes] | None = None,
|
||||
token_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Submit a new lookup request to LMCache if there is no ongoing request.
|
||||
|
||||
Supports both token-based and hash-based vLLM:
|
||||
- token_ids: token IDs (token-based vLLM) -> single token-mode key
|
||||
- block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys
|
||||
|
||||
Exactly one of block_hashes or token_ids must be provided.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the lookup request. The same ID indicates it's
|
||||
from the same request
|
||||
block_hashes: Block hashes to lookup from LMCache (hash mode)
|
||||
token_ids: Token IDs to lookup from LMCache (token mode)
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Notes:
|
||||
This function will have a side-effect: submitting a look up request to
|
||||
LMCache, which will essentially 'lock' the KV cache chunks in the LMCache
|
||||
for later retrieve operations.
|
||||
In the meantime, this function will record the lookup request, and the
|
||||
status of the look up request can be checked by `check_lookup_result`.
|
||||
"""
|
||||
if request_id in self.lookup_futures:
|
||||
# Skip if there is already a lookup request
|
||||
return
|
||||
|
||||
assert (block_hashes is None) != (token_ids is None), (
|
||||
"Exactly one of block_hashes or token_ids must be provided"
|
||||
)
|
||||
|
||||
if block_hashes is not None:
|
||||
# Hash mode: stride block hashes -> N hash-mode keys
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
|
||||
]
|
||||
else:
|
||||
# Token mode: truncate to chunk-aligned length
|
||||
assert token_ids is not None
|
||||
aligned_end = (len(token_ids) // self.chunk_size) * self.chunk_size
|
||||
if aligned_end == 0:
|
||||
return
|
||||
keys = [
|
||||
self._create_key(
|
||||
token_ids,
|
||||
start=0,
|
||||
end=aligned_end,
|
||||
request_id=request_id,
|
||||
).no_worker_id_version()
|
||||
]
|
||||
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.LOOKUP,
|
||||
[keys],
|
||||
)
|
||||
self.lookup_futures[request_id] = future
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def check_lookup_result(self, request_id: str) -> int | None:
|
||||
"""
|
||||
Check the result of a previously submitted lookup request.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the lookup request submitted in
|
||||
`maybe_submit_lookup_request`
|
||||
|
||||
Returns:
|
||||
An integer representing the total number of tokens matched
|
||||
in LMCache (prefix matching), or
|
||||
None if the lookup request is not finished yet.
|
||||
"""
|
||||
assert request_id in self.lookup_futures, (
|
||||
f"Lookup request for request_id={request_id} has not been submitted"
|
||||
)
|
||||
|
||||
future = self.lookup_futures[request_id]
|
||||
if not future.query():
|
||||
return None
|
||||
|
||||
result = future.result()
|
||||
num_chunks = result
|
||||
return num_chunks * self.chunk_size
|
||||
|
||||
def num_blocks_per_chunk(self) -> int:
|
||||
"""
|
||||
Returns:
|
||||
The number of vllm blocks in a LMCache data chunk
|
||||
"""
|
||||
return self.blocks_in_chunk
|
||||
|
||||
def cleanup_lookup_result(self, request_id: str) -> None:
|
||||
"""
|
||||
Clean up lookup future for a finished request to prevent memory leak.
|
||||
Args:
|
||||
request_id: The ID of the finished request.
|
||||
"""
|
||||
self.lookup_futures.pop(request_id, None)
|
||||
|
||||
def end_session(self, request_id: str) -> None:
|
||||
"""
|
||||
Notify LMCache server to remove the session for a finished request.
|
||||
Args:
|
||||
request_id: The ID of the finished request.
|
||||
"""
|
||||
send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.END_SESSION,
|
||||
[request_id],
|
||||
)
|
||||
|
||||
# Helper functions
|
||||
def _create_key(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
start: int = 0,
|
||||
end: int = 0,
|
||||
request_id: str | None = None,
|
||||
) -> IPCCacheEngineKey:
|
||||
"""Convert token IDs to an IPC cache engine key"""
|
||||
return IPCCacheEngineKey(
|
||||
model_name=self.model_name,
|
||||
world_size=self.world_size,
|
||||
worker_id=self.worker_id,
|
||||
token_ids=tuple(token_ids),
|
||||
start=start,
|
||||
end=end,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _create_hash_key(
|
||||
self, chunk_hash: bytes, request_id: str | None = None
|
||||
) -> IPCCacheEngineKey:
|
||||
"""Create a hash-mode IPC cache engine key"""
|
||||
return IPCCacheEngineKey(
|
||||
model_name=self.model_name,
|
||||
world_size=self.world_size,
|
||||
worker_id=None,
|
||||
chunk_hash=chunk_hash,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
|
||||
class LMCacheMPWorkerAdapter:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
context: zmq.Context,
|
||||
model_name: str,
|
||||
world_size: int,
|
||||
kv_rank: int,
|
||||
vllm_block_size: int,
|
||||
):
|
||||
self.mq_client = MessageQueueClient(server_url, context)
|
||||
|
||||
# Instance id for GPU worker
|
||||
self.instance_id = os.getpid()
|
||||
|
||||
# Registered kv caches from vLLM
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Request futures
|
||||
# request_id -> (future, other merged requests)
|
||||
self.store_futures: dict[
|
||||
str, tuple[MessagingFuture[StoreResult], list[str]]
|
||||
] = {}
|
||||
self.retrieve_futures: dict[
|
||||
str, tuple[MessagingFuture[RetrieveResult], list[str]]
|
||||
] = {}
|
||||
|
||||
# The store requests that have finished execution in LMCache
|
||||
self.finished_stores: set[str] = set()
|
||||
# The finished request ids that are passed via vLLM and also
|
||||
# have corresponding store requests submitted to LMCache before
|
||||
self.previously_finished: set[str] = set()
|
||||
|
||||
self.model_name = model_name
|
||||
self.world_size = world_size
|
||||
self.worker_id = kv_rank
|
||||
|
||||
# Read chunk size from lmcache
|
||||
chunk_size = get_lmcache_chunk_size(self.mq_client)
|
||||
assert chunk_size % vllm_block_size == 0, (
|
||||
"LMCache chunk size should be a multiple of vLLM block size"
|
||||
)
|
||||
self.blocks_in_chunk = chunk_size // vllm_block_size
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Register the kv caches with LMCache server
|
||||
|
||||
Args:
|
||||
kv_caches: A dict of kv caches to register. The keys are the
|
||||
layer names and the values are the corresponding tensors.
|
||||
"""
|
||||
# Register kv cache and send the request
|
||||
self.kv_caches = kv_caches
|
||||
logger.info("Registering kv caches")
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.REGISTER_KV_CACHE,
|
||||
[self.instance_id, wrap_kv_caches(kv_caches)],
|
||||
)
|
||||
future.result()
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def submit_store_request(
|
||||
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
|
||||
):
|
||||
"""
|
||||
Submit a KV cache store request to LMCache
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request
|
||||
op: The LoadStoreOp describing the store operation.
|
||||
event: The CUDA event that is recorded after the current
|
||||
model inference step
|
||||
"""
|
||||
if op.block_hashes is not None:
|
||||
# Hash mode
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
|
||||
]
|
||||
else:
|
||||
# Token mode
|
||||
assert op.token_ids is not None
|
||||
keys = [
|
||||
self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
|
||||
]
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.STORE,
|
||||
[keys, self.instance_id, op.block_ids, event.ipc_handle()],
|
||||
).to_cuda_future()
|
||||
self.store_futures[request_id] = (future, [])
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def submit_retrieve_request(
|
||||
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
|
||||
):
|
||||
"""
|
||||
Submit a KV cache retrieve request to LMCache
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request
|
||||
op: The LoadStoreOp describing the retrieve operation.
|
||||
event: The CUDA event that is recorded after the current
|
||||
model inference step
|
||||
"""
|
||||
if op.block_hashes is not None:
|
||||
# Hash mode
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
|
||||
]
|
||||
else:
|
||||
# Token mode
|
||||
assert op.token_ids is not None
|
||||
keys = [
|
||||
self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
|
||||
]
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.RETRIEVE,
|
||||
[keys, self.instance_id, op.block_ids, event.ipc_handle()],
|
||||
).to_cuda_future()
|
||||
self.retrieve_futures[request_id] = (future, [])
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def batched_submit_store_requests(
|
||||
self,
|
||||
request_ids: list[str],
|
||||
ops: list[LoadStoreOp],
|
||||
event: torch.cuda.Event,
|
||||
):
|
||||
"""
|
||||
Submit a batched store request to LMCache
|
||||
|
||||
Args:
|
||||
request_ids: The IDs of the requests
|
||||
ops: The LoadStoreOps describing the store operations. Should have
|
||||
the same length as request_ids
|
||||
event: The CUDA event that is recorded after the current
|
||||
model inference step
|
||||
"""
|
||||
all_keys: list[IPCCacheEngineKey] = []
|
||||
block_ids: list[int] = []
|
||||
for request_id, op in zip(request_ids, ops, strict=False):
|
||||
if op.block_hashes is not None:
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id)
|
||||
for ch in chunk_hashes
|
||||
]
|
||||
all_keys.extend(keys)
|
||||
else:
|
||||
assert op.token_ids is not None
|
||||
all_keys.append(
|
||||
self._create_key(
|
||||
op.token_ids, op.start, op.end, request_id=request_id
|
||||
)
|
||||
)
|
||||
block_ids.extend(op.block_ids)
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.STORE,
|
||||
[
|
||||
all_keys,
|
||||
self.instance_id,
|
||||
block_ids,
|
||||
event.ipc_handle(),
|
||||
],
|
||||
).to_cuda_future()
|
||||
self.store_futures[request_ids[0]] = (future, list(request_ids[1:]))
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def batched_submit_retrieve_requests(
|
||||
self,
|
||||
request_ids: list[str],
|
||||
ops: list[LoadStoreOp],
|
||||
event: torch.cuda.Event,
|
||||
):
|
||||
"""
|
||||
Submit a batched retrieve request to LMCache
|
||||
|
||||
Args:
|
||||
request_ids: The IDs of the requests
|
||||
ops: The LoadStoreOps describing the retrieve operations. Should have
|
||||
the same length as request_ids
|
||||
event: The CUDA event that is recorded after the current
|
||||
model inference step
|
||||
"""
|
||||
all_keys: list[IPCCacheEngineKey] = []
|
||||
block_ids: list[int] = []
|
||||
for request_id, op in zip(request_ids, ops, strict=False):
|
||||
if op.block_hashes is not None:
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id)
|
||||
for ch in chunk_hashes
|
||||
]
|
||||
all_keys.extend(keys)
|
||||
else:
|
||||
assert op.token_ids is not None
|
||||
all_keys.append(
|
||||
self._create_key(
|
||||
op.token_ids, op.start, op.end, request_id=request_id
|
||||
)
|
||||
)
|
||||
block_ids.extend(op.block_ids)
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.RETRIEVE,
|
||||
[
|
||||
all_keys,
|
||||
self.instance_id,
|
||||
block_ids,
|
||||
event.ipc_handle(),
|
||||
],
|
||||
).to_cuda_future()
|
||||
self.retrieve_futures[request_ids[0]] = (future, list(request_ids[1:]))
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def get_finished(
|
||||
self, finished_req_ids_from_engine: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Check and get the finished store and retrieve requests.
|
||||
|
||||
Args:
|
||||
finished_req_ids_from_engine: the set of request ids that are
|
||||
reported as finished from the vLLM engine side.
|
||||
|
||||
Returns:
|
||||
A tuple of two sets:
|
||||
- The first set contains the finished store request ids. The returned
|
||||
store request ids MUST be seen before in the
|
||||
`finished_req_ids_from_engine`.
|
||||
- The second set contains the finished retrieve request ids.
|
||||
|
||||
Notes:
|
||||
When enabling async scheduling in vLLM, the same request ID may appear
|
||||
multiple times in `finished_req_ids_from_engine`. The adapter should
|
||||
take care of deduplicating the request IDs and only return the request
|
||||
IDs that have not been returned before.
|
||||
"""
|
||||
finished_stores = set()
|
||||
finished_retrieves = set()
|
||||
for request_id, (s_future, other_reqs) in self.store_futures.items():
|
||||
if not s_future.query():
|
||||
continue
|
||||
|
||||
s_result = s_future.result()
|
||||
finished_stores.add(request_id)
|
||||
finished_stores.update(other_reqs)
|
||||
|
||||
if not s_result:
|
||||
# TODO: add error handling here
|
||||
logger.error(
|
||||
"Something went wrong when processing the "
|
||||
"store request for request_id=%s",
|
||||
request_id,
|
||||
)
|
||||
|
||||
for request_id, (r_future, other_reqs) in self.retrieve_futures.items():
|
||||
if not r_future.query():
|
||||
continue
|
||||
|
||||
r_result = r_future.result()
|
||||
finished_retrieves.add(request_id)
|
||||
finished_retrieves.update(other_reqs)
|
||||
|
||||
if not all(r_result):
|
||||
# TODO: add error handing here
|
||||
logger.error(
|
||||
"Something went wrong when processing the "
|
||||
"retrieve request for request_id=%s, result=%s",
|
||||
request_id,
|
||||
r_result,
|
||||
)
|
||||
|
||||
# Remove the finished requests from the tracking dicts
|
||||
for request_id in finished_stores:
|
||||
self.store_futures.pop(request_id, None)
|
||||
for request_id in finished_retrieves:
|
||||
self.retrieve_futures.pop(request_id, None)
|
||||
|
||||
# Update the internal states
|
||||
self.finished_stores.update(finished_stores)
|
||||
|
||||
ret_stores = set()
|
||||
for req_id in finished_req_ids_from_engine:
|
||||
if req_id in self.finished_stores or req_id in self.store_futures:
|
||||
self.previously_finished.add(req_id)
|
||||
else:
|
||||
ret_stores.add(req_id)
|
||||
|
||||
# Calculate the final finished stores
|
||||
ret_stores.update(self._update_and_get_finished_store())
|
||||
|
||||
return ret_stores, finished_retrieves
|
||||
|
||||
def num_blocks_per_chunk(self) -> int:
|
||||
"""
|
||||
Returns:
|
||||
The number of vllm blocks in a LMCache data chunk
|
||||
"""
|
||||
return self.blocks_in_chunk
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Shutdown the LMCache MP worker adapter
|
||||
"""
|
||||
logger.info("Unregistering kv caches")
|
||||
send_lmcache_request(
|
||||
self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
|
||||
).result()
|
||||
|
||||
self.mq_client.close()
|
||||
|
||||
# Helper functions
|
||||
def _update_and_get_finished_store(
|
||||
self,
|
||||
) -> set[str]:
|
||||
"""Converge the internal states about finished stores
|
||||
and returns the 'safe finished store request ids' back
|
||||
"""
|
||||
safe_finished_s = self.finished_stores.intersection(self.previously_finished)
|
||||
self.finished_stores.difference_update(self.previously_finished)
|
||||
self.previously_finished.difference_update(safe_finished_s)
|
||||
|
||||
return safe_finished_s
|
||||
|
||||
def _create_key(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
start: int = 0,
|
||||
end: int = 0,
|
||||
request_id: str | None = None,
|
||||
) -> IPCCacheEngineKey:
|
||||
"""Convert token IDs to an IPC cache engine key"""
|
||||
return IPCCacheEngineKey(
|
||||
model_name=self.model_name,
|
||||
world_size=self.world_size,
|
||||
worker_id=self.worker_id,
|
||||
token_ids=tuple(token_ids),
|
||||
start=start,
|
||||
end=end,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _create_hash_key(
|
||||
self, chunk_hash: bytes, request_id: str | None = None
|
||||
) -> IPCCacheEngineKey:
|
||||
"""Create a hash-mode IPC cache engine key"""
|
||||
return IPCCacheEngineKey(
|
||||
model_name=self.model_name,
|
||||
world_size=self.world_size,
|
||||
worker_id=self.worker_id,
|
||||
chunk_hash=chunk_hash,
|
||||
request_id=request_id,
|
||||
)
|
||||
@@ -0,0 +1,211 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Standard
|
||||
import os
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from lmcache.logging import init_logger
|
||||
from lmcache.v1.config import LMCacheEngineConfig as V1Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
ENGINE_NAME = "vllm-instance"
|
||||
|
||||
# Thread-safe singleton storage
|
||||
_config_instance: V1Config | None = None
|
||||
_config_lock = threading.Lock()
|
||||
|
||||
|
||||
def is_false(value: str) -> bool:
|
||||
"""Check if the given string value is equivalent to 'false'."""
|
||||
return value.lower() in ("false", "0", "no", "n", "off")
|
||||
|
||||
|
||||
def lmcache_get_or_create_config() -> V1Config:
|
||||
"""Get the LMCache configuration from the environment variable
|
||||
`LMCACHE_CONFIG_FILE`. If the environment variable is not set, this
|
||||
function will return the default configuration.
|
||||
|
||||
This function is thread-safe and implements singleton pattern,
|
||||
ensuring the configuration is loaded only once.
|
||||
"""
|
||||
global _config_instance
|
||||
|
||||
# Double-checked locking for thread-safe singleton
|
||||
if _config_instance is None:
|
||||
with _config_lock:
|
||||
if _config_instance is None: # Check again within lock
|
||||
LMCacheEngineConfig = V1Config # type: ignore[assignment]
|
||||
|
||||
if "LMCACHE_CONFIG_FILE" not in os.environ:
|
||||
logger.warning(
|
||||
"No LMCache configuration file is set. Trying to read"
|
||||
" configurations from the environment variables."
|
||||
)
|
||||
logger.warning(
|
||||
"You can set the configuration file through "
|
||||
"the environment variable: LMCACHE_CONFIG_FILE"
|
||||
)
|
||||
_config_instance = LMCacheEngineConfig.from_env()
|
||||
else:
|
||||
config_file = os.environ["LMCACHE_CONFIG_FILE"]
|
||||
logger.info("Loading LMCache config file %s", config_file)
|
||||
_config_instance = LMCacheEngineConfig.from_file(config_file)
|
||||
# Update config from environment variables
|
||||
_config_instance.update_config_from_env()
|
||||
return _config_instance
|
||||
|
||||
|
||||
def hex_hash_to_int16(s: str) -> int:
|
||||
"""
|
||||
Convert a hex hash string to a 16-bit integer.
|
||||
"""
|
||||
return int(s, 16) & 0xFFFF
|
||||
|
||||
|
||||
def apply_mm_hashes_to_token_ids(
|
||||
token_ids: torch.Tensor,
|
||||
mm_hashes: list[str],
|
||||
mm_positions: list["PlaceholderRange"],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Overwrite token_ids in-place for multimodal placeholders using
|
||||
efficient slice assignments.
|
||||
"""
|
||||
n = token_ids.size(0)
|
||||
for hash_str, placeholder in zip(mm_hashes, mm_positions):
|
||||
start, length = placeholder.offset, placeholder.length
|
||||
if start >= n:
|
||||
continue
|
||||
end = min(start + length, n)
|
||||
token_ids[start:end] = hex_hash_to_int16(hash_str)
|
||||
return token_ids
|
||||
|
||||
|
||||
def mla_enabled(model_config: "ModelConfig") -> bool:
|
||||
return (
|
||||
hasattr(model_config, "use_mla")
|
||||
and isinstance(model_config.use_mla, bool)
|
||||
and model_config.use_mla
|
||||
)
|
||||
|
||||
|
||||
def create_lmcache_metadata(
|
||||
vllm_config=None, model_config=None, parallel_config=None, cache_config=None
|
||||
):
|
||||
"""
|
||||
Create LMCacheEngineMetadata from vLLM configuration.
|
||||
|
||||
This function extracts common metadata creation logic that was duplicated
|
||||
across multiple files.
|
||||
|
||||
Args:
|
||||
vllm_config (VllmConfig): vLLM configuration object containing model,
|
||||
parallel, and cache configs (alternative to
|
||||
individual config parameters)
|
||||
model_config (ModelConfig): Model configuration (alternative to
|
||||
vllm_config)
|
||||
parallel_config (ParallelConfig): Parallel configuration (alternative
|
||||
to vllm_config)
|
||||
cache_config (CacheConfig): Cache configuration (alternative to
|
||||
vllm_config)
|
||||
"""
|
||||
# Third Party
|
||||
# First Party
|
||||
from lmcache.config import LMCacheEngineMetadata
|
||||
|
||||
from vllm.utils.torch_utils import get_kv_cache_torch_dtype
|
||||
|
||||
config = lmcache_get_or_create_config()
|
||||
# Support both vllm_config object and individual config parameters
|
||||
if vllm_config is not None:
|
||||
model_cfg = vllm_config.model_config
|
||||
parallel_cfg = vllm_config.parallel_config
|
||||
cache_cfg = vllm_config.cache_config
|
||||
else:
|
||||
if model_config is None or parallel_config is None or cache_config is None:
|
||||
raise ValueError(
|
||||
"Either vllm_config must be provided, or all of "
|
||||
"model_config, parallel_config, and cache_config must be provided."
|
||||
)
|
||||
model_cfg = model_config
|
||||
parallel_cfg = parallel_config
|
||||
cache_cfg = cache_config
|
||||
|
||||
# Get KV cache dtype
|
||||
kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype)
|
||||
|
||||
# Check if MLA is enabled
|
||||
use_mla = mla_enabled(model_cfg)
|
||||
|
||||
# Construct KV shape (for memory pool)
|
||||
num_layer = model_cfg.get_num_layers(parallel_cfg)
|
||||
chunk_size = config.chunk_size
|
||||
num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg)
|
||||
head_size = model_cfg.get_head_size()
|
||||
kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
|
||||
|
||||
# Create metadata
|
||||
metadata = LMCacheEngineMetadata(
|
||||
model_cfg.model,
|
||||
parallel_cfg.world_size,
|
||||
parallel_cfg.rank,
|
||||
"vllm",
|
||||
kv_dtype,
|
||||
kv_shape,
|
||||
use_mla,
|
||||
)
|
||||
|
||||
return metadata, config
|
||||
|
||||
|
||||
def extract_mm_features(
|
||||
request: Union["Request", "NewRequestData"], modify: bool = False
|
||||
) -> tuple[list[str], list["PlaceholderRange"]]:
|
||||
"""
|
||||
Normalize multimodal information from a Request into parallel lists.
|
||||
|
||||
This helper reads either:
|
||||
1) `request.mm_features` (objects each exposing `.identifier` and
|
||||
`.mm_position`), or
|
||||
2) legacy fields `request.mm_hashes` and `request.mm_positions`.
|
||||
|
||||
It returns two equally sized lists: the multimodal hash identifiers and
|
||||
their corresponding positions. If the request contains no multimodal info,
|
||||
it returns `([], [])`.
|
||||
|
||||
Args:
|
||||
request (Request): The source object.
|
||||
modify (bool):
|
||||
Controls copy semantics for the legacy-path return values.
|
||||
- If True and legacy fields are used, shallow-copies are returned so
|
||||
the caller can mutate the lists without affecting `request`.
|
||||
- If False, the original legacy sequences are returned as-is
|
||||
(zero-copy); treat them as read-only.
|
||||
|
||||
Returns:
|
||||
tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`).
|
||||
May be `([], [])` when no multimodal data is present.
|
||||
"""
|
||||
if getattr(request, "mm_features", None):
|
||||
mm_hashes, mm_positions = zip(
|
||||
*((f.identifier, f.mm_position) for f in request.mm_features)
|
||||
)
|
||||
return (list(mm_hashes), list(mm_positions))
|
||||
elif getattr(request, "mm_hashes", None):
|
||||
if modify:
|
||||
return (
|
||||
request.mm_hashes.copy(), # type: ignore
|
||||
request.mm_positions.copy(), # type: ignore
|
||||
)
|
||||
else:
|
||||
return (request.mm_hashes, request.mm_positions) # type: ignore
|
||||
else:
|
||||
return ([], [])
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,955 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import enum
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from lmcache.integration.vllm.utils import mla_enabled
|
||||
from lmcache.utils import init_logger as lmcache_init_logger
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
try:
|
||||
from lmcache.integration.vllm.vllm_multi_process_adapter import (
|
||||
LMCacheMPSchedulerAdapter,
|
||||
LMCacheMPWorkerAdapter,
|
||||
LoadStoreOp,
|
||||
)
|
||||
except ImportError:
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
|
||||
LMCacheMPSchedulerAdapter,
|
||||
LMCacheMPWorkerAdapter,
|
||||
LoadStoreOp,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = lmcache_init_logger(__name__)
|
||||
|
||||
|
||||
# Helper functions
|
||||
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
|
||||
if block_ids is None:
|
||||
return []
|
||||
assert isinstance(block_ids, tuple), (
|
||||
f"Expected block_ids to be a tuple of lists, but got {type(block_ids)}"
|
||||
)
|
||||
|
||||
if len(block_ids) > 1:
|
||||
raise RuntimeError(
|
||||
"LMCacheMPConnector only works without hybrid kv cache manager. "
|
||||
"Please pass --disable-hybrid-kv-cache-manager when starting vllm"
|
||||
)
|
||||
|
||||
return block_ids[0]
|
||||
|
||||
|
||||
def extract_world_size_and_kv_rank(
|
||||
world_size: int,
|
||||
rank: int,
|
||||
vllm_config: VllmConfig,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Convert the rank for the MLA.
|
||||
"""
|
||||
use_mla = mla_enabled(vllm_config.model_config)
|
||||
if not use_mla:
|
||||
return world_size, rank
|
||||
else:
|
||||
# Tensor parallel does not change the KV caches for MLA models.
|
||||
# So we need to "exclude" the effect of TP on rank and world size
|
||||
tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
# vLLM constructs TP groups first, and then construct other
|
||||
# parallel groups on top of TP groups.
|
||||
# for example, TP=4, PP=2,
|
||||
# TP group: [0, 1, 2, 3], [4, 5, 6, 7]
|
||||
# PP group: [0, 4], [1, 5], [2, 6], [3, 7]
|
||||
# So we can "exclude" the effect of TP by rank // tp_size.
|
||||
return world_size // tp_size, rank // tp_size
|
||||
|
||||
|
||||
def create_scheduler_adapter(
|
||||
server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
|
||||
) -> LMCacheMPSchedulerAdapter:
|
||||
world_size, kv_rank = extract_world_size_and_kv_rank(
|
||||
vllm_config.parallel_config.world_size,
|
||||
vllm_config.parallel_config.rank,
|
||||
vllm_config,
|
||||
)
|
||||
return LMCacheMPSchedulerAdapter(
|
||||
server_url,
|
||||
zmq_context,
|
||||
vllm_config.model_config.model,
|
||||
world_size,
|
||||
kv_rank,
|
||||
vllm_config.cache_config.block_size,
|
||||
)
|
||||
|
||||
|
||||
def create_worker_adapter(
|
||||
server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
|
||||
) -> LMCacheMPWorkerAdapter:
|
||||
world_size, kv_rank = extract_world_size_and_kv_rank(
|
||||
vllm_config.parallel_config.world_size,
|
||||
vllm_config.parallel_config.rank,
|
||||
vllm_config,
|
||||
)
|
||||
return LMCacheMPWorkerAdapter(
|
||||
server_url,
|
||||
zmq_context,
|
||||
vllm_config.model_config.model,
|
||||
world_size,
|
||||
kv_rank,
|
||||
vllm_config.cache_config.block_size,
|
||||
)
|
||||
|
||||
|
||||
class LMCacheMPRequestState(enum.Enum):
|
||||
"""
|
||||
State machine:
|
||||
PREFETCHING -- update_state_after_alloc --> WAITING_FOR_LOAD
|
||||
WAITING_FOR_LOAD -- process_loading_requests --> READY
|
||||
"""
|
||||
|
||||
PREFETCHING = enum.auto()
|
||||
WAITING_FOR_LOAD = enum.auto()
|
||||
READY = enum.auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class LMCacheMPRequestTracker:
|
||||
# NOTE: this class used vLLM data structures, should be part of
|
||||
# vLLM integration code
|
||||
|
||||
request_id: str
|
||||
|
||||
# Read-only lists to track the token ids and block hashes
|
||||
all_token_ids: ConstantList[int]
|
||||
block_hashes: ConstantList["BlockHash"]
|
||||
|
||||
# Block ids and hashes will be updated at update_states_after_alloc and
|
||||
# during the generation
|
||||
allocated_block_ids: list[int] = field(default_factory=list)
|
||||
|
||||
# Number of scheduled tokens in this request. We keep tracking this to
|
||||
# avoid saving half-full blocks.
|
||||
num_scheduled_tokens: int = 0
|
||||
|
||||
# Number of blocks stored will be initialized when lookup the external
|
||||
# hit tokens and will be updated when processing new requests and cached
|
||||
# requests.
|
||||
num_stored_blocks: int = 0
|
||||
|
||||
# Staging load operation -- save vllm and lmcache hit tokens during lookup
|
||||
num_vllm_hit_blocks: int = 0
|
||||
num_lmcache_hit_blocks: int = 0
|
||||
|
||||
# Main state
|
||||
state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING
|
||||
|
||||
def __init__(self, request: "Request"):
|
||||
self.request_id = request.request_id
|
||||
self.all_token_ids = request.all_token_ids
|
||||
self.block_hashes = ConstantList(request.block_hashes)
|
||||
self.allocated_block_ids = []
|
||||
self.num_stored_blocks = 0
|
||||
self.num_vllm_hit_blocks = 0
|
||||
self.num_lmcache_hit_blocks = 0
|
||||
self.state = LMCacheMPRequestState.PREFETCHING
|
||||
|
||||
####
|
||||
# Check the state of the request
|
||||
####
|
||||
def needs_retrieve(self) -> bool:
|
||||
"""Check whether the current request needs retrieve, will be used
|
||||
update_stage_after_alloc"""
|
||||
return (
|
||||
self.num_lmcache_hit_blocks > self.num_vllm_hit_blocks
|
||||
and self.state != LMCacheMPRequestState.READY
|
||||
)
|
||||
|
||||
def is_ready_for_retrieving(self) -> bool:
|
||||
"""Check whether the current request is ready for retrieving,
|
||||
will be used in process_loading_requests"""
|
||||
return (
|
||||
self.state == LMCacheMPRequestState.WAITING_FOR_LOAD
|
||||
and self.needs_retrieve()
|
||||
)
|
||||
|
||||
####
|
||||
# Update internal states
|
||||
####
|
||||
def increase_num_scheduled_tokens(self, num_new_tokens: int):
|
||||
self.num_scheduled_tokens += num_new_tokens
|
||||
|
||||
def increase_num_stored_blocks(self, num_new_blocks: int):
|
||||
"""Increase the number of stored blocks for the current request
|
||||
This function will be called when processing the cached requests.
|
||||
"""
|
||||
self.num_stored_blocks += num_new_blocks
|
||||
|
||||
def append_block_ids(
|
||||
self,
|
||||
new_block_ids: list[int],
|
||||
):
|
||||
"""Update the block ids for the current request
|
||||
This function will be called when processing the cached requests.
|
||||
"""
|
||||
self.allocated_block_ids.extend(new_block_ids)
|
||||
|
||||
####
|
||||
# For debugging
|
||||
####
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"LMCacheMPRequestTracker(request_id={self.request_id}, "
|
||||
f"num_tokens={len(self.all_token_ids)}, "
|
||||
f"num_block_hashes={len(self.block_hashes)}, "
|
||||
f"num_allocated_blocks={len(self.allocated_block_ids)}, "
|
||||
f"num_stored_blocks={self.num_stored_blocks}, "
|
||||
f"vllm_hit_blocks={self.num_vllm_hit_blocks}, "
|
||||
f"lmcache_hit_blocks={self.num_lmcache_hit_blocks}, "
|
||||
f"state={self.state})"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class LMCacheMPRequestMetadata:
|
||||
request_id: str
|
||||
direction: Literal["STORE", "RETRIEVE"]
|
||||
op: LoadStoreOp
|
||||
|
||||
@staticmethod
|
||||
def GetStoreMetadata(
|
||||
tracker: LMCacheMPRequestTracker,
|
||||
blocks_in_chunk: int,
|
||||
vllm_block_size: int,
|
||||
) -> "LMCacheMPRequestMetadata | None":
|
||||
"""
|
||||
Generate the store metadata for the current request tracker.
|
||||
|
||||
Args:
|
||||
tracker: The request tracker to generate the metadata from.
|
||||
blocks_in_chunk: the number of blocks in a LMCache data chunk
|
||||
vllm_block_size: the block size used in vLLM
|
||||
"""
|
||||
# Store the blocks that has block hashes
|
||||
# NOTE: the invariant here is that `num_stored_blocks` should
|
||||
# always be a multiple of `blocks_in_chunk`
|
||||
# TODO: This should be checked everytime we update the num_stored_blocks
|
||||
min_available_blocks = min(
|
||||
len(tracker.block_hashes),
|
||||
len(tracker.allocated_block_ids),
|
||||
tracker.num_scheduled_tokens // vllm_block_size,
|
||||
)
|
||||
num_staging_blocks = min_available_blocks - tracker.num_stored_blocks
|
||||
num_chunks = num_staging_blocks // blocks_in_chunk
|
||||
|
||||
if num_chunks >= 1:
|
||||
start = tracker.num_stored_blocks
|
||||
end = start + num_chunks * blocks_in_chunk
|
||||
block_ids = tracker.allocated_block_ids[start:end]
|
||||
start_token_idx = start * vllm_block_size
|
||||
end_token_idx = end * vllm_block_size
|
||||
token_ids = list(tracker.all_token_ids)
|
||||
op = LoadStoreOp(
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
start=start_token_idx,
|
||||
end=end_token_idx,
|
||||
)
|
||||
|
||||
ret = LMCacheMPRequestMetadata(
|
||||
request_id=tracker.request_id,
|
||||
direction="STORE",
|
||||
op=op,
|
||||
)
|
||||
|
||||
# Update the request tracker
|
||||
tracker.increase_num_stored_blocks(end - start)
|
||||
return ret
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def GetRetrieveMetadata(
|
||||
tracker: LMCacheMPRequestTracker,
|
||||
blocks_in_chunk: int,
|
||||
vllm_block_size: int,
|
||||
) -> "LMCacheMPRequestMetadata | None":
|
||||
"""
|
||||
Generate the retrieve metadata for the current request tracker.
|
||||
|
||||
Args:
|
||||
tracker: The request tracker to generate the metadata from.
|
||||
blocks_in_chunk: the number of blocks in a LMCache data chunk
|
||||
vllm_block_size: the block size used in vLLM
|
||||
"""
|
||||
if not tracker.is_ready_for_retrieving():
|
||||
return None
|
||||
|
||||
# |---------------------|-----------------|----------------|
|
||||
# | num_vllm_hit_blocks |
|
||||
# | lmcache chunk 1 | lmcache chunk 2 |
|
||||
# | need to retrieve |
|
||||
|
||||
start = tracker.num_vllm_hit_blocks // blocks_in_chunk * blocks_in_chunk
|
||||
end = tracker.num_lmcache_hit_blocks
|
||||
assert end % blocks_in_chunk == 0, (
|
||||
"The number of LMCache hit blocks should be a multiple of the "
|
||||
"number of blocks in a lmcache chunk. "
|
||||
)
|
||||
assert len(tracker.block_hashes) >= end, (
|
||||
"The number of block hashes should be greater than or equal to the "
|
||||
"number of LMCache hit blocks. "
|
||||
)
|
||||
if end > start:
|
||||
block_ids = tracker.allocated_block_ids[start:end]
|
||||
start_token_idx = start * vllm_block_size
|
||||
end_token_idx = end * vllm_block_size
|
||||
token_ids = list(tracker.all_token_ids)
|
||||
op = LoadStoreOp(
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
start=start_token_idx,
|
||||
end=end_token_idx,
|
||||
)
|
||||
|
||||
ret = LMCacheMPRequestMetadata(
|
||||
request_id=tracker.request_id,
|
||||
direction="RETRIEVE",
|
||||
op=op,
|
||||
)
|
||||
return ret
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class LMCacheMPConnectorMetadata(KVConnectorMetadata):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.requests: list[LMCacheMPRequestMetadata] = []
|
||||
|
||||
def add_request_metadata(self, request_metadata: LMCacheMPRequestMetadata):
|
||||
self.requests.append(request_metadata)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
# For debugging
|
||||
def __str__(self):
|
||||
request_strs = []
|
||||
for req_meta in self.requests:
|
||||
request_strs.append(
|
||||
f"RequestMetadata(request_id={req_meta.request_id}, "
|
||||
f"direction={req_meta.direction}, "
|
||||
f"num_blocks={len(req_meta.op)}, "
|
||||
f"block_ids={req_meta.op.block_ids})"
|
||||
)
|
||||
return "[" + "\n".join(request_strs) + "]"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
The connector for LMCache multi-process mode.
|
||||
|
||||
Extra configs (kv_transfer_config.extra_config):
|
||||
- lmcache.mp.host: the host of the LMCache server.
|
||||
- lmcache.mp.port: the port of the LMCache server.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
server_host = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"lmcache.mp.host", "tcp://localhost"
|
||||
)
|
||||
server_port = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"lmcache.mp.port", 5555
|
||||
)
|
||||
|
||||
server_url = f"{server_host}:{server_port}"
|
||||
zmq_context = zmq.Context.instance()
|
||||
if self.role == KVConnectorRole.SCHEDULER:
|
||||
self.scheduler_adapter = create_scheduler_adapter(
|
||||
server_url, zmq_context, vllm_config
|
||||
)
|
||||
self.request_trackers: dict[str, LMCacheMPRequestTracker] = {}
|
||||
elif self.role == KVConnectorRole.WORKER:
|
||||
self.worker_adapter = create_worker_adapter(
|
||||
server_url, zmq_context, vllm_config
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown KVConnectorRole: {self.role}")
|
||||
|
||||
self.vllm_block_size = vllm_config.cache_config.block_size
|
||||
|
||||
@property
|
||||
def role(self) -> KVConnectorRole:
|
||||
return self._role
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def _get_connector_metadata(self) -> KVConnectorMetadata:
|
||||
"""Get the connector metadata.
|
||||
|
||||
This function should only be called inside the connector.
|
||||
|
||||
Returns:
|
||||
ConnectorMetadata: the connector metadata.
|
||||
"""
|
||||
|
||||
# Should only be called while set to valid metadata.
|
||||
assert self._connector_metadata is not None
|
||||
return self._connector_metadata
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Initialize with the KV caches. Useful for pre-registering the
|
||||
KV Caches in the KVConnector (e.g. for NIXL).
|
||||
|
||||
Args:
|
||||
kv_caches: dictionary of layer names, kv cache
|
||||
"""
|
||||
logger.info("Registering kv caches!")
|
||||
self.worker_adapter.register_kv_caches(kv_caches)
|
||||
return
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
metadata = self._get_connector_metadata()
|
||||
assert isinstance(metadata, LMCacheMPConnectorMetadata)
|
||||
|
||||
request_ids = []
|
||||
ops = []
|
||||
|
||||
for meta in metadata.requests:
|
||||
if meta.direction != "RETRIEVE":
|
||||
continue
|
||||
request_ids.append(meta.request_id)
|
||||
ops.append(meta.op)
|
||||
|
||||
if len(request_ids) == 0:
|
||||
return
|
||||
|
||||
with torch.cuda.stream(torch.cuda.current_stream()):
|
||||
event = torch.cuda.Event(interprocess=True)
|
||||
event.record()
|
||||
|
||||
self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
return
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Start saving a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
return
|
||||
|
||||
def wait_for_save(self):
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
metadata = self._get_connector_metadata()
|
||||
assert isinstance(metadata, LMCacheMPConnectorMetadata)
|
||||
|
||||
request_ids = []
|
||||
ops = []
|
||||
for meta in metadata.requests:
|
||||
if meta.direction != "STORE":
|
||||
continue
|
||||
request_ids.append(meta.request_id)
|
||||
ops.append(meta.op)
|
||||
|
||||
if len(request_ids) == 0:
|
||||
return
|
||||
|
||||
with torch.cuda.stream(torch.cuda.current_stream()):
|
||||
event = torch.cuda.Event(interprocess=True)
|
||||
event.record()
|
||||
|
||||
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens on the worker.
|
||||
The scheduler process (via the Executors) will use this output
|
||||
to track which workers are done.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
val = self.worker_adapter.get_finished(finished_req_ids)
|
||||
# logger.error("Finished req ids: %s, %s", val[0], val[1])
|
||||
return val
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
|
||||
"""
|
||||
Get the set of block IDs that failed to load.
|
||||
|
||||
Returns:
|
||||
Set of block IDs that encountered load errors.
|
||||
Empty set if no load errors occurred.
|
||||
|
||||
Notes:
|
||||
- Applies to both sync- and async-loading requests.
|
||||
- Async loading: failed blocks may be reported in any forward pass
|
||||
up to and including the pass where the request ID is returned by
|
||||
`get_finished()`. Even if failures occur, the request must still
|
||||
be reported via `get_finished()`, and the failed block IDs must
|
||||
appear here no later than that same pass.
|
||||
- Sync loading: failed blocks should be reported in the forward
|
||||
pass in which they are detected.
|
||||
"""
|
||||
# TODO: add error tracking
|
||||
return set()
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Shutdown the connector. This is called when the worker process
|
||||
is shutting down to ensure that all the async operations are
|
||||
completed and the connector is cleaned up properly.
|
||||
"""
|
||||
if hasattr(self, "worker_adapter"):
|
||||
self.worker_adapter.shutdown()
|
||||
return None
|
||||
|
||||
def get_kv_connector_stats(self) -> "KVConnectorStats | None":
|
||||
"""
|
||||
Get the KV connector stats collected during the last interval.
|
||||
"""
|
||||
return None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- An optional number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
If None, it means that the connector needs more time to
|
||||
determine the number of matched tokens, and the scheduler
|
||||
should query for this request again later.
|
||||
- `True` if external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps). Must be
|
||||
'False' if the first element is 0.
|
||||
|
||||
Notes:
|
||||
The connector should only consider the largest prefix of prompt-
|
||||
tokens for which KV cache is actually available at the time of the
|
||||
call. If the cache cannot be loaded for some tokens (e.g., due to
|
||||
connectivity issues or eviction), those tokens must not be taken
|
||||
into account.
|
||||
"""
|
||||
tracker = self._get_or_create_request_tracker(request)
|
||||
# TODO: support loading KV for preempted requests in the future
|
||||
if request.status == RequestStatus.PREEMPTED:
|
||||
return 0, False
|
||||
|
||||
self.scheduler_adapter.maybe_submit_lookup_request(
|
||||
request.request_id,
|
||||
token_ids=list(request.all_token_ids),
|
||||
)
|
||||
|
||||
ret = self.scheduler_adapter.check_lookup_result(request.request_id)
|
||||
if ret is None:
|
||||
return None, True
|
||||
|
||||
if ret == 0:
|
||||
return 0, False
|
||||
|
||||
assert (
|
||||
ret % (self.scheduler_adapter.num_blocks_per_chunk() * self.vllm_block_size)
|
||||
== 0
|
||||
)
|
||||
|
||||
# Update num stored blocks for the tracker
|
||||
num_vllm_blocks = num_computed_tokens // self.vllm_block_size
|
||||
num_lmcache_blocks = ret // self.vllm_block_size
|
||||
tracker.increase_num_stored_blocks(num_lmcache_blocks)
|
||||
|
||||
# Save the vllm and lmcache hit tokens
|
||||
tracker.num_vllm_hit_blocks = num_vllm_blocks
|
||||
tracker.num_lmcache_hit_blocks = num_lmcache_blocks
|
||||
|
||||
need_to_load = max(0, ret - num_computed_tokens)
|
||||
logger.debug(
|
||||
"vLLM hit is: %d, Need to load is %d", num_computed_tokens, need_to_load
|
||||
)
|
||||
return need_to_load, need_to_load > 0
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If get_num_new_matched_tokens previously returned True for a
|
||||
request, this function may be called twice for that same request -
|
||||
first when blocks are allocated for the connector tokens to be
|
||||
asynchronously loaded into, and second when any additional blocks
|
||||
are allocated, after the load/transfer is complete.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
blocks (KVCacheBlocks): the blocks allocated for the request.
|
||||
num_external_tokens (int): the number of tokens that will be
|
||||
loaded from the external KV cache.
|
||||
"""
|
||||
# NOTE: the `blocks` are NEW BLOCKS allocated for this request.
|
||||
tracker = self._get_request_tracker(request.request_id)
|
||||
block_ids = reformat_block_ids(blocks.get_block_ids())
|
||||
|
||||
# No matter we need to retrieve or not, we need to update
|
||||
# the block ids into the tracker
|
||||
tracker.append_block_ids(block_ids)
|
||||
|
||||
# Update the state of the tracker
|
||||
condition = tracker.needs_retrieve()
|
||||
if tracker.state == LMCacheMPRequestState.PREFETCHING:
|
||||
# If need to retrieve, change to WAITING_FOR_LOAD
|
||||
# Otherwise, change to READY
|
||||
tracker.state = (
|
||||
LMCacheMPRequestState.WAITING_FOR_LOAD
|
||||
if condition
|
||||
else LMCacheMPRequestState.READY
|
||||
)
|
||||
# Clean up lookup future in scheduler adapter
|
||||
self.scheduler_adapter.cleanup_lookup_result(request.request_id)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
metadata = LMCacheMPConnectorMetadata()
|
||||
|
||||
self._process_retrieve_requests(metadata)
|
||||
self._process_new_requests(scheduler_output, metadata)
|
||||
self._process_cached_requests(scheduler_output, metadata)
|
||||
|
||||
if len(metadata) > 0:
|
||||
logger.debug("Final connector metadata: %s", metadata)
|
||||
|
||||
return metadata
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
return
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished, before its blocks are
|
||||
freed.
|
||||
|
||||
The connector may assumes responsibility for freeing the blocks
|
||||
asynchronously by returning True.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
# Clean up request tracker to prevent memory leak
|
||||
self._cleanup_request_tracker(request.request_id)
|
||||
# Notify LMCache to end the session for this request
|
||||
self.scheduler_adapter.end_session(request.request_id)
|
||||
|
||||
return True, None
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||
"""
|
||||
Take the KV cache events from the connector.
|
||||
|
||||
Yields:
|
||||
New KV cache events since the last call.
|
||||
"""
|
||||
return ()
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
|
||||
"""
|
||||
Get the required KV cache layout for this connector.
|
||||
Args:
|
||||
vllm_config (VllmConfig): the vllm config.
|
||||
|
||||
Returns:
|
||||
str: the required KV cache layout. e.g. HND, or NHD.
|
||||
None if the connector does not require a specific layout.
|
||||
"""
|
||||
|
||||
if cls is KVConnectorBase_V1:
|
||||
raise TypeError(
|
||||
"get_required_kvcache_layout should not be called "
|
||||
"on the abstract base class"
|
||||
)
|
||||
return None
|
||||
|
||||
def get_finished_count(self) -> int | None:
|
||||
"""
|
||||
Get the count of requests expected to complete send/receive operations
|
||||
via this connector. This method is used to initialize the
|
||||
KVOutputAggregator, overwriting the default world_size.
|
||||
|
||||
Returns:
|
||||
int: expected sending or receiving completion count.
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls, data: dict[str, Any] | None = None
|
||||
) -> "KVConnectorStats | None":
|
||||
"""
|
||||
KVConnectorStats resolution method. This method allows dynamically
|
||||
registered connectors to return their own KVConnectorStats object,
|
||||
which can implement custom aggregation logic on the data dict.
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
) -> "KVConnectorPromMetrics | None":
|
||||
"""
|
||||
Create a KVConnectorPromMetrics subclass which should register
|
||||
per-connector Prometheus metrics and implement observe() to
|
||||
expose connector transfer stats via Prometheus.
|
||||
"""
|
||||
return None
|
||||
|
||||
##############################
|
||||
# Helper functions
|
||||
##############################
|
||||
def _process_retrieve_requests(
|
||||
self,
|
||||
metadata: LMCacheMPConnectorMetadata,
|
||||
) -> None:
|
||||
blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()
|
||||
|
||||
for request_tracker in self.request_trackers.values():
|
||||
if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD:
|
||||
continue
|
||||
r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata(
|
||||
request_tracker,
|
||||
blocks_per_chunk,
|
||||
vllm_block_size=self.vllm_block_size,
|
||||
)
|
||||
if r_metadata is not None:
|
||||
metadata.add_request_metadata(r_metadata)
|
||||
request_tracker.state = LMCacheMPRequestState.READY
|
||||
|
||||
def _process_new_requests(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
metadata: LMCacheMPConnectorMetadata,
|
||||
) -> None:
|
||||
blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()
|
||||
|
||||
for new_request in scheduler_output.scheduled_new_reqs:
|
||||
request_tracker = self._get_request_tracker(new_request.req_id)
|
||||
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens[new_request.req_id]
|
||||
request_tracker.increase_num_scheduled_tokens(num_new_tokens)
|
||||
|
||||
r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
|
||||
request_tracker, blocks_per_chunk, self.vllm_block_size
|
||||
)
|
||||
if r_meta is not None:
|
||||
metadata.add_request_metadata(r_meta)
|
||||
|
||||
def _process_cached_requests(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
metadata: LMCacheMPConnectorMetadata,
|
||||
) -> None:
|
||||
blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for idx, request_id in enumerate(cached_reqs.req_ids):
|
||||
request_tracker = self._get_request_tracker(request_id)
|
||||
|
||||
# Update block ids
|
||||
new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
|
||||
if request_id not in cached_reqs.resumed_req_ids:
|
||||
request_tracker.append_block_ids(new_block_ids)
|
||||
|
||||
# Update new scheduled tokens
|
||||
num_new_tokens = cached_reqs.num_computed_tokens[idx]
|
||||
request_tracker.increase_num_scheduled_tokens(num_new_tokens)
|
||||
|
||||
r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
|
||||
request_tracker, blocks_per_chunk, self.vllm_block_size
|
||||
)
|
||||
|
||||
if r_meta is not None:
|
||||
metadata.add_request_metadata(r_meta)
|
||||
|
||||
def _get_request_tracker(self, request_id: str) -> LMCacheMPRequestTracker:
|
||||
assert request_id in self.request_trackers, (
|
||||
f"Request tracker for request_id {request_id} not found. "
|
||||
)
|
||||
return self.request_trackers[request_id]
|
||||
|
||||
def _get_or_create_request_tracker(
|
||||
self, request: "Request"
|
||||
) -> LMCacheMPRequestTracker:
|
||||
request_id = request.request_id
|
||||
# Remove the old trackers that is created before the preemption
|
||||
if (
|
||||
request.status == RequestStatus.PREEMPTED
|
||||
and request_id in self.request_trackers
|
||||
):
|
||||
tracker = self.request_trackers[request_id]
|
||||
|
||||
# NOTE: since this function may be called multiple times
|
||||
# for a single request (because get_num_new_matched_tokens
|
||||
# may be called multiple times) for the same request, we
|
||||
# will only do the remove if the tracker is not in the "fresh"
|
||||
# state, i.e., PREFETCHING
|
||||
if tracker.state != LMCacheMPRequestState.PREFETCHING:
|
||||
self.request_trackers.pop(request_id)
|
||||
|
||||
if request_id not in self.request_trackers:
|
||||
new_tracker = LMCacheMPRequestTracker(request)
|
||||
self.request_trackers[request_id] = new_tracker
|
||||
return self.request_trackers[request_id]
|
||||
|
||||
def _cleanup_request_tracker(self, request_id: str) -> None:
|
||||
"""
|
||||
Clean up request tracker and associated lookup future for a request.
|
||||
This should be called when a request is finished to prevent memory leak.
|
||||
"""
|
||||
# Clean up request tracker
|
||||
if self.request_trackers.pop(request_id, None):
|
||||
logger.debug(
|
||||
"[KVConnector] Cleaned up request_tracker for request %s",
|
||||
request_id,
|
||||
)
|
||||
186
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Normal file
186
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TypeAlias, TypeVar
|
||||
|
||||
from prometheus_client import Counter, Gauge, Histogram
|
||||
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.logger import init_logger
|
||||
|
||||
PromMetric: TypeAlias = Gauge | Counter | Histogram
|
||||
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVConnectorStats:
|
||||
"""
|
||||
Base class for KV Connector Stats, a container for transfer performance
|
||||
metrics or otherwise important telemetry from the connector.
|
||||
All sub-classes need to be serializable as stats are sent from worker to
|
||||
logger process.
|
||||
"""
|
||||
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def reset(self):
|
||||
"""Reset the stats, clear the state."""
|
||||
raise NotImplementedError
|
||||
|
||||
def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
|
||||
"""
|
||||
Aggregate stats with another `KVConnectorStats` object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def reduce(self) -> dict[str, int | float]:
|
||||
"""
|
||||
Reduce the observations collected during a time interval to one or
|
||||
more representative values (eg avg/median/sum of the series).
|
||||
This is meant to be called by the logger to produce a summary of the
|
||||
stats for the last time interval.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Return True if the stats are empty."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KVConnectorLogging:
|
||||
def __init__(self, kv_transfer_config: KVTransferConfig | None):
|
||||
# Instantiate the connector's stats class.
|
||||
if kv_transfer_config and kv_transfer_config.kv_connector:
|
||||
self.connector_cls = KVConnectorFactory.get_connector_class(
|
||||
kv_transfer_config
|
||||
)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.transfer_stats_accumulator: KVConnectorStats | None = None
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any]):
|
||||
# Should not be called when a KVConnector is not configured.
|
||||
assert self.connector_cls is not None
|
||||
# Called periodically when connector syncs with the scheduler.
|
||||
# Note that this is not the same as the logging interval.
|
||||
# We expect transfer_stats_data to be aggregated across all workers and
|
||||
# consist of observations from a single connector or a MultiConnector.
|
||||
transfer_stats = self.connector_cls.build_kv_connector_stats(
|
||||
transfer_stats_data
|
||||
)
|
||||
if transfer_stats is None:
|
||||
logger.warning_once(
|
||||
"The connector %s is collecting stats but "
|
||||
"does not implement the "
|
||||
"`build_kv_connector_stats` method. "
|
||||
"Stats will not be logged.",
|
||||
self.connector_cls,
|
||||
)
|
||||
return
|
||||
|
||||
if self.transfer_stats_accumulator is None:
|
||||
self.transfer_stats_accumulator = transfer_stats
|
||||
else:
|
||||
# Accumulate last interval stats.
|
||||
self.transfer_stats_accumulator = self.transfer_stats_accumulator.aggregate(
|
||||
transfer_stats
|
||||
)
|
||||
|
||||
def log(self, log_fn=logger.info):
|
||||
"""Log transfer metrics periodically, similar to throughput logging"""
|
||||
if (
|
||||
self.transfer_stats_accumulator
|
||||
and not self.transfer_stats_accumulator.is_empty()
|
||||
):
|
||||
# Produce a single cumulative stats object for the last time
|
||||
# interval from the recorded observations.
|
||||
xfer_metrics = self.transfer_stats_accumulator.reduce()
|
||||
xfer_metrics_str = ", ".join(f"{k}={v}" for k, v in xfer_metrics.items())
|
||||
log_fn("KV Transfer metrics: %s", xfer_metrics_str)
|
||||
|
||||
# Reset metrics for next interval
|
||||
self.reset()
|
||||
|
||||
|
||||
class KVConnectorPromMetrics:
|
||||
"""
|
||||
A base class for per-connector Prometheus metric registration
|
||||
and recording.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
):
|
||||
self._kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self._gauge_cls = metric_types[Gauge]
|
||||
self._counter_cls = metric_types[Counter]
|
||||
self._histogram_cls = metric_types[Histogram]
|
||||
self._labelnames = labelnames
|
||||
self.per_engine_labelvalues = per_engine_labelvalues
|
||||
|
||||
def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
|
||||
"""
|
||||
Create a per-engine child of a prometheus_client.Metric with
|
||||
the appropriate labels set. The parent metric must be created
|
||||
using the labelnames list.
|
||||
"""
|
||||
return {
|
||||
idx: metric.labels(*labelvalues)
|
||||
for idx, labelvalues in self.per_engine_labelvalues.items()
|
||||
}
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
"""
|
||||
Record the supplied transfer statistics to Prometheus metrics. These
|
||||
statistics are engine-specific, and should be recorded to a metric
|
||||
with the appropriate 'engine' label. These metric instances can be
|
||||
created using the make_per_engine() helper method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KVConnectorPrometheus:
|
||||
"""
|
||||
Support for registering per-connector Prometheus metrics, and
|
||||
recording transfer statistics to those metrics. Uses
|
||||
KVConnectorBase.build_prom_metrics().
|
||||
"""
|
||||
|
||||
_gauge_cls = Gauge
|
||||
_counter_cls = Counter
|
||||
_histogram_cls = Histogram
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
):
|
||||
self.prom_metrics: KVConnectorPromMetrics | None = None
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
if kv_transfer_config and kv_transfer_config.kv_connector:
|
||||
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
|
||||
metric_types = {
|
||||
Gauge: self._gauge_cls,
|
||||
Counter: self._counter_cls,
|
||||
Histogram: self._histogram_cls,
|
||||
}
|
||||
self.prom_metrics = connector_cls.build_prom_metrics(
|
||||
vllm_config,
|
||||
metric_types,
|
||||
labelnames,
|
||||
per_engine_labelvalues,
|
||||
)
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
if self.prom_metrics is None:
|
||||
return
|
||||
self.prom_metrics.observe(transfer_stats_data, engine_idx)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import EngineId
|
||||
from vllm.logger import init_logger
|
||||
|
||||
WorkerAddr = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RegisterWorkerPayload(BaseModel):
|
||||
engine_id: EngineId
|
||||
dp_rank: int
|
||||
tp_rank: int
|
||||
pp_rank: int
|
||||
addr: WorkerAddr
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineEntry:
|
||||
engine_id: EngineId
|
||||
# {tp_rank: {pp_rank: worker_addr}}
|
||||
worker_addr: dict[int, dict[int, WorkerAddr]]
|
||||
|
||||
|
||||
class MooncakeBootstrapServer:
|
||||
"""
|
||||
A centralized server running on the global rank 0 prefiller worker.
|
||||
Prefiller workers register their connection info (IP, port, ranks) here.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, host: str, port: int):
|
||||
self.workers: dict[int, EngineEntry] = {}
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.app = FastAPI()
|
||||
self._register_routes()
|
||||
self.server_thread: threading.Thread | None = None
|
||||
self.server: uvicorn.Server | None = None
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def _register_routes(self):
|
||||
# All methods are async. No need to use lock to protect data.
|
||||
self.app.post("/register")(self.register_worker)
|
||||
self.app.get("/query", response_model=dict[int, EngineEntry])(self.query)
|
||||
|
||||
def start(self):
|
||||
if self.server_thread:
|
||||
return
|
||||
|
||||
config = uvicorn.Config(app=self.app, host=self.host, port=self.port)
|
||||
self.server = uvicorn.Server(config=config)
|
||||
self.server_thread = threading.Thread(
|
||||
target=self.server.run, name="mooncake_bootstrap_server", daemon=True
|
||||
)
|
||||
self.server_thread.start()
|
||||
while not self.server.started:
|
||||
time.sleep(0.1) # Wait for the server to start
|
||||
logger.info("Mooncake Bootstrap Server started at %s:%d", self.host, self.port)
|
||||
|
||||
def shutdown(self):
|
||||
if self.server_thread is None or self.server is None or not self.server.started:
|
||||
return
|
||||
|
||||
self.server.should_exit = True
|
||||
self.server_thread.join()
|
||||
logger.info("Mooncake Bootstrap Server stopped.")
|
||||
|
||||
async def register_worker(self, payload: RegisterWorkerPayload):
|
||||
"""Handles registration of a prefiller worker."""
|
||||
if payload.dp_rank not in self.workers:
|
||||
self.workers[payload.dp_rank] = EngineEntry(
|
||||
engine_id=payload.engine_id,
|
||||
worker_addr={},
|
||||
)
|
||||
|
||||
dp_entry = self.workers[payload.dp_rank]
|
||||
if dp_entry.engine_id != payload.engine_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Engine ID mismatch for dp_rank={payload.dp_rank}: "
|
||||
f"expected {dp_entry.engine_id}, got {payload.engine_id}"
|
||||
),
|
||||
)
|
||||
if payload.tp_rank not in dp_entry.worker_addr:
|
||||
dp_entry.worker_addr[payload.tp_rank] = {}
|
||||
|
||||
tp_entry = dp_entry.worker_addr[payload.tp_rank]
|
||||
if payload.pp_rank in tp_entry:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Worker with dp_rank={payload.dp_rank}, "
|
||||
f"tp_rank={payload.tp_rank}, pp_rank={payload.pp_rank} "
|
||||
f"is already registered at "
|
||||
f"{tp_entry[payload.pp_rank]}, "
|
||||
f"but still want to register at {payload.addr}"
|
||||
),
|
||||
)
|
||||
|
||||
tp_entry[payload.pp_rank] = payload.addr
|
||||
logger.debug(
|
||||
"Registered worker: engine_id=%s, dp_rank=%d, tp_rank=%d, pp_rank=%d at %s",
|
||||
payload.engine_id,
|
||||
payload.dp_rank,
|
||||
payload.tp_rank,
|
||||
payload.pp_rank,
|
||||
payload.addr,
|
||||
)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
async def query(self) -> dict[int, EngineEntry]:
|
||||
return self.workers
|
||||
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
get_open_port,
|
||||
make_zmq_socket,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from dataclasses import field
|
||||
from enum import Enum
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
Transfer = tuple[int, float]
|
||||
EngineId = str
|
||||
ReqId = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class WriteTask:
|
||||
request_id: str
|
||||
dst_engine_id: str
|
||||
local_block_ids: list[int]
|
||||
remote_block_ids_hint: list[int] | None
|
||||
layer_name: str
|
||||
event: torch.cuda.Event
|
||||
remote_notify_port: int
|
||||
remote_ip: str
|
||||
enqueue_time: float = field(default_factory=time.perf_counter)
|
||||
retried: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerTransferPlan:
|
||||
"""Plan for transferring a single layer."""
|
||||
|
||||
request_id: str
|
||||
layer_name: str
|
||||
sess_idx: int
|
||||
transfer_local_offsets: list[int]
|
||||
transfer_remote_offsets: list[int]
|
||||
transfer_sizes: list[int]
|
||||
use_batch: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteAllocInfo:
|
||||
"""Information about remote block allocation."""
|
||||
|
||||
block_ids: list[int]
|
||||
writes_done: int = 0
|
||||
decode_dp_rank: int = 0
|
||||
transfer_offset: tuple[list[int], list[int], list[int]] | None = None
|
||||
|
||||
|
||||
class ROLE(Enum):
|
||||
PRODUCER = "producer"
|
||||
CONSUMER = "consumer"
|
||||
NOTINIT = "notinit"
|
||||
|
||||
|
||||
class MoRIIOAgentMetadata(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
# required for @cached_property.d
|
||||
dict=True,
|
||||
):
|
||||
engine_id: str
|
||||
agent_metadata: bytes
|
||||
kv_caches_base_addr: list[int]
|
||||
num_blocks: int
|
||||
block_len: int
|
||||
attn_backend_name: str
|
||||
|
||||
|
||||
class RoleManager:
|
||||
"""Manages role state across the connector."""
|
||||
|
||||
_instance: "RoleManager | None" = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._role: ROLE = ROLE.NOTINIT
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "RoleManager":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def set_role(self, role: ROLE) -> None:
|
||||
"""Set the current role."""
|
||||
with self._lock:
|
||||
self._role = role
|
||||
|
||||
def get_role(self) -> ROLE:
|
||||
"""Get the current role."""
|
||||
return self._role
|
||||
|
||||
|
||||
def set_role(role: ROLE):
|
||||
"""Set the global role."""
|
||||
RoleManager.get_instance().set_role(role)
|
||||
|
||||
|
||||
def get_role() -> ROLE:
|
||||
"""Get the global role."""
|
||||
return RoleManager.get_instance().get_role()
|
||||
|
||||
|
||||
class MoRIIOMode(Enum):
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
|
||||
|
||||
class MoRIIOError(Exception):
|
||||
"""Base exception for MoRIIO operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class HandshakeError(MoRIIOError):
|
||||
"""Exception raised when handshake fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TransferError(MoRIIOError):
|
||||
"""Exception raised when transfer fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_moriio_mode() -> MoRIIOMode:
|
||||
read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
|
||||
logger.debug("MoRIIO Connector read_mode: %s", read_mode)
|
||||
if read_mode:
|
||||
return MoRIIOMode.READ
|
||||
else:
|
||||
return MoRIIOMode.WRITE
|
||||
|
||||
|
||||
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
|
||||
return (dp_rank) * tp_size + tp_rank
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoRIIOConfig:
|
||||
local_ip: str
|
||||
local_kv_port: int
|
||||
proxy_ip: str
|
||||
local_ping_port: int
|
||||
proxy_ping_port: int
|
||||
http_port: int
|
||||
handshake_port: int
|
||||
notify_port: int
|
||||
tp_rank: int
|
||||
dp_rank: int
|
||||
dp_size: int
|
||||
tp_size: int
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
|
||||
# Port Configuration:
|
||||
# local_ping_port -> Outgoing heartbeat to proxy
|
||||
# proxy_ping_port -> Remote proxy's heartbeat ingress port
|
||||
# http_port -> Instance's HTTP service endpoint
|
||||
# local_kv_port -> service port for mori engine
|
||||
# notify_port -> For synchronizing stages between prefill and decode
|
||||
# handshake_port -> For initial handshake between mori engine
|
||||
|
||||
# TODO : merge notify_port and handshake_port to simplify port management
|
||||
# supports non-contiguous ports
|
||||
assert vllm_config.kv_transfer_config is not None, (
|
||||
"kv_transfer_config must be set for MoRIIOConnector"
|
||||
)
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
extra_config = kv_transfer_config.kv_connector_extra_config
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
base_notify_port = int(extra_config["notify_port"])
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
port_offset = get_port_offset(dp_rank, tp_rank)
|
||||
|
||||
return cls(
|
||||
local_ip=get_ip(),
|
||||
local_kv_port=get_open_port(),
|
||||
proxy_ip=extra_config["proxy_ip"],
|
||||
local_ping_port=get_open_port(),
|
||||
proxy_ping_port=int(extra_config["proxy_ping_port"]),
|
||||
http_port=int(extra_config["http_port"]),
|
||||
handshake_port=int(extra_config["handshake_port"]),
|
||||
notify_port=base_notify_port + port_offset,
|
||||
tp_rank=tp_rank,
|
||||
dp_rank=dp_rank,
|
||||
dp_size=dp_size,
|
||||
tp_size=tp_size,
|
||||
)
|
||||
|
||||
|
||||
class MoRIIOConstants:
|
||||
"""Constants for MoRIIO connector."""
|
||||
|
||||
# ZMQ message types
|
||||
GET_META_MSG = b"get_meta_msg"
|
||||
POP_DONE_RECV = b"pop_done_recv"
|
||||
OVER = b"OVER"
|
||||
COMPLETION_PREFIX = "cmpl"
|
||||
|
||||
PING_INTERVAL = 5
|
||||
MAX_PING_RETRIES = 100
|
||||
DEFAULT_HANDSHAKE_PORT = "6301"
|
||||
DEFAULT_NOTIFY_PORT = "61005"
|
||||
|
||||
VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
"""Metadata for a single request."""
|
||||
|
||||
local_block_ids: list[int]
|
||||
remote_block_ids: list[int]
|
||||
remote_host: str
|
||||
remote_port: int
|
||||
remote_handshake_port: int
|
||||
remote_notify_port: int
|
||||
remote_engine_id: str
|
||||
tp_size: int
|
||||
remote_dp_size: int
|
||||
|
||||
|
||||
class MoRIIOConnectorMetadata(KVConnectorMetadata):
|
||||
def __init__(self):
|
||||
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
|
||||
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
|
||||
self.reqs_to_send: dict[ReqId, float] = {}
|
||||
|
||||
def __repr__(self):
|
||||
return_str = ""
|
||||
for req_id, req_meta in self.reqs_to_recv.items():
|
||||
return_str += (
|
||||
f"{req_id = },{req_meta.local_block_ids = },"
|
||||
f"{req_meta.remote_host = },{req_meta.remote_port = }"
|
||||
f"{req_meta.remote_engine_id = },{req_meta.tp_size = }"
|
||||
)
|
||||
return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str},"
|
||||
|
||||
for req_id, expiry in self.reqs_to_send.items():
|
||||
return_str += f"{req_id = },{expiry = }"
|
||||
return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str},"
|
||||
return return_str
|
||||
|
||||
def add_new_req(
|
||||
self,
|
||||
request_id: ReqId,
|
||||
local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any],
|
||||
write_mode=False,
|
||||
):
|
||||
_req = ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
remote_handshake_port=kv_transfer_params["remote_handshake_port"],
|
||||
remote_notify_port=kv_transfer_params["remote_notify_port"],
|
||||
tp_size=kv_transfer_params.get("tp_size", 1),
|
||||
remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
|
||||
)
|
||||
if write_mode:
|
||||
self.reqs_to_save[request_id] = _req
|
||||
else:
|
||||
self.reqs_to_recv[request_id] = _req
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER):
|
||||
raise ValueError(f"Unexpected socket type: {socket_type}")
|
||||
|
||||
ctx: zmq.Context | None = None
|
||||
try:
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
yield make_zmq_socket(
|
||||
ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER
|
||||
)
|
||||
finally:
|
||||
if ctx is not None:
|
||||
ctx.destroy(linger=0)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,609 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from weakref import ref as weakref_ref
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import (
|
||||
make_zmq_path,
|
||||
make_zmq_socket,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from queue import Empty, Queue
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
|
||||
ROLE,
|
||||
HandshakeError,
|
||||
LayerTransferPlan,
|
||||
MoRIIOAgentMetadata,
|
||||
MoRIIOConstants,
|
||||
MoRIIOError,
|
||||
RemoteAllocInfo,
|
||||
TransferError,
|
||||
WriteTask,
|
||||
get_port_offset,
|
||||
get_role,
|
||||
zmq_ctx,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
|
||||
MoRIIOConnectorWorker,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
try:
|
||||
from mori.io import (
|
||||
EngineDesc,
|
||||
IOEngine,
|
||||
MemoryDesc,
|
||||
PollCqMode,
|
||||
RdmaBackendConfig,
|
||||
)
|
||||
|
||||
logger.info("MoRIIO is available")
|
||||
except ImportError:
|
||||
logger.error("MoRIIO is not available")
|
||||
|
||||
|
||||
"""Write task execution logic for MoRIIO connector."""
|
||||
|
||||
|
||||
class MoRIIOWriter:
|
||||
"""Handles write operations for KV cache transfers.
|
||||
Implements distributed KV cache transfer using the MoRIIO library
|
||||
for RDMA-based communication between prefill and decode instances."""
|
||||
|
||||
def __init__(self, worker: "MoRIIOConnectorWorker"):
|
||||
"""Initialize the writer.
|
||||
|
||||
Args:
|
||||
worker: Reference to the parent worker
|
||||
"""
|
||||
self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker)
|
||||
self._write_task_q: Queue[WriteTask] = Queue()
|
||||
self._write_worker_started = False
|
||||
self._write_worker_lock = threading.Lock()
|
||||
self._deferred_tasks: list[WriteTask] = []
|
||||
|
||||
@property
|
||||
def worker(self) -> "MoRIIOConnectorWorker":
|
||||
"""Get the worker instance.
|
||||
|
||||
Returns:
|
||||
The parent worker instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If worker has been garbage collected
|
||||
"""
|
||||
worker = self._worker_ref()
|
||||
if worker is None:
|
||||
raise RuntimeError("Parent worker has been garbage collected")
|
||||
return worker
|
||||
|
||||
def ensure_worker_started(self) -> None:
|
||||
"""Ensure the background write worker is running."""
|
||||
if self._write_worker_started:
|
||||
return
|
||||
self._write_worker_started = True
|
||||
with self._write_worker_lock:
|
||||
thread = threading.Thread(
|
||||
target=self._write_worker_loop, daemon=True, name="moriio-write-worker"
|
||||
)
|
||||
thread.start()
|
||||
logger.info("Started MoRIIO write worker thread")
|
||||
|
||||
def schedule_write(self, task: WriteTask) -> None:
|
||||
"""Schedule a write task.
|
||||
|
||||
Args:
|
||||
task: The write task to schedule
|
||||
"""
|
||||
self.ensure_worker_started()
|
||||
self._write_task_q.put(task)
|
||||
|
||||
def _write_worker_loop(self) -> None:
|
||||
"""Main loop for the write worker thread."""
|
||||
|
||||
while True:
|
||||
# Process deferred tasks first
|
||||
self._process_deferred_tasks()
|
||||
|
||||
# Get new task
|
||||
try:
|
||||
task = self._write_task_q.get(timeout=0.01)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
# Check if remote blocks are ready
|
||||
if not self._is_remote_ready(task):
|
||||
# task.retry_count += 1
|
||||
self._deferred_tasks.append(task)
|
||||
# logger.debug(
|
||||
# "Deferred task for request %s (retry %d)",
|
||||
# task.request_id, task.retry_count
|
||||
# )
|
||||
continue
|
||||
|
||||
# Execute the task
|
||||
|
||||
self._execute_write_task(task)
|
||||
|
||||
def _process_deferred_tasks(self) -> None:
|
||||
"""Process tasks that were previously deferred."""
|
||||
if not self._deferred_tasks:
|
||||
return
|
||||
|
||||
still_deferred: list[WriteTask] = []
|
||||
for task in self._deferred_tasks:
|
||||
if self._is_remote_ready(task):
|
||||
self._execute_write_task(task)
|
||||
else:
|
||||
still_deferred.append(task)
|
||||
|
||||
self._deferred_tasks = still_deferred
|
||||
|
||||
def _is_remote_ready(self, task: WriteTask) -> bool:
|
||||
"""Check if remote blocks are allocated for this task.
|
||||
|
||||
Args:
|
||||
task: The write task
|
||||
|
||||
Returns:
|
||||
True if remote blocks are ready
|
||||
"""
|
||||
return (
|
||||
task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict
|
||||
)
|
||||
|
||||
def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo:
|
||||
"""Get remote allocation info for a request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID
|
||||
|
||||
Returns:
|
||||
Remote allocation information
|
||||
|
||||
Raises:
|
||||
KeyError: If allocation info is missing
|
||||
"""
|
||||
try:
|
||||
return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id]
|
||||
except KeyError as e:
|
||||
raise KeyError(
|
||||
f"Remote allocation info missing for request {request_id}"
|
||||
) from e
|
||||
|
||||
def _execute_write_task(self, task: WriteTask) -> None:
|
||||
"""Execute a single write task.
|
||||
|
||||
Args:
|
||||
task: The write task to execute
|
||||
|
||||
"""
|
||||
# Get remote allocation info
|
||||
request_info = self._get_remote_alloc_info(task.request_id)
|
||||
|
||||
if request_info.block_ids is None:
|
||||
logger.debug("Request %s remote block IDs not ready", task.request_id)
|
||||
return
|
||||
|
||||
# Wait for CUDA event
|
||||
# The attention computation of the current layer cannot
|
||||
# overlap with the kv transfer task,
|
||||
# otherwise it will cause precision issues.
|
||||
# This event is used to synchronize the kv transfer and computation tasks.
|
||||
task.event.synchronize()
|
||||
|
||||
# Update engine ID with DP rank
|
||||
task.dst_engine_id = self.worker.get_engine_name_with_dp(
|
||||
task.dst_engine_id, request_info.decode_dp_rank
|
||||
)
|
||||
|
||||
# Get or create sessions
|
||||
sessions, remote_moriio_meta = self.worker._get_built_session(
|
||||
task.dst_engine_id
|
||||
)
|
||||
|
||||
# Prepare transfer plan
|
||||
plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)
|
||||
|
||||
# Execute transfer
|
||||
self._do_layer_write(plan, sessions)
|
||||
|
||||
# Finalize if all layers complete
|
||||
self._finalize_if_complete(task, request_info)
|
||||
|
||||
def _prepare_transfer_plan(
|
||||
self,
|
||||
task: WriteTask,
|
||||
request_info: RemoteAllocInfo,
|
||||
remote_moriio_meta: MoRIIOAgentMetadata,
|
||||
) -> LayerTransferPlan:
|
||||
"""Prepare the transfer plan for a layer.
|
||||
|
||||
Args:
|
||||
task: The write task
|
||||
request_info: Remote allocation information
|
||||
|
||||
Returns:
|
||||
The transfer plan
|
||||
"""
|
||||
# Compute offsets if not cached
|
||||
if request_info.transfer_offset is None:
|
||||
offsets = self.worker._compute_block_transfer_offsets(
|
||||
task.layer_name,
|
||||
task.local_block_ids,
|
||||
request_info.block_ids,
|
||||
remote_moriio_meta,
|
||||
)
|
||||
request_info.transfer_offset = offsets
|
||||
|
||||
# Get session index
|
||||
layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys())
|
||||
sess_idx = layer_names.index(task.layer_name)
|
||||
|
||||
local_off, remote_off, sizes = request_info.transfer_offset
|
||||
|
||||
return LayerTransferPlan(
|
||||
request_id=task.request_id,
|
||||
layer_name=task.layer_name,
|
||||
sess_idx=sess_idx,
|
||||
transfer_local_offsets=local_off,
|
||||
transfer_remote_offsets=remote_off,
|
||||
transfer_sizes=sizes,
|
||||
use_batch=True,
|
||||
)
|
||||
|
||||
def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None:
|
||||
"""Perform the actual layer write.
|
||||
|
||||
Args:
|
||||
plan: The transfer plan
|
||||
sessions: List of transfer sessions
|
||||
"""
|
||||
if plan.use_batch:
|
||||
self.worker.moriio_wrapper.write_remote_data(
|
||||
plan.transfer_sizes,
|
||||
plan.transfer_local_offsets,
|
||||
plan.transfer_remote_offsets,
|
||||
sessions[plan.sess_idx],
|
||||
)
|
||||
else:
|
||||
for i in range(len(plan.transfer_local_offsets)):
|
||||
self.worker.moriio_wrapper.write_remote_data_single(
|
||||
plan.transfer_sizes[i],
|
||||
plan.transfer_local_offsets[i],
|
||||
plan.transfer_remote_offsets[i],
|
||||
plan.sess_idx,
|
||||
)
|
||||
|
||||
def _finalize_if_complete(
|
||||
self, task: WriteTask, request_info: RemoteAllocInfo
|
||||
) -> None:
|
||||
"""Finalize transfer if all layers are complete.
|
||||
|
||||
Args:
|
||||
task: The write task
|
||||
request_info: Remote allocation information
|
||||
"""
|
||||
request_info.writes_done += 1
|
||||
|
||||
if request_info.writes_done >= self.worker.num_layers:
|
||||
# Wait for transfer to complete
|
||||
self.worker.moriio_wrapper.waiting_for_transfer_complete()
|
||||
|
||||
remote_port = task.remote_notify_port + get_port_offset(
|
||||
request_info.decode_dp_rank, self.worker.tp_rank
|
||||
)
|
||||
# Consider using RDMA immediate data in decode side
|
||||
# to eliminate the need for this notification.
|
||||
# Consider including the first gen token from prefill in the notification
|
||||
|
||||
# Send completion notification
|
||||
self.worker.moriio_wrapper.send_notify(
|
||||
task.request_id, task.remote_ip, remote_port
|
||||
)
|
||||
# mark request as done, then we can free the blocks
|
||||
with self.worker.moriio_wrapper.lock:
|
||||
self.worker.moriio_wrapper.done_req_ids.append(task.request_id)
|
||||
del self.worker.moriio_wrapper.done_remote_allocate_req_dict[
|
||||
task.request_id
|
||||
]
|
||||
logger.debug(
|
||||
"Completed transfer for request %s, notified port %d",
|
||||
task.request_id,
|
||||
remote_port,
|
||||
)
|
||||
|
||||
|
||||
class MoRIIOWrapper:
|
||||
"""Wrapper for MoRIIO engine operations.
|
||||
|
||||
Handles both producer and consumer roles for KV cache transfers.
|
||||
|
||||
Args:
|
||||
moriio_engine: MoRIIO engine instance
|
||||
tp_rank: Tensor parallel rank
|
||||
dp_rank: Data parallel rank
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moriio_engine: "IOEngine | None" = None,
|
||||
tp_rank: int = 0,
|
||||
dp_rank: int = 0,
|
||||
):
|
||||
self.tp_rank = tp_rank
|
||||
self.dp_rank = dp_rank
|
||||
self.moriio_engine = moriio_engine
|
||||
self.remote_memory_metadata = None
|
||||
self.local_memory_registered = False
|
||||
self.local_memory_metadata = None
|
||||
self.transfer_status: list[Any] = []
|
||||
self.remote_engine_ip: str | None = None
|
||||
self.notify_port: int | None = None
|
||||
self.lock = threading.Lock()
|
||||
self.done_req_ids: list[str] = []
|
||||
self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {}
|
||||
self.done_write_cache_req_ids: list[str] = []
|
||||
self.notify_thread: threading.Thread | None = None
|
||||
self.sessions: list[IOEngine.Session] = []
|
||||
self.paths: dict[str, zmq.Socket] = {}
|
||||
|
||||
def set_moriio_engine(self, moriio_engine):
|
||||
assert moriio_engine is not None, (
|
||||
"You Cannot pass None engine to MoRIIOWrapper!"
|
||||
)
|
||||
self.moriio_engine = moriio_engine
|
||||
|
||||
def set_backend_type(self, backend_type):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
|
||||
post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
|
||||
num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
|
||||
poll_mode = PollCqMode.POLLING
|
||||
rdma_cfg = RdmaBackendConfig(
|
||||
qp_per_transfer,
|
||||
post_batch_size,
|
||||
num_worker_threads,
|
||||
poll_mode,
|
||||
)
|
||||
self.moriio_engine.create_backend(backend_type, rdma_cfg)
|
||||
|
||||
def get_agent_metadata(self):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
engine_metadata = self.moriio_engine.get_engine_desc()
|
||||
engine_metadata_packed = engine_metadata.pack()
|
||||
return engine_metadata_packed
|
||||
|
||||
def register_remote_engine(self, remote_packed_engine_metadata):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata)
|
||||
self.moriio_engine.register_remote_engine(consumer_engine_metadata)
|
||||
return consumer_engine_metadata.key
|
||||
|
||||
def register_local_tensor(self, tensor: torch.Tensor):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
try:
|
||||
self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
|
||||
tensor
|
||||
)
|
||||
assert self.local_memory_metadata is not None, (
|
||||
"register_torch_tensor returned None"
|
||||
)
|
||||
local_memory_metadata_packed = self.local_memory_metadata.pack()
|
||||
except Exception as e:
|
||||
raise MoRIIOError(f"Failed to register local memory: {e}") from e
|
||||
self.local_memory_registered = True
|
||||
return local_memory_metadata_packed
|
||||
|
||||
def get_unpack_memory_metadata(self, packed_memory_metadata):
|
||||
return MemoryDesc.unpack(packed_memory_metadata)
|
||||
|
||||
def build_session(self, local_memory_metadata, remote_memory_metadata):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
return self.moriio_engine.create_session(
|
||||
local_memory_metadata, remote_memory_metadata
|
||||
)
|
||||
|
||||
def read_remote_data(
|
||||
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
||||
):
|
||||
assert self.local_memory_registered, "You have not register local memory data!"
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
transfer_status = session.batch_read(
|
||||
local_offset,
|
||||
remote_offset,
|
||||
transfer_size_byte,
|
||||
self.moriio_engine.allocate_transfer_uid(),
|
||||
)
|
||||
|
||||
return transfer_status
|
||||
|
||||
def write_remote_data(
|
||||
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
||||
):
|
||||
assert self.local_memory_registered, "You have not register local memory data!"
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
write_uid = self.moriio_engine.allocate_transfer_uid()
|
||||
|
||||
transfer_status = session.batch_write(
|
||||
local_offset, remote_offset, transfer_size_byte, write_uid
|
||||
)
|
||||
with self.lock:
|
||||
self.transfer_status.append(transfer_status)
|
||||
|
||||
def write_remote_data_single(
|
||||
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
|
||||
):
|
||||
assert self.local_memory_registered, "You have not register local memory data!"
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
transfer_status = self.sessions[sess_idx].write(
|
||||
local_offset,
|
||||
remote_offset,
|
||||
transfer_size_byte,
|
||||
self.moriio_engine.allocate_transfer_uid(),
|
||||
)
|
||||
with self.lock:
|
||||
self.transfer_status.append(transfer_status)
|
||||
|
||||
def waiting_for_transfer_complete(self):
|
||||
if not self.transfer_status:
|
||||
return
|
||||
|
||||
transfers_to_wait = []
|
||||
with self.lock:
|
||||
transfers_to_wait = self.transfer_status[:]
|
||||
self.transfer_status.clear()
|
||||
|
||||
for status in transfers_to_wait:
|
||||
try:
|
||||
status.Wait()
|
||||
if not status.Succeeded():
|
||||
logger.error(
|
||||
"Transfer failed: %s, Code: %s", status.Message(), status.Code()
|
||||
)
|
||||
raise TransferError("MoRIIO transfer failed!")
|
||||
except Exception as e:
|
||||
logger.error("Transfer %s failed: %s", status, e)
|
||||
raise
|
||||
|
||||
def async_wait_reqid(self):
|
||||
assert self.notify_port is not None, "Notify port cannot be None"
|
||||
|
||||
if self.notify_thread is not None:
|
||||
return
|
||||
|
||||
def _async_wait():
|
||||
host = "*"
|
||||
path = make_zmq_path("tcp", host, self.notify_port)
|
||||
logger.info("Node starting to listen notify from path = %s", path)
|
||||
|
||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||
while True:
|
||||
try:
|
||||
identity, msg = sock.recv_multipart()
|
||||
self._handle_message(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error processing message: %s", e)
|
||||
raise HandshakeError(f"Error processing message: {e}") from e
|
||||
|
||||
self.notify_thread = threading.Thread(
|
||||
target=_async_wait, daemon=True, name="moriio-notify-listener"
|
||||
)
|
||||
self.notify_thread.start()
|
||||
|
||||
def _handle_message(self, msg: bytes):
|
||||
"""Handles incoming messages from remote nodes."""
|
||||
# Handles incoming remote messages:
|
||||
# Prefill Role:
|
||||
# [write] mode: receives block information (allocation)
|
||||
# [read] mode: receives block release messages from decode side
|
||||
# Decode Role:
|
||||
# [write] mode: receives KV cache write completion notifications
|
||||
handled = False
|
||||
try:
|
||||
data = msgpack.loads(msg)
|
||||
if isinstance(data, dict) and "req_id" in data:
|
||||
self._handle_structured_message(data)
|
||||
|
||||
return
|
||||
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
||||
logger.debug("Failed to decode msgpack message, will try as string")
|
||||
pass
|
||||
|
||||
try:
|
||||
msg_str = msg.decode("UTF-8")
|
||||
if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX):
|
||||
self._handle_completion_message(msg_str)
|
||||
handled = True
|
||||
except UnicodeDecodeError:
|
||||
logger.warning("Received non-UTF8 message: %s", msg_str)
|
||||
if not handled:
|
||||
raise MoRIIOError(f"Unhandled message format: {msg_str}")
|
||||
|
||||
def _handle_structured_message(self, data: dict):
|
||||
assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages"
|
||||
req_id = data["req_id"]
|
||||
block_notify_list = data.get("block_notify_list", [])
|
||||
decode_dp_rank = data.get("decode_rank", 0)
|
||||
assert len(block_notify_list) > 0, (
|
||||
"block_notify_list cannot be empty in remote allocate message"
|
||||
)
|
||||
|
||||
with self.lock:
|
||||
self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo(
|
||||
block_ids=block_notify_list, decode_dp_rank=decode_dp_rank
|
||||
)
|
||||
|
||||
def _handle_completion_message(self, msg: str):
|
||||
with self.lock:
|
||||
if get_role() == ROLE.PRODUCER:
|
||||
self.done_req_ids.append(msg)
|
||||
else:
|
||||
self.done_write_cache_req_ids.append(msg)
|
||||
|
||||
def send_notify(self, req_ids, remote_ip, remote_port):
|
||||
if not remote_ip or not remote_port:
|
||||
logger.warning("Missing remote_ip or remote_port for notification")
|
||||
return
|
||||
|
||||
path = make_zmq_path("tcp", remote_ip, remote_port)
|
||||
|
||||
if path not in self.paths:
|
||||
ctx = zmq.Context.instance()
|
||||
sock = make_zmq_socket(
|
||||
ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
|
||||
)
|
||||
self.paths[path] = sock
|
||||
|
||||
req_list = req_ids if isinstance(req_ids, list) else [req_ids]
|
||||
|
||||
sock = self.paths[path]
|
||||
try:
|
||||
for req_id in req_list:
|
||||
if not isinstance(req_id, str):
|
||||
logger.warning(
|
||||
"Invalid req_id type: %s, expected str", type(req_id)
|
||||
)
|
||||
continue
|
||||
sock.send(req_id.encode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error("Failed to send notification to %s: %s", path, e)
|
||||
self.paths.pop(path, None)
|
||||
raise
|
||||
|
||||
def pop_finished_req_ids(self):
|
||||
# producer invocation: get the set of completed requests at the decode
|
||||
with self.lock:
|
||||
done_send = set(self.done_req_ids)
|
||||
self.done_req_ids = []
|
||||
return done_send
|
||||
|
||||
def pop_finished_write_req_ids(self):
|
||||
# Call the consumer in write mode to get the collection after write completion
|
||||
with self.lock:
|
||||
done_write_cache = set(self.done_write_cache_req_ids)
|
||||
self.done_write_cache_req_ids = []
|
||||
return done_write_cache
|
||||
|
||||
def shutdown(self):
|
||||
logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
|
||||
for path, sock in self.paths.items():
|
||||
try:
|
||||
sock.close(linger=0)
|
||||
logger.debug("Closed ZMQ socket for path: %s", path)
|
||||
except Exception as e:
|
||||
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
|
||||
self.paths.clear()
|
||||
515
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Normal file
515
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Normal file
@@ -0,0 +1,515 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
CopyBlocksOp,
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorHandshakeMetadata,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiKVConnectorMetadata(KVConnectorMetadata):
|
||||
metadata: tuple[KVConnectorMetadata, ...]
|
||||
extra_async_saves: dict[str, int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiKVConnectorStats(KVConnectorStats):
|
||||
"""
|
||||
Maintain a dict of KVConnectorStats objects, one for each connector.
|
||||
This is used to aggregate the stats from all connectors separately.
|
||||
"""
|
||||
|
||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||
for connector_id, stats in other.data.items():
|
||||
if connector_id not in self.data:
|
||||
self[connector_id] = stats
|
||||
else:
|
||||
assert isinstance(stats, type(self.data[connector_id]))
|
||||
self[connector_id] = self[connector_id].aggregate(stats)
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
for stats in self.data.values():
|
||||
stats.reset()
|
||||
|
||||
def reduce(self) -> dict[str, Any]:
|
||||
# TODO (NickLucche) Adjust for logging on separate lines
|
||||
return {
|
||||
connector_id: stats.reduce() for connector_id, stats in self.data.items()
|
||||
}
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return all(stats.is_empty() for stats in self.data.values())
|
||||
|
||||
def __getitem__(self, connector_id: str) -> KVConnectorStats:
|
||||
return self.data[connector_id]
|
||||
|
||||
def __setitem__(self, connector_id: str, stats: KVConnectorStats):
|
||||
self.data[connector_id] = stats
|
||||
|
||||
|
||||
class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
prom_metrics: dict[str, KVConnectorPromMetrics],
|
||||
):
|
||||
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
||||
self._prom_metrics = prom_metrics
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
for connector_id, stats_data in transfer_stats_data.items():
|
||||
assert connector_id in self._prom_metrics, (
|
||||
f"{connector_id} is not contained in the list of registered connectors "
|
||||
f"with Prometheus metrics support: {self._prom_metrics.keys()}"
|
||||
)
|
||||
self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)
|
||||
|
||||
|
||||
class MultiConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
A wrapper for using multiple KVConnectors at the same time.
|
||||
|
||||
The current logic is:
|
||||
- Load KV from the first connector that advertises available tokens from
|
||||
get_num_new_matched_tokens(), based on the order in the config.
|
||||
- Save to all connectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
|
||||
)
|
||||
|
||||
self._connectors: list[KVConnectorBase_V1] = []
|
||||
self._ktc_kv_transfer_config = []
|
||||
for connector_cls, temp_config in self._get_connector_classes_and_configs(
|
||||
vllm_config
|
||||
):
|
||||
self._connectors.append(connector_cls(temp_config, role, kv_cache_config))
|
||||
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
|
||||
|
||||
# A mapping from request id to the index of the connector chosen to
|
||||
# load the request from (if any).
|
||||
self._requests_to_connector: dict[str, int] = {}
|
||||
|
||||
# Keeps track of *additional* remaining async saves (beyond 1) to be
|
||||
# finished per request. Not needed for async loads since we only allow
|
||||
# a single connector to load.
|
||||
# Propagated from scheduler to worker side via the connector metadata.
|
||||
self._extra_async_saves: dict[str, int] = {}
|
||||
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
if not self._connectors:
|
||||
return False
|
||||
return all(c.prefer_cross_layer_blocks for c in self._connectors)
|
||||
|
||||
@classmethod
|
||||
def _get_connector_classes_and_configs(
|
||||
cls, vllm_config: "VllmConfig"
|
||||
) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors"
|
||||
)
|
||||
assert ktcs is not None
|
||||
ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
|
||||
for ktc in ktcs:
|
||||
temp_config = copy.copy(vllm_config)
|
||||
engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
|
||||
temp_config.kv_transfer_config = KVTransferConfig(
|
||||
**ktc, engine_id=engine_id
|
||||
)
|
||||
ret.append(
|
||||
(
|
||||
KVConnectorFactory.get_connector_class(
|
||||
temp_config.kv_transfer_config
|
||||
),
|
||||
temp_config,
|
||||
)
|
||||
)
|
||||
return ret
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
):
|
||||
# Register on all connectors
|
||||
for c in self._connectors:
|
||||
c.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
for c in self._connectors:
|
||||
c.register_kv_caches(kv_caches)
|
||||
|
||||
# We must override the base class method here because we need to bind
|
||||
# the metadata to each connector in the order of the connectors in the
|
||||
# MultiKVConnectorMetadata.
|
||||
#
|
||||
# Note: Call the base class method to ensure metadata is also set on the
|
||||
# MultiConnector instance itself; otherwise, `has_connector_metadata()` will
|
||||
# always return False.
|
||||
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
|
||||
if connector_metadata.extra_async_saves:
|
||||
self._extra_async_saves.update(connector_metadata.extra_async_saves)
|
||||
for c, cm in zip(self._connectors, connector_metadata.metadata):
|
||||
c.bind_connector_metadata(cm)
|
||||
super().bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
for c in self._connectors:
|
||||
c.clear_connector_metadata()
|
||||
super().clear_connector_metadata()
|
||||
|
||||
def shutdown(self):
|
||||
exception: Exception | None = None
|
||||
for c in self._connectors:
|
||||
try:
|
||||
c.shutdown()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Exception during connector %s shutdown.", c.__class__.__name__
|
||||
)
|
||||
exception = e
|
||||
if exception:
|
||||
raise exception
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
for c in self._connectors:
|
||||
c.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
for c in self._connectors:
|
||||
c.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for c in self._connectors:
|
||||
c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
|
||||
|
||||
def wait_for_save(self):
|
||||
for c in self._connectors:
|
||||
c.wait_for_save()
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
finished_sending: set[str] = set()
|
||||
finished_recving: set[str] = set()
|
||||
for c in self._connectors:
|
||||
sending, recving = c.get_finished(finished_req_ids)
|
||||
if not recving and not sending:
|
||||
continue
|
||||
# Aggregate finished recving request ids.
|
||||
finished_recving.update(recving or ())
|
||||
# Aggregate finished sending request ids - only include
|
||||
# once we've drained the "extra" count (for cases where
|
||||
# more than one connector is async-saving the same request).
|
||||
for req_id in sending or ():
|
||||
extra_pending = self._extra_async_saves.get(req_id)
|
||||
if extra_pending is None:
|
||||
finished_sending.add(req_id)
|
||||
continue
|
||||
assert extra_pending > 0
|
||||
if extra_pending == 1:
|
||||
del self._extra_async_saves[req_id]
|
||||
else:
|
||||
self._extra_async_saves[req_id] = extra_pending - 1
|
||||
|
||||
return finished_sending or None, finished_recving or None
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
|
||||
agg_block_ids: set[int] = set()
|
||||
for c in self._connectors:
|
||||
agg_block_ids |= c.get_block_ids_with_load_errors()
|
||||
return agg_block_ids
|
||||
|
||||
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
|
||||
"""Set xPU-specific copy ops for all sub-connectors."""
|
||||
for c in self._connectors:
|
||||
c.set_host_xfer_buffer_ops(copy_operation)
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
"""Handle preempted requests for all sub-connectors."""
|
||||
for c in self._connectors:
|
||||
c.handle_preemptions(preempted_req_ids)
|
||||
|
||||
def get_finished_count(self) -> int | None:
|
||||
# TODO(https://github.com/vllm-project/vllm/issues/33400)
|
||||
# Currently no connectors return non-None
|
||||
return None
|
||||
|
||||
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events'
|
||||
# method for the MultiConnector. It should be able to get events from
|
||||
# multiple connectors, handling the case where only a subset of the
|
||||
# requested connectors implements the 'get_kv_connector_kv_cache_events'
|
||||
# WIP: https://github.com/vllm-project/vllm/pull/31811
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
to_return = (0, False)
|
||||
for i, c in enumerate(self._connectors):
|
||||
toks, load_async = c.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
)
|
||||
# If there is a connector still looking up the matches,
|
||||
# we return None to indicate that we are not done yet.
|
||||
if toks is None:
|
||||
return (None, False)
|
||||
# The first connector that has new matched tokens will be assigned
|
||||
# to this request.
|
||||
if to_return[0] == 0 and toks > 0:
|
||||
self._requests_to_connector[request.request_id] = i
|
||||
to_return = (toks, load_async)
|
||||
return to_return
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
chosen_connector = self._requests_to_connector.get(request.request_id, -1)
|
||||
empty_blocks = blocks.new_empty()
|
||||
for i, c in enumerate(self._connectors):
|
||||
if i == chosen_connector:
|
||||
# Forward call to the chosen connector (if any).
|
||||
c.update_state_after_alloc(request, blocks, num_external_tokens)
|
||||
else:
|
||||
# Call with empty blocks for other connectors.
|
||||
c.update_state_after_alloc(request, empty_blocks, 0)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> MultiKVConnectorMetadata:
|
||||
metadata = MultiKVConnectorMetadata(
|
||||
metadata=tuple(
|
||||
c.build_connector_meta(scheduler_output) for c in self._connectors
|
||||
)
|
||||
)
|
||||
if self._extra_async_saves:
|
||||
metadata.extra_async_saves = self._extra_async_saves
|
||||
self._extra_async_saves = {}
|
||||
return metadata
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
for c in self._connectors:
|
||||
c.update_connector_output(connector_output)
|
||||
|
||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||
"""
|
||||
Get the KVConnector handshake metadata from sub-connectors.
|
||||
Returns the first non-None metadata from sub-connectors.
|
||||
"""
|
||||
for c in self._connectors:
|
||||
metadata = c.get_handshake_metadata()
|
||||
if metadata is not None:
|
||||
return metadata
|
||||
return None
|
||||
|
||||
def set_xfer_handshake_metadata(
|
||||
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||
) -> None:
|
||||
"""
|
||||
Set the KV connector handshake metadata for all sub-connectors.
|
||||
This is needed to start the NIXL listener thread for NixlConnector.
|
||||
"""
|
||||
for c in self._connectors:
|
||||
c.set_xfer_handshake_metadata(metadata)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
blocks: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
async_saves = 0
|
||||
kv_txfer_params = None
|
||||
for c in self._connectors:
|
||||
async_save, txfer_params = c.request_finished(request, blocks)
|
||||
if async_save:
|
||||
async_saves += 1
|
||||
if txfer_params is not None:
|
||||
if kv_txfer_params is not None:
|
||||
# TODO we can probably change this to merge the dicts here,
|
||||
# checking for key clashes.
|
||||
raise RuntimeError(
|
||||
"Only one connector can produce KV transfer params"
|
||||
)
|
||||
kv_txfer_params = txfer_params
|
||||
if async_saves > 1:
|
||||
self._extra_async_saves[request.request_id] = async_saves - 1
|
||||
|
||||
# Clean up other state for this request.
|
||||
self._requests_to_connector.pop(request.request_id, None)
|
||||
|
||||
return async_saves > 0, kv_txfer_params
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||
for c in self._connectors:
|
||||
yield from c.take_events()
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
|
||||
"""
|
||||
Get the required KV cache layout for this connector.
|
||||
Args:
|
||||
vllm_config (VllmConfig): the vllm config.
|
||||
|
||||
Returns:
|
||||
str: the required KV cache layout. e.g. HND, or NHD.
|
||||
None if the connector does not require a specific layout.
|
||||
"""
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
layouts: set[str] = set()
|
||||
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
|
||||
vllm_config
|
||||
):
|
||||
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
|
||||
temp_config
|
||||
)
|
||||
if required_kvcache_layout is not None:
|
||||
layouts.add(required_kvcache_layout)
|
||||
|
||||
if len(layouts) > 1:
|
||||
raise ValueError(
|
||||
f"KV cache layout mismatch: "
|
||||
f"found {len(layouts)} different layouts "
|
||||
f"({', '.join(layouts)})."
|
||||
f"All connectors must use the same layout."
|
||||
)
|
||||
return next(iter(layouts), None)
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls, data: dict[str, Any] | None = None
|
||||
) -> KVConnectorStats | None:
|
||||
if data is None:
|
||||
return MultiKVConnectorStats()
|
||||
|
||||
# data is a dict mapping connector name to their stats data.
|
||||
# The stats data can be either:
|
||||
# 1. Already-instantiated KVConnectorStats objects (same process)
|
||||
# 2. Serialized dicts (cross-process after serialization)
|
||||
# We need to reconstruct proper KVConnectorStats objects from dicts
|
||||
reconstructed_data = {}
|
||||
for connector_name, stats_value in data.items():
|
||||
# If already a KVConnectorStats object, use it directly
|
||||
if isinstance(stats_value, KVConnectorStats):
|
||||
reconstructed_data[connector_name] = stats_value
|
||||
continue
|
||||
|
||||
# Otherwise, reconstruct from serialized dict
|
||||
# Get the connector class to reconstruct its stats
|
||||
connector_cls = KVConnectorFactory.get_connector_class_by_name(
|
||||
connector_name
|
||||
)
|
||||
|
||||
# stats_value is the serialized dataclass which contains {'data': {...}}
|
||||
# We need to extract the inner 'data' field to avoid double-nesting
|
||||
assert isinstance(stats_value, dict) and "data" in stats_value, (
|
||||
f"Expected a dict with a 'data' field, got {stats_value}"
|
||||
)
|
||||
inner_data = stats_value["data"]
|
||||
|
||||
# Use the connector's build_kv_connector_stats to reconstruct
|
||||
if reconstructed_stats := connector_cls.build_kv_connector_stats(
|
||||
data=inner_data
|
||||
):
|
||||
reconstructed_data[connector_name] = reconstructed_stats
|
||||
|
||||
return MultiKVConnectorStats(data=reconstructed_data)
|
||||
|
||||
def get_kv_connector_stats(self) -> MultiKVConnectorStats | None:
|
||||
# Group connector stats by connector type.
|
||||
stats_by_connector: MultiKVConnectorStats | None = None
|
||||
for c in self._connectors:
|
||||
stats = c.get_kv_connector_stats()
|
||||
if stats is None:
|
||||
continue
|
||||
if stats_by_connector is None:
|
||||
# Lazy init to allow optional return value.
|
||||
stats_by_connector = MultiKVConnectorStats()
|
||||
stats_by_connector[c.__class__.__name__] = stats
|
||||
return stats_by_connector
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
) -> KVConnectorPromMetrics:
|
||||
prom_metrics: dict[str, KVConnectorPromMetrics] = {}
|
||||
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
|
||||
vllm_config
|
||||
):
|
||||
connector_prom = connector_cls.build_prom_metrics(
|
||||
temp_config, metric_types, labelnames, per_engine_labelvalues
|
||||
)
|
||||
if connector_prom is not None:
|
||||
prom_metrics[connector_cls.__name__] = connector_prom
|
||||
return MultiKVConnectorPromMetrics(
|
||||
vllm_config,
|
||||
metric_types,
|
||||
labelnames,
|
||||
per_engine_labelvalues,
|
||||
prom_metrics,
|
||||
)
|
||||
|
||||
def reset_cache(self) -> bool:
|
||||
results = [c.reset_cache() is not False for c in self._connectors]
|
||||
return all(results)
|
||||
2790
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Normal file
2790
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,800 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.kv_offload.abstract import OffloadingManager
|
||||
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import (
|
||||
OffloadingWorker,
|
||||
TransferSpec,
|
||||
TransferType,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
ReqId = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingOperationMetrics:
|
||||
op_size: int
|
||||
op_time: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingConnectorStats(KVConnectorStats):
|
||||
def __post_init__(self):
|
||||
if not self.data:
|
||||
# Empty container init, no data is passed in.
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.data: dict[str, list[OffloadingOperationMetrics]] = {}
|
||||
|
||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||
if not other.is_empty():
|
||||
for k, v in other.data.items():
|
||||
if k not in self.data:
|
||||
self.data[k] = v
|
||||
else:
|
||||
accumulator = self.data[k]
|
||||
assert isinstance(accumulator, list)
|
||||
accumulator.extend(v)
|
||||
return self
|
||||
|
||||
def reduce(self) -> dict[str, int | float]:
|
||||
"""
|
||||
Reduce the observations collected during a time interval to one or
|
||||
more representative values (eg avg/median/sum of the series).
|
||||
This is meant to be called by the logger to produce a summary of the
|
||||
stats for the last time interval.
|
||||
"""
|
||||
return_dict: dict[str, int | float] = {}
|
||||
for transfer_type, ops_list in self.data.items():
|
||||
assert isinstance(ops_list, list)
|
||||
total_bytes = 0
|
||||
total_time = 0.0
|
||||
for op in ops_list:
|
||||
assert isinstance(op, dict)
|
||||
total_bytes += op["op_size"]
|
||||
total_time += op["op_time"]
|
||||
return_dict[f"{transfer_type}_total_bytes"] = total_bytes
|
||||
return_dict[f"{transfer_type}_total_time"] = total_time
|
||||
return return_dict
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.data
|
||||
|
||||
def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
|
||||
src, dst = transfer_type
|
||||
transfer_type_key = src + "_to_" + dst
|
||||
op = OffloadingOperationMetrics(num_bytes, time)
|
||||
if transfer_type_key in self.data:
|
||||
self.data[transfer_type_key].append(op)
|
||||
else:
|
||||
self.data[transfer_type_key] = [op]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
reqs_to_load: dict[ReqId, TransferSpec]
|
||||
reqs_to_store: dict[ReqId, TransferSpec]
|
||||
|
||||
|
||||
class OffloadingConnector(KVConnectorBase_V1):
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
return True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: KVCacheConfig | None = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)
|
||||
|
||||
self.connector_scheduler: OffloadingConnectorScheduler | None = None
|
||||
self.connector_worker: OffloadingConnectorWorker | None = None
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = OffloadingConnectorScheduler(spec)
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_worker = OffloadingConnectorWorker(spec)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.handle_preemptions(preempted_req_ids)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
|
||||
self.connector_worker.start_kv_transfers(self._connector_metadata)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
|
||||
self.connector_worker.prepare_store_kv(self._connector_metadata)
|
||||
|
||||
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(finished_req_ids)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int | None, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
)
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens
|
||||
)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
assert self.connector_scheduler is not None
|
||||
self.connector_scheduler.update_connector_output(connector_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.take_events()
|
||||
|
||||
def get_kv_connector_stats(self) -> KVConnectorStats | None:
|
||||
if self.connector_worker is None:
|
||||
return None # We only emit stats from the worker-side
|
||||
return self.connector_worker.get_kv_connector_stats()
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls, data: dict[str, Any] | None = None
|
||||
) -> KVConnectorStats | None:
|
||||
return (
|
||||
OffloadingConnectorStats(data=data)
|
||||
if data is not None
|
||||
else OffloadingConnectorStats()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
) -> KVConnectorPromMetrics:
|
||||
return OffloadPromMetrics(
|
||||
vllm_config, metric_types, labelnames, per_engine_labelvalues
|
||||
)
|
||||
|
||||
|
||||
class OffloadingConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
self.gpu_block_size = spec.gpu_block_size
|
||||
self.offloaded_block_size = spec.offloaded_block_size
|
||||
self.block_size_factor = self.offloaded_block_size // self.gpu_block_size
|
||||
self.manager: OffloadingManager = spec.get_manager()
|
||||
|
||||
self._requests: dict[ReqId, Request] = {}
|
||||
# list of GPU block IDs per request
|
||||
self._request_block_ids: dict[ReqId, list[int]] = {}
|
||||
# requests to load for the current scheduler step
|
||||
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
|
||||
# request blocks are stored in order
|
||||
# index of next block (of size offloaded_block_size) to offload
|
||||
self._next_stored_block_idx: dict[ReqId, int] = {}
|
||||
# if GPU prefix caching is enabled,
|
||||
# track loaded blocks to avoid redundant loads
|
||||
self._blocks_being_loaded: set[BlockHash] | None = (
|
||||
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
|
||||
)
|
||||
|
||||
# request ID -> set(block hashes being stored/load)
|
||||
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
|
||||
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
|
||||
|
||||
def _get_block_hashes(
|
||||
self,
|
||||
req: Request,
|
||||
start_idx: int = 0,
|
||||
end_idx: int | None = None,
|
||||
) -> Iterable[BlockHash]:
|
||||
return islice(
|
||||
req.block_hashes,
|
||||
self.block_size_factor * start_idx + self.block_size_factor - 1,
|
||||
self.block_size_factor * end_idx if end_idx else None,
|
||||
self.block_size_factor,
|
||||
)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: Request, num_computed_tokens: int
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded beyond the
|
||||
num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- The number of tokens that can be loaded beyond what is
|
||||
already computed.
|
||||
If None, it means that the connector needs more time to
|
||||
determine the number of matched tokens, and the scheduler
|
||||
should query for this request again later.
|
||||
- `True` if tokens will be loaded asynchronously
|
||||
(between scheduler steps).
|
||||
"""
|
||||
num_blocks = request.num_tokens // self.offloaded_block_size
|
||||
|
||||
assert len(request.block_hashes) // self.block_size_factor == num_blocks
|
||||
block_hashes = self._get_block_hashes(request)
|
||||
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
full_block_tokens = self.offloaded_block_size * num_blocks
|
||||
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
|
||||
# we can load less than a block, skip
|
||||
return 0, False
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
hits = self.manager.lookup(
|
||||
self._get_block_hashes(request, start_idx=start_block_idx)
|
||||
)
|
||||
if hits is None:
|
||||
# indicates a lookup that should be tried later
|
||||
return None, False
|
||||
if hits == 0:
|
||||
return 0, False
|
||||
|
||||
num_hit_tokens = (
|
||||
self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
|
||||
)
|
||||
logger.debug(
|
||||
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
|
||||
request.request_id,
|
||||
num_hit_tokens,
|
||||
num_computed_tokens,
|
||||
)
|
||||
if num_hit_tokens < self.offloaded_block_size:
|
||||
return 0, False
|
||||
|
||||
if self._blocks_being_loaded:
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
|
||||
)
|
||||
|
||||
if any(
|
||||
block_hash in self._blocks_being_loaded for block_hash in block_hashes
|
||||
):
|
||||
# hit blocks are being loaded, delay request
|
||||
logger.debug(
|
||||
"Delaying request %s since some of its blocks are already"
|
||||
" being loaded",
|
||||
request.request_id,
|
||||
)
|
||||
return None, False
|
||||
|
||||
return num_hit_tokens, True
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
|
||||
):
|
||||
self._requests[request.request_id] = request
|
||||
# the block ids are updated in _get_reqs_to_store
|
||||
self._request_block_ids[request.request_id] = []
|
||||
|
||||
if num_external_tokens == 0:
|
||||
return
|
||||
|
||||
block_groups = blocks.get_block_ids()
|
||||
block_ids = block_groups[0]
|
||||
|
||||
num_computed_gpu_blocks = sum(
|
||||
block.block_hash is not None for block in blocks.blocks[0]
|
||||
)
|
||||
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
|
||||
full_block_tokens = num_computed_tokens + num_external_tokens
|
||||
assert full_block_tokens % self.offloaded_block_size == 0
|
||||
|
||||
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
|
||||
assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
num_blocks = full_block_tokens // self.offloaded_block_size
|
||||
|
||||
assert len(request.block_hashes) // self.block_size_factor >= num_blocks
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
|
||||
src_spec = self.manager.prepare_load(block_hashes)
|
||||
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
|
||||
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
|
||||
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
|
||||
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
|
||||
req_blocks_being_loaded.update(block_hashes)
|
||||
self._next_stored_block_idx[request.request_id] = num_blocks
|
||||
|
||||
if self._blocks_being_loaded is not None:
|
||||
self._blocks_being_loaded.update(req_blocks_being_loaded)
|
||||
|
||||
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
|
||||
reqs_to_store: dict[ReqId, TransferSpec] = {}
|
||||
# iterate over both new and cached requests
|
||||
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
|
||||
if preempted:
|
||||
self._request_block_ids[req_id] = []
|
||||
|
||||
if new_block_id_groups:
|
||||
new_block_ids = new_block_id_groups[0]
|
||||
self._request_block_ids[req_id] += new_block_ids
|
||||
|
||||
block_ids = self._request_block_ids[req_id]
|
||||
|
||||
req = self._requests[req_id]
|
||||
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
total_tokens = req.num_computed_tokens + new_tokens
|
||||
num_blocks = total_tokens // self.offloaded_block_size
|
||||
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
|
||||
num_new_blocks = num_blocks - start_block_idx
|
||||
|
||||
if num_new_blocks <= 0:
|
||||
continue
|
||||
|
||||
# NOTE: In async scheduling, placeholders may temporarily make
|
||||
# len(req.block_hashes) < num_blocks * self.block_size_factor.
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
store_output = self.manager.prepare_store(new_block_hashes)
|
||||
if store_output is None:
|
||||
logger.warning(
|
||||
"Request %s: cannot store %s blocks", req_id, num_new_blocks
|
||||
)
|
||||
continue
|
||||
|
||||
self._next_stored_block_idx[req_id] = num_blocks
|
||||
|
||||
if not store_output.block_hashes_to_store:
|
||||
continue
|
||||
block_hashes_to_store = set(store_output.block_hashes_to_store)
|
||||
|
||||
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
dst_spec = store_output.store_spec
|
||||
src_block_ids: list[int] = []
|
||||
for idx, blk_hash in enumerate(new_block_hashes):
|
||||
if blk_hash not in block_hashes_to_store:
|
||||
continue
|
||||
offloaded_block_idx = start_block_idx + idx
|
||||
gpu_block_idx = offloaded_block_idx * self.block_size_factor
|
||||
for i in range(self.block_size_factor):
|
||||
src_block_ids.append(block_ids[gpu_block_idx + i])
|
||||
src_spec = GPULoadStoreSpec(src_block_ids)
|
||||
|
||||
reqs_to_store[req_id] = (src_spec, dst_spec)
|
||||
self._reqs_being_stored[req_id] |= block_hashes_to_store
|
||||
|
||||
logger.debug(
|
||||
"Request %s offloading %s blocks starting from block #%d",
|
||||
req_id,
|
||||
len(block_hashes_to_store),
|
||||
start_block_idx,
|
||||
)
|
||||
|
||||
return reqs_to_store
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
meta = OffloadingConnectorMetadata(
|
||||
reqs_to_load=self._reqs_to_load,
|
||||
reqs_to_store=self._get_reqs_to_store(scheduler_output),
|
||||
)
|
||||
self._reqs_to_load = {}
|
||||
|
||||
# NOTE (orozery): we should move this logic to update_connector_output
|
||||
# once KVConnectorOutput allows us to report completed transfers
|
||||
for req_id in scheduler_output.preempted_req_ids or ():
|
||||
block_hashes = self._reqs_being_stored.get(req_id)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
block_hashes.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
for req_id in connector_output.finished_sending or []:
|
||||
block_hashes = self._reqs_being_stored.pop(req_id, None)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
|
||||
for req_id in connector_output.finished_recving or []:
|
||||
block_hashes = self._reqs_being_loaded.pop(req_id, None)
|
||||
if block_hashes:
|
||||
if self._blocks_being_loaded:
|
||||
self._blocks_being_loaded.difference_update(block_hashes)
|
||||
self.manager.complete_load(block_hashes)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: Request,
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
req_id = request.request_id
|
||||
self._requests.pop(req_id, None)
|
||||
self._request_block_ids.pop(req_id, None)
|
||||
self._next_stored_block_idx.pop(req_id, None)
|
||||
|
||||
request_being_stored = req_id in self._reqs_being_stored
|
||||
return request_being_stored, None
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
"""Take the KV cache events from the connector.
|
||||
|
||||
Returns:
|
||||
A list of KV cache events.
|
||||
"""
|
||||
for event in self.manager.take_events():
|
||||
if event.removed:
|
||||
yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
|
||||
else:
|
||||
yield BlockStored(
|
||||
block_hashes=event.block_hashes,
|
||||
parent_block_hash=None,
|
||||
token_ids=[],
|
||||
lora_id=None,
|
||||
block_size=event.block_size,
|
||||
medium=event.medium,
|
||||
lora_name=None,
|
||||
)
|
||||
|
||||
|
||||
class OffloadingConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
self.spec = spec
|
||||
self.worker = OffloadingWorker()
|
||||
|
||||
self._job_counter = 0
|
||||
|
||||
self.kv_connector_stats = OffloadingConnectorStats()
|
||||
# req_id -> (job_id, store)
|
||||
self._jobs: dict[int, tuple[ReqId, bool]] = {}
|
||||
# req_id -> active job IDs
|
||||
self._load_job: dict[ReqId, int] = {}
|
||||
# req_id -> set(active job IDs)
|
||||
self._store_jobs = defaultdict[ReqId, set[int]](set)
|
||||
# list of store jobs pending submission (job_id, transfer_spec)
|
||||
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
|
||||
|
||||
self._finished_reqs_waiting_for_store: set[ReqId] = set()
|
||||
|
||||
def _generate_job_id(self) -> int:
|
||||
job_id = self._job_counter
|
||||
self._job_counter = job_id + 1
|
||||
return job_id
|
||||
|
||||
def _register_handlers(
|
||||
self,
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, type[AttentionBackend]],
|
||||
):
|
||||
for src_cls, dst_cls, handler in self.spec.get_handlers(
|
||||
kv_caches, attn_backends
|
||||
):
|
||||
self.worker.register_handler(src_cls, dst_cls, handler)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
layer_names = list(kv_caches.keys())
|
||||
layers = get_layers_from_vllm_config(
|
||||
self.spec.vllm_config, Attention, layer_names
|
||||
)
|
||||
attn_backends = {
|
||||
layer_name: layers[layer_name].get_attn_backend()
|
||||
for layer_name in layer_names
|
||||
}
|
||||
self._register_handlers(kv_caches, attn_backends)
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
):
|
||||
cross_layer_name = "ALL_LAYERS"
|
||||
kv_caches = {cross_layer_name: kv_cache}
|
||||
attn_backends = {cross_layer_name: attn_backend}
|
||||
self._register_handlers(kv_caches, attn_backends)
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
self._unsubmitted_store_jobs.clear()
|
||||
|
||||
for req_id in preempted_req_ids:
|
||||
job_ids = self._store_jobs.get(req_id)
|
||||
if job_ids:
|
||||
self.worker.wait(job_ids)
|
||||
|
||||
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
self._unsubmitted_store_jobs.clear()
|
||||
|
||||
for req_id, transfer_spec in metadata.reqs_to_load.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, False)
|
||||
assert req_id not in self._load_job
|
||||
self._load_job[req_id] = job_id
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
|
||||
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
for req_id, transfer_spec in metadata.reqs_to_store.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, True)
|
||||
self._store_jobs[req_id].add(job_id)
|
||||
# NOTE(orozery): defer the store to the beginning of the next engine step,
|
||||
# so that offloading starts AFTER transfers related to token sampling,
|
||||
# thereby avoiding delays to token generation due to offloading.
|
||||
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
|
||||
|
||||
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
Returns a list of request IDs that finished loading or storing.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
"""
|
||||
finished_sending = set()
|
||||
finished_recving = set()
|
||||
for transfer_result in self.worker.get_finished():
|
||||
# we currently do not support job failures
|
||||
job_id = transfer_result.job_id
|
||||
assert transfer_result.success
|
||||
req_id, store = self._jobs.pop(job_id)
|
||||
if (
|
||||
transfer_result.transfer_time
|
||||
and transfer_result.transfer_size is not None
|
||||
and transfer_result.transfer_type is not None
|
||||
):
|
||||
self.kv_connector_stats.record_transfer(
|
||||
num_bytes=transfer_result.transfer_size,
|
||||
time=transfer_result.transfer_time,
|
||||
transfer_type=transfer_result.transfer_type,
|
||||
)
|
||||
if store:
|
||||
req_jobs = self._store_jobs[req_id]
|
||||
req_jobs.remove(job_id)
|
||||
if req_jobs:
|
||||
continue
|
||||
|
||||
if req_id in self._finished_reqs_waiting_for_store:
|
||||
self._finished_reqs_waiting_for_store.remove(req_id)
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
else:
|
||||
req_job = self._load_job[req_id]
|
||||
assert job_id == req_job
|
||||
del self._load_job[req_id]
|
||||
finished_recving.add(req_id)
|
||||
|
||||
for req_id in finished_req_ids:
|
||||
pending_req_jobs = self._store_jobs.get(req_id)
|
||||
if pending_req_jobs:
|
||||
self._finished_reqs_waiting_for_store.add(req_id)
|
||||
elif pending_req_jobs is not None:
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
|
||||
return finished_sending, finished_recving
|
||||
|
||||
def get_kv_connector_stats(self) -> KVConnectorStats | None:
|
||||
"""
|
||||
Get the KV transfer stats for the connector.
|
||||
"""
|
||||
|
||||
if self.kv_connector_stats.is_empty():
|
||||
return None
|
||||
# Clear stats for next iteration
|
||||
kv_connector_stats = self.kv_connector_stats
|
||||
self.kv_connector_stats = OffloadingConnectorStats()
|
||||
return kv_connector_stats
|
||||
|
||||
|
||||
class OffloadPromMetrics(KVConnectorPromMetrics):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
):
|
||||
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
||||
# (engine_idx, transfer_tupe) -> (metric with bounded labels)
|
||||
self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
|
||||
self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
|
||||
self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
|
||||
buckets = [ # In bytes
|
||||
1e6,
|
||||
5e6,
|
||||
10e6,
|
||||
20e6,
|
||||
40e6,
|
||||
60e6,
|
||||
80e6,
|
||||
100e6,
|
||||
150e6,
|
||||
200e6,
|
||||
]
|
||||
|
||||
self._counter_kv_bytes = self._counter_cls(
|
||||
name="vllm:kv_offload_total_bytes",
|
||||
documentation="Number of bytes offloaded by KV connector",
|
||||
labelnames=labelnames + ["transfer_type"],
|
||||
)
|
||||
|
||||
self._counter_kv_transfer_time = self._counter_cls(
|
||||
name="vllm:kv_offload_total_time",
|
||||
documentation="Total time measured by all KV offloading operations",
|
||||
labelnames=labelnames + ["transfer_type"],
|
||||
)
|
||||
|
||||
self._histogram_transfer_size = self._histogram_cls(
|
||||
name="vllm:kv_offload_size",
|
||||
documentation="Histogram of KV offload transfer size, in bytes.",
|
||||
buckets=buckets[:],
|
||||
labelnames=labelnames + ["transfer_type"],
|
||||
)
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
"""
|
||||
Observe transfer statistics from the new data structure.
|
||||
transfer_stats_data is expected to be a dict where:
|
||||
- keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
|
||||
- values are lists of OffloadingOperationMetrics objects
|
||||
"""
|
||||
|
||||
for transfer_type, ops in transfer_stats_data.items():
|
||||
# Cache:
|
||||
if (engine_idx, transfer_type) not in self.histogram_transfer_size:
|
||||
self.histogram_transfer_size[(engine_idx, transfer_type)] = (
|
||||
self._histogram_transfer_size.labels(
|
||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||
)
|
||||
)
|
||||
self.counter_kv_bytes[(engine_idx, transfer_type)] = (
|
||||
self._counter_kv_bytes.labels(
|
||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||
)
|
||||
)
|
||||
self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
|
||||
self._counter_kv_transfer_time.labels(
|
||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||
)
|
||||
)
|
||||
|
||||
# Process ops:
|
||||
assert isinstance(ops, list)
|
||||
for op in ops: # ops is a list of serialized OffloadingOperationMetrics
|
||||
assert isinstance(op, dict)
|
||||
# Observe size histogram
|
||||
self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
|
||||
op["op_size"]
|
||||
)
|
||||
|
||||
# Increment byte and time counters
|
||||
self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])
|
||||
|
||||
self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
|
||||
op["op_time"]
|
||||
)
|
||||
@@ -0,0 +1,531 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
|
||||
P2pNcclEngine,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request Id
|
||||
request_id: str
|
||||
# Request block ids
|
||||
block_ids: torch.Tensor
|
||||
# Request num tokens
|
||||
num_tokens: int
|
||||
|
||||
@staticmethod
|
||||
def make_meta(
|
||||
request_id: str, token_ids: list[int], block_ids: list[int], block_size: int
|
||||
) -> "ReqMeta":
|
||||
block_ids_tensor = torch.tensor(block_ids)
|
||||
return ReqMeta(
|
||||
request_id=request_id,
|
||||
block_ids=block_ids_tensor,
|
||||
num_tokens=len(token_ids),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class P2pNcclConnectorMetadata(KVConnectorMetadata):
|
||||
requests: list[ReqMeta]
|
||||
|
||||
def __init__(self):
|
||||
self.requests = []
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
) -> None:
|
||||
self.requests.append(
|
||||
ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)
|
||||
)
|
||||
|
||||
|
||||
class P2pNcclConnector(KVConnectorBase_V1):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
self._requests_need_load: dict[str, Any] = {}
|
||||
self.is_producer = self._kv_transfer_config.is_kv_producer
|
||||
self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}
|
||||
|
||||
self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0
|
||||
self._local_rank = (
|
||||
get_world_group().local_rank if role == KVConnectorRole.WORKER else 0
|
||||
)
|
||||
|
||||
self.p2p_nccl_engine = (
|
||||
P2pNcclEngine(
|
||||
local_rank=self._local_rank,
|
||||
config=self._kv_transfer_config,
|
||||
hostname="",
|
||||
port_offset=self._rank,
|
||||
)
|
||||
if role == KVConnectorRole.WORKER
|
||||
else None
|
||||
)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""Start loading the KV cache from the connector buffer to vLLM's
|
||||
paged KV buffer.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
"""
|
||||
|
||||
# Only consumer/decode loads KV Cache
|
||||
if self.is_producer:
|
||||
return
|
||||
|
||||
assert self.p2p_nccl_engine is not None
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
return
|
||||
|
||||
def inject_kv_into_layer(
|
||||
layer: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
block_ids: torch.Tensor,
|
||||
request_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Inject KV cache data into a given attention layer tensor.
|
||||
|
||||
This function updates `layer` in-place with values from `kv_cache`,
|
||||
handling different backend layouts:
|
||||
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
|
||||
indexed along the first dimension.
|
||||
- FlashAttention: KV tensors are indexed along the second
|
||||
dimension.
|
||||
|
||||
If the number of provided block IDs does not match the number of KV
|
||||
blocks, only the overlapping portion is updated, and a warning is
|
||||
logged.
|
||||
|
||||
Args:
|
||||
layer (torch.Tensor): The attention layer KV tensor to update.
|
||||
kv_cache (torch.Tensor): The KV cache tensor to inject.
|
||||
block_ids (torch.Tensor): Indices of the blocks to update.
|
||||
request_id (str): Request identifier used for logging.
|
||||
|
||||
Returns:
|
||||
None. The function modifies `layer` in-place.
|
||||
"""
|
||||
if (
|
||||
isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2
|
||||
): # MLA or FlashInfer
|
||||
num_block = kv_cache.shape[0]
|
||||
self.check_tensors_except_dim(layer, kv_cache, 0)
|
||||
if len(block_ids) == num_block:
|
||||
layer[block_ids, ...] = kv_cache
|
||||
else:
|
||||
layer[block_ids[:num_block], ...] = kv_cache
|
||||
logger.warning(
|
||||
"🚧kv_cache does not match, block_ids:%d, "
|
||||
"num_block:%d, request_id:%s",
|
||||
len(block_ids),
|
||||
num_block,
|
||||
request_id,
|
||||
)
|
||||
|
||||
elif layer.shape[0] == 2: # FlashAttention
|
||||
num_block = kv_cache.shape[1]
|
||||
self.check_tensors_except_dim(layer, kv_cache, 1)
|
||||
if len(block_ids) == num_block:
|
||||
layer[:, block_ids, ...] = kv_cache
|
||||
else:
|
||||
layer[:, block_ids[:num_block], ...] = kv_cache
|
||||
logger.warning(
|
||||
"🚧kv_cache does not match, block_ids:%d, "
|
||||
"num_block:%d, request_id:%s",
|
||||
len(block_ids),
|
||||
num_block,
|
||||
request_id,
|
||||
)
|
||||
|
||||
# Get the metadata
|
||||
metadata: KVConnectorMetadata = self._get_connector_metadata()
|
||||
assert isinstance(metadata, P2pNcclConnectorMetadata)
|
||||
|
||||
if metadata is None:
|
||||
return
|
||||
|
||||
# Load the KV for each request each layer
|
||||
for request in metadata.requests:
|
||||
request_id = request.request_id
|
||||
ip, port = self.parse_request_id(request_id, False)
|
||||
remote_address = ip + ":" + str(port + self._rank)
|
||||
for layer_name in forward_context.no_compile_layers:
|
||||
layer = forward_context.no_compile_layers[layer_name]
|
||||
|
||||
# Only process layers that have kv_cache
|
||||
# attribute (attention layers) Skip non-attention
|
||||
# layers like FusedMoE
|
||||
kv_cache = getattr(layer, "kv_cache", None)
|
||||
if kv_cache is None:
|
||||
continue
|
||||
|
||||
layer = kv_cache[forward_context.virtual_engine]
|
||||
|
||||
kv_cache = self.p2p_nccl_engine.recv_tensor(
|
||||
request.request_id + "#" + layer_name, remote_address
|
||||
)
|
||||
|
||||
if kv_cache is None:
|
||||
logger.warning("🚧kv_cache is None, %s", request.request_id)
|
||||
continue
|
||||
|
||||
inject_kv_into_layer(
|
||||
layer, kv_cache, request.block_ids, request.request_id
|
||||
)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""Blocking until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
return
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
to the connector.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
|
||||
# Only producer/prefill saves KV Cache
|
||||
if not self.is_producer:
|
||||
return
|
||||
|
||||
assert self.p2p_nccl_engine is not None
|
||||
|
||||
def extract_kv_from_layer(
|
||||
layer: torch.Tensor,
|
||||
block_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Extract KV cache slices from a given attention layer tensor.
|
||||
|
||||
This function handles multiple backend layouts:
|
||||
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
|
||||
indexed along the first dimension.
|
||||
- FlashAttention: KV tensors are indexed along the second
|
||||
dimension.
|
||||
|
||||
Args:
|
||||
layer (torch.Tensor): The KV cache from the attention layer.
|
||||
block_ids (torch.Tensor): Indices of blocks to extract.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor containing the extracted KV slices.
|
||||
Returns None if the layout is unsupported.
|
||||
"""
|
||||
if (
|
||||
isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2
|
||||
): # MLA or FlashInfer
|
||||
return layer[block_ids, ...]
|
||||
|
||||
if layer.shape[0] == 2: # FlashAttention
|
||||
return layer[:, block_ids, ...]
|
||||
|
||||
return None
|
||||
|
||||
connector_metadata = self._get_connector_metadata()
|
||||
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
|
||||
for request in connector_metadata.requests:
|
||||
request_id = request.request_id
|
||||
ip, port = self.parse_request_id(request_id, True)
|
||||
remote_address = ip + ":" + str(port + self._rank)
|
||||
|
||||
kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
|
||||
self.p2p_nccl_engine.send_tensor(
|
||||
request_id + "#" + layer_name, kv_cache, remote_address
|
||||
)
|
||||
|
||||
def wait_for_save(self):
|
||||
if self.is_producer:
|
||||
assert self.p2p_nccl_engine is not None
|
||||
self.p2p_nccl_engine.wait_for_sent()
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str], **kwargs: Any
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer,
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
|
||||
assert self.p2p_nccl_engine is not None
|
||||
|
||||
no_compile_layers = self._vllm_config.compilation_config.static_forward_context
|
||||
return self.p2p_nccl_engine.get_finished(finished_req_ids, no_compile_layers)
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
if self.is_producer:
|
||||
return 0, False
|
||||
|
||||
prompt_token_ids = request.prompt_token_ids or []
|
||||
num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens
|
||||
|
||||
if num_external_tokens < 0:
|
||||
num_external_tokens = 0
|
||||
|
||||
return num_external_tokens, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
"""
|
||||
if not self.is_producer and num_external_tokens > 0:
|
||||
self._requests_need_load[request.request_id] = (
|
||||
request,
|
||||
blocks.get_block_ids()[0],
|
||||
)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
|
||||
meta = P2pNcclConnectorMetadata()
|
||||
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
if self.is_producer:
|
||||
num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[
|
||||
new_req.req_id
|
||||
]
|
||||
num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
|
||||
# the request's prompt is chunked prefill
|
||||
if num_tokens < len(new_req.prompt_token_ids or []):
|
||||
# 'CachedRequestData' has no attribute 'prompt_token_ids'
|
||||
self.chunked_prefill[new_req.req_id] = (
|
||||
new_req.block_ids[0],
|
||||
new_req.prompt_token_ids,
|
||||
)
|
||||
continue
|
||||
# the request's prompt is not chunked prefill
|
||||
meta.add_request(
|
||||
request_id=new_req.req_id,
|
||||
token_ids=new_req.prompt_token_ids or [],
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
)
|
||||
continue
|
||||
if new_req.req_id in self._requests_need_load:
|
||||
meta.add_request(
|
||||
request_id=new_req.req_id,
|
||||
token_ids=new_req.prompt_token_ids or [],
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
)
|
||||
self._requests_need_load.pop(new_req.req_id)
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
num_computed_tokens = cached_reqs.num_computed_tokens[i]
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
|
||||
|
||||
if self.is_producer:
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
num_tokens = num_scheduled_tokens + num_computed_tokens
|
||||
assert req_id in self.chunked_prefill
|
||||
assert new_block_ids is not None
|
||||
block_ids = new_block_ids[0]
|
||||
if not resumed_from_preemption:
|
||||
block_ids = self.chunked_prefill[req_id][0] + block_ids
|
||||
prompt_token_ids = self.chunked_prefill[req_id][1]
|
||||
assert prompt_token_ids is not None
|
||||
# the request's prompt is chunked prefill again
|
||||
if num_tokens < len(prompt_token_ids):
|
||||
self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
|
||||
continue
|
||||
# the request's prompt is all prefilled finally
|
||||
meta.add_request(
|
||||
request_id=req_id,
|
||||
token_ids=prompt_token_ids,
|
||||
block_ids=block_ids,
|
||||
block_size=self._block_size,
|
||||
)
|
||||
self.chunked_prefill.pop(req_id, None)
|
||||
continue
|
||||
|
||||
# NOTE(rob): here we rely on the resumed requests being
|
||||
# the first N requests in the list scheduled_cache_reqs.
|
||||
if not resumed_from_preemption:
|
||||
break
|
||||
if req_id in self._requests_need_load:
|
||||
request, _ = self._requests_need_load.pop(req_id)
|
||||
total_tokens = num_computed_tokens + 1
|
||||
token_ids = request.all_token_ids[:total_tokens]
|
||||
|
||||
# NOTE(rob): For resumed req, new_block_ids is all
|
||||
# of the block_ids for the request.
|
||||
assert new_block_ids is not None
|
||||
block_ids = new_block_ids[0]
|
||||
|
||||
meta.add_request(
|
||||
request_id=req_id,
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
block_size=self._block_size,
|
||||
)
|
||||
|
||||
self._requests_need_load.clear()
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
|
||||
self.chunked_prefill.pop(request.request_id, None)
|
||||
|
||||
return False, None
|
||||
|
||||
# ==============================
|
||||
# Static methods
|
||||
# ==============================
|
||||
|
||||
@staticmethod
|
||||
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
|
||||
# Regular expression to match the string hostname and integer port
|
||||
if is_prefill:
|
||||
pattern = r"___decode_addr_(.*):(\d+)"
|
||||
else:
|
||||
pattern = r"___prefill_addr_(.*):(\d+)___"
|
||||
|
||||
# Use re.search to find the pattern in the request_id
|
||||
match = re.search(pattern, request_id)
|
||||
if match:
|
||||
# Extract the ranks
|
||||
ip = match.group(1)
|
||||
port = int(match.group(2))
|
||||
|
||||
return ip, port
|
||||
raise ValueError(f"Request id {request_id} does not contain hostname and port")
|
||||
|
||||
@staticmethod
|
||||
def check_tensors_except_dim(tensor1, tensor2, dim):
|
||||
shape1 = tensor1.size()
|
||||
shape2 = tensor2.size()
|
||||
|
||||
if len(shape1) != len(shape2) or not all(
|
||||
s1 == s2 for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
|
||||
"and others will be supported in future PRs."
|
||||
)
|
||||
@@ -0,0 +1,632 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary,
|
||||
buffer_type,
|
||||
cudaStream_t,
|
||||
ncclComm_t,
|
||||
ncclDataTypeEnum,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
|
||||
TensorMemoryPool,
|
||||
)
|
||||
from vllm.utils.network_utils import get_ip
|
||||
from vllm.utils.torch_utils import current_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MEM_POOL_SIZE_GB = 32
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_p2p_nccl_context(num_channels: str):
|
||||
original_values: dict[str, Any] = {}
|
||||
env_vars = [
|
||||
"NCCL_MAX_NCHANNELS",
|
||||
"NCCL_MIN_NCHANNELS",
|
||||
"NCCL_CUMEM_ENABLE",
|
||||
"NCCL_BUFFSIZE",
|
||||
"NCCL_PROTO", # LL,LL128,SIMPLE
|
||||
"NCCL_ALGO", # RING,TREE
|
||||
]
|
||||
|
||||
for var in env_vars:
|
||||
original_values[var] = os.environ.get(var)
|
||||
|
||||
logger.info("set_p2p_nccl_context, original_values: %s", original_values)
|
||||
|
||||
try:
|
||||
os.environ["NCCL_MAX_NCHANNELS"] = num_channels
|
||||
os.environ["NCCL_MIN_NCHANNELS"] = num_channels
|
||||
os.environ["NCCL_CUMEM_ENABLE"] = "1"
|
||||
yield
|
||||
finally:
|
||||
for var in env_vars:
|
||||
if original_values[var] is not None:
|
||||
os.environ[var] = original_values[var]
|
||||
else:
|
||||
os.environ.pop(var, None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendQueueItem:
|
||||
tensor_id: str
|
||||
remote_address: str
|
||||
tensor: torch.Tensor
|
||||
|
||||
|
||||
class P2pNcclEngine:
|
||||
def __init__(
|
||||
self,
|
||||
local_rank: int,
|
||||
config: KVTransferConfig,
|
||||
hostname: str = "",
|
||||
port_offset: int = 0,
|
||||
library_path: str | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.rank = port_offset
|
||||
self.local_rank = local_rank
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
|
||||
if not hostname:
|
||||
hostname = get_ip()
|
||||
port = int(self.config.kv_port) + port_offset
|
||||
if port == 0:
|
||||
raise ValueError("Port cannot be 0")
|
||||
self._hostname = hostname
|
||||
self._port = port
|
||||
|
||||
# Each card corresponds to a ZMQ address.
|
||||
self.zmq_address = f"{self._hostname}:{self._port}"
|
||||
|
||||
# If `proxy_ip` or `proxy_port` is `""`,
|
||||
# then the ping thread will not be enabled.
|
||||
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
|
||||
proxy_port = self.config.get_from_extra_config("proxy_port", "")
|
||||
if proxy_ip == "" or proxy_port == "":
|
||||
self.proxy_address = ""
|
||||
self.http_address = ""
|
||||
else:
|
||||
self.proxy_address = proxy_ip + ":" + proxy_port
|
||||
# the `http_port` must be consistent with the port of OpenAI.
|
||||
http_port = self.config.get_from_extra_config("http_port", None)
|
||||
if http_port is None:
|
||||
example_cfg = {
|
||||
"kv_connector": "P2pNcclConnector",
|
||||
"kv_connector_extra_config": {"http_port": 8000},
|
||||
}
|
||||
example = (
|
||||
f"--port=8000 --kv-transfer-config='{json.dumps(example_cfg)}'"
|
||||
)
|
||||
raise ValueError(
|
||||
"kv_connector_extra_config.http_port is required. "
|
||||
f"Example: {example}"
|
||||
)
|
||||
self.http_address = f"{self._hostname}:{http_port}"
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.router_socket = self.context.socket(zmq.ROUTER)
|
||||
self.router_socket.bind(f"tcp://{self.zmq_address}")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.router_socket, zmq.POLLIN)
|
||||
|
||||
self.send_store_cv = threading.Condition()
|
||||
self.send_queue_cv = threading.Condition()
|
||||
self.recv_store_cv = threading.Condition()
|
||||
|
||||
self.send_stream = torch.cuda.Stream()
|
||||
self.recv_stream = torch.cuda.Stream()
|
||||
|
||||
mem_pool_size_gb = float(
|
||||
self.config.get_from_extra_config(
|
||||
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB
|
||||
)
|
||||
)
|
||||
self.pool = TensorMemoryPool(
|
||||
max_block_size=int(mem_pool_size_gb * 1024**3)
|
||||
) # GB
|
||||
|
||||
# The sending type includes tree mutually exclusive options:
|
||||
# PUT, GET, PUT_ASYNC.
|
||||
self.send_type = self.config.get_from_extra_config("send_type", "PUT_ASYNC")
|
||||
if self.send_type == "GET":
|
||||
# tensor_id: torch.Tensor
|
||||
self.send_store: dict[str, torch.Tensor] = {}
|
||||
else:
|
||||
# PUT or PUT_ASYNC
|
||||
# tensor_id: torch.Tensor
|
||||
self.send_queue: deque[SendQueueItem] = deque()
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
self._send_thread = threading.Thread(
|
||||
target=self.send_async, daemon=True
|
||||
)
|
||||
self._send_thread.start()
|
||||
|
||||
# tensor_id: torch.Tensor/(addr, dtype, shape)
|
||||
self.recv_store: dict[str, Any] = {}
|
||||
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
|
||||
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
|
||||
self.socks: dict[str, Any] = {} # remote_address: client socket
|
||||
self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
|
||||
|
||||
self.buffer_size = 0
|
||||
self.buffer_size_threshold = float(self.config.kv_buffer_size)
|
||||
|
||||
self.nccl_num_channels = self.config.get_from_extra_config(
|
||||
"nccl_num_channels", "8"
|
||||
)
|
||||
|
||||
self._listener_thread = threading.Thread(
|
||||
target=self.listen_for_requests, daemon=True
|
||||
)
|
||||
self._listener_thread.start()
|
||||
|
||||
self._ping_thread = None
|
||||
if port_offset == 0 and self.proxy_address != "":
|
||||
self._ping_thread = threading.Thread(target=self.ping, daemon=True)
|
||||
self._ping_thread.start()
|
||||
|
||||
logger.info(
|
||||
"💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
|
||||
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
|
||||
"threshold:%.2f, nccl_num_channels:%s",
|
||||
self.rank,
|
||||
self.local_rank,
|
||||
self.http_address,
|
||||
self.zmq_address,
|
||||
self.proxy_address,
|
||||
self.send_type,
|
||||
self.buffer_size_threshold,
|
||||
self.nccl_num_channels,
|
||||
)
|
||||
|
||||
def create_connect(self, remote_address: str | None = None):
|
||||
assert remote_address is not None
|
||||
if remote_address not in self.socks:
|
||||
sock = self.context.socket(zmq.DEALER)
|
||||
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
|
||||
sock.connect(f"tcp://{remote_address}")
|
||||
self.socks[remote_address] = sock
|
||||
if remote_address in self.comms:
|
||||
logger.info(
|
||||
"👋comm exists, remote_address:%s, comms:%s",
|
||||
remote_address,
|
||||
self.comms,
|
||||
)
|
||||
return sock, self.comms[remote_address]
|
||||
|
||||
unique_id = self.nccl.ncclGetUniqueId()
|
||||
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
with torch.cuda.device(self.device):
|
||||
rank = 0
|
||||
with set_p2p_nccl_context(self.nccl_num_channels):
|
||||
comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank)
|
||||
self.comms[remote_address] = (comm, rank)
|
||||
logger.info(
|
||||
"🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
|
||||
self.zmq_address,
|
||||
remote_address,
|
||||
rank,
|
||||
)
|
||||
|
||||
return self.socks[remote_address], self.comms[remote_address]
|
||||
|
||||
def send_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
tensor: torch.Tensor,
|
||||
remote_address: str | None = None,
|
||||
) -> bool:
|
||||
if remote_address is None:
|
||||
with self.recv_store_cv:
|
||||
self.recv_store[tensor_id] = tensor
|
||||
self.recv_store_cv.notify()
|
||||
return True
|
||||
|
||||
item = SendQueueItem(
|
||||
tensor_id=tensor_id, remote_address=remote_address, tensor=tensor
|
||||
)
|
||||
|
||||
if self.send_type == "PUT":
|
||||
return self.send_sync(item)
|
||||
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
with self.send_queue_cv:
|
||||
self.send_queue.append(item)
|
||||
self.send_queue_cv.notify()
|
||||
return True
|
||||
|
||||
# GET
|
||||
with self.send_store_cv:
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
if tensor_size > self.buffer_size_threshold:
|
||||
logger.warning(
|
||||
"❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
|
||||
"buffer size threshold :%d, skip send to %s, rank:%d",
|
||||
tensor_id,
|
||||
tensor_size,
|
||||
self.buffer_size_threshold,
|
||||
remote_address,
|
||||
self.rank,
|
||||
)
|
||||
return False
|
||||
while self.buffer_size + tensor_size > self.buffer_size_threshold:
|
||||
assert len(self.send_store) > 0
|
||||
oldest_tensor_id = next(iter(self.send_store))
|
||||
oldest_tensor = self.send_store.pop(oldest_tensor_id)
|
||||
oldest_tensor_size = (
|
||||
oldest_tensor.element_size() * oldest_tensor.numel()
|
||||
)
|
||||
self.buffer_size -= oldest_tensor_size
|
||||
logger.debug(
|
||||
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
|
||||
" buffer_size:%d, oldest_tensor_size:%d, rank:%d",
|
||||
remote_address,
|
||||
tensor_id,
|
||||
tensor_size,
|
||||
self.buffer_size,
|
||||
oldest_tensor_size,
|
||||
self.rank,
|
||||
)
|
||||
|
||||
self.send_store[tensor_id] = tensor
|
||||
self.buffer_size += tensor_size
|
||||
logger.debug(
|
||||
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
|
||||
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
|
||||
remote_address,
|
||||
tensor_id,
|
||||
tensor_size,
|
||||
tensor.shape,
|
||||
self.rank,
|
||||
self.buffer_size,
|
||||
self.buffer_size / self.buffer_size_threshold * 100,
|
||||
)
|
||||
return True
|
||||
|
||||
def recv_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
remote_address: str | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
|
||||
start_time = time.time()
|
||||
with self.recv_store_cv:
|
||||
while tensor_id not in self.recv_store:
|
||||
self.recv_store_cv.wait()
|
||||
tensor = self.recv_store[tensor_id]
|
||||
|
||||
if tensor is not None:
|
||||
if isinstance(tensor, tuple):
|
||||
addr, dtype, shape = tensor
|
||||
tensor = self.pool.load_tensor(addr, dtype, shape, self.device)
|
||||
else:
|
||||
self.buffer_size -= tensor.element_size() * tensor.numel()
|
||||
else:
|
||||
duration = time.time() - start_time
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, rank:%d",
|
||||
remote_address,
|
||||
tensor_id,
|
||||
duration * 1000,
|
||||
self.rank,
|
||||
)
|
||||
return tensor
|
||||
|
||||
# GET
|
||||
if remote_address is None:
|
||||
return None
|
||||
|
||||
if remote_address not in self.socks:
|
||||
self.create_connect(remote_address)
|
||||
|
||||
sock = self.socks[remote_address]
|
||||
comm, rank = self.comms[remote_address]
|
||||
|
||||
data = {"cmd": "GET", "tensor_id": tensor_id}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
message = sock.recv()
|
||||
data = msgpack.loads(message)
|
||||
if data["ret"] != 0:
|
||||
logger.warning(
|
||||
"🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
|
||||
remote_address,
|
||||
tensor_id,
|
||||
data["ret"],
|
||||
)
|
||||
return None
|
||||
|
||||
with torch.cuda.stream(self.recv_stream):
|
||||
tensor = torch.empty(
|
||||
data["shape"], dtype=getattr(torch, data["dtype"]), device=self.device
|
||||
)
|
||||
|
||||
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
|
||||
|
||||
return tensor
|
||||
|
||||
def listen_for_requests(self):
|
||||
while True:
|
||||
socks = dict(self.poller.poll())
|
||||
if self.router_socket not in socks:
|
||||
continue
|
||||
|
||||
remote_address, message = self.router_socket.recv_multipart()
|
||||
data = msgpack.loads(message)
|
||||
if data["cmd"] == "NEW":
|
||||
unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"]))
|
||||
with torch.cuda.device(self.device):
|
||||
rank = 1
|
||||
with set_p2p_nccl_context(self.nccl_num_channels):
|
||||
comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
2, unique_id, rank
|
||||
)
|
||||
self.comms[remote_address.decode()] = (comm, rank)
|
||||
logger.info(
|
||||
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
|
||||
self.zmq_address,
|
||||
remote_address.decode(),
|
||||
rank,
|
||||
)
|
||||
elif data["cmd"] == "PUT":
|
||||
tensor_id = data["tensor_id"]
|
||||
try:
|
||||
with torch.cuda.stream(self.recv_stream):
|
||||
tensor = torch.empty(
|
||||
data["shape"],
|
||||
dtype=getattr(torch, data["dtype"]),
|
||||
device=self.device,
|
||||
)
|
||||
self.router_socket.send_multipart([remote_address, b"0"])
|
||||
comm, rank = self.comms[remote_address.decode()]
|
||||
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
if self.buffer_size + tensor_size > self.buffer_size_threshold:
|
||||
# Store Tensor in memory pool
|
||||
addr = self.pool.store_tensor(tensor)
|
||||
tensor = (addr, tensor.dtype, tensor.shape)
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv Tensor, Out Of Threshold, "
|
||||
"%s👈%s, data:%s, addr:%d",
|
||||
self.zmq_address,
|
||||
remote_address.decode(),
|
||||
data,
|
||||
addr,
|
||||
)
|
||||
else:
|
||||
self.buffer_size += tensor_size
|
||||
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self.router_socket.send_multipart([remote_address, b"1"])
|
||||
tensor = None
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, data:%s",
|
||||
self.zmq_address,
|
||||
remote_address.decode(),
|
||||
data,
|
||||
)
|
||||
|
||||
with self.recv_store_cv:
|
||||
self.recv_store[tensor_id] = tensor
|
||||
self.have_received_tensor_id(tensor_id)
|
||||
self.recv_store_cv.notify()
|
||||
|
||||
elif data["cmd"] == "GET":
|
||||
tensor_id = data["tensor_id"]
|
||||
with self.send_store_cv:
|
||||
tensor = self.send_store.pop(tensor_id, None)
|
||||
if tensor is not None:
|
||||
data = {
|
||||
"ret": 0,
|
||||
"shape": tensor.shape,
|
||||
"dtype": str(tensor.dtype).replace("torch.", ""),
|
||||
}
|
||||
# LRU
|
||||
self.send_store[tensor_id] = tensor
|
||||
self.have_sent_tensor_id(tensor_id)
|
||||
else:
|
||||
data = {"ret": 1}
|
||||
|
||||
self.router_socket.send_multipart([remote_address, msgpack.dumps(data)])
|
||||
|
||||
if data["ret"] == 0:
|
||||
comm, rank = self.comms[remote_address.decode()]
|
||||
self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
|
||||
else:
|
||||
logger.warning(
|
||||
"🚧Unexpected, Received message from %s, data:%s",
|
||||
remote_address,
|
||||
data,
|
||||
)
|
||||
|
||||
def have_sent_tensor_id(self, tensor_id: str):
|
||||
request_id = tensor_id.split("#")[0]
|
||||
if request_id not in self.send_request_id_to_tensor_ids:
|
||||
self.send_request_id_to_tensor_ids[request_id] = set()
|
||||
self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
|
||||
|
||||
def have_received_tensor_id(self, tensor_id: str):
|
||||
request_id = tensor_id.split("#")[0]
|
||||
if request_id not in self.recv_request_id_to_tensor_ids:
|
||||
self.recv_request_id_to_tensor_ids[request_id] = set()
|
||||
self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
|
||||
|
||||
def send_async(self):
|
||||
while True:
|
||||
with self.send_queue_cv:
|
||||
while not self.send_queue:
|
||||
self.send_queue_cv.wait()
|
||||
item = self.send_queue.popleft()
|
||||
if not self.send_queue:
|
||||
self.send_queue_cv.notify()
|
||||
self.send_sync(item)
|
||||
|
||||
def wait_for_sent(self):
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
start_time = time.time()
|
||||
with self.send_queue_cv:
|
||||
while self.send_queue:
|
||||
self.send_queue_cv.wait()
|
||||
duration = time.time() - start_time
|
||||
logger.debug(
|
||||
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
|
||||
" to be empty, rank:%d",
|
||||
duration * 1000,
|
||||
self.rank,
|
||||
)
|
||||
|
||||
def send_sync(self, item: SendQueueItem) -> bool:
|
||||
if item.remote_address is None:
|
||||
return False
|
||||
if item.remote_address not in self.socks:
|
||||
self.create_connect(item.remote_address)
|
||||
|
||||
tensor = item.tensor
|
||||
|
||||
sock = self.socks[item.remote_address]
|
||||
comm, rank = self.comms[item.remote_address]
|
||||
data = {
|
||||
"cmd": "PUT",
|
||||
"tensor_id": item.tensor_id,
|
||||
"shape": tensor.shape,
|
||||
"dtype": str(tensor.dtype).replace("torch.", ""),
|
||||
}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
response = sock.recv()
|
||||
if response != b"0":
|
||||
logger.error(
|
||||
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
|
||||
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
|
||||
self.zmq_address,
|
||||
item.remote_address,
|
||||
rank,
|
||||
data,
|
||||
tensor.shape,
|
||||
tensor.element_size() * tensor.numel() / 1024**3,
|
||||
response.decode(),
|
||||
)
|
||||
return False
|
||||
|
||||
self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
|
||||
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
self.have_sent_tensor_id(item.tensor_id)
|
||||
|
||||
return True
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str], no_compile_layers
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer,
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
|
||||
# Clear the buffer upon request completion.
|
||||
for request_id in finished_req_ids:
|
||||
for layer_name in no_compile_layers:
|
||||
tensor_id = request_id + "#" + layer_name
|
||||
if tensor_id in self.recv_store:
|
||||
with self.recv_store_cv:
|
||||
tensor = self.recv_store.pop(tensor_id, None)
|
||||
self.send_request_id_to_tensor_ids.pop(request_id, None)
|
||||
self.recv_request_id_to_tensor_ids.pop(request_id, None)
|
||||
if isinstance(tensor, tuple):
|
||||
addr, _, _ = tensor
|
||||
self.pool.free(addr)
|
||||
|
||||
# TODO:Retrieve requests that have already sent the KV cache.
|
||||
finished_sending: set[str] = set()
|
||||
|
||||
# TODO:Retrieve requests that have already received the KV cache.
|
||||
finished_recving: set[str] = set()
|
||||
|
||||
return finished_sending or None, finished_recving or None
|
||||
|
||||
def ping(self):
|
||||
sock = self.context.socket(zmq.DEALER)
|
||||
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
|
||||
logger.debug("ping start, zmq_address:%s", self.zmq_address)
|
||||
sock.connect(f"tcp://{self.proxy_address}")
|
||||
data = {
|
||||
"type": "P" if self.config.is_kv_producer else "D",
|
||||
"http_address": self.http_address,
|
||||
"zmq_address": self.zmq_address,
|
||||
}
|
||||
while True:
|
||||
sock.send(msgpack.dumps(data))
|
||||
time.sleep(3)
|
||||
|
||||
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
dst,
|
||||
comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
stream.synchronize()
|
||||
|
||||
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
stream.synchronize()
|
||||
|
||||
def close(self) -> None:
|
||||
self._listener_thread.join()
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
self._send_thread.join()
|
||||
if self._ping_thread is not None:
|
||||
self._ping_thread.join()
|
||||
@@ -0,0 +1,273 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import atexit
|
||||
import ctypes
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryBlock:
|
||||
size: int
|
||||
addr: int
|
||||
|
||||
|
||||
"""A memory pool for managing pinned host memory allocations for tensors.
|
||||
|
||||
This class implements a buddy allocation system to efficiently manage pinned
|
||||
host memory for tensor storage. It supports allocation, deallocation, and
|
||||
tensor storage/retrieval operations.
|
||||
|
||||
Key Features:
|
||||
- Uses power-of-two block sizes for efficient buddy allocation
|
||||
- Supports splitting and merging of memory blocks
|
||||
- Provides methods to store CUDA tensors in pinned host memory
|
||||
- Allows loading tensors from pinned memory back to device
|
||||
- Automatically cleans up memory on destruction
|
||||
|
||||
Attributes:
|
||||
max_block_size (int): Maximum block size (rounded to nearest power of two)
|
||||
min_block_size (int): Minimum block size (rounded to nearest power of two)
|
||||
free_lists (dict): Dictionary of free memory blocks by size
|
||||
allocated_blocks (dict): Dictionary of currently allocated blocks
|
||||
base_tensor (torch.Tensor): Base pinned memory tensor
|
||||
base_address (int): Base memory address of the pinned memory region
|
||||
|
||||
Example:
|
||||
>>> pool = TensorMemoryPool(max_block_size=1024*1024)
|
||||
>>> tensor = torch.randn(100, device='cuda')
|
||||
>>> addr = pool.store_tensor(tensor)
|
||||
>>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
|
||||
... tensor.shape, 'cuda')
|
||||
>>> pool.free(addr)
|
||||
"""
|
||||
|
||||
|
||||
class TensorMemoryPool:
|
||||
"""Initializes the memory pool with given size constraints.
|
||||
|
||||
Args:
|
||||
max_block_size (int): Maximum size of memory blocks to manage
|
||||
min_block_size (int, optional): Minimum size of memory blocks
|
||||
to manage. Defaults to 512.
|
||||
|
||||
Raises:
|
||||
ValueError: If block sizes are invalid or max_block_size is less
|
||||
than min_block_size
|
||||
"""
|
||||
|
||||
def __init__(self, max_block_size: int, min_block_size: int = 512):
|
||||
if max_block_size <= 0 or min_block_size <= 0:
|
||||
raise ValueError("Block sizes must be positive")
|
||||
if max_block_size < min_block_size:
|
||||
raise ValueError("Max block size must be greater than min block size")
|
||||
|
||||
self.max_block_size = self._round_to_power_of_two(max_block_size)
|
||||
self.min_block_size = self._round_to_power_of_two(min_block_size)
|
||||
|
||||
self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
|
||||
self.allocated_blocks: dict[int, MemoryBlock] = {}
|
||||
|
||||
self._initialize_free_lists()
|
||||
self._allocate_pinned_memory()
|
||||
|
||||
atexit.register(self.cleanup)
|
||||
|
||||
def _round_to_power_of_two(self, size: int) -> int:
|
||||
return 1 << (size - 1).bit_length()
|
||||
|
||||
def _initialize_free_lists(self):
|
||||
size = self.max_block_size
|
||||
while size >= self.min_block_size:
|
||||
self.free_lists[size] = {}
|
||||
size //= 2
|
||||
|
||||
def _allocate_pinned_memory(self):
|
||||
self.base_tensor = torch.empty(
|
||||
self.max_block_size // 4, dtype=torch.float32, pin_memory=True
|
||||
)
|
||||
self.base_address = self.base_tensor.data_ptr()
|
||||
initial_block = MemoryBlock(size=self.max_block_size, addr=self.base_address)
|
||||
self.free_lists[self.max_block_size][initial_block.addr] = initial_block
|
||||
|
||||
logger.debug(
|
||||
"TensorMemoryPool, base_address:%d, max_block_size:%d",
|
||||
self.base_address,
|
||||
self.max_block_size,
|
||||
)
|
||||
|
||||
def allocate(self, size: int) -> int:
|
||||
"""Allocates a memory block of at least the requested size.
|
||||
|
||||
Args:
|
||||
size (int): Minimum size of memory to allocate
|
||||
|
||||
Returns:
|
||||
int: Address of the allocated memory block
|
||||
|
||||
Raises:
|
||||
ValueError: If size is invalid or insufficient memory is available
|
||||
"""
|
||||
if size <= 0:
|
||||
raise ValueError("Allocation size must be positive")
|
||||
|
||||
required_size = self._round_to_power_of_two(max(size, self.min_block_size))
|
||||
if required_size > self.max_block_size:
|
||||
raise ValueError("Requested size exceeds maximum block size")
|
||||
|
||||
current_size = required_size
|
||||
while current_size <= self.max_block_size:
|
||||
if self.free_lists[current_size]:
|
||||
_, block = self.free_lists[current_size].popitem()
|
||||
self._split_block(block, required_size)
|
||||
self.allocated_blocks[block.addr] = block
|
||||
return block.addr
|
||||
current_size *= 2
|
||||
|
||||
raise ValueError("Insufficient memory")
|
||||
|
||||
def _split_block(self, block: MemoryBlock, required_size: int):
|
||||
while block.size > required_size and block.size // 2 >= self.min_block_size:
|
||||
buddy_size = block.size // 2
|
||||
buddy_addr = block.addr + buddy_size
|
||||
|
||||
buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
|
||||
block.size = buddy_size
|
||||
|
||||
self.free_lists[buddy_size][buddy.addr] = buddy
|
||||
|
||||
def free(self, addr: int):
|
||||
"""Frees an allocated memory block.
|
||||
|
||||
Args:
|
||||
addr (int): Address of the block to free
|
||||
|
||||
Raises:
|
||||
ValueError: If address is invalid or not allocated
|
||||
"""
|
||||
if addr not in self.allocated_blocks:
|
||||
raise ValueError("Invalid address to free")
|
||||
|
||||
block = self.allocated_blocks.pop(addr)
|
||||
self._merge_buddies(block)
|
||||
|
||||
def _merge_buddies(self, block: MemoryBlock):
|
||||
MAX_MERGE_DEPTH = 30
|
||||
depth = 0
|
||||
|
||||
while depth < MAX_MERGE_DEPTH:
|
||||
buddy_offset = (
|
||||
block.size
|
||||
if (block.addr - self.base_address) % (2 * block.size) == 0
|
||||
else -block.size
|
||||
)
|
||||
buddy_addr = block.addr + buddy_offset
|
||||
buddy = self.free_lists[block.size].get(buddy_addr)
|
||||
if buddy:
|
||||
del self.free_lists[buddy.size][buddy.addr]
|
||||
merged_addr = min(block.addr, buddy.addr)
|
||||
merged_size = block.size * 2
|
||||
block = MemoryBlock(size=merged_size, addr=merged_addr)
|
||||
depth += 1
|
||||
else:
|
||||
break
|
||||
self.free_lists[block.size][block.addr] = block
|
||||
|
||||
def store_tensor(self, tensor: torch.Tensor) -> int:
|
||||
"""Stores a CUDA tensor in pinned host memory.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): CUDA tensor to store
|
||||
|
||||
Returns:
|
||||
int: Address where the tensor is stored
|
||||
|
||||
Raises:
|
||||
ValueError: If tensor is not on CUDA or allocation fails
|
||||
"""
|
||||
if not tensor.is_cuda:
|
||||
raise ValueError("Only CUDA tensors can be stored")
|
||||
|
||||
size = tensor.element_size() * tensor.numel()
|
||||
addr = self.allocate(size)
|
||||
block = self.allocated_blocks[addr]
|
||||
|
||||
if block.size < size:
|
||||
self.free(addr)
|
||||
raise ValueError(
|
||||
f"Allocated block size {block.size} is smaller than "
|
||||
f"required size {size}"
|
||||
)
|
||||
|
||||
try:
|
||||
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
|
||||
cpu_tensor = torch.frombuffer(
|
||||
buffer, dtype=tensor.dtype, count=tensor.numel()
|
||||
).reshape(tensor.shape)
|
||||
except ValueError as err:
|
||||
self.free(addr)
|
||||
raise ValueError(f"Failed to create tensor view: {err}") from err
|
||||
|
||||
cpu_tensor.copy_(tensor)
|
||||
|
||||
return addr
|
||||
|
||||
def load_tensor(
|
||||
self,
|
||||
addr: int,
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, ...],
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""Loads a tensor from pinned host memory to the specified device.
|
||||
|
||||
Args:
|
||||
addr (int): Address where tensor is stored
|
||||
dtype (torch.dtype): Data type of the tensor
|
||||
shape (tuple[int, ...]): Shape of the tensor
|
||||
device: Target device for the loaded tensor
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The loaded tensor on the specified device
|
||||
|
||||
Raises:
|
||||
ValueError: If address is invalid or sizes don't match
|
||||
"""
|
||||
if addr not in self.allocated_blocks:
|
||||
raise ValueError("Invalid address to load")
|
||||
|
||||
block = self.allocated_blocks[addr]
|
||||
num_elements = math.prod(shape)
|
||||
dtype_size = torch.tensor([], dtype=dtype).element_size()
|
||||
required_size = num_elements * dtype_size
|
||||
|
||||
if required_size > block.size:
|
||||
raise ValueError("Requested tensor size exceeds block size")
|
||||
|
||||
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
|
||||
cpu_tensor = torch.frombuffer(buffer, dtype=dtype, count=num_elements).reshape(
|
||||
shape
|
||||
)
|
||||
|
||||
cuda_tensor = torch.empty(shape, dtype=dtype, device=device)
|
||||
|
||||
cuda_tensor.copy_(cpu_tensor)
|
||||
|
||||
return cuda_tensor
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleans up all memory resources and resets the pool state."""
|
||||
self.free_lists.clear()
|
||||
self.allocated_blocks.clear()
|
||||
if hasattr(self, "base_tensor"):
|
||||
del self.base_tensor
|
||||
|
||||
def __del__(self):
|
||||
self.cleanup()
|
||||
78
vllm/distributed/kv_transfer/kv_transfer_state.py
Normal file
78
vllm/distributed/kv_transfer/kv_transfer_state.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
_KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None
|
||||
|
||||
|
||||
def get_kv_transfer_group() -> KVConnectorBaseType:
|
||||
assert _KV_CONNECTOR_AGENT is not None, (
|
||||
"disaggregated KV cache transfer parallel group is not initialized"
|
||||
)
|
||||
return _KV_CONNECTOR_AGENT
|
||||
|
||||
|
||||
def has_kv_transfer_group() -> bool:
|
||||
return _KV_CONNECTOR_AGENT is not None
|
||||
|
||||
|
||||
def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> bool:
|
||||
"""Check if the KV connector is the v1 connector.
|
||||
If the argument is None, it will check the global KV connector
|
||||
|
||||
Args:
|
||||
connector: The KV connector to check. If None, it will check the
|
||||
global KV connector.
|
||||
|
||||
Note:
|
||||
This function will no-longer be needed after the v1 KV connector
|
||||
becomes the default.
|
||||
"""
|
||||
if connector is None:
|
||||
connector = _KV_CONNECTOR_AGENT
|
||||
|
||||
if connector is None:
|
||||
return False
|
||||
|
||||
return isinstance(connector, KVConnectorBase_V1)
|
||||
|
||||
|
||||
def ensure_kv_transfer_initialized(
|
||||
vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig | None" = None
|
||||
) -> None:
|
||||
"""
|
||||
Initialize KV cache transfer parallel group.
|
||||
"""
|
||||
|
||||
global _KV_CONNECTOR_AGENT
|
||||
|
||||
if vllm_config.kv_transfer_config is None:
|
||||
return
|
||||
|
||||
if (
|
||||
vllm_config.kv_transfer_config.is_kv_transfer_instance
|
||||
and _KV_CONNECTOR_AGENT is None
|
||||
):
|
||||
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
|
||||
config=vllm_config,
|
||||
role=KVConnectorRole.WORKER,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
|
||||
def ensure_kv_transfer_shutdown() -> None:
|
||||
global _KV_CONNECTOR_AGENT
|
||||
if _KV_CONNECTOR_AGENT is not None:
|
||||
_KV_CONNECTOR_AGENT.shutdown()
|
||||
_KV_CONNECTOR_AGENT = None
|
||||
1959
vllm/distributed/parallel_state.py
Normal file
1959
vllm/distributed/parallel_state.py
Normal file
File diff suppressed because it is too large
Load Diff
566
vllm/distributed/utils.py
Normal file
566
vllm/distributed/utils.py
Normal file
@@ -0,0 +1,566 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
import dataclasses
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup, TCPStore
|
||||
from torch.distributed.distributed_c10d import (
|
||||
Backend,
|
||||
PrefixStore,
|
||||
_get_default_timeout,
|
||||
_unregister_process_group,
|
||||
)
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import get_tcp_uri
|
||||
from vllm.utils.system_utils import suppress_stdout
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# We prefer to use os.sched_yield as it results in tighter polling loops,
|
||||
# measured to be around 3e-7 seconds. However on earlier versions of Python
|
||||
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
|
||||
USE_SCHED_YIELD = (sys.version_info[:3] >= (3, 11, 1)) or (
|
||||
sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8
|
||||
)
|
||||
|
||||
|
||||
def sched_yield():
|
||||
if USE_SCHED_YIELD:
|
||||
os.sched_yield()
|
||||
else:
|
||||
time.sleep(0)
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator."""
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(
|
||||
numerator, denominator
|
||||
)
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator and return
|
||||
the division value."""
|
||||
ensure_divisibility(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
contiguous_split_chunks: bool = False,
|
||||
) -> Sequence[torch.Tensor]:
|
||||
"""Split a tensor along its last dimension.
|
||||
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
|
||||
Returns:
|
||||
A list of Tensors
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# NOTE: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
|
||||
|
||||
def get_pp_indices(
|
||||
num_hidden_layers: int, pp_rank: int, pp_size: int
|
||||
) -> tuple[int, int]:
|
||||
"""Try to evenly distribute layers across partitions.
|
||||
|
||||
If the number of layers is not divisible by the number of partitions,
|
||||
the remaining layers are evenly distributed across all but the last
|
||||
partition. The last partition is excluded because it often contains an
|
||||
additional norm layer and we are attempting to balance compute.
|
||||
|
||||
If `pp_size > 2` and the number of remaining layers is
|
||||
`0 < x <= pp_size - 2` then the remaining layers are evenly distributed
|
||||
across the middle partitions. The first and last partitions are excluded
|
||||
because they contain the input and output embeddings respectively and we
|
||||
are attempting to reduce maximum memory consumption across partitions.
|
||||
"""
|
||||
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
|
||||
if partition_list_str is not None:
|
||||
try:
|
||||
partitions = [int(layer) for layer in partition_list_str.split(",")]
|
||||
except ValueError as err:
|
||||
raise ValueError(
|
||||
"Invalid partition string: {}".format(partition_list_str)
|
||||
) from err
|
||||
if len(partitions) != pp_size:
|
||||
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
|
||||
if sum(partitions) != num_hidden_layers:
|
||||
raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
|
||||
else:
|
||||
layers_per_partition = num_hidden_layers // pp_size
|
||||
partitions = [layers_per_partition for _ in range(pp_size)]
|
||||
|
||||
if remaining_layers := num_hidden_layers % pp_size:
|
||||
for i in range(2, remaining_layers + 2):
|
||||
partitions[-i] += 1
|
||||
logger.info(
|
||||
"Hidden layers were unevenly partitioned: [%s]. "
|
||||
"This can be manually overridden using the "
|
||||
"VLLM_PP_LAYER_PARTITION environment variable",
|
||||
",".join(str(p) for p in partitions),
|
||||
)
|
||||
|
||||
start_layer = sum(partitions[:pp_rank])
|
||||
end_layer = start_layer + partitions[pp_rank]
|
||||
|
||||
return (start_layer, end_layer)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StatelessProcessGroup:
|
||||
"""A dataclass to hold a metadata store, and the rank, world_size of the
|
||||
group. Only use it to communicate metadata between processes.
|
||||
For data-plane communication, create NCCL-related objects.
|
||||
"""
|
||||
|
||||
rank: int
|
||||
world_size: int
|
||||
store: torch._C._distributed_c10d.Store
|
||||
|
||||
# stores a reference to the socket so that the file descriptor stays alive
|
||||
socket: socket.socket | None
|
||||
|
||||
data_expiration_seconds: int = 3600 # 1 hour
|
||||
|
||||
# dst rank -> counter
|
||||
send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
# src rank -> counter
|
||||
recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
broadcast_send_counter: int = 0
|
||||
broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
|
||||
# A deque to store the data entries, with key and timestamp.
|
||||
entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque)
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.rank < self.world_size
|
||||
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
|
||||
|
||||
def send_obj(self, obj: Any, dst: int):
|
||||
"""Send an object to a destination rank."""
|
||||
self.expire_data()
|
||||
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.send_dst_counter[dst] += 1
|
||||
self.entries.append((key, time.time()))
|
||||
|
||||
def expire_data(self):
|
||||
"""Expire data that is older than `data_expiration_seconds` seconds."""
|
||||
while self.entries:
|
||||
# check the oldest entry
|
||||
key, timestamp = self.entries[0]
|
||||
if time.time() - timestamp > self.data_expiration_seconds:
|
||||
self.store.delete_key(key)
|
||||
self.entries.popleft()
|
||||
else:
|
||||
break
|
||||
|
||||
def recv_obj(self, src: int) -> Any:
|
||||
"""Receive an object from a source rank."""
|
||||
obj = pickle.loads(
|
||||
self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
|
||||
)
|
||||
self.recv_src_counter[src] += 1
|
||||
return obj
|
||||
|
||||
def broadcast_obj(self, obj: Any | None, src: int) -> Any:
|
||||
"""Broadcast an object from a source rank to all other ranks.
|
||||
It does not clean up after all ranks have received the object.
|
||||
Use it for limited times, e.g., for initialization.
|
||||
"""
|
||||
if self.rank == src:
|
||||
self.expire_data()
|
||||
key = f"broadcast_from/{src}/{self.broadcast_send_counter}"
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.broadcast_send_counter += 1
|
||||
self.entries.append((key, time.time()))
|
||||
return obj
|
||||
else:
|
||||
key = f"broadcast_from/{src}/{self.broadcast_recv_src_counter[src]}"
|
||||
recv_obj = pickle.loads(self.store.get(key))
|
||||
self.broadcast_recv_src_counter[src] += 1
|
||||
return recv_obj
|
||||
|
||||
def all_gather_obj(self, obj: Any) -> list[Any]:
|
||||
"""All gather an object from all ranks."""
|
||||
gathered_objs = []
|
||||
for i in range(self.world_size):
|
||||
if i == self.rank:
|
||||
gathered_objs.append(obj)
|
||||
self.broadcast_obj(obj, src=self.rank)
|
||||
else:
|
||||
recv_obj = self.broadcast_obj(None, src=i)
|
||||
gathered_objs.append(recv_obj)
|
||||
return gathered_objs
|
||||
|
||||
def barrier(self, timeout: float = 30.0):
|
||||
"""A robust barrier to synchronize all ranks.
|
||||
|
||||
|
||||
Uses a multi-phase approach to ensure all processes reach the barrier
|
||||
before proceeding:
|
||||
|
||||
1. Each process signals it has reached the barrier
|
||||
|
||||
2. Each process signals that it has confirmed the arrival of all other
|
||||
ranks.
|
||||
|
||||
3. Rank 0 waits for all other ranks to signal their departure to ensure
|
||||
that all ranks have departed the barrier first.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time in seconds to wait for each phase (in seconds)
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If coordination fails or times out
|
||||
"""
|
||||
# Generate a barrier ID that is globally unique
|
||||
try:
|
||||
if self.rank == 0:
|
||||
barrier_id = f"barrier_{uuid.uuid4()}"
|
||||
self.broadcast_obj(barrier_id, src=0)
|
||||
else:
|
||||
barrier_id = self.broadcast_obj(None, src=0)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Failed to broadcast barrier_id") from e
|
||||
|
||||
# Phase 1: Signal arrival at barrier
|
||||
# Wait for all processes to arrive
|
||||
# We need all ranks to confirm the arrival of all other ranks.
|
||||
# This is the key synchronization point.
|
||||
arrival_key = f"arrival_{barrier_id}_{self.rank}"
|
||||
try:
|
||||
self.store.set(arrival_key, b"1")
|
||||
except Exception as e:
|
||||
raise RuntimeError("Failed to signal barrier arrival") from e
|
||||
|
||||
start_time = time.time()
|
||||
processes_arrived: set[int] = set()
|
||||
|
||||
while len(processes_arrived) < self.world_size:
|
||||
# Check for timeout
|
||||
cur_time = time.time()
|
||||
if cur_time - start_time > timeout:
|
||||
raise RuntimeError(f"Barrier timed out after {timeout:.2f} seconds")
|
||||
|
||||
# Check for each process
|
||||
for i in range(self.world_size):
|
||||
if i in processes_arrived:
|
||||
continue
|
||||
|
||||
key = f"arrival_{barrier_id}_{i}"
|
||||
try:
|
||||
# Try to get the key - if it exists, we'll get a value
|
||||
# If it doesn't exist, it will throw an exception
|
||||
self.store.get(key)
|
||||
processes_arrived.add(i)
|
||||
except KeyError:
|
||||
# Key doesn't exist yet
|
||||
pass
|
||||
except Exception as check_e:
|
||||
logger.debug("Error checking key existence: %s", check_e)
|
||||
sched_yield()
|
||||
|
||||
# Short sleep to avoid tight polling
|
||||
if len(processes_arrived) < self.world_size:
|
||||
sched_yield()
|
||||
|
||||
# Phase 2: Signal departure from barrier
|
||||
# We only care to block at this stage in rank 0, which runs the
|
||||
# server side of the TCPStore. We want to make sure that all
|
||||
# clients have departed the barrier before rank 0 in case the
|
||||
# next thing after the barrier is a shutdown, including tearing
|
||||
# down the TCPStore. Other ranks can exit the barrier immediately
|
||||
# after signaling their departure.
|
||||
departure_key = f"departure_{barrier_id}_{self.rank}"
|
||||
try:
|
||||
self.store.set(departure_key, b"1")
|
||||
except Exception as e:
|
||||
raise RuntimeError("Failed to signal barrier departure") from e
|
||||
|
||||
if self.rank != 0:
|
||||
return
|
||||
|
||||
# Make rank 0 wait for all processes to signal departure
|
||||
start_time = time.time()
|
||||
processes_departed: set[int] = set()
|
||||
|
||||
while len(processes_departed) < self.world_size:
|
||||
# Check for timeout
|
||||
if time.time() - start_time > timeout:
|
||||
raise RuntimeError(
|
||||
f"Barrier departure timed out after {timeout:.2f} seconds"
|
||||
)
|
||||
|
||||
# Check for each process
|
||||
for i in range(self.world_size):
|
||||
if i in processes_departed:
|
||||
continue
|
||||
|
||||
key = f"departure_{barrier_id}_{i}"
|
||||
try:
|
||||
# Try to get the key - if it exists, we'll get a value
|
||||
# If it doesn't exist, it will throw an exception
|
||||
self.store.get(key)
|
||||
processes_departed.add(i)
|
||||
except KeyError:
|
||||
# Key doesn't exist yet
|
||||
pass
|
||||
except Exception as check_e:
|
||||
logger.debug("Error checking key existence: %s", check_e)
|
||||
sched_yield()
|
||||
|
||||
# Short sleep to avoid tight polling
|
||||
if len(processes_departed) < self.world_size:
|
||||
sched_yield()
|
||||
|
||||
# Clean up keys to avoid leaking memory in the store
|
||||
for i in range(self.world_size):
|
||||
try:
|
||||
self.store.delete_key(f"arrival_{barrier_id}_{i}")
|
||||
except Exception:
|
||||
logger.debug("Error deleting key: %s", f"arrival_{barrier_id}_{i}")
|
||||
|
||||
try:
|
||||
self.store.delete_key(f"departure_{barrier_id}_{i}")
|
||||
except Exception:
|
||||
logger.debug("Error deleting key: %s", f"departure_{barrier_id}_{i}")
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
host: str,
|
||||
port: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
data_expiration_seconds: int = 3600,
|
||||
store_timeout: int = 300,
|
||||
) -> "StatelessProcessGroup":
|
||||
"""A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state.
|
||||
|
||||
If we have process A and process B called `torch.distributed.init_process_group`
|
||||
to form a group, and then we want to form another group with process A, B, C,
|
||||
D, it is not possible in PyTorch, because process A and process B have already
|
||||
formed a group, and process C and process D cannot join that group. This
|
||||
function is a workaround for this issue.
|
||||
|
||||
`torch.distributed.init_process_group` is a global call, while this function
|
||||
is a stateless call. It will return a `StatelessProcessGroup` object that can be
|
||||
used for exchanging metadata. With this function, process A and process B
|
||||
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
||||
C, and D can call `StatelessProcessGroup.create` to form another group.
|
||||
""" # noqa
|
||||
launch_server = rank == 0
|
||||
if launch_server:
|
||||
# listen on the specified interface (instead of 0.0.0.0)
|
||||
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
listen_socket.bind((host, port))
|
||||
listen_socket.listen()
|
||||
listen_fd = listen_socket.fileno()
|
||||
else:
|
||||
listen_socket = None
|
||||
listen_fd = None
|
||||
|
||||
store = TCPStore(
|
||||
host_name=host,
|
||||
port=port,
|
||||
world_size=world_size,
|
||||
is_master=launch_server,
|
||||
timeout=timedelta(seconds=store_timeout),
|
||||
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
|
||||
master_listen_fd=listen_fd,
|
||||
)
|
||||
|
||||
return StatelessProcessGroup(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
store=store,
|
||||
socket=listen_socket,
|
||||
data_expiration_seconds=data_expiration_seconds,
|
||||
)
|
||||
|
||||
|
||||
def init_gloo_process_group(
|
||||
prefix_store: PrefixStore,
|
||||
group_rank: int,
|
||||
group_size: int,
|
||||
timeout: timedelta,
|
||||
) -> ProcessGroup:
|
||||
"""
|
||||
Stateless init ProcessGroup with gloo backend compatible with
|
||||
different torch versions.
|
||||
"""
|
||||
with suppress_stdout():
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||
|
||||
backend_class = ProcessGroupGloo(
|
||||
prefix_store, group_rank, group_size, timeout=timeout
|
||||
)
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
device = torch.device("cpu")
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
|
||||
def stateless_init_torch_distributed_process_group(
|
||||
host: str, port: int, rank: int, world_size: int, backend: str
|
||||
) -> ProcessGroup:
|
||||
"""
|
||||
A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state. The created ProcessGroup object can be used for
|
||||
some operations such as `allreduce`, because it does not depend on the
|
||||
global rank. However, some operations such as `broadcast` cannot be used
|
||||
because it depends on the global rank.
|
||||
|
||||
# TODO: ask for help from PyTorch team if we need the `broadcast` operation.
|
||||
|
||||
This function is useful when we are not sure about the total number of
|
||||
processes in the process group. For example, we may have process
|
||||
1, 2, ..., 8 who want to communicate, and process 9 might be the same
|
||||
process as process 1, or it might be a different process; process 10
|
||||
might be the same process as process 5, or it might be a different process.
|
||||
In this case, how can we reliably form a communication channel within
|
||||
process 9 and 10, without affecting the communication channel within
|
||||
process 1, 2, ..., 8?
|
||||
|
||||
One possible solution is to figure out if process 9 and 10 are the same
|
||||
as process 1 and 5 beforehand, and then form a communication channel
|
||||
based on the information, adjusting the ranks and world_size etc. However,
|
||||
figuring out the information is not always easy, and it will interfere
|
||||
with the main communication channel.
|
||||
|
||||
Our solution is to always form a communication channel with process 1, 2,
|
||||
..., 8, and then use this function to form another communication channel
|
||||
with process 9 and 10. This way, regardless of whether process 9 and 10
|
||||
are the same as process 1 and 5, the main communication channel is
|
||||
always formed with process 1, 2, ..., 8, and the additional communication
|
||||
channel is formed with process 9 and 10.
|
||||
"""
|
||||
init_method = get_tcp_uri(host, port)
|
||||
backend = Backend(backend) # it is basically string
|
||||
timeout = _get_default_timeout(backend)
|
||||
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout)
|
||||
)
|
||||
store.set_timeout(timeout)
|
||||
|
||||
group_rank = rank
|
||||
group_size = world_size
|
||||
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
prefix_store = PrefixStore(init_method, store)
|
||||
try:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
return current_platform.stateless_init_device_torch_dist_pg(
|
||||
backend=backend,
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
except NotImplementedError:
|
||||
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
|
||||
# will raise a NotImplementedError. In this case, we fall back to gloo.
|
||||
return init_gloo_process_group(
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
|
||||
"""
|
||||
Destroy ProcessGroup returned by
|
||||
stateless_init_torch_distributed_process_group().
|
||||
"""
|
||||
pg.shutdown()
|
||||
_unregister_process_group(pg.group_name)
|
||||
|
||||
|
||||
def get_worker_rank_suffix(global_rank: int | None = None) -> str:
|
||||
"""Generate a descriptive rank suffix for worker identification.
|
||||
|
||||
Returns a string like 'dp0_pp0_tp0_dcp0_ep0_rank0' including all
|
||||
parallel dimensions: DP, PP, TP, DCP, EP.
|
||||
|
||||
Args:
|
||||
global_rank: Optional global rank to append. If not provided,
|
||||
only parallel dimension ranks are included.
|
||||
|
||||
Returns:
|
||||
A string suffix identifying the worker's position in the
|
||||
distributed topology.
|
||||
"""
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dcp_group,
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
|
||||
try:
|
||||
dp_rank = get_dp_group().rank_in_group
|
||||
pp_rank = get_pp_group().rank_in_group
|
||||
tp_rank = get_tp_group().rank_in_group
|
||||
dcp_rank = get_dcp_group().rank_in_group
|
||||
ep_rank = get_ep_group().rank_in_group
|
||||
|
||||
suffix = f"dp{dp_rank}_pp{pp_rank}_tp{tp_rank}_dcp{dcp_rank}_ep{ep_rank}"
|
||||
if global_rank is not None:
|
||||
suffix = f"{suffix}_rank{global_rank}"
|
||||
return suffix
|
||||
except Exception:
|
||||
# Fallback if parallel state not initialized
|
||||
if global_rank is not None:
|
||||
return f"rank{global_rank}"
|
||||
return ""
|
||||
12
vllm/distributed/weight_transfer/__init__.py
Normal file
12
vllm/distributed/weight_transfer/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Weight transfer engines for syncing model weights from trainers
|
||||
to inference workers.
|
||||
"""
|
||||
|
||||
from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory
|
||||
|
||||
__all__ = [
|
||||
"WeightTransferEngineFactory",
|
||||
]
|
||||
158
vllm/distributed/weight_transfer/base.py
Normal file
158
vllm/distributed/weight_transfer/base.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Base class for weight transfer engines."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import KW_ONLY, dataclass, field
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
|
||||
TInitInfo = TypeVar("TInitInfo", bound="WeightTransferInitInfo")
|
||||
TUpdateInfo = TypeVar("TUpdateInfo", bound="WeightTransferUpdateInfo")
|
||||
|
||||
|
||||
# Base protocols for backend-specific dataclasses
|
||||
@dataclass
|
||||
class WeightTransferInitInfo(ABC): # noqa: B024
|
||||
"""Base class for backend-specific initialization info."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightTransferUpdateInfo(ABC): # noqa: B024
|
||||
"""Base class for backend-specific weight update info."""
|
||||
|
||||
_: KW_ONLY
|
||||
is_checkpoint_format: bool = True
|
||||
"""Set to True if weights are in checkpoint/original model format and need
|
||||
layerwise processing. Set to False if weights have already been processed
|
||||
into kernel format (repacking, renaming, etc.)."""
|
||||
|
||||
|
||||
# API-level request classes (accept dicts for backend-agnostic serialization)
|
||||
@dataclass
|
||||
class WeightTransferInitRequest:
|
||||
"""API-level weight transfer initialization request."""
|
||||
|
||||
init_info: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightTransferUpdateRequest:
|
||||
"""API-level weight update request."""
|
||||
|
||||
update_info: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
|
||||
"""
|
||||
Base class for weight transfer engines that handle transport of model weights
|
||||
from a trainer to inference workers.
|
||||
|
||||
This abstraction separates weight transfer transport logic from the worker
|
||||
implementation, allowing different backends (NCCL, CUDA IPC[TODO], RDMA[TODO]) to be
|
||||
plugged in.
|
||||
|
||||
Subclasses should define:
|
||||
init_info_cls: Type of backend-specific initialization info
|
||||
update_info_cls: Type of backend-specific update info
|
||||
"""
|
||||
|
||||
# Subclasses should override these class attributes
|
||||
init_info_cls: type[TInitInfo]
|
||||
update_info_cls: type[TUpdateInfo]
|
||||
|
||||
def __init__(
|
||||
self, config: WeightTransferConfig, parallel_config: ParallelConfig
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the weight transfer engine.
|
||||
|
||||
Args:
|
||||
config: The configuration for the weight transfer engine
|
||||
parallel_config: The configuration for the parallel setup
|
||||
"""
|
||||
self.config = config
|
||||
self.parallel_config = parallel_config
|
||||
|
||||
def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
|
||||
"""
|
||||
Construct typed init info from dict with validation.
|
||||
|
||||
Args:
|
||||
init_dict: Dictionary containing backend-specific initialization parameters
|
||||
|
||||
Returns:
|
||||
Typed backend-specific init info dataclass
|
||||
|
||||
Raises:
|
||||
ValueError: If init_dict is invalid for this backend
|
||||
"""
|
||||
try:
|
||||
return self.init_info_cls(**init_dict)
|
||||
except TypeError as e:
|
||||
raise ValueError(
|
||||
f"Invalid init_info for {self.__class__.__name__}: {e}"
|
||||
) from e
|
||||
|
||||
def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
|
||||
"""
|
||||
Construct typed update info from dict with validation.
|
||||
|
||||
Args:
|
||||
update_dict: Dictionary containing backend-specific update parameters
|
||||
|
||||
Returns:
|
||||
Typed backend-specific update info dataclass
|
||||
|
||||
Raises:
|
||||
ValueError: If update_dict is invalid for this backend
|
||||
"""
|
||||
try:
|
||||
return self.update_info_cls(**update_dict)
|
||||
except TypeError as e:
|
||||
raise ValueError(
|
||||
f"Invalid update_info for {self.__class__.__name__}: {e}"
|
||||
) from e
|
||||
|
||||
@abstractmethod
|
||||
def init_transfer_engine(self, init_info: TInitInfo) -> None:
|
||||
"""
|
||||
Initialize the weight transfer mechanism.
|
||||
This is called once at the beginning of training.
|
||||
|
||||
Args:
|
||||
init_info: Backend-specific initialization info
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def receive_weights(
|
||||
self,
|
||||
update_info: TUpdateInfo,
|
||||
load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
|
||||
) -> None:
|
||||
"""
|
||||
Receive weights from the trainer and load them incrementally.
|
||||
|
||||
Args:
|
||||
update_info: Backend-specific update info containing parameter metadata
|
||||
and any backend-specific data
|
||||
load_weights: Callable that loads weights into the model. Called
|
||||
incrementally for each weight to avoid OOM.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the weight transfer engine.
|
||||
This should be called when the worker is shutting down.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
116
vllm/distributed/weight_transfer/factory.py
Normal file
116
vllm/distributed/weight_transfer/factory.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Factory for weight transfer engines with lazy loading."""
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.distributed.weight_transfer.base import WeightTransferEngine
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class WeightTransferEngineFactory:
|
||||
"""Factory for creating weight transfer engines with lazy loading.
|
||||
|
||||
This factory implements a registry pattern that supports:
|
||||
- Lazy loading: Engine modules are only imported when actually needed
|
||||
- Extensibility: Custom engines can be registered at runtime
|
||||
- Centralized registration: All built-in engines registered in one place
|
||||
"""
|
||||
|
||||
_registry: dict[str, Callable[[], type[WeightTransferEngine]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register_engine(
|
||||
cls,
|
||||
name: str,
|
||||
module_path_or_cls: str | type[WeightTransferEngine],
|
||||
class_name: str | None = None,
|
||||
) -> None:
|
||||
"""Register an engine with lazy-loading or direct class reference.
|
||||
|
||||
Supports two calling conventions:
|
||||
1. Lazy loading: register_engine(name, module_path, class_name)
|
||||
2. Direct class: register_engine(name, engine_cls)
|
||||
|
||||
Args:
|
||||
name: The name to register the engine under (e.g., "nccl")
|
||||
module_path_or_cls: Either a module path string for lazy loading,
|
||||
or the engine class directly
|
||||
class_name: Name of the engine class (required if module_path is string)
|
||||
|
||||
Raises:
|
||||
ValueError: If an engine with the same name is already registered
|
||||
"""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Weight transfer engine '{name}' is already registered.")
|
||||
|
||||
if isinstance(module_path_or_cls, str):
|
||||
# Lazy loading path
|
||||
module_path = module_path_or_cls
|
||||
if class_name is None:
|
||||
raise ValueError(
|
||||
"class_name is required when registering with module path"
|
||||
)
|
||||
|
||||
def loader() -> type[WeightTransferEngine]:
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
cls._registry[name] = loader
|
||||
else:
|
||||
# Direct class registration
|
||||
engine_cls = module_path_or_cls
|
||||
cls._registry[name] = lambda: engine_cls
|
||||
|
||||
@classmethod
|
||||
def create_engine(
|
||||
cls,
|
||||
config: "WeightTransferConfig",
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> WeightTransferEngine:
|
||||
"""Create a weight transfer engine instance.
|
||||
|
||||
Args:
|
||||
config: Weight transfer configuration containing the backend name
|
||||
parallel_config: Parallel configuration for the engine
|
||||
|
||||
Returns:
|
||||
An initialized weight transfer engine instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the backend is not registered
|
||||
"""
|
||||
backend = config.backend
|
||||
if backend not in cls._registry:
|
||||
available = list(cls._registry.keys())
|
||||
raise ValueError(
|
||||
f"Invalid weight transfer backend: {backend}. "
|
||||
f"Available engines: {available}"
|
||||
)
|
||||
engine_cls = cls._registry[backend]()
|
||||
|
||||
logger.info(
|
||||
"Creating weight transfer engine: %s",
|
||||
engine_cls.__name__,
|
||||
)
|
||||
|
||||
return engine_cls(config, parallel_config)
|
||||
|
||||
|
||||
# Register built-in weight transfer engines here.
|
||||
# Registration should be centralized to ensure lazy loading -
|
||||
# engine modules are only imported when actually used.
|
||||
|
||||
WeightTransferEngineFactory.register_engine(
|
||||
"nccl",
|
||||
"vllm.distributed.weight_transfer.nccl_engine",
|
||||
"NCCLWeightTransferEngine",
|
||||
)
|
||||
315
vllm/distributed/weight_transfer/nccl_engine.py
Normal file
315
vllm/distributed/weight_transfer/nccl_engine.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""NCCL-based weight transfer engine."""
|
||||
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferEngine,
|
||||
WeightTransferInitInfo,
|
||||
WeightTransferUpdateInfo,
|
||||
)
|
||||
from vllm.distributed.weight_transfer.packed_tensor import (
|
||||
DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
DEFAULT_PACKED_NUM_BUFFERS,
|
||||
packed_broadcast_consumer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
|
||||
"""Initialization info for NCCL weight transfer backend."""
|
||||
|
||||
master_address: str
|
||||
master_port: int
|
||||
rank_offset: int
|
||||
world_size: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
"""Update info for NCCL weight transfer backend."""
|
||||
|
||||
names: list[str]
|
||||
dtype_names: list[str]
|
||||
shapes: list[list[int]]
|
||||
packed: bool = False
|
||||
"""Whether to use packed tensor broadcasting for efficiency.
|
||||
When True, multiple tensors are batched together before broadcasting
|
||||
to reduce NCCL communication overhead."""
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
|
||||
"""Size in bytes for each packed tensor buffer. Default is 1GB.
|
||||
Both producer and consumer must use the same value."""
|
||||
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
|
||||
"""Number of buffers for double/triple buffering during packed transfer.
|
||||
Both producer and consumer must use the same value."""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that all lists have the same length."""
|
||||
num_params = len(self.names)
|
||||
if len(self.dtype_names) != num_params:
|
||||
raise ValueError(
|
||||
f"`dtype_names` should be of the same size as `names`: "
|
||||
f"got {len(self.dtype_names)} and {len(self.names)}"
|
||||
)
|
||||
if len(self.shapes) != num_params:
|
||||
raise ValueError(
|
||||
f"`shapes` should be of the same size as `names`: "
|
||||
f"got {len(self.shapes)} and {len(self.names)}"
|
||||
)
|
||||
|
||||
|
||||
class NCCLWeightTransferEngine(
|
||||
WeightTransferEngine[NCCLWeightTransferInitInfo, NCCLWeightTransferUpdateInfo]
|
||||
):
|
||||
"""
|
||||
Weight transfer engine using NCCL for communication between trainer and workers.
|
||||
|
||||
This implementation uses NCCL broadcast operations to transfer weights from
|
||||
the trainer (rank 0) to all inference workers in a process group.
|
||||
"""
|
||||
|
||||
# Define backend-specific dataclass types
|
||||
init_info_cls = NCCLWeightTransferInitInfo
|
||||
update_info_cls = NCCLWeightTransferUpdateInfo
|
||||
|
||||
def __init__(
|
||||
self, config: WeightTransferConfig, parallel_config: ParallelConfig
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the NCCL weight transfer engine.
|
||||
|
||||
Args:
|
||||
config: The configuration for the weight transfer engine
|
||||
parallel_config: The configuration for the parallel setup
|
||||
"""
|
||||
super().__init__(config, parallel_config)
|
||||
self.model_update_group: PyNcclCommunicator | None = None
|
||||
|
||||
def init_transfer_engine(self, init_info: NCCLWeightTransferInitInfo) -> None:
|
||||
"""
|
||||
Initialize NCCL process group with the trainer.
|
||||
|
||||
Args:
|
||||
init_info: NCCL initialization info containing master address, port,
|
||||
rank offset, and world size
|
||||
"""
|
||||
|
||||
# Calculate the global rank in the trainer-worker process group
|
||||
# Must account for data parallel to get unique ranks across all workers
|
||||
dp_rank = self.parallel_config.data_parallel_rank
|
||||
world_size_per_dp = self.parallel_config.world_size # TP * PP
|
||||
rank_within_dp = self.parallel_config.rank
|
||||
|
||||
# Unique rank across all DP groups
|
||||
worker_rank = dp_rank * world_size_per_dp + rank_within_dp
|
||||
rank = worker_rank + init_info.rank_offset
|
||||
# Create stateless process group
|
||||
self.model_update_group = (
|
||||
NCCLWeightTransferEngine._stateless_init_process_group(
|
||||
init_info.master_address,
|
||||
init_info.master_port,
|
||||
rank,
|
||||
init_info.world_size,
|
||||
torch.cuda.current_device(),
|
||||
)
|
||||
)
|
||||
|
||||
def receive_weights(
|
||||
self,
|
||||
update_info: NCCLWeightTransferUpdateInfo,
|
||||
load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
|
||||
) -> None:
|
||||
"""
|
||||
Receive weights from trainer via NCCL broadcast and load them incrementally.
|
||||
|
||||
If update_info.packed is True, uses packed tensor broadcasting for
|
||||
efficient transfer of multiple weights in batches. Otherwise, uses simple
|
||||
one-by-one broadcasting.
|
||||
|
||||
Args:
|
||||
update_info: NCCL update info containing parameter names, dtypes, shapes,
|
||||
and packed flag
|
||||
load_weights: Callable that loads weights into the model. Called
|
||||
incrementally for each batch of weights to avoid OOM.
|
||||
"""
|
||||
if self.model_update_group is None:
|
||||
raise RuntimeError(
|
||||
"NCCL weight transfer not initialized. "
|
||||
"Call init_transfer_engine() first."
|
||||
)
|
||||
|
||||
if update_info.packed:
|
||||
# Build iterator of (name, (shape, dtype)) from update_info
|
||||
def state_dict_info_iterator():
|
||||
for name, dtype_name, shape in zip(
|
||||
update_info.names, update_info.dtype_names, update_info.shapes
|
||||
):
|
||||
dtype = getattr(torch, dtype_name)
|
||||
yield (name, (shape, dtype))
|
||||
|
||||
packed_broadcast_consumer(
|
||||
iterator=state_dict_info_iterator(),
|
||||
group=self.model_update_group,
|
||||
src=0,
|
||||
post_unpack_func=load_weights,
|
||||
buffer_size_bytes=update_info.packed_buffer_size_bytes,
|
||||
num_buffers=update_info.packed_num_buffers,
|
||||
)
|
||||
else:
|
||||
# Use simple one-by-one broadcasting
|
||||
for name, dtype_name, shape in zip(
|
||||
update_info.names, update_info.dtype_names, update_info.shapes
|
||||
):
|
||||
dtype = getattr(torch, dtype_name)
|
||||
weight = torch.empty(shape, dtype=dtype, device="cuda")
|
||||
self.model_update_group.broadcast(
|
||||
weight, src=0, stream=torch.cuda.current_stream()
|
||||
)
|
||||
load_weights([(name, weight)])
|
||||
del weight
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.model_update_group is not None:
|
||||
# Clean up the communicator by removing the reference
|
||||
self.model_update_group = None
|
||||
|
||||
@staticmethod
|
||||
def trainer_send_weights(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
group: Any,
|
||||
src: int = 0,
|
||||
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
|
||||
| None = None,
|
||||
packed: bool = False,
|
||||
stream: torch.cuda.Stream | None = None,
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
|
||||
) -> None:
|
||||
"""Broadcast weights from trainer to vLLM workers.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples
|
||||
group: Process group (PyNcclCommunicator)
|
||||
src: Source rank (default 0, trainer is typically rank 0)
|
||||
post_iter_func: Optional function to apply to each (name, tensor) pair
|
||||
before broadcasting. If None, extracts just the tensor.
|
||||
packed: Whether to use packed tensor broadcasting for efficiency.
|
||||
When True, multiple tensors are batched together before
|
||||
broadcasting to reduce NCCL communication overhead.
|
||||
stream: CUDA stream to use for broadcasting if packed is False.
|
||||
If packed is True, new streams will be created for each buffer.
|
||||
packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo.
|
||||
packed_num_buffers: Number of buffers for double/triple buffering.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo.
|
||||
|
||||
Example:
|
||||
>>> from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
... NCCLWeightTransferEngine,
|
||||
... )
|
||||
>>> param_iter = ((n, p) for n, p in model.named_parameters())
|
||||
>>> NCCLWeightTransferEngine.trainer_send_weights(
|
||||
... param_iter, group, packed=True
|
||||
... )
|
||||
"""
|
||||
if post_iter_func is None:
|
||||
# Default: extract just the tensor from (name, tensor) tuple
|
||||
post_iter_func = lambda x: x[1]
|
||||
|
||||
if packed:
|
||||
# Use packed tensor broadcasting for efficiency
|
||||
from vllm.distributed.weight_transfer.packed_tensor import (
|
||||
packed_broadcast_producer,
|
||||
)
|
||||
|
||||
packed_broadcast_producer(
|
||||
iterator=iterator,
|
||||
group=group,
|
||||
src=src,
|
||||
post_iter_func=post_iter_func,
|
||||
buffer_size_bytes=packed_buffer_size_bytes,
|
||||
num_buffers=packed_num_buffers,
|
||||
)
|
||||
else:
|
||||
# Use simple one-by-one broadcasting
|
||||
for item in iterator:
|
||||
tensor = post_iter_func(item)
|
||||
group.broadcast(
|
||||
tensor, src=src, stream=stream or torch.cuda.current_stream()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def trainer_init(
|
||||
init_info: NCCLWeightTransferInitInfo | dict,
|
||||
) -> "PyNcclCommunicator":
|
||||
"""
|
||||
Initialize NCCL process group for trainer-side weight transfer.
|
||||
|
||||
The trainer is always rank 0 in the process group. Uses the current
|
||||
CUDA device (torch.cuda.current_device()).
|
||||
|
||||
Args:
|
||||
init_info: Either an NCCLWeightTransferInitInfo object or a dict with keys:
|
||||
- master_address: str
|
||||
- master_port: int
|
||||
- world_size: int
|
||||
|
||||
Returns:
|
||||
PyNcclCommunicator for weight transfer.
|
||||
|
||||
Example:
|
||||
>>> from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
... NCCLWeightTransferEngine,
|
||||
... )
|
||||
>>> group = NCCLWeightTransferEngine.trainer_init(
|
||||
... dict(
|
||||
... master_address=master_address,
|
||||
... master_port=master_port,
|
||||
... world_size=world_size,
|
||||
... ),
|
||||
... )
|
||||
"""
|
||||
if isinstance(init_info, dict):
|
||||
master_address = init_info["master_address"]
|
||||
master_port = init_info["master_port"]
|
||||
world_size = init_info["world_size"]
|
||||
else:
|
||||
# NCCLWeightTransferInitInfo object
|
||||
master_address = init_info.master_address
|
||||
master_port = init_info.master_port
|
||||
world_size = init_info.world_size
|
||||
|
||||
# Trainer is always rank 0
|
||||
return NCCLWeightTransferEngine._stateless_init_process_group(
|
||||
master_address, master_port, 0, world_size, torch.cuda.current_device()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _stateless_init_process_group(
|
||||
master_address, master_port, rank, world_size, device
|
||||
):
|
||||
"""
|
||||
vLLM provides `StatelessProcessGroup` to create a process group
|
||||
without considering the global process group in torch.distributed.
|
||||
It is recommended to create `StatelessProcessGroup`, and then initialize
|
||||
the data-plane communication (NCCL) between external (train processes)
|
||||
and vLLM workers.
|
||||
"""
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
|
||||
pg = StatelessProcessGroup.create(
|
||||
host=master_address, port=master_port, rank=rank, world_size=world_size
|
||||
)
|
||||
pynccl = PyNcclCommunicator(pg, device=device)
|
||||
return pynccl
|
||||
216
vllm/distributed/weight_transfer/packed_tensor.py
Normal file
216
vllm/distributed/weight_transfer/packed_tensor.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Packed tensor utilities for efficient weight transfer."""
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
# Default values for packed tensor configuration.
|
||||
# These are imported by NCCLWeightTransferUpdateInfo and trainer_send_weights.
|
||||
DEFAULT_PACKED_BUFFER_SIZE_BYTES = 1024 * 1024 * 1024 # 1GB
|
||||
DEFAULT_PACKED_NUM_BUFFERS = 2
|
||||
|
||||
|
||||
def packed_broadcast_producer(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
group: Any,
|
||||
src: int,
|
||||
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor],
|
||||
buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
|
||||
) -> None:
|
||||
"""Broadcast tensors in a packed manner from trainer to workers.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns a tuple of (name, tensor)
|
||||
group: Process group (PyNcclCommunicator)
|
||||
src: Source rank (0 in current implementation)
|
||||
post_iter_func: Function to apply to each (name, tensor) pair before
|
||||
packing, should return a tensor
|
||||
buffer_size_bytes: Size in bytes for each packed tensor buffer.
|
||||
Both producer and consumer must use the same value.
|
||||
num_buffers: Number of buffers for double/triple buffering.
|
||||
Both producer and consumer must use the same value.
|
||||
|
||||
"""
|
||||
target_packed_tensor_size = buffer_size_bytes
|
||||
|
||||
streams = [torch.cuda.Stream() for _ in range(num_buffers)]
|
||||
buffer_idx = 0
|
||||
|
||||
packing_tensor_list: list[list[torch.Tensor]] = [[] for _ in range(num_buffers)]
|
||||
packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
|
||||
packed_tensors: list[torch.Tensor] = [
|
||||
torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
|
||||
]
|
||||
|
||||
while True:
|
||||
# Synchronize the current stream
|
||||
streams[buffer_idx].synchronize()
|
||||
# Start tasks for the new buffer in a new stream
|
||||
with torch.cuda.stream(streams[buffer_idx]):
|
||||
try:
|
||||
# Initialize the packing tensor list and sizes
|
||||
packing_tensor_list[buffer_idx] = []
|
||||
packing_tensor_sizes[buffer_idx] = 0
|
||||
# Pack the tensors
|
||||
while True:
|
||||
# Apply post processing and convert to linearized uint8 tensor
|
||||
tensor = (
|
||||
post_iter_func(next(iterator))
|
||||
.contiguous()
|
||||
.view(torch.uint8)
|
||||
.view(-1)
|
||||
)
|
||||
packing_tensor_list[buffer_idx].append(tensor)
|
||||
packing_tensor_sizes[buffer_idx] += tensor.numel()
|
||||
if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
|
||||
break
|
||||
# Pack the tensors and call broadcast collective
|
||||
packed_tensors[buffer_idx] = torch.cat(
|
||||
packing_tensor_list[buffer_idx], dim=0
|
||||
)
|
||||
group.broadcast(packed_tensors[buffer_idx], src=src)
|
||||
# Move to the next buffer
|
||||
buffer_idx = (buffer_idx + 1) % num_buffers
|
||||
except StopIteration:
|
||||
# Do the last broadcast if there are remaining tensors
|
||||
if len(packing_tensor_list[buffer_idx]) > 0:
|
||||
packed_tensors[buffer_idx] = torch.cat(
|
||||
packing_tensor_list[buffer_idx], dim=0
|
||||
)
|
||||
group.broadcast(packed_tensors[buffer_idx], src=src)
|
||||
break
|
||||
|
||||
|
||||
def packed_broadcast_consumer(
|
||||
iterator: Iterator[tuple[str, tuple[list[int], torch.dtype]]],
|
||||
group: Any,
|
||||
src: int,
|
||||
post_unpack_func: Callable[[list[tuple[str, torch.Tensor]]], None],
|
||||
buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
|
||||
) -> None:
|
||||
"""Consume packed tensors and unpack them into a list of tensors.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of parameter metadata. Returns (name, (shape, dtype))
|
||||
group: Process group (PyNcclCommunicator)
|
||||
src: Source rank (0 in current implementation)
|
||||
post_unpack_func: Function to apply to each list of (name, tensor) after
|
||||
unpacking
|
||||
buffer_size_bytes: Size in bytes for each packed tensor buffer.
|
||||
Both producer and consumer must use the same value.
|
||||
num_buffers: Number of buffers for double/triple buffering.
|
||||
Both producer and consumer must use the same value.
|
||||
|
||||
"""
|
||||
|
||||
def unpack_tensor(
|
||||
packed_tensor: torch.Tensor,
|
||||
names: list[str],
|
||||
shapes: list[list[int]],
|
||||
dtypes: list[torch.dtype],
|
||||
tensor_sizes: list[int],
|
||||
) -> list[tuple[str, torch.Tensor]]:
|
||||
"""Unpack a single tensor into a list of tensors.
|
||||
|
||||
Args:
|
||||
packed_tensor: The packed torch.uint8 tensor to unpack
|
||||
names: List of tensor names
|
||||
shapes: List of tensor shapes
|
||||
dtypes: List of tensor dtypes
|
||||
tensor_sizes: List of tensor sizes in bytes
|
||||
|
||||
Returns:
|
||||
unpacked List[(name, tensor)]
|
||||
"""
|
||||
unpacked_tensors = packed_tensor.split(tensor_sizes)
|
||||
|
||||
unpacked_list = [
|
||||
(name, tensor.contiguous().view(dtype).view(*shape))
|
||||
for name, shape, dtype, tensor in zip(
|
||||
names, shapes, dtypes, unpacked_tensors
|
||||
)
|
||||
]
|
||||
|
||||
return unpacked_list
|
||||
|
||||
target_packed_tensor_size = buffer_size_bytes
|
||||
|
||||
streams = [torch.cuda.Stream() for _ in range(num_buffers)]
|
||||
buffer_idx = 0
|
||||
|
||||
packing_tensor_meta_data: list[list[tuple[str, list[int], torch.dtype, int]]] = [
|
||||
[] for _ in range(num_buffers)
|
||||
]
|
||||
packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
|
||||
packed_tensors: list[torch.Tensor] = [
|
||||
torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
|
||||
]
|
||||
|
||||
while True:
|
||||
# Synchronize the current stream
|
||||
streams[buffer_idx].synchronize()
|
||||
with torch.cuda.stream(streams[buffer_idx]):
|
||||
# Initialize the packing tensor meta data
|
||||
packing_tensor_meta_data[buffer_idx] = []
|
||||
packing_tensor_sizes[buffer_idx] = 0
|
||||
try:
|
||||
# Form a packed tensor
|
||||
while True:
|
||||
name, (shape, dtype) = next(iterator)
|
||||
tensor_size = math.prod(shape) * dtype.itemsize
|
||||
packing_tensor_meta_data[buffer_idx].append(
|
||||
(name, shape, dtype, tensor_size)
|
||||
)
|
||||
packing_tensor_sizes[buffer_idx] += tensor_size
|
||||
if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
|
||||
break
|
||||
# Create a packed tensor and broadcast it
|
||||
packed_tensors[buffer_idx] = torch.empty(
|
||||
packing_tensor_sizes[buffer_idx], dtype=torch.uint8, device="cuda"
|
||||
)
|
||||
group.broadcast(packed_tensors[buffer_idx], src=src)
|
||||
# Load the packed tensor into the model
|
||||
names, shapes, dtypes, tensor_sizes = zip(
|
||||
*packing_tensor_meta_data[buffer_idx]
|
||||
)
|
||||
post_unpack_func(
|
||||
unpack_tensor(
|
||||
packed_tensors[buffer_idx],
|
||||
list(names),
|
||||
list(shapes),
|
||||
list(dtypes),
|
||||
list(tensor_sizes),
|
||||
)
|
||||
)
|
||||
# Move to the next buffer
|
||||
buffer_idx = (buffer_idx + 1) % num_buffers
|
||||
except StopIteration:
|
||||
# Do the last broadcast if there are remaining tensors
|
||||
if len(packing_tensor_meta_data[buffer_idx]) > 0:
|
||||
# Create a packed tensor and broadcast it
|
||||
packed_tensors[buffer_idx] = torch.empty(
|
||||
packing_tensor_sizes[buffer_idx],
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
group.broadcast(packed_tensors[buffer_idx], src=src)
|
||||
# Load the packed tensor into the model
|
||||
names, shapes, dtypes, tensor_sizes = zip(
|
||||
*packing_tensor_meta_data[buffer_idx]
|
||||
)
|
||||
post_unpack_func(
|
||||
unpack_tensor(
|
||||
packed_tensors[buffer_idx],
|
||||
list(names),
|
||||
list(shapes),
|
||||
list(dtypes),
|
||||
list(tensor_sizes),
|
||||
)
|
||||
)
|
||||
break
|
||||
Reference in New Issue
Block a user