This commit is contained in:
root
2026-04-09 11:19:36 +08:00
parent 809cecae09
commit 8082d5f4b2
2579 changed files with 3675 additions and 0 deletions

View File

@@ -1,40 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .hasher import MultiModalHasher
from .inputs import (
BatchedTensorInputs,
ModalityData,
MultiModalDataBuiltins,
MultiModalDataDict,
MultiModalKwargs,
MultiModalKwargsItems,
MultiModalPlaceholderDict,
MultiModalUUIDDict,
NestedTensors,
)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry]
is used by model runners to dispatch data processing according to the target
model.
Info:
[mm_processing](../../../design/mm_processing.md)
"""
__all__ = [
"BatchedTensorInputs",
"ModalityData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalHasher",
"MultiModalKwargs",
"MultiModalKwargsItems",
"MultiModalPlaceholderDict",
"MultiModalUUIDDict",
"NestedTensors",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
]

View File

@@ -1,118 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from io import BytesIO
from pathlib import Path
from typing import Literal
import numpy as np
import numpy.typing as npt
from vllm.utils.import_utils import PlaceholderModule
from .base import MediaIO
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
def resample_audio_librosa(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample `audio` from `orig_sr` to `target_sr` using librosa."""
    resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
    return resampled
def resample_audio_scipy(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
):
    """Resample `audio` from `orig_sr` to `target_sr` using scipy.

    Args:
        audio: The audio samples (time along the leading axis).
        orig_sr: The original sampling rate.
        target_sr: The desired sampling rate.

    Returns:
        The resampled audio. The input is returned unchanged when the two
        rates already match.
    """
    # lazy import scipy.signal, otherwise it will crash doc build.
    import scipy.signal
    from math import gcd

    if orig_sr == target_sr:
        return audio

    # Resample by the exact rational ratio target_sr/orig_sr.
    # The previous implementation used floor division of the two rates,
    # which was wrong whenever one rate was not an integer multiple of
    # the other (e.g. 44100 -> 16000 was treated as a 1/2 ratio, and
    # 3 -> 2 performed no resampling at all).
    up = int(round(target_sr))
    down = int(round(orig_sr))
    g = gcd(up, down)
    return scipy.signal.resample_poly(audio, up // g, down // g)
class AudioResampler:
"""Resample audio data to a target sample rate."""
def __init__(
self,
target_sr: float | None = None,
method: Literal["librosa", "scipy"] = "librosa",
):
self.target_sr = target_sr
self.method = method
def resample(
self,
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
) -> npt.NDArray[np.floating]:
if self.target_sr is None:
raise RuntimeError(
"Audio resampling is not supported when `target_sr` is not provided"
)
if self.method == "librosa":
return resample_audio_librosa(
audio, orig_sr=orig_sr, target_sr=self.target_sr
)
elif self.method == "scipy":
return resample_audio_scipy(
audio, orig_sr=orig_sr, target_sr=self.target_sr
)
else:
raise ValueError(
f"Invalid resampling method: {self.method}. "
"Supported methods are 'librosa' and 'scipy'."
)
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Decode/encode audio media as `(samples, sample_rate)` tuples."""

    def __init__(self, **kwargs) -> None:
        super().__init__()
        # `kwargs` contains custom arguments from
        # --media-io-kwargs for this modality.
        # They can be passed to the underlying
        # media loaders (e.g. custom implementations)
        # for flexible control.
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        # sr=None preserves the native sampling rate of the payload.
        return librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        raw = base64.b64decode(data)
        return self.load_bytes(raw)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        return librosa.load(filepath, sr=None)

    def encode_base64(self, media: tuple[npt.NDArray, int]) -> str:
        audio, sr = media

        with BytesIO() as buffer:
            soundfile.write(buffer, audio, sr, format="WAV")
            payload = buffer.getvalue()

        return base64.b64encode(payload).decode("utf-8")

View File

@@ -1,26 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Generic, TypeVar
_T = TypeVar("_T")


class MediaIO(ABC, Generic[_T]):
    """Abstract interface for decoding one modality of media into `_T`."""

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Decode media from an in-memory binary payload."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Decode media from a base64-encoded payload.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Decode media from a file on disk."""
        raise NotImplementedError

View File

@@ -1,755 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
import sys
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast
import torch
from typing_extensions import override
import vllm.envs as envs
from vllm.distributed.device_communicators.shm_object_storage import (
MsgpackSerde,
SingleWriterShmObjectStorage,
SingleWriterShmRingBuffer,
)
from vllm.logger import init_logger
from vllm.utils.cache import CacheInfo, LRUCache
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
from .inputs import (
MultiModalBatchedField,
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargs,
MultiModalKwargsItem,
MultiModalKwargsItems,
NestedTensors,
)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from .processing import ResolvedPromptUpdate
from .registry import MultiModalRegistry
logger = init_logger(__name__)
class MultiModalProcessorCacheItem:
    """
    The data to store inside `MultiModalProcessorOnlyCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
        prompt_updates: The prompt updates corresponding to `item`.
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        # Stored verbatim; eviction size accounting reads `item` directly.
        self.item = item
        self.prompt_updates = prompt_updates
class MultiModalProcessorCacheItemMetadata:
    """
    The metadata to store inside `MultiModalProcessorSenderCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
            Since P1 already stores the tensor data, we only store its size
            metadata in P0 to reduce memory usage. The size metadata is still
            needed to keep the same cache eviction policy as P0.
        prompt_updates: The prompt updates corresponding to `item`.
            This needs to stay on P0 because for some models, they are
            dependent on the processed tensor data (cached on P1).
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        # Only the byte size of `item` is retained; the tensors live on P1.
        self.item_size = MultiModalCache.get_item_size(item)
        self.prompt_updates = prompt_updates
# Union of every value type the multi-modal caches may hold.
MultiModalCacheValue: TypeAlias = (
    MultiModalProcessorCacheItem
    | MultiModalProcessorCacheItemMetadata
    | MultiModalKwargsItems
    | MultiModalKwargsItem
    | MultiModalKwargs
    | Mapping[str, NestedTensors]
)

# Type parameter for caches specialized to one of the value types above.
_V = TypeVar("_V", bound=MultiModalCacheValue)
class MultiModalCache:
    """Helpers for sizing multi-modal cache values and building LRU caches."""

    @classmethod
    def get_leaf_size(cls, leaf: object) -> int:
        """Return the approximate in-memory byte size of a single leaf."""
        if isinstance(leaf, MultiModalProcessorCacheItem):
            return cls.get_leaf_size(leaf.item)
        if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
            return leaf.item_size

        # These are not subclasses of dict
        wrapper_types = (
            MultiModalKwargs,
            MultiModalKwargsItems,
            MultiModalKwargsItem,
            MultiModalFieldElem,
        )
        if isinstance(leaf, wrapper_types):
            return cls.get_item_size(leaf.data)  # type: ignore

        # sys.getsizeof doesn't work for tensors
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        return sys.getsizeof(leaf)

    @classmethod
    def get_item_size(
        cls,
        value: MultiModalCacheValue,
        *,
        debug: bool = False,
    ) -> int:
        """Sum the leaf sizes of `value`, optionally logging the result."""
        total = json_reduce_leaves(
            operator.add, json_map_leaves(cls.get_leaf_size, value)
        )

        if debug:
            # Leaf count is only computed when debug logging is requested.
            logger.debug(
                "Calculated size of %s to be %.2f GiB (%d leaves)",
                type(value),
                total / GiB_bytes,
                json_count_leaves(value),
            )

        return total

    @classmethod
    def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
        """
        Get the number of leaf elements in a multi-modal cache value.

        This provides a measure of structural complexity that can be useful
        for debugging cache performance and understanding data patterns.

        Args:
            value: The multi-modal cache value to analyze.

        Returns:
            The number of leaf elements in the nested structure.
        """
        return json_count_leaves(value)

    @classmethod
    def get_lru_cache(
        cls,
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """Build an `LRUCache` bounded by `capacity_gb` GiB of value sizes."""
        return LRUCache(
            GiB_bytes * capacity_gb,
            getsizeof=lambda x: cls.get_item_size(x, debug=debug),
        )
_I = TypeVar("_I", contravariant=True)
_O = TypeVar("_O", covariant=True)


class BaseMultiModalCache(ABC, Generic[_I, _O]):
    """
    Abstract base class to read/write multi-modal items from cache.

    The idea of multi-modal caching is based on having a client and server
    where the client executes in the frontend process (=P0) and
    the server in the core process (=P1). The data flow is as follows:

    ```
        is_cached() x N    get_and_update()
    P0: From API -----------------> -----------------> To P1

        get_and_update()
    P1: From P0 -----------------> To model
    ```

    `is_cached()` can be called any number of times in P0. However,
    `get_and_update()` must be called in P0 and P1 one after another
    so that their cache eviction order remains the same.

    This ensures that the keys in P0 and P1 caches are mirrored,
    allowing us to determine whether a key is cached in P1 by looking
    up the P0 cache, without having to communicate with P1.
    """

    @abstractmethod
    def get_and_update_item(
        self,
        mm_item: _I,
        mm_hash: str,
    ) -> _O:
        """
        Possibly update a multi-modal item based on whether it is
        in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_item: The multi-modal item to update.
            mm_hash: The hash of `mm_item`.

        Returns:
            The updated multi-modal item.
        """
        raise NotImplementedError

    def get_and_update(
        self,
        mm_items: Sequence[_I],
        mm_hashes: list[str],
    ) -> list[_O]:
        """
        Possibly update a sequence of multi-modal items based on whether they
        are in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_items: The multi-modal items to update.
            mm_hashes: The hash of each item in `mm_items`.

        Returns:
            A new list of updated multi-modal items.
        """
        assert len(mm_items) == len(mm_hashes)

        updated: list[_O] = []
        for mm_item, mm_hash in zip(mm_items, mm_hashes):
            updated.append(self.get_and_update_item(mm_item, mm_hash))

        return updated

    @abstractmethod
    def clear_cache(self) -> None:
        """Clear the underlying cache."""
        raise NotImplementedError
# P0 cache input: `(item, prompt updates)`, or `None` when the sender
# already knows the item is cached downstream.
MultiModalProcessorCacheInItem: TypeAlias = (
    tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] | None
)

# P0 cache output: the item (or `None` when its data is cached on P1)
# together with its prompt updates.
MultiModalProcessorCacheOutItem: TypeAlias = tuple[
    MultiModalKwargsItem | None, Sequence["ResolvedPromptUpdate"]
]
class BaseMultiModalProcessorCache(
    BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem]
):
    """The required interface for caches on P0."""

    @abstractmethod
    def is_cached_item(self, mm_hash: str) -> bool:
        """
        Check whether a multi-modal item is in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hash: The hash of the item to check.

        Returns:
            `True` if the item is cached, otherwise `False`.
        """
        raise NotImplementedError

    def is_cached(self, mm_hashes: list[str]) -> list[bool]:
        """
        Check whether each of a sequence of multi-modal items is
        in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hashes: The hash of each item to check.

        Returns:
            For each item, `True` if the item is cached, otherwise `False`.
        """
        return list(map(self.is_cached_item, mm_hashes))

    @abstractmethod
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        """
        Get (and reset) the multi-modal cache stats.

        Returns:
            The current multi-modal caching stats.
        """
        raise NotImplementedError
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is disabled.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes
      tensor data and metadata) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()
        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItem,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        cached = self._cache.get(mm_hash)
        if cached is not None:
            return cached.item, cached.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.
    - If the item is not in the cache, store the metadata of that item so
      that the eviction policy remains the same as the cache on P1,
      and return the input.

    By only storing the metadata, we avoid keeping the data itself in
    memory inside P0.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()
        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItemMetadata,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        cached = self._cache.get(mm_hash)
        if cached is not None:
            # Already sent to P1: drop the tensor data, keep prompt updates.
            return None, cached.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.
    - If the item is not in the cache, store the data in shared memory.
    """

    def __init__(self, vllm_config: "VllmConfig") -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        # Single-writer shared-memory ring buffer backing the object store;
        # this process (the sender) creates it and is the only writer.
        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=True,  # sender is the writer
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
        )
        # cache (prompt_updates, modality) for P0 only
        self._p0_cache: dict[str, tuple[Sequence["ResolvedPromptUpdate"], str]] = {}

        # Hit/total counters backing `make_stats`.
        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    def _stat(self, *, delta: bool = False) -> CacheInfo:
        """Return cumulative stats or, if `delta`, stats since the last call."""
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return self._shm_cache.is_cached(mm_hash)

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if self._shm_cache.is_cached(mm_hash):
            self._hits += 1
            self._total += 1
            # Replace the tensor data with its shm address so P1 can read
            # it from shared memory instead of receiving it over IPC.
            address, monotonic_id = self._shm_cache.get_cached(mm_hash)
            prompt_updates, modality = self._p0_cache[mm_hash]
            return self.address_as_item(address, monotonic_id, modality), prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._total += 1
        try:
            address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
            # Try to remove dangling items if p0 cache is too large.
            if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
                self.remove_dangling_items()
            self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
            address_item = self.address_as_item(
                address, monotonic_id, mm_item[0].modality
            )
            return address_item, mm_item[1]
        except (ValueError, MemoryError) as e:
            # put may fail if the object is too large or
            # the cache is full.
            # In this case we log the error and keep the original mm_input.
            logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
            return mm_item

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()
        self._p0_cache.clear()
        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._stat(delta=delta)

    def remove_dangling_items(self) -> None:
        """Remove items that are no longer in the shared memory cache."""
        cached_hashes = self._shm_cache.key_index.keys()
        dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
        for mm_hash in dangling_hashes:
            del self._p0_cache[mm_hash]

    def address_as_item(
        self, address: int, monotonic_id: int, modality: str
    ) -> MultiModalKwargsItem:
        """Wrap a shm `(address, monotonic_id)` pair as a placeholder item."""
        addr_elem = MultiModalFieldElem(
            modality=modality,
            key="address",
            data=address,
            field=MultiModalBatchedField(),
        )
        id_elem = MultiModalFieldElem(
            modality=modality,
            key="monotonic_id",
            data=monotonic_id,
            field=MultiModalBatchedField(),
        )
        mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
        return mm_item
def _enable_processor_cache(
model_config: "ModelConfig",
mm_registry: "MultiModalRegistry",
) -> bool:
if not mm_registry.supports_multimodal_inputs(model_config):
return False
mm_config = model_config.get_multimodal_config()
return mm_config.mm_processor_cache_gb > 0
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
parallel_config = vllm_config.parallel_config
supports_ipc_cache = (
parallel_config._api_process_count == 1
and parallel_config.data_parallel_size == 1
) or parallel_config.data_parallel_external_lb
return supports_ipc_cache
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
    """Whether the shared memory based cache should be enabled."""
    # Short-circuit: the shm cache only makes sense when IPC caching works.
    if not _enable_ipc_cache(vllm_config):
        return False

    mm_cfg = vllm_config.model_config.get_multimodal_config()
    return mm_cfg.mm_processor_cache_type == "shm"
def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalProcessorCache | None:
    """Return a `BaseMultiModalProcessorCache`, if enabled."""
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    # No IPC: a single in-process cache holds both data and metadata.
    if not _enable_ipc_cache(vllm_config):
        return MultiModalProcessorOnlyCache(model_config)

    # IPC enabled: choose the shm-backed sender when configured for it,
    # otherwise the metadata-only sender cache.
    if _enable_mm_input_shm_cache(vllm_config):
        return ShmObjectStoreSenderCache(vllm_config)

    return MultiModalProcessorSenderCache(model_config)
def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> "MultiModalProcessorOnlyCache | None":
    """Return a `MultiModalProcessorOnlyCache`, if enabled."""
    if not _enable_processor_cache(model_config, mm_registry):
        return None

    return MultiModalProcessorOnlyCache(model_config)
class BaseMultiModalReceiverCache(
    BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
):
    """The required interface for caches on P1."""

    def get_and_update_features(
        self,
        mm_features: list["MultiModalFeatureSpec"],
    ) -> list["MultiModalFeatureSpec"]:
        """Update multimodal features with cached encoder outputs."""
        # In-place update of each feature's data; the list itself is reused.
        for spec in mm_features:
            spec.data = self.get_and_update_item(spec.data, spec.identifier)
        return mm_features
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 when IPC caching is enabled.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes tensor
      data) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()
        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalKwargsItem,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        cached = self._cache.get(mm_hash)
        if cached is not None:
            return cached

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = mm_item
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 Worker Process when IPC caching is enabled.

    How to update each item:

    - If the item has an address, replace the input with the cached item.
    - If not, return the input.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        # Attach (create=False) to the ring buffer created by the P0 sender;
        # this side only ever reads from it.
        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=False,  # Server is a reader
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
            reader_lock=shared_worker_lock,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        assert mm_item is not None, f"Expected an address item for {mm_hash=}"
        # Items carrying an "address" element are placeholders written by
        # the sender cache; resolve them from shared memory.
        if "address" in mm_item:
            address = cast(int, mm_item["address"].data)
            monotonic_id = cast(int, mm_item["monotonic_id"].data)
            return self._shm_cache.get(address, monotonic_id)
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()
def engine_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the engine process.

    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
    mm_processor_cache_type=="lru".
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None
    if not _enable_ipc_cache(vllm_config):
        return None

    # The shm-backed path resolves items in the worker, not the engine.
    if _enable_mm_input_shm_cache(vllm_config):
        return None

    return MultiModalReceiverCache(model_config)
def worker_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the worker process.

    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
    mm_processor_cache_type=="shm".
    """
    model_config = vllm_config.model_config

    # All three conditions must hold (short-circuit left to right).
    prerequisites = (
        _enable_processor_cache(model_config, mm_registry)
        and _enable_ipc_cache(vllm_config)
        and _enable_mm_input_shm_cache(vllm_config)
    )
    if not prerequisites:
        return None

    return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)

View File

@@ -1,294 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import typing
import torch
def compute_retained_tokens_count(
tokens_per_frame: int, num_frames: int, q: float
) -> int:
"""
Compute the number of retained tokens for a given video.
Method ensures that we retain all the tokens from the first frame
regardless of the pruning rate.
Args:
tokens_per_frame: The number of tokens per frame.
num_frames: The total number of frames.
q: The pruning rate.
Returns:
The number of retained tokens.
"""
total_tokens = tokens_per_frame * num_frames
evs_num_tokens = int(total_tokens * (1 - q))
min_num_tokens = tokens_per_frame
return max(min_num_tokens, evs_num_tokens)
def compute_retention_mask(
video_embeds: torch.Tensor,
video_size_thw: torch.LongTensor | tuple[int, int, int],
spatial_merge_size: int,
q: float,
) -> torch.Tensor:
"""
Computes the retention mask for input video embeddings.
Args:
video_embeds (`torch.Tensor`): The input video embeddings
of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
video_size_thw (`torch.LongTensor` of shape `(3)`):
The temporal, height and width of video.
spatial_merge_size: Size reduction for rows & cols dimensions.
q: (`float`): Pruning rate factor [0,1)
Returns:
`torch.Tensor`: The retention mask for the video embeddings of
`(T * H * W // spatial_merge_size ^ 2)` shape.
"""
T, H, W = map(int, video_size_thw)
# Use reshape instead of einops to avoid graph breaks
video_embeds = video_embeds.reshape(
T,
H // spatial_merge_size,
W // spatial_merge_size,
video_embeds.size(-1),
)
tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)
# Core EVS
similarity = torch.nn.functional.cosine_similarity(
video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
)
dissimilarity = 1 - similarity
# Always ensure we include all tokens from the first frame
dissimilarity = torch.cat(
[255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0
)
dissimilarity_flat = dissimilarity.view(-1)
order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
retain_num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_per_frame, num_frames=T, q=q
)
topk_indices = order[:retain_num_tokens]
retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
retention_mask[topk_indices] = True
retention_mask = retention_mask.reshape(dissimilarity.size())
mask = retention_mask.view(-1) # "T H W -> (T H W)"
return mask
def compute_mrope_for_media(
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    tokens_per_second: float = 1.0,
    video_second_per_grid: float = 1.0,
) -> torch.Tensor:
    """
    Computes the mrope for video embeddings based on the grid dimensions.

    Computed mrope positions match original qwen 2.5 implementation,
    but positions are built for media being the first element in sequence.

    Args:
        video_size_thw: Media size (num frames, rows, cols)
        spatial_merge_size: Size reduction for rows & cols dimensions.
        tokens_per_second: Number of tokens per second.
        video_second_per_grid: Number of seconds per video.

    Returns:
        Tensor of shape `(T * H * W, 4)` where last dimension
        represents mrope positions [0:3), while the last channel
        contains value of llm_grid_w repeated for all positions.
    """
    t, h, w = video_size_thw
    grid_t = int(t)
    grid_h = int(h) // spatial_merge_size
    grid_w = int(w) // spatial_merge_size

    time_scale = tokens_per_second * video_second_per_grid

    # Temporal index: one (scaled) value per frame, repeated over the grid.
    t_index = (
        (torch.arange(grid_t).view(-1, 1).expand(-1, grid_h * grid_w).mul(time_scale))
        .long()
        .flatten()
    )
    h_index = torch.arange(grid_h).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten()
    w_index = torch.arange(grid_w).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten()

    # Fourth channel: the media's grid width broadcast to every position.
    w_max = torch.full((grid_t * grid_h * grid_w,), grid_w)

    return torch.stack([t_index, h_index, w_index, w_max], dim=1)
def recompute_mrope_positions(
    input_ids: torch.LongTensor,
    multimodal_positions: list[torch.Tensor],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
    vision_start_token_id: int,
    image_token_id: int,
    video_token_id: int,
) -> tuple[torch.LongTensor, int]:
    """
    Update part of input mrope positions.

    Original mrope_positions are computed incorrectly, so once we prune media
    tokens we should reflect this in the mrope positions for the LLM.

    This method supports chunked prefill approach where
    multimodal_embeddings are passed to LLM in chunks, so input
    multimodal_embeddings may contain zero, some or even some part of all
    multimodal_embeddings for a given prompt.

    Each multimodal_positions has 4 extra channels
    (First 3 channels corresponds to original 3 mrope positions, last channel
    is the maximum width of the media repeated). Provided multimodal_positions
    do not reflect location of media position in sequence - they are computed
    like the media is in the 0-th position in the sequence.

    Method works as follows: it recomputes mrope_positions starting from the
    `num_computed_tokens` for `total_len_of_multimodal_embeddings` and then
    shifts all text tokens that goes after total_len_of_multimodal_embeddings.
    It also handles case when multimodal_embeddings is partial
    (e.g. one media is split into two prefill stages)

    Args:
        input_ids: (N,) All input tokens of the prompt (entire sequence).
        multimodal_positions: List of mrope positions for each media.
        mrope_positions: Existing mrope positions (4, N) for entire sequence.
        num_computed_tokens: A number of computed tokens so far.
        vision_start_token_id: Token indicating start of vision media.
        image_token_id: Image token id
        video_token_id: Video token id

    Returns:
        Tuple of (mrope_positions, mrope_position_delta).
    """
    # Tensors
    positions: torch.LongTensor = typing.cast(
        torch.LongTensor, mrope_positions.clone()
    )  # (3, N)
    N = input_ids.numel()

    # Boolean masks over the full sequence separating media and text tokens.
    image_mask = input_ids.eq(image_token_id)
    video_mask = input_ids.eq(video_token_id)
    media_mask = image_mask | video_mask
    text_mask = ~media_mask

    # Early exit: no media in this chunk
    if len(multimodal_positions) == 0:
        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
        return positions, delta

    total_mm_tokens = torch.count_nonzero(media_mask)
    seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])

    # Early exit: we've updated positions for all media tokens
    # (and consequently - for all remaining text tokens)
    if seen_mm_tokens == total_mm_tokens:
        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
        return positions, delta

    # Indices of all vision-start tokens in the full sequence.
    vision_start_indices = (input_ids == vision_start_token_id).nonzero(as_tuple=True)[
        0
    ]

    for mm_pos in multimodal_positions:
        # Each mm_pos can be a complete embedding for single media
        # or it can be a part of a single media (due to chunked prefill)
        # Cases to cover
        # - Current prefill chunk has no vision start indexes at all
        # - Vision start token appeared in previous prefill round
        # - Regular case
        seen_vision_start_indices = vision_start_indices[
            vision_start_indices < num_computed_tokens
        ]

        if len(seen_vision_start_indices):
            # If we have encountered some vision start indexes,
            # then we should check the condition:
            # | --- prefill 1 ------| ---- prefill 2 ----- |
            # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
            last_vision_start_token = seen_vision_start_indices[-1]
            seem_mm_tokens_before_last_vision_start = torch.count_nonzero(
                media_mask[:last_vision_start_token]
            )
            in_the_middle_of_media = (
                seen_mm_tokens > seem_mm_tokens_before_last_vision_start
            )
            if in_the_middle_of_media:
                # Continue a media whose earlier part was already prefilled.
                mm_embeddings_seen = (
                    seen_mm_tokens - seem_mm_tokens_before_last_vision_start
                )
                global_mm_start = last_vision_start_token
            else:
                # We have completed previous mm_embedding part and
                # ready to start a new one
                next_vision_start_token = vision_start_indices[
                    vision_start_indices >= num_computed_tokens
                ][0]
                mm_embeddings_seen = 0
                global_mm_start = next_vision_start_token
        else:
            # If there were no vision start indexes so far,
            # let's find first vision start index
            next_vision_start_token = vision_start_indices[
                vision_start_indices >= num_computed_tokens
            ][0]
            mm_embeddings_seen = 0
            global_mm_start = next_vision_start_token

        # Offset right after vision_start_token
        base = positions[-1, global_mm_start] + 1
        local_start = global_mm_start + 1 + mm_embeddings_seen
        local_end = local_start + mm_pos.shape[1]

        # Write the media's (shifted) mrope rows into the sequence slice.
        positions[:, local_start:local_end] = mm_pos[0:3] + base

        # mm_pos[3, 0] is the max width of the media
        offset = mm_pos[3, 0] + base
        # Re-number all following text tokens contiguously after the media.
        text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
        positions[:, local_end:N] = text_pos_sum + offset - 1

        # Include distance to the next vision start token
        num_computed_tokens += mm_pos.shape[1]

    mrope_positions_delta = (positions.max() + 1 - N).item()
    return positions, mrope_positions_delta

View File

@@ -1,106 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle
import uuid
from collections.abc import Iterable
import numpy as np
import torch
from blake3 import blake3
from PIL import Image
from vllm.logger import init_logger
logger = init_logger(__name__)


class MultiModalHasher:
    """Serializes multi-modal data items into a canonical byte stream and
    hashes them with BLAKE3 to produce stable cache keys.

    NOTE: the byte stream is order- and format-sensitive; changing any
    serialization detail below invalidates all previously computed hashes.
    """

    @classmethod
    def serialize_item(cls, obj: object) -> Iterable[bytes | memoryview]:
        """Yield the canonical byte representation of a single leaf object.

        Recognizes bytes/str/numbers, PIL images, torch tensors and NumPy
        arrays; anything else falls back to ``pickle`` with a warning.
        """
        # Simple cases
        if isinstance(obj, (bytes, memoryview)):
            return (obj,)
        if isinstance(obj, str):
            return (obj.encode("utf-8"),)
        if isinstance(obj, (int, float)):
            return (np.array(obj).tobytes(),)

        if isinstance(obj, Image.Image):
            exif = obj.getexif()
            if Image.ExifTags.Base.ImageID in exif and isinstance(
                exif[Image.ExifTags.Base.ImageID], uuid.UUID
            ):
                # If the image has exif ImageID tag, use that
                return (exif[Image.ExifTags.Base.ImageID].bytes,)
            # Hash the decoded pixels together with the mode (and palette,
            # if any) so visually different images cannot collide.
            data = {"mode": obj.mode, "data": np.asarray(obj)}
            if obj.palette is not None:
                data["palette"] = obj.palette.palette
                if obj.palette.rawmode is not None:
                    data["palette_rawmode"] = obj.palette.rawmode
            return cls.iter_item_to_bytes("image", data)

        if isinstance(obj, torch.Tensor):
            tensor_obj: torch.Tensor = obj.cpu()
            tensor_dtype = tensor_obj.dtype
            tensor_shape = tensor_obj.shape
            # NumPy does not support bfloat16.
            # Workaround: View the tensor as a contiguous 1D array of bytes
            if tensor_dtype == torch.bfloat16:
                tensor_obj = tensor_obj.contiguous()
                tensor_obj = tensor_obj.view((tensor_obj.numel(),)).view(torch.uint8)
                # Record the original dtype/shape so tensors with identical
                # raw bytes but different shapes/dtypes do not collide.
                return cls.iter_item_to_bytes(
                    "tensor",
                    {
                        "original_dtype": str(tensor_dtype),
                        "original_shape": tuple(tensor_shape),
                        "data": tensor_obj.numpy(),
                    },
                )
            return cls.iter_item_to_bytes("tensor", tensor_obj.numpy())

        if isinstance(obj, np.ndarray):
            # If the array is non-contiguous, we need to copy it first
            arr_data = (
                obj.view(np.uint8).data if obj.flags.c_contiguous else obj.tobytes()
            )
            return cls.iter_item_to_bytes(
                "ndarray",
                {
                    "dtype": obj.dtype.str,
                    "shape": obj.shape,
                    "data": arr_data,
                },
            )

        logger.warning(
            "No serialization method found for %s. Falling back to pickle.", type(obj)
        )
        return (pickle.dumps(obj),)

    @classmethod
    def iter_item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[bytes | memoryview]:
        """Recursively flatten ``obj``, yielding each leaf's dotted key path
        (e.g. ``"image.mode"``) followed by its serialized bytes."""
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from cls.iter_item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
        else:
            yield key.encode("utf-8")
            yield from cls.serialize_item(obj)

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Return the BLAKE3 hex digest of the given keyword arguments."""
        hasher = blake3()
        for k, v in kwargs.items():
            for bytes_ in cls.iter_item_to_bytes(k, v):
                hasher.update(bytes_)
        return hasher.hexdigest()

View File

@@ -1,130 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from io import BytesIO
from pathlib import Path
import pybase64
import torch
from PIL import Image
from .base import MediaIO
def rescale_image_size(
    image: Image.Image, size_factor: float, transpose: int = -1
) -> Image.Image:
    """Rescale the dimensions of an image by a constant factor.

    Width and height are each multiplied by ``size_factor`` and truncated
    to integers. If ``transpose`` is non-negative, it is interpreted as a
    ``PIL.Image.Transpose`` operation applied after resizing.
    """
    scaled = image.resize(
        (int(image.width * size_factor), int(image.height * size_factor))
    )
    if transpose < 0:
        return scaled
    return scaled.transpose(Image.Transpose(transpose))
def rgba_to_rgb(
    image: Image.Image,
    background_color: tuple[int, int, int] | list[int] = (255, 255, 255),
) -> Image.Image:
    """Convert an RGBA image to RGB by compositing over a solid background."""
    assert image.mode == "RGBA"
    background = Image.new("RGB", image.size, background_color)
    # Band 3 is the alpha channel; using it as the paste mask lets the
    # background color show through transparent regions.
    alpha = image.split()[3]
    background.paste(image, mask=alpha)
    return background
def convert_image_mode(image: Image.Image, to_mode: str):
    """Return ``image`` converted to ``to_mode`` (no-op if already there)."""
    if image.mode == to_mode:
        return image
    if image.mode == "RGBA" and to_mode == "RGB":
        # Special-case RGBA->RGB so transparency is filled with white
        # rather than PIL's default black.
        return rgba_to_rgb(image)
    return image.convert(to_mode)
class ImageMediaIO(MediaIO[Image.Image]):
    """Loads and encodes ``PIL.Image`` media, normalizing to ``image_mode``."""

    def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
        super().__init__()

        self.image_mode = image_mode
        # `kwargs` carries custom arguments from --media-io-kwargs for this
        # modality; they may be consumed by custom media loaders.
        self.kwargs = kwargs

        # Background used when flattening RGBA onto RGB; defaults to white
        # for backward compatibility.
        background = kwargs.get("rgba_background_color", (255, 255, 255))
        if isinstance(background, list):
            background = tuple(background)

        is_valid = (
            isinstance(background, tuple)
            and len(background) == 3
            and all(isinstance(c, int) and 0 <= c <= 255 for c in background)
        )
        if not is_valid:
            raise ValueError(
                "rgba_background_color must be a list or tuple of 3 integers "
                "in the range [0, 255]."
            )
        self.rgba_background_color = background

    def _convert_image_mode(self, image: Image.Image) -> Image.Image:
        """Convert ``image`` to ``self.image_mode``, honoring the configured
        RGBA background color."""
        if image.mode == self.image_mode:
            return image
        if image.mode == "RGBA" and self.image_mode == "RGB":
            return rgba_to_rgb(image, self.rgba_background_color)
        return convert_image_mode(image, self.image_mode)

    def load_bytes(self, data: bytes) -> Image.Image:
        img = Image.open(BytesIO(data))
        # Force decoding now so corrupt data fails here, not lazily later.
        img.load()
        return self._convert_image_mode(img)

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        # validate=True rejects malformed base64 instead of ignoring it.
        raw = pybase64.b64decode(data, validate=True)
        return self.load_bytes(raw)

    def load_file(self, filepath: Path) -> Image.Image:
        img = Image.open(filepath)
        img.load()
        return self._convert_image_mode(img)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        """Encode ``media`` as base64 in the given ``image_format``."""
        with BytesIO() as buffer:
            self._convert_image_mode(media).save(buffer, image_format)
            encoded = buffer.getvalue()

        return pybase64.b64encode(encoded).decode("utf-8")
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Loads precomputed image-embedding tensors serialized via ``torch.save``."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        # weights_only=True blocks arbitrary-code pickle payloads.
        return torch.load(BytesIO(data), weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        decoded = pybase64.b64decode(data, validate=True)
        return self.load_bytes(decoded)

    def load_file(self, filepath: Path) -> torch.Tensor:
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        # The tensor's raw buffer (via numpy) is base64-encoded directly.
        return pybase64.b64encode(media.numpy()).decode("utf-8")

File diff suppressed because it is too large Load Diff

View File

@@ -1,544 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
NamedTuple,
TypeAlias,
TypeGuard,
TypeVar,
)
import numpy as np
import torch
from typing_extensions import assert_never
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader
from .audio import AudioResampler
from .inputs import (
AudioItem,
HfAudioItem,
HfImageItem,
HfVideoItem,
ImageItem,
ModalityData,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
# Generic parameters for ModalityDataItems: _T is the raw data container
# type, _I is the type of a single item.
_T = TypeVar("_T")
_I = TypeVar("_I")

if TYPE_CHECKING:
    import PIL.Image as PILImage
else:
    # Defer the PIL import until first attribute access to keep module
    # import cheap when images are not used.
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    Provides sequence-like access (``len``, indexing) on top of the
    abstract `get_count`/`get` primitives implemented by subclasses.
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        # Raw container holding the items (list, tensor, dict of tensors...).
        self.data: _T = data
        self.modality = modality

    def __repr__(self) -> str:
        return f"{type(self).__name__}(modality={self.modality!r}, len={len(self)})"

    def __len__(self) -> int:
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        return self.get(index)

    if TYPE_CHECKING:
        # Auto-generated
        def __iter__(self) -> Iterator[_I]: ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items."""
        return [self.get(idx) for idx in range(self.get_count())]

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items that are arranged in a list."""

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> _T:
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # HF processors expect the pluralized modality name, e.g. "images".
        return {f"{self.modality}s": self.data}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # List-based items all go through the HF processor; nothing bypasses it.
        return {}
class EmbeddingItems(
    ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).
    """

    def get_count(self) -> int:
        # Works for both forms: len() of a batched tensor is its first dim.
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # Precomputed embeddings skip the HF processor entirely.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return {f"{self.modality}_embeds": self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Number of embedding vectors for this item (first dim of its tensor)."""
        return len(self.get(item_idx))
class DictEmbeddingItems(
    ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        # Local import keeps `transformers` off the module import path.
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        # Validate the raw data first so its error is raised before any
        # field-config error.
        missing_required_data_keys = required_fields - data.keys()
        if missing_required_data_keys:
            data_keys = set(data.keys())
            msg = (
                f"The data should contain the fields: {required_fields}, "
                f"but only found the following keys: {data_keys}"
            )
            raise ValueError(msg)

        # The factory-produced field config must also cover every required key.
        fields_config = fields_factory(data)
        missing_required_fields = required_fields - fields_config.keys()
        if missing_required_fields:
            fields = set(fields_config.keys())
            msg = f"{required_fields=} should be a subset of {fields=}"
            raise ValueError(msg)

        self.fields_config = fields_config
        self.required_fields = required_fields

        # Pre-split the batched dict into per-item kwargs for indexed access.
        self._kwargs = MultiModalKwargsItems.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        return len(self._kwargs[self.modality])

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        """Return the tensors belonging to the ``index``-th item."""
        return self._kwargs[self.modality][index].get_data()

    def get_processor_data(self) -> Mapping[str, object]:
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # The dict is already in model-ready form; bypass the HF processor.
        return self.data
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """Audio items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfAudioItem] | None) -> None:
        # A `None` payload is normalized to a single placeholder item.
        if data is None:
            data = [None]
        super().__init__(data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Number of samples in the audio at ``item_idx``."""
        # NOTE(review): crashes with TypeError if the item is the `None`
        # placeholder inserted above — presumably callers never ask for the
        # length of a placeholder; confirm.
        audio = self.get(item_idx)
        return len(audio)
class AudioEmbeddingItems(EmbeddingItems):
    """Precomputed audio embeddings (bypass the HF processor)."""

    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
        super().__init__(data, "audio")
class ImageSize(NamedTuple):
    # Image dimensions in pixels. Note the (width, height) order — the same
    # convention as PIL's `Image.size`.
    width: int
    height: int
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """Image items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfImageItem] | None) -> None:
        # A `None` payload still counts as a single placeholder item.
        items = [None] if data is None else data
        super().__init__(items, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Return the (width, height) of the image at ``item_idx``."""
        image = self.get(item_idx)

        if isinstance(image, PILImage.Image):
            width, height = image.size
            return ImageSize(width, height)

        if isinstance(image, (np.ndarray, torch.Tensor)):
            # Assumed channels-first layout: dims 1 and 2 are height/width.
            _, height, width = image.shape
            return ImageSize(width, height)

        assert_never(image)
class ImageEmbeddingItems(EmbeddingItems):
    """Precomputed image embeddings (bypass the HF processor)."""

    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
        super().__init__(data, "image")
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """Video items to be fed through the HF processor."""

    def __init__(
        self,
        data: Sequence[HfVideoItem] | None,
        metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
    ) -> None:
        # A `None` payload still counts as a single placeholder item.
        items = [None] if data is None else data
        super().__init__(items, "video")
        self.metadata = metadata

    def get_num_frames(self, item_idx: int) -> int:
        """Return the frame count of the video at ``item_idx``."""
        return len(self.get(item_idx))

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Return the (width, height) of the first frame of a video."""
        # Assume that the video isn't empty
        frame = self.get(item_idx)[0]

        if isinstance(frame, PILImage.Image):
            width, height = frame.size
            return ImageSize(width, height)

        if isinstance(frame, (np.ndarray, torch.Tensor)):
            # Assumed channels-first layout: dims 1 and 2 are height/width.
            _, height, width = frame.shape
            return ImageSize(width, height)

        assert_never(frame)
class VideoEmbeddingItems(EmbeddingItems):
    """Precomputed video embeddings (bypass the HF processor)."""

    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
        super().__init__(data, "video")
# Bound for `get_items`: any concrete ModalityDataItems subclass.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])


class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
    normalized such that each entry corresponds to a list.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising [`KeyError`][]
        even if the modality is not found.
        """
        if modality in self:
            return self[modality].get_count()

        if strict:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return 0

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {modality: items.get_count() for modality, items in self.items()}

    def get_items(
        self,
        modality: str,
        typ: type[_D] | tuple[type[_D], ...],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.
        """
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        items = self[modality]
        if not isinstance(items, typ):
            raise TypeError(
                f"Invalid type of data items for {modality=}. "
                f"Expected type: {typ}, but "
                f"found type: {type(items)}"
            )

        return items  # type: ignore[return-value]
# A callable that converts raw modality data into parsed items, or returns
# `None` when the data should be skipped (e.g. empty input).
ModalityDataParser: TypeAlias = Callable[
    [ModalityData[Any]], ModalityDataItems[Any, Any] | None
]
class MultiModalDataParser:
    """
    Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
    into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
    """

    def __init__(
        self,
        *,
        target_sr: float | None = None,
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
        video_needs_metadata: bool = False,
    ) -> None:
        super().__init__()

        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )
        # When set, every video item must carry metadata; bare frame arrays
        # are rejected in `_parse_video_data`.
        self.video_needs_metadata = video_needs_metadata

    @classmethod
    def is_embeddings(
        cls, data: object
    ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
        """Return whether `data` holds precomputed embeddings: a batched 3D
        tensor, or a list of per-item 2D tensors."""
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
            # Assumes all elements share the first element's rank.
            return data[0].ndim == 2  # type: ignore[index]

        return False

    def _is_empty(self, data: object) -> TypeGuard[None]:
        """Return whether `data` contains zero items."""
        if isinstance(data, list):
            return len(data) == 0
        if isinstance(data, np.ndarray):
            return data.size == 0
        if isinstance(data, torch.Tensor):
            # FIX: `torch.Tensor.size` is a *method* (unlike NumPy's `.size`
            # attribute), so the previous shared `data.size == 0` check
            # compared a bound method against an int and was always False —
            # empty torch tensors were never detected as empty.
            return data.numel() == 0

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, float | None]:
        """Normalize an audio item to `(waveform, sampling_rate)`; the rate
        is `None` when the input did not carry one."""
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

    def _get_video_with_metadata(
        self,
        video: VideoItem,
    ) -> tuple[np.ndarray, dict[str, Any] | None]:
        """Normalize a video item to `(frames, metadata)`; the metadata is
        `None` when the input did not carry any."""
        if isinstance(video, tuple):
            return video
        if isinstance(video, list):
            return np.array(video), None
        if isinstance(video, np.ndarray):
            return video, None
        if isinstance(video, torch.Tensor):
            return video.numpy(), None

        assert_never(video)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio data, resampling items that carry a sampling rate.

        Returns `None` for empty data so the modality is dropped.
        """
        if data is None:
            return AudioProcessorItems(None)

        # also check single audio item with sampling rate
        if self._is_empty(data) or (
            isinstance(data, tuple) and self._is_empty(data[0])
        ):
            return None

        if self.is_embeddings(data):
            return AudioEmbeddingItems(data)

        data_items: list[AudioItem]
        if (
            is_list_of(data, float)
            or isinstance(data, (np.ndarray, torch.Tensor))
            and data.ndim == 1
            or isinstance(data, tuple)
        ):
            # A single item: bare waveform or a (waveform, sr) tuple.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array: one item per leading-dim slice.
            data_items = [elem for elem in data]
        else:
            data_items = data  # type: ignore[assignment]

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                # Unknown source rate -> cannot resample; use as-is.
                new_audio = audio
            else:
                new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)

            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image data into processor or embedding items.

        Returns `None` for empty data so the modality is dropped.
        """
        if data is None:
            return ImageProcessorItems(None)

        if self._is_empty(data):
            return None

        if self.is_embeddings(data):
            return ImageEmbeddingItems(data)

        if (
            isinstance(data, PILImage.Image)
            or isinstance(data, (np.ndarray, torch.Tensor))
            and data.ndim == 3
        ):
            # A single image (PIL, or a 3D array).
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array: one image per leading-dim slice.
            data_items = [elem for elem in data]
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse video data, enforcing metadata when the model requires it.

        Raises:
            ValueError: If `video_needs_metadata` is set and an item has none.
        """
        if data is None:
            return VideoProcessorItems(None)

        if self._is_empty(data):
            return None

        if self.is_embeddings(data):
            return VideoEmbeddingItems(data)

        data_items: list[VideoItem]
        if (
            is_list_of(data, PILImage.Image)
            or isinstance(data, (np.ndarray, torch.Tensor))
            and data.ndim == 4
        ):
            # A single video: list of frames, or a 4D array.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        elif isinstance(data, tuple) and len(data) == 2:
            # A single (frames, metadata) pair.
            data_items = [data]
        else:
            data_items = data  # type: ignore[assignment]

        new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
        metadata_lst: list[dict[str, Any] | None] = []
        for data_item in data_items:
            video, metadata = self._get_video_with_metadata(data_item)
            if self.video_needs_metadata:
                if metadata is None:
                    raise ValueError(
                        "Video metadata is required but not found in mm input. "
                        "Please check your video input in `multi_modal_data`"
                    )
                # Keep the metadata both attached to the item and mirrored in
                # the separate list consumed via `VideoProcessorItems.metadata`.
                new_videos.append((video, metadata))
                metadata_lst.append(metadata)
            else:
                # Items are bare frame arrays; `metadata_lst` stays empty.
                new_videos.append(video)

        return VideoProcessorItems(new_videos, metadata=metadata_lst)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Map each supported modality name to its parser method."""
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
        }

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse a full multi-modal data dict, skipping empty entries.

        Raises:
            ValueError: If a modality has no registered subparser.
        """
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

            # ignore empty embedding data
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data

        return mm_items

File diff suppressed because it is too large Load Diff

View File

@@ -1,369 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, NamedTuple, TypeVar, cast
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm.config.multimodal import (
AudioDummyOptions,
BaseDummyOptions,
ImageDummyOptions,
VideoDummyOptions,
)
from vllm.logger import init_logger
from .inputs import (
MultiModalDataDict,
MultiModalEncDecInputs,
MultiModalInputs,
MultiModalKwargsItems,
MultiModalPlaceholderDict,
)
from .processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
EncDecMultiModalProcessor,
)
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # Text prompt, or pre-tokenized prompt token IDs.
    prompt: str | list[int]
    # Raw multi-modal data keyed by modality.
    mm_data: MultiModalDataDict
    # Extra kwargs forwarded to the HF processor.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Extra kwargs forwarded to tokenization.
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
class DummyEncoderData(NamedTuple):
    """Dummy data used for profiling."""

    # Token IDs fed to the encoder during profiling.
    prompt_token_ids: list[int]
class DummyDecoderData(NamedTuple):
    """Dummy data used for profiling."""

    # Token IDs fed to the decoder during profiling.
    prompt_token_ids: list[int]
    # Fully-processed multi-modal kwargs for the dummy items.
    multi_modal_data: MultiModalKwargsItems
    # Placeholder positions of the multi-modal items in the prompt.
    multi_modal_placeholders: MultiModalPlaceholderDict
# Processing-info type that a dummy-inputs builder is parameterized over.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                If None, use model defaults for backward compatibility.
                If provided, models can use these to customize dummy
                data generation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional)
        """
        dummy_text = self.get_dummy_text(mm_counts)

        # Use the unified function for both legacy and configurable cases
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)

        # Truncation would silently shrink the profiled sequence.
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(
            prompt=dummy_text,
            mm_data=dummy_mm_data,
            tokenization_kwargs=tokenization_kwargs,
        )

    @staticmethod
    def _clamp_dummy_dim(
        option_name: str,
        max_label: str,
        model_max: int,
        override: int | None,
    ) -> int:
        """Clamp a user-provided dummy-data override to the model maximum.

        Overrides may only shrink dummy data, never grow it past what the
        model supports; an oversized override is ignored with a warning.
        A falsy (``None``/0) override leaves ``model_max`` unchanged.
        """
        if not override:
            return model_max
        if override > model_max:
            logger.warning(
                "%s override (%d) exceeds model's maximum %s (%d), will be ignored",
                option_name,
                override,
                max_label,
                model_max,
            )
        return min(model_max, override)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """Return `num_audios` references to one silent waveform of
        `length` samples (clamped by `overrides.length`)."""
        if num_audios == 0:
            return []
        if overrides:
            length = self._clamp_dummy_dim(
                "audio.length", "length", length, overrides.length
            )
        audio = np.zeros((length,))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """Return `num_images` references to one solid-color RGB image of
        size `width` x `height` (each dimension clamped by `overrides`)."""
        if num_images == 0:
            return []
        if overrides:
            width = self._clamp_dummy_dim(
                "image.width", "width", width, overrides.width
            )
            height = self._clamp_dummy_dim(
                "image.height", "height", height, overrides.height
            )
        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """Return `num_videos` references to one all-white uint8 video
        (each dimension clamped by `overrides`)."""
        if num_videos == 0:
            return []
        if overrides:
            num_frames = self._clamp_dummy_dim(
                "video.num_frames", "number of frames", num_frames,
                overrides.num_frames,
            )
            width = self._clamp_dummy_dim(
                "video.width", "width", width, overrides.width
            )
            height = self._clamp_dummy_dim(
                "video.height", "height", height, overrides.height
            )
        # NOTE(review): axis order (num_frames, width, height, 3) looks
        # transposed vs the usual (frames, height, width, channels);
        # preserved as-is — confirm against consumers before changing.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        return [video] * num_videos
class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        # Per-prompt item limits for each modality, as exposed by the processor.
        return self.processor.allowed_mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalInputs:
        """Build dummy inputs and run them through the full processor."""
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts, mm_options
        )

        return self.processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """Count placeholder tokens per modality.

        With ``mm_embeddings_only``, count only positions filled by
        embeddings; otherwise count each placeholder's full span length.
        """
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality: sum(
                item.get_num_embeds() if mm_embeddings_only else item.length
                for item in placeholders
            )
            for modality, placeholders in placeholders_by_modality.items()
        }

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> DummyEncoderData:
        """Build dummy data for profiling the encoder of an enc-dec model."""
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)
        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            # Pad (never truncate) the encoder prompt up to seq_len.
            num_tokens_to_pad = max(total_len, seq_len) - total_len
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> DummyDecoderData:
        """Build dummy data for profiling the decoder (or a decoder-only model)."""
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        # Pad the prompt with zeros up to seq_len so profiling covers the
        # full sequence length.
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def _get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """Maximum multi-modal token count per modality.

        Prefers the model-declared per-item maximum when available;
        otherwise falls back to counting tokens in processed dummy inputs.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len=seq_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            # Only report modalities that are actually enabled (count > 0).
            return {
                modality: max_tokens
                for modality, max_tokens in max_tokens_per_item.items()
                if mm_counts.get(modality, 0) > 0
            }

        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)

    def get_mm_max_contiguous_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
    ) -> Mapping[str, int]:
        """
        Returns the maximum length of the multimodal (image placeholders+text)
        tokens, including any break/text tokens in-between image embeddings.

        `<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
        Returns 9, even when the number of image embeddings is 6.

        This is important to take into account when profiling and
        initializing the encoder cache size.
        """
        return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)

