Sync from v0.13
This commit is contained in:
823
vllm/multimodal/cache.py
Normal file
823
vllm/multimodal/cache.py
Normal file
@@ -0,0 +1,823 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import operator
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||
MsgpackSerde,
|
||||
SingleWriterShmObjectStorage,
|
||||
SingleWriterShmRingBuffer,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.cache import CacheInfo, LRUCache
|
||||
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
|
||||
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
|
||||
|
||||
from .inputs import (
|
||||
MultiModalBatchedField,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldElem,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
NestedTensors,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
from .processing import ResolvedPromptUpdate
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
# Module-level logger, following the vLLM convention of one logger per module.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItem:
    """
    The value type stored inside `MultiModalProcessorOnlyCache`.

    Unlike the metadata-only variant, this keeps the full processed
    tensor data in memory on P0.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
        prompt_updates: The prompt updates corresponding to `item`.
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        self.prompt_updates = prompt_updates
        self.item = item
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItemMetadata:
    """
    The value type stored inside `MultiModalProcessorSenderCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
            Since P1 already stores the tensor data, we only store its size
            metadata in P0 to reduce memory usage. The size metadata is still
            needed to keep the same cache eviction policy as P0.
        prompt_updates: The prompt updates corresponding to `item`.
            This needs to stay on P0 because for some models, they are
            dependent on the processed tensor data (cached on P1).
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        self.prompt_updates = prompt_updates
        # Only the byte size of the item is retained, not the item itself.
        self.item_size = MultiModalCache.get_item_size(item)
|
||||
|
||||
|
||||
# Union of every value type that a multi-modal cache may store.
MultiModalCacheValue: TypeAlias = (
    MultiModalProcessorCacheItem
    | MultiModalProcessorCacheItemMetadata
    | MultiModalKwargsItems
    | MultiModalKwargsItem
    | Mapping[str, NestedTensors]
)

# Generic cache value type, bounded to the union above.
_V = TypeVar("_V", bound=MultiModalCacheValue)
|
||||
|
||||
|
||||
class MultiModalCache:
    """Namespace of helpers for sizing and constructing multi-modal caches."""

    @classmethod
    def get_leaf_size(cls, leaf: object) -> int:
        """Return the approximate memory footprint of a single leaf, in bytes."""
        if isinstance(leaf, MultiModalProcessorCacheItem):
            return cls.get_leaf_size(leaf.item)
        if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
            # Size was precomputed when the metadata was created.
            return leaf.item_size

        # These containers are not subclasses of dict, so unwrap explicitly.
        mm_containers = (
            MultiModalKwargsItems,
            MultiModalKwargsItem,
            MultiModalFieldElem,
        )
        if isinstance(leaf, mm_containers):
            return cls.get_item_size(leaf.data)  # type: ignore

        # sys.getsizeof doesn't report tensor storage; use nbytes instead.
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        return sys.getsizeof(leaf)

    @classmethod
    def get_item_size(
        cls,
        value: MultiModalCacheValue,
        *,
        debug: bool = False,
    ) -> int:
        """Return the total size in bytes over all leaves of `value`."""
        total = json_reduce_leaves(
            operator.add,
            json_map_leaves(cls.get_leaf_size, value),
        )

        if debug:
            logger.debug(
                "Calculated size of %s to be %.2f GiB (%d leaves)",
                type(value),
                total / GiB_bytes,
                json_count_leaves(value),
            )

        return total

    @classmethod
    def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
        """
        Get the number of leaf elements in a multi-modal cache value.

        This provides a measure of structural complexity that can be useful
        for debugging cache performance and understanding data patterns.

        Args:
            value: The multi-modal cache value to analyze.

        Returns:
            The number of leaf elements in the nested structure.
        """
        return json_count_leaves(value)

    @classmethod
    def get_lru_cache(
        cls,
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """Build an `LRUCache` bounded by `capacity_gb` gibibytes."""

        def sizeof(value: _V) -> int:
            return cls.get_item_size(value, debug=debug)

        return LRUCache(GiB_bytes * capacity_gb, getsizeof=sizeof)
|
||||
|
||||
|
||||
# Input (contravariant) and output (covariant) item types for
# `BaseMultiModalCache` subclasses.
_I = TypeVar("_I", contravariant=True)
_O = TypeVar("_O", covariant=True)
|
||||
|
||||
|
||||
class BaseMultiModalCache(ABC, Generic[_I, _O]):
    """
    Abstract base class to read/write multi-modal items from cache.

    The idea of multi-modal caching is based on having a client and server
    where the client executes in the frontend process (=P0) and
    the server in the core process (=P1). The data flow is as follows:

    ```
            is_cached() x N    get_and_update()
    P0: From API -----------------> -----------------> To P1

            get_and_update()
    P1: From P0 -----------------> To model
    ```

    `is_cached()` can be called any number of times in P0. However,
    `get_and_update()` must be called in P0 and P1 one after another
    so that their cache eviction order remains the same.

    This ensures that the keys in P0 and P1 caches are mirrored,
    allowing us to determine whether a key is cached in P1 by looking
    up the P0 cache, without having to communicate with P1.
    """

    @abstractmethod
    def get_and_update_item(
        self,
        mm_item: _I,
        mm_hash: str,
    ) -> _O:
        """
        Possibly replace a multi-modal item based on whether it is
        in the underlying cache.

        This replacement is done out-of-place and updates the cache
        eviction order.

        Args:
            mm_item: The multi-modal item to update.
            mm_hash: The hash of `mm_item`.

        Returns:
            The updated multi-modal item.
        """
        raise NotImplementedError

    def get_and_update(
        self,
        mm_items: Sequence[_I],
        mm_hashes: list[str],
    ) -> list[_O]:
        """
        Possibly replace a sequence of multi-modal items based on whether
        they are in the underlying cache.

        This replacement is done out-of-place and updates the cache
        eviction order.

        Args:
            mm_items: The multi-modal items to update.
            mm_hashes: The hash of each item in `mm_items`.

        Returns:
            A new list of updated multi-modal items.
        """
        assert len(mm_items) == len(mm_hashes)

        return [
            self.get_and_update_item(item, item_hash)
            for item, item_hash in zip(mm_items, mm_hashes)
        ]

    @abstractmethod
    def clear_cache(self) -> None:
        """Clear the underlying cache."""
        raise NotImplementedError
|
||||
|
||||
|
||||
# Input to a P0 processor cache: the processed item plus its prompt updates,
# or `None` when the item is known to already be cached.
MultiModalProcessorCacheInItem: TypeAlias = (
    tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] | None
)


# Output of a P0 processor cache: the (possibly cleared) item plus the
# prompt updates, which always stay on P0.
MultiModalProcessorCacheOutItem: TypeAlias = tuple[
    MultiModalKwargsItem | None, Sequence["ResolvedPromptUpdate"]
]
|
||||
|
||||
|
||||
class BaseMultiModalProcessorCache(
    BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem]
):
    """The required interface for caches on P0."""

    @abstractmethod
    def is_cached_item(self, mm_hash: str) -> bool:
        """
        Check whether a multi-modal item is in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hash: The hash of the item to check.

        Returns:
            `True` if the item is cached, otherwise `False`.
        """
        raise NotImplementedError

    def is_cached(self, mm_hashes: list[str]) -> list[bool]:
        """
        Check whether a sequence of multi-modal items are
        in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hashes: The hash of each item to check.

        Returns:
            For each item, `True` if the item is cached, otherwise `False`.
        """
        return list(map(self.is_cached_item, mm_hashes))

    @abstractmethod
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        """
        Update the cache eviction order for a multi-modal item.

        This touches the item in the cache without changing its value.

        Args:
            mm_hash: The hash of the multi-modal item.
        """
        raise NotImplementedError

    @abstractmethod
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        """
        Get (and reset) the multi-modal cache stats.

        Returns:
            The current multi-modal caching stats.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is disabled.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes
      tensor data and metadata) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        cache_gb = model_config.get_multimodal_config().mm_processor_cache_gb
        self._cache = MultiModalCache.get_lru_cache(
            cache_gb,
            MultiModalProcessorCacheItem,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        cached = self._cache.get(mm_hash)
        if cached is not None:
            return cached.item, cached.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
        return mm_item

    @override
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        # Only the eviction order changes; the stored value is untouched.
        self._cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)
|
||||
|
||||
|
||||
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.

    - If the item is not in the cache, store the metadata of that item so
      that the eviction policy remains the same as the cache on P1,
      and return the input.
      By only storing the metadata, we avoid keeping the data itself in
      memory inside P0.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        cache_gb = model_config.get_multimodal_config().mm_processor_cache_gb
        self._cache = MultiModalCache.get_lru_cache(
            cache_gb,
            MultiModalProcessorCacheItemMetadata,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        cached = self._cache.get(mm_hash)
        if cached is not None:
            # Clear the tensor data (P1 already has it); keep prompt updates.
            return None, cached.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
        return mm_item

    @override
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        # Only the eviction order changes; the stored value is untouched.
        self._cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)
|
||||
|
||||
|
||||
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.

    - If the item is not in the cache, store the data in shared memory.
    """

    def __init__(self, vllm_config: "VllmConfig") -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=True,  # sender is the writer
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
        )
        # cache (prompt_updates, modality) for P0 only.
        # NOTE: `ResolvedPromptUpdate` is imported only under TYPE_CHECKING,
        # and PEP 526 evaluates the annotation of an attribute (complex)
        # target at runtime, so the forward reference must be quoted here to
        # avoid a NameError during construction.
        self._p0_cache: dict[str, tuple[Sequence["ResolvedPromptUpdate"], str]] = {}

        # Hit/total counters backing `make_stats()`.
        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    def _stat(self, *, delta: bool = False) -> CacheInfo:
        """Snapshot the stats; with `delta=True`, report change since last call."""
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return self._shm_cache.is_cached(mm_hash)

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if self._shm_cache.is_cached(mm_hash):
            self._hits += 1
            self._total += 1

            # Replace the tensor data with its shared-memory address.
            address, monotonic_id = self._shm_cache.get_cached(mm_hash)
            prompt_updates, modality = self._p0_cache[mm_hash]
            return self.address_as_item(address, monotonic_id, modality), prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._total += 1

        try:
            address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
            # Try to remove dangling items if p0 cache is too large.
            if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
                self.remove_dangling_items()
            self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
            address_item = self.address_as_item(
                address, monotonic_id, mm_item[0].modality
            )
            return address_item, mm_item[1]
        except (ValueError, MemoryError) as e:
            # put may fail if the object is too large or
            # the cache is full.
            # In this case we log the error and keep the original mm_input.
            logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
            return mm_item

    @override
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        """Touch the item in shared memory cache to prevent eviction.
        Increments writer_flag on sender side."""
        self._shm_cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()
        self._p0_cache.clear()

        # Reset stats so that hit rates reflect the fresh cache.
        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._stat(delta=delta)

    def remove_dangling_items(self) -> None:
        """Remove items that are no longer in the shared memory cache."""
        cached_hashes = self._shm_cache.key_index.keys()
        dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
        for mm_hash in dangling_hashes:
            del self._p0_cache[mm_hash]

    def address_as_item(
        self, address: int, monotonic_id: int, modality: str
    ) -> MultiModalKwargsItem:
        """Wrap a shared-memory (address, id) pair as a multi-modal item."""
        addr_elem = MultiModalFieldElem(
            modality=modality,
            key="address",
            data=address,
            field=MultiModalBatchedField(),
        )
        id_elem = MultiModalFieldElem(
            modality=modality,
            key="monotonic_id",
            data=monotonic_id,
            field=MultiModalBatchedField(),
        )
        mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
        return mm_item
|
||||
|
||||
|
||||
def _enable_processor_cache(
|
||||
model_config: "ModelConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> bool:
|
||||
if not mm_registry.supports_multimodal_inputs(model_config):
|
||||
return False
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
return mm_config.mm_processor_cache_gb > 0
|
||||
|
||||
|
||||
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
supports_ipc_cache = (
|
||||
parallel_config._api_process_count == 1
|
||||
and parallel_config.data_parallel_size == 1
|
||||
) or parallel_config.data_parallel_external_lb
|
||||
|
||||
return supports_ipc_cache
|
||||
|
||||
|
||||
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
    """Whether the shared memory based cache should be enabled."""
    if not _enable_ipc_cache(vllm_config):
        return False

    mm_config = vllm_config.model_config.get_multimodal_config()
    cache_type = mm_config.mm_processor_cache_type
    return cache_type == "shm"
|
||||
|
||||
|
||||
def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalProcessorCache | None:
    """Return a `BaseMultiModalProcessorCache`, if enabled."""
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        # No IPC: everything is cached locally on P0.
        return MultiModalProcessorOnlyCache(model_config)

    if _enable_mm_input_shm_cache(vllm_config):
        # IPC via the shared-memory object store.
        return ShmObjectStoreSenderCache(vllm_config)

    # IPC with the LRU cache: P0 keeps only metadata.
    return MultiModalProcessorSenderCache(model_config)
|
||||
|
||||
|
||||
def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> MultiModalProcessorOnlyCache | None:
    """Return a `MultiModalProcessorOnlyCache`, if enabled.

    Args:
        model_config: The model configuration to read cache settings from.
        mm_registry: The registry used to check multi-modal support.

    Returns:
        The P0-only cache, or `None` when the processor cache is disabled.
    """
    # Annotated return type for consistency with `processor_cache_from_config`.
    if not _enable_processor_cache(model_config, mm_registry):
        return None

    return MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
|
||||
class BaseMultiModalReceiverCache(
    BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
):
    """The required interface for caches on P1."""

    def get_and_update_features(
        self,
        mm_features: list["MultiModalFeatureSpec"],
    ) -> list["MultiModalFeatureSpec"]:
        """
        Update multimodal features with cached encoder outputs.

        All identifiers are touched first so that items already updated
        cannot be evicted while the remaining items are being updated.
        """
        for feature in mm_features:
            self.touch_receiver_cache_item(feature.identifier, feature.data)

        for feature in mm_features:
            feature.data = self.get_and_update_item(feature.data, feature.identifier)
        return mm_features

    @abstractmethod
    def touch_receiver_cache_item(
        self,
        mm_hash: str,
        mm_item: MultiModalKwargsItem | None = None,
    ) -> None:
        """
        Update the cache eviction order for a multi-modal item.

        This touches the item in the cache without changing its value.

        Args:
            mm_hash: The hash of the multi-modal item.
            mm_item: The multi-modal item itself. This is optional and
                may not be needed by some cache implementations.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 when IPC caching is enabled.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes tensor
      data) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        cache_gb = model_config.get_multimodal_config().mm_processor_cache_gb
        self._cache = MultiModalCache.get_lru_cache(
            cache_gb,
            MultiModalKwargsItem,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        cached = self._cache.get(mm_hash)
        if cached is not None:
            return cached

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = mm_item
        return mm_item

    @override
    def touch_receiver_cache_item(
        self,
        mm_hash: str,
        mm_item: MultiModalKwargsItem | None = None,
    ) -> None:
        # Only the eviction order changes; `mm_item` is not needed here.
        self._cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
|
||||
|
||||
|
||||
class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 Worker Process when IPC caching is enabled.

    How to update each item:

    - If the item has an address, replace the input with the cached item.
    - If not, return the input.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=False,  # Server is a reader
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
            reader_lock=shared_worker_lock,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        assert mm_item is not None, f"Expected an address item for {mm_hash=}"
        if "address" not in mm_item:
            # Not an address item; the data was sent inline.
            return mm_item

        address = cast(int, mm_item["address"].data)
        monotonic_id = cast(int, mm_item["monotonic_id"].data)
        return self._shm_cache.get(address, monotonic_id)

    @override
    def touch_receiver_cache_item(
        self,
        mm_hash: str,
        mm_item: MultiModalKwargsItem | None = None,
    ) -> None:
        """Touch the item in shared memory cache to prevent eviction.
        Increments reader_count on receiver side."""
        assert mm_item is not None
        if "address" in mm_item:
            address = cast(int, mm_item["address"].data)
            monotonic_id = cast(int, mm_item["monotonic_id"].data)
            self._shm_cache.touch(mm_hash, address=address, monotonic_id=monotonic_id)

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()
|
||||
|
||||
|
||||
def engine_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the engine process.
    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
    mm_processor_cache_type=="lru".
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return None

    if _enable_mm_input_shm_cache(vllm_config):
        # The shm cache lives in the worker process, not the engine.
        return None

    return MultiModalReceiverCache(model_config)
|
||||
|
||||
|
||||
def worker_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the worker process.
    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
    mm_processor_cache_type=="shm".
    """
    model_config = vllm_config.model_config

    # All three conditions must hold; checks are ordered cheapest-first and
    # short-circuit exactly as the individual guards would.
    if not (
        _enable_processor_cache(model_config, mm_registry)
        and _enable_ipc_cache(vllm_config)
        and _enable_mm_input_shm_cache(vllm_config)
    ):
        return None

    return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
|
||||
Reference in New Issue
Block a user