View File

@@ -1,360 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
import torch.nn as nn
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
from vllm.utils.collection_utils import ClassRegistry
from .cache import BaseMultiModalProcessorCache
from .processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
)
from .profiling import (
BaseDummyInputsBuilder,
DummyDecoderData,
DummyEncoderData,
MultiModalProfiler,
)
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)

# Model class registered against processor factories.
N = TypeVar("N", bound=type[nn.Module])

# Processing-info type; the covariant variant is used for the read-only
# factory Protocol below.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """
    # NOTE: the previous docstring claimed this constructs a
    # BaseMultiModalProcessor; per the return type (_I_co, bound to
    # BaseProcessingInfo) it constructs the processing info object.

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
    instance from the context.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the context.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """Bundle of the factory callables needed to build a processor."""

    info: ProcessingInfoFactory[_I]
    processor: MultiModalProcessorFactory[_I]
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ):
        """Instantiate the processor for ``ctx``, wiring together the
        processing info and the dummy-inputs builder."""
        processing_info = self.info(ctx)
        inputs_builder = self.dummy_inputs(processing_info)
        return self.processor(processing_info, inputs_builder, cache=cache)
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def __init__(self) -> None:
        # Maps each registered model class to its processor factory bundle.
        self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]()
def _extract_mm_options(
self,
model_config: "ModelConfig",
) -> Mapping[str, BaseDummyOptions] | None:
"""
Extract multimodal dummy options from model config.
Returns None if no configurable options are found, otherwise returns
a mapping of modality names to their dummy options.
"""
if not model_config.multimodal_config:
return None
mm_options = {
m: opt
for m in model_config.multimodal_config.limit_per_prompt
if (opt := model_config.multimodal_config.get_dummy_options(m)) is not None
}
return mm_options if len(mm_options) > 0 else None
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
"""
Checks if the model supports multimodal inputs.
Returns True if the model is multimodal with any non-zero supported
modalities, otherwise returns False, effectively running in
text-only mode.
"""
if not model_config.is_multimodal_model:
return False
info = self._create_processing_info(model_config, tokenizer=None)
supported_modalities = info.get_supported_mm_limits()
mm_config = model_config.get_multimodal_config()
# Check if all supported modalities have limit == 0
if all(
mm_config.get_limit_per_prompt(modality) == 0
for modality in supported_modalities
):
logger.info_once(
"All limits of multimodal modalities supported by the model "
"are set to 0, running in text-only mode."
)
return False
return True
def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
profiler_limits: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
on underlying model configuration.
"""
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len
profiler_limits = (
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
)
return profiler.get_mm_max_contiguous_tokens(
seq_len,
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
)
def get_mm_limits_per_prompt(
self,
model_config: "ModelConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
that are allowed per prompt for a model class.
"""
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()
def register_processor(
self,
processor: MultiModalProcessorFactory[_I],
*,
info: ProcessingInfoFactory[_I],
dummy_inputs: DummyInputsBuilderFactory[_I],
):
"""
Register a multi-modal processor to a model class. The processor
is constructed lazily, hence a factory method should be passed.
When the model receives multi-modal data, the provided function is
invoked to transform the data into a dictionary of model inputs.
"""
def wrapper(model_cls: N) -> N:
if self._processor_factories.contains(model_cls, strict=True):
logger.warning(
"Model class %s already has a multi-modal processor "
"registered to %s. It is overwritten by the new one.",
model_cls,
self,
)
self._processor_factories[model_cls] = _ProcessorFactories(
info=info,
dummy_inputs=dummy_inputs,
processor=processor,
)
return model_cls
return wrapper
def _get_model_cls(self, model_config: "ModelConfig"):
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
return model_cls
def _create_processing_ctx(
self,
model_config: "ModelConfig",
tokenizer: AnyTokenizer | None = None,
) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer)
def _create_processing_info(
self,
model_config: "ModelConfig",
*,
tokenizer: AnyTokenizer | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]
ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.info(ctx)
def create_processor(
self,
model_config: "ModelConfig",
*,
tokenizer: AnyTokenizer | None = None,
cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")
model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]
ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.build_processor(ctx, cache=cache)
def get_decoder_dummy_data(
self,
model_config: "ModelConfig",
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> DummyDecoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
processor = self.create_processor(model_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)
# Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids
if len(token_ids) < seq_len:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(token_ids)} tokens instead."
)
return dummy_data
def get_encoder_dummy_data(
self,
model_config: "ModelConfig",
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> DummyEncoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
processor = self.create_processor(model_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options)
# Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids
if len(token_ids) < seq_len:
logger.warning_once(
"Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.", # noqa: E501
seq_len,
len(token_ids),
)
return dummy_data
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
"""
Get the maximum length of the encoder input for encoder-decoder models.
"""
if not model_config.is_encoder_decoder:
return 0
max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
if not max_tokens:
# TODO - this function assumes encoder-decoder models are
# multimodal. This will need to change when adding support for more
# than whisper.
return 0
assert len(max_tokens) == 1, (
"Encoder-decoder models are expected \
to implement the multimodal interface with at most one modality."
)
first_modality = next(iter(max_tokens))
return max_tokens[first_modality]

View File

@@ -1,512 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import atexit
from collections.abc import Iterable, Set
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
from urllib.parse import ParseResult, urlparse
from urllib.request import url2pathname
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image, UnidentifiedImageError
import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.registry import ExtensionManager
from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .video import VideoMediaIO
if TYPE_CHECKING:
    from .inputs import (
        BatchedTensorInputs,
        MultiModalKwargsItem,
        MultiModalPlaceholderDict,
    )
else:
    # These names are only used in annotations at runtime, so plain `Any`
    # aliases avoid importing `.inputs` eagerly.
    BatchedTensorInputs = Any
    MultiModalKwargsItem = Any
    MultiModalPlaceholderDict = Any
logger = init_logger(__name__)
# Shared pool for blocking media decode work, sized via
# VLLM_MEDIA_LOADING_THREAD_COUNT and shut down at interpreter exit.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT
)
atexit.register(global_thread_pool.shutdown)
_M = TypeVar("_M")
# Registry of media connector implementations (e.g. the "http" one below).
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
    """Loads audio/image/video media from HTTP, `data:` or `file:` URLs."""
    def __init__(
        self,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                inputs, keyed by modalities. For example,
                to set num_frames for video, set
                `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                from.
            allowed_media_domains: If set, only URLs whose hostname is in
                this list may be fetched over HTTP(S).
        """
        super().__init__()
        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs if media_io_kwargs else {}
        )
        self.connection = connection
        # Validate the local-media sandbox directory eagerly so
        # misconfiguration fails at startup rather than on first request.
        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)
            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist."
                )
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory."
                )
        else:
            allowed_local_media_path_ = None
        self.allowed_local_media_path = allowed_local_media_path_
        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains
    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M: # type: ignore[type-var]
        """Decode a `data:<media_type>;base64,<payload>` URL."""
        data_spec, data = url_spec.path.split(",", 1)
        media_type, data_type = data_spec.split(";", 1)
        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)
        return media_io.load_base64(media_type, data)
    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M: # type: ignore[type-var]
        """Load a `file:` URL, restricted to the allowed local media path."""
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError(
                "Cannot load local files without `--allowed-local-media-path`."
            )
        filepath = Path(url2pathname(url_spec.path))
        # resolve() canonicalizes the path so `..` segments and symlinks
        # cannot escape the allowed directory.
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}."
            )
        return media_io.load_file(filepath)
    def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
        """Raise ValueError if the URL's host is not an allowed domain."""
        if (
            self.allowed_media_domains
            and url_spec.hostname not in self.allowed_media_domains
        ):
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}"
            )
    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M: # type: ignore[type-var]
        """Synchronously load media from an HTTP, data, or file URL."""
        url_spec = urlparse(url)
        # Covers both "http" and "https" schemes.
        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)
            connection = self.connection
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            return media_io.load_bytes(data)
        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)
        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)
        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)
    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:
        """Async variant of `load_from_url`; decoding runs in a thread pool."""
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()
        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)
            connection = self.connection
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            # Decode in the shared pool to avoid blocking the event loop.
            future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
            return await future
        if url_spec.scheme == "data":
            future = loop.run_in_executor(
                global_thread_pool, self._load_data_url, url_spec, media_io
            )
            return await future
        if url_spec.scheme == "file":
            future = loop.run_in_executor(
                global_thread_pool, self._load_file_url, url_spec, media_io
            )
            return await future
        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)
    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Load audio from a URL.

        Returns a `(waveform, sampling_rate)` tuple.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )
    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )
    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP or base64 data URL.
        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e
    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP or base64 data URL.
        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e
    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP or base64 data URL.

        Returns a `(frames, metadata)` tuple.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))
        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP or base64 data URL.
        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))
        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load image embedding from a URL.
        """
        # The payload is a bare base64 string, not a URL scheme.
        image_embedding_io = ImageEmbeddingMediaIO()
        return image_embedding_io.load_base64("", data)
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Encode an audio waveform and its sampling rate as base64."""
    return AudioMediaIO().encode_base64((audio, sampling_rate))
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.
    By default, the image is converted into RGB format before being encoded.
    """
    media_io = ImageMediaIO(image_mode=image_mode)
    encoded = media_io.encode_base64(image, image_format=format)
    return encoded
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode video frames as a base64 string."""
    return VideoMediaIO(ImageMediaIO()).encode_base64(frames)
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholderDict`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.
    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    entries = [
        (modality, idx, placeholder)
        for modality, placeholders in mm_positions.items()
        for idx, placeholder in enumerate(placeholders)
    ]
    # Stable sort keyed on offset only, so ties keep dict iteration order.
    entries.sort(key=lambda entry: entry[2].offset)
    return [(modality, idx) for modality, idx, _ in entries]
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
    merge_by_field_config: bool | None = None,
    # Fields listed here are kept on CPU even when `device` is given.
    multimodal_cpu_fields: Set[str] = frozenset(),
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together into the same `MultiModalKwargs` instance.
    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    if merge_by_field_config is None:
        raise RuntimeError(
            "`group_mm_kwargs_by_modality` now requires "
            "`merge_by_field_config` arg, please update your model runner "
            "according to https://github.com/vllm-project/vllm/pull/25676."
        )
    if merge_by_field_config is False:
        logger.warning_once(
            "The legacy code for batching multi-modal kwargs is deprecated and "
            "will be removed in v0.12. Please update your model with "
            "`merge_by_field_config=True` to use the new code defined by "
            "`MultiModalFieldConfig`. You can refer to "
            "https://github.com/vllm-project/vllm/issues/26149 "
            "for some examples on how to do this."
        )
    # Imported here to avoid a circular import at module load time.
    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
    # NOTE: itertools.groupby only merges *consecutive* items with the same
    # modality, matching the "consecutive" contract in the docstring.
    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        items_lst = list(items)
        if merge_by_field_config:
            mm_kwargs_group: BatchedTensorInputs = dict(
                MultiModalKwargsItems.from_seq(items_lst).get_data(
                    pin_memory=pin_memory
                )
            )
            # Move tensor leaves to the target device, except fields that
            # must stay on CPU (`multimodal_cpu_fields`).
            if device is not None:
                mm_kwargs_group = {
                    k: json_map_leaves(
                        lambda x: x.to(device=device, non_blocking=True)
                        if isinstance(x, torch.Tensor)
                        else x,
                        v,
                    )
                    if k not in multimodal_cpu_fields
                    else v
                    for k, v in mm_kwargs_group.items()
                }
        else:
            # Legacy (deprecated) batching path.
            mm_kwargs_group = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch(
                    [
                        MultiModalKwargsItems.from_seq([item]).get_data()
                        for item in items_lst
                    ],
                    pin_memory=pin_memory,
                ),
                device=device,
            )
        yield modality, len(items_lst), mm_kwargs_group
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.
    """
    io_kwargs = {"audio": audio_io_kwargs} if audio_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_audio(audio_url)
def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.
    """
    io_kwargs = {"image": image_io_kwargs} if image_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_image(image_url)
def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.
    """
    io_kwargs = {"video": video_io_kwargs} if video_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_video(video_url)

View File

@@ -1,306 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import math
from abc import abstractmethod
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm import envs
from vllm.logger import init_logger
from vllm.utils.registry import ExtensionManager
from .base import MediaIO
from .image import ImageMediaIO
# Module-level logger named after this module.
logger = init_logger(__name__)
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    """Resize every frame of an (N, H, W, C) video to `size` = (height, width)."""
    # lazy import cv2 to avoid bothering users who only use text models
    import cv2

    target_height, target_width = size
    frame_count, _, _, channels = frames.shape
    resized = np.empty(
        (frame_count, target_height, target_width, channels), dtype=frames.dtype
    )
    for idx, frame in enumerate(frames):
        # cv2.resize takes (width, height), i.e. the reverse of `size`.
        resized[idx] = cv2.resize(frame, (target_width, target_height))
    return resized
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
    """Scale both spatial dimensions of an (N, H, W, C) video by `size_factor`."""
    _, height, width, _ = frames.shape
    scaled_size = (int(height * size_factor), int(width * size_factor))
    return resize_video(frames, scaled_size)
def sample_frames_from_video(frames: npt.NDArray, num_frames: int) -> npt.NDArray:
    """Uniformly sample `num_frames` frames; `num_frames == -1` keeps them all."""
    if num_frames == -1:
        return frames
    total = frames.shape[0]
    # Evenly spaced indices covering the full range [0, total - 1].
    indices = np.linspace(0, total - 1, num_frames, dtype=int)
    return frames[indices, ...]
class VideoLoader:
    """Abstract interface for decoding raw video bytes into frames + metadata."""
    # NOTE(review): this class does not inherit `abc.ABC`, so `@abstractmethod`
    # is not enforced at instantiation time.
    @classmethod
    @abstractmethod
    def load_bytes(
        cls, data: bytes, num_frames: int = -1, **kwargs
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        # Returns `(frames, metadata)`; concrete backends register themselves
        # in VIDEO_LOADER_REGISTRY.
        raise NotImplementedError
# Global registry of video loader backends, selected by name via
# the VLLM_VIDEO_LOADER_BACKEND environment variable.
VIDEO_LOADER_REGISTRY = ExtensionManager()
@VIDEO_LOADER_REGISTRY.register("opencv")
class OpenCVVideoBackend(VideoLoader):
    """Decode video bytes with OpenCV, uniformly sampling frames."""
    def get_cv2_video_api(self):
        """Return the preferred buffered-stream videoio backend, or None."""
        import cv2.videoio_registry as vr
        api_pref = None
        for backend in vr.getStreamBufferedBackends():
            if not vr.hasBackend(backend):
                continue
            # Non-built-in (plugin) backends must expose a recent enough
            # ABI/API version (>= 1.2) to be usable here.
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                if abi < 1 or (abi == 1 and api < 2):
                    continue
            api_pref = backend
            break
        return api_pref
    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        fps: int = -1,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Decode `data` into `(frames, metadata)`, uniformly sampling to at
        most `num_frames` frames and/or the frame budget implied by `fps`;
        -1 disables the respective limit.
        """
        import cv2
        backend = cls().get_cv2_video_api()
        # NOTE(review): decoding from an in-memory BytesIO presumably requires
        # an OpenCV build with buffered-stream backend support — confirm.
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")
        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames_num / original_fps if original_fps > 0 else 0
        # resample video to target num_frames and fps
        # - the minimum of the two will be used
        num_frames_to_sample = total_frames_num
        if num_frames > 0:
            num_frames_to_sample = min(num_frames, total_frames_num)
        if fps > 0:
            num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
        num_frames_to_sample = max(1, num_frames_to_sample) # at least one sample
        if num_frames_to_sample == total_frames_num:
            frame_idx = list(range(0, num_frames_to_sample))
        else:
            uniform_sampled_frames = np.linspace(
                0, total_frames_num - 1, num_frames_to_sample, dtype=int
            )
            frame_idx = uniform_sampled_frames.tolist()
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
        i = 0
        # grab() advances the stream frame by frame; only the sampled
        # indices are actually decoded (retrieve) and converted to RGB.
        for idx in range(max(frame_idx) + 1):
            ok = cap.grab()
            if not ok:
                break
            if idx in frame_idx:
                ret, frame = cap.retrieve()
                if ret:
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    i += 1
        assert i == num_frames_to_sample, (
            f"Expected reading {num_frames_to_sample} frames, "
            f"but only loaded {i} frames from video."
        )
        # Use transformers transformers.video_utils.VideoMetadata format
        # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
        # can cause incorrect timestamp calculation without num_frames=-1.
        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": duration,
            "video_backend": "opencv",
            "frames_indices": list(frame_idx),
            # extra field used to control hf processor's video
            # sampling behavior
            "do_sample_frames": num_frames_to_sample == total_frames_num,
        }
        return frames, metadata
@VIDEO_LOADER_REGISTRY.register("opencv_dynamic")
class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
    """Decode video bytes with OpenCV, sampling at `fps` up to `max_duration`."""
    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        fps: int = 2,
        max_duration: int = 300,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Decode `data` into `(frames, metadata)`, sampling `fps` frames per
        second; videos longer than `max_duration` seconds are capped at
        `max_duration * fps` uniformly spread samples.
        """
        import cv2
        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")
        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames_num / original_fps if original_fps > 0 else 0
        # resample video to target num_frames
        max_frame_idx = total_frames_num - 1
        # Fall back to an estimate from the frame count when the container
        # reports zero duration.
        duration = duration or round(max_frame_idx / original_fps) + 1
        # Refer to:
        # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
        frame_indices: range | list[int]
        if duration <= max_duration:
            # Short video: one sample per 1/fps seconds (deduplicated set,
            # clamped to the last frame).
            n = int(math.floor(duration * fps))
            frame_indices = sorted(
                {
                    min(max_frame_idx, int(math.ceil(i * original_fps / fps)))
                    for i in range(n)
                }
            )
        else:
            # Long video: cap the sample count and spread it over the full
            # duration instead.
            num_samples = int(max_duration * fps)
            if num_samples >= total_frames_num:
                frame_indices = range(total_frames_num)
            else:
                target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
                frame_indices = sorted(
                    {
                        min(max_frame_idx, int(math.ceil(t * original_fps)))
                        for t in target_seconds
                    }
                )
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = np.empty((len(frame_indices), height, width, 3), dtype=np.uint8)
        i = 0
        # grab() advances the stream; only sampled indices are decoded.
        for idx in range(total_frames_num):
            ok = cap.grab()
            if not ok:
                break
            if idx in frame_indices:
                ret, frame = cap.retrieve()
                if ret:
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    i += 1
        assert i == len(frame_indices), (
            f"Expected reading {len(frame_indices)} frames, "
            f"but only loaded {i} frames from video."
        )
        # Use transformers transformers.video_utils.VideoMetadata format
        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": duration,
            "video_backend": "opencv_dynamic",
            "frames_indices": list(frame_indices),
            # Frames are already sampled here; tell the HF processor not to
            # resample.
            "do_sample_frames": False,
        }
        return frames, metadata
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Media IO for videos: decodes bytes/base64/files into frame arrays."""
    def __init__(
        self,
        image_io: ImageMediaIO,
        num_frames: int = 32,
        **kwargs,
    ) -> None:
        """
        Args:
            image_io: Used to decode/encode individual frames for the
                "video/jpeg" base64 format.
            num_frames: Target number of frames passed to the loader backend.
        """
        super().__init__()
        self.image_io = image_io
        self.num_frames = num_frames
        # `kwargs` contains custom arguments from
        # --media-io-kwargs for this modality.
        # They can be passed to the underlying
        # media loaders (e.g. custom implementations)
        # for flexible control.
        self.kwargs = kwargs
        # Backend is chosen once per instance from the environment.
        video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
        self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode raw video bytes into `(frames, metadata)`."""
        return self.video_loader.load_bytes(
            data, num_frames=self.num_frames, **self.kwargs
        )
    def load_base64(
        self, media_type: str, data: str
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode base64 video data; "video/jpeg" is a comma-separated
        sequence of base64 JPEG frames, anything else is a whole video."""
        if media_type.lower() == "video/jpeg":
            load_frame = partial(
                self.image_io.load_base64,
                "image/jpeg",
            )
            # Per-frame decoding yields no video metadata, hence the empty dict.
            return np.stack(
                [np.asarray(load_frame(frame_data)) for frame_data in data.split(",")]
            ), {}
        return self.load_bytes(base64.b64decode(data))
    def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
        """Read a video file from disk and decode it."""
        with filepath.open("rb") as f:
            data = f.read()
        return self.load_bytes(data)
    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode frames as a comma-separated string of base64 JPEG images."""
        video = media
        if video_format == "JPEG":
            encode_frame = partial(
                self.image_io.encode_base64,
                image_format=video_format,
            )
            return ",".join(encode_frame(Image.fromarray(frame)) for frame in video)
        msg = "Only JPEG format is supported for now."
        raise NotImplementedError(msg)