Add minimal vLLM 0.16.1 build repo for BI-V150

Commit d69657327e · 2026-04-18 10:56:22 +08:00
1895 changed files with 615301 additions and 0 deletions


@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .hasher import MultiModalHasher
from .inputs import (
BatchedTensorInputs,
ModalityData,
MultiModalDataBuiltins,
MultiModalDataDict,
MultiModalKwargsItems,
MultiModalPlaceholderDict,
MultiModalUUIDDict,
NestedTensors,
)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry]
is used by model runners to dispatch data processing according to the target
model.
Info:
[mm_processing](../../../design/mm_processing.md)
"""
__all__ = [
"BatchedTensorInputs",
"ModalityData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalHasher",
"MultiModalKwargsItems",
"MultiModalPlaceholderDict",
"MultiModalUUIDDict",
"NestedTensors",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
]
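A quick sketch of how these exports are typically consumed (illustrative only; assumes vLLM is installed):

from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry

# Model runners consult the module-level singleton to dispatch
# per-model multi-modal processing.
assert isinstance(MULTIMODAL_REGISTRY, MultiModalRegistry)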

vllm/multimodal/audio.py Normal file

@@ -0,0 +1,336 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from enum import Enum
from typing import Literal
import numpy as np
import numpy.typing as npt
import torch
from vllm.utils.import_utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import scipy.signal as scipy_signal
except ImportError:
scipy_signal = PlaceholderModule("scipy").placeholder_attr("signal") # type: ignore[assignment]
# ============================================================
class ChannelReduction(str, Enum):
"""Method to reduce multi-channel audio to target channels."""
MEAN = "mean" # Average across channels (default, preserves energy balance)
FIRST = "first" # Take first channel only
MAX = "max" # Take max value across channels
SUM = "sum" # Sum across channels
@dataclass
class AudioSpec:
"""Specification for target audio format.
This dataclass defines the expected audio format for a model's feature
extractor. It is used to normalize audio data before processing.
Attributes:
target_channels: Number of output channels. None means passthrough
(no normalization). 1 = mono, 2 = stereo, etc.
channel_reduction: Method to reduce channels when input has more
channels than target. Only used when reducing channels.
"""
target_channels: int | None = 1
channel_reduction: ChannelReduction = ChannelReduction.MEAN
@property
def needs_normalization(self) -> bool:
"""Whether audio normalization is needed."""
return self.target_channels is not None
def __repr__(self) -> str:
if self.target_channels is None:
return "AudioSpec(passthrough)"
return (
f"AudioSpec(channels={self.target_channels}, "
f"reduction={self.channel_reduction.value})"
)
# Pre-defined specs for common use cases
MONO_AUDIO_SPEC = AudioSpec(target_channels=1, channel_reduction=ChannelReduction.MEAN)
PASSTHROUGH_AUDIO_SPEC = AudioSpec(target_channels=None)
def normalize_audio(
audio: npt.NDArray[np.floating] | torch.Tensor,
spec: AudioSpec,
) -> npt.NDArray[np.floating] | torch.Tensor:
"""Normalize audio to the specified format.
This function handles channel reduction for multi-channel audio,
supporting both numpy arrays and torch tensors.
Args:
audio: Input audio data. Can be:
- 1D array/tensor: (time,) - already mono
- 2D array/tensor: (channels, time) - standard format from torchaudio
- 2D array/tensor: (time, channels) - format from soundfile
(will be auto-detected and transposed if time > channels)
spec: AudioSpec defining the target format.
Returns:
Normalized audio in the same type as input (numpy or torch).
For mono output (target_channels=1), returns 1D array/tensor.
Raises:
ValueError: If audio has unsupported dimensions or channel expansion
is requested (e.g., mono to stereo).
"""
if not spec.needs_normalization:
return audio
# Handle 1D audio (already mono)
if audio.ndim == 1:
if spec.target_channels == 1:
return audio
raise ValueError(f"Cannot expand mono audio to {spec.target_channels} channels")
# Handle 2D audio
if audio.ndim != 2:
raise ValueError(f"Unsupported audio shape: {audio.shape}. Expected 1D or 2D.")
# Auto-detect format: if shape[0] > shape[1], assume (time, channels)
# This handles soundfile format where time dimension is typically much larger
if audio.shape[0] > audio.shape[1]:
# Transpose from (time, channels) to (channels, time)
audio = audio.T  # .T works for both np.ndarray and torch.Tensor
num_channels = audio.shape[0]
# No reduction needed if already at target
if num_channels == spec.target_channels:
return audio
# Cannot expand channels
if num_channels < spec.target_channels:
raise ValueError(
f"Cannot expand {num_channels} channels to {spec.target_channels}"
)
# Reduce channels
is_numpy = isinstance(audio, np.ndarray)
if spec.target_channels == 1:
# Reduce to mono
if spec.channel_reduction == ChannelReduction.MEAN:
result = np.mean(audio, axis=0) if is_numpy else audio.mean(dim=0)
elif spec.channel_reduction == ChannelReduction.FIRST:
result = audio[0]
elif spec.channel_reduction == ChannelReduction.MAX:
result = np.max(audio, axis=0) if is_numpy else audio.max(dim=0).values
elif spec.channel_reduction == ChannelReduction.SUM:
result = np.sum(audio, axis=0) if is_numpy else audio.sum(dim=0)
else:
raise ValueError(f"Unknown reduction method: {spec.channel_reduction}")
return result
else:
# Reduce to N channels (take first N and apply reduction if needed)
# For now, just take first N channels
return audio[: spec.target_channels]
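# Illustrative sketch of the normalization contract above (assumes the
# definitions in this file; random data is a stand-in for real audio):
#
#     >>> import numpy as np
#     >>> stereo = np.random.randn(2, 16000).astype(np.float32)  # (channels, time)
#     >>> normalize_audio(stereo, MONO_AUDIO_SPEC).shape  # mean-reduced to mono
#     (16000,)
#     >>> normalize_audio(stereo.T, MONO_AUDIO_SPEC).shape  # (time, channels) auto-detected
#     (16000,)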
# ============================================================
# Audio Resampling
# ============================================================
def resample_audio_librosa(
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
target_sr: float,
) -> npt.NDArray[np.floating]:
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
def resample_audio_scipy(
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
target_sr: float,
):
if orig_sr > target_sr:
return scipy_signal.resample_poly(audio, 1, orig_sr // target_sr)
elif orig_sr < target_sr:
return scipy_signal.resample_poly(audio, target_sr // orig_sr, 1)
return audio
class AudioResampler:
"""Resample audio data to a target sample rate."""
def __init__(
self,
target_sr: float | None = None,
method: Literal["librosa", "scipy"] = "librosa",
):
self.target_sr = target_sr
self.method = method
def resample(
self,
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
) -> npt.NDArray[np.floating]:
if self.target_sr is None:
raise RuntimeError(
"Audio resampling is not supported when `target_sr` is not provided"
)
if math.isclose(
float(orig_sr),
float(self.target_sr),
rel_tol=0.0,
abs_tol=1e-6,
):
return audio
if self.method == "librosa":
return resample_audio_librosa(
audio, orig_sr=orig_sr, target_sr=self.target_sr
)
elif self.method == "scipy":
return resample_audio_scipy(
audio, orig_sr=orig_sr, target_sr=self.target_sr
)
else:
raise ValueError(
f"Invalid resampling method: {self.method}. "
"Supported methods are 'librosa' and 'scipy'."
)
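# Usage sketch (illustrative): resampling 1 s of 48 kHz audio down to 16 kHz.
# Note that the scipy path floor-divides the sample rates, so it is only
# exact for integer ratios; prefer method="librosa" otherwise.
#
#     >>> import numpy as np
#     >>> resampler = AudioResampler(target_sr=16000, method="scipy")
#     >>> resampler.resample(np.zeros(48000, dtype=np.float32), orig_sr=48000).shape
#     (16000,)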
# ============================================================
# Audio Chunking / Splitting
# ============================================================
def split_audio(
audio_data: np.ndarray,
sample_rate: int,
max_clip_duration_s: float,
overlap_duration_s: float,
min_energy_window_size: int,
) -> list[np.ndarray]:
"""Split audio into chunks with intelligent split points.
Splits long audio into smaller chunks at low-energy regions to minimize
cutting through speech. Uses overlapping windows to find quiet moments
for splitting.
Args:
audio_data: Audio array to split. Can be 1D (mono) or multi-dimensional.
Splits along the last dimension (time axis).
sample_rate: Sample rate of the audio in Hz.
max_clip_duration_s: Maximum duration of each chunk in seconds.
overlap_duration_s: Overlap duration in seconds between consecutive chunks.
Used to search for optimal split points.
min_energy_window_size: Window size in samples for finding low-energy regions.
Returns:
List of audio chunks. Each chunk is a numpy array with the same shape
as the input except for the last (time) dimension.
Example:
>>> audio = np.random.randn(1040000) # 65 seconds at 16kHz
>>> chunks = split_audio(
... audio_data=audio,
... sample_rate=16000,
... max_clip_duration_s=30.0,
... overlap_duration_s=1.0,
... min_energy_window_size=1600,
... )
>>> len(chunks)
3
"""
chunk_size = int(sample_rate * max_clip_duration_s)
overlap_size = int(sample_rate * overlap_duration_s)
chunks = []
i = 0
while i < audio_data.shape[-1]:
if i + chunk_size >= audio_data.shape[-1]:
# Handle last chunk - take everything remaining
chunks.append(audio_data[..., i:])
break
# Find the best split point in the overlap region
search_start = i + chunk_size - overlap_size
search_end = min(i + chunk_size, audio_data.shape[-1])
split_point = find_split_point(
audio_data, search_start, search_end, min_energy_window_size
)
# Extract chunk up to the split point
chunks.append(audio_data[..., i:split_point])
i = split_point
return chunks
def find_split_point(
wav: np.ndarray,
start_idx: int,
end_idx: int,
min_energy_window: int,
) -> int:
"""Find the best point to split audio by looking for silence or low amplitude.
Searches for the quietest region within a specified range by calculating
RMS energy in sliding windows.
Args:
wav: Audio array. Can be 1D or multi-dimensional.
start_idx: Start index of search region (inclusive).
end_idx: End index of search region (exclusive).
min_energy_window: Window size in samples for energy calculation.
Returns:
Index of the quietest point within the search region. This is the
recommended split point to minimize audio artifacts.
Example:
>>> audio = np.random.randn(32000)
>>> # Insert quiet region
>>> audio[16000:17600] = 0.01
>>> split_idx = find_split_point(
... wav=audio,
... start_idx=0,
... end_idx=32000,
... min_energy_window=1600,
... )
>>> 16000 <= split_idx <= 17600
True
"""
segment = wav[start_idx:end_idx]
# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
for i in range(0, len(segment) - min_energy_window, min_energy_window):
window = segment[i : i + min_energy_window]
energy = (window**2).mean() ** 0.5
if energy < min_energy:
quietest_idx = i + start_idx
min_energy = energy
return quietest_idx

vllm/multimodal/cache.py Normal file

@@ -0,0 +1,725 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
import sys
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast
import torch
from typing_extensions import override
import vllm.envs as envs
from vllm.distributed.device_communicators.shm_object_storage import (
MsgpackSerde,
SingleWriterShmObjectStorage,
SingleWriterShmRingBuffer,
)
from vllm.logger import init_logger
from vllm.utils.cache import CacheInfo, LRUCache
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
from vllm.utils.mem_utils import format_gib
from .inputs import (
MultiModalBatchedField,
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalKwargsItems,
NestedTensors,
)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from .processing.processor import ResolvedPromptUpdate
logger = init_logger(__name__)
class MultiModalProcessorCacheItem:
"""
The data to store inside `MultiModalProcessorOnlyCache`.
Args:
item: The processed tensor data corresponding to a multi-modal item.
prompt_updates: The prompt updates corresponding to `item`.
"""
def __init__(
self,
item: MultiModalKwargsItem,
prompt_updates: Sequence["ResolvedPromptUpdate"],
) -> None:
super().__init__()
self.item = item
self.prompt_updates = prompt_updates
class MultiModalProcessorCacheItemMetadata:
"""
The metadata to store inside `MultiModalProcessorSenderCache`.
Args:
item: The processed tensor data corresponding to a multi-modal item.
Since P1 already stores the tensor data, we only store its size
metadata in P0 to reduce memory usage. The size metadata is still
needed to keep the same cache eviction policy as P1.
prompt_updates: The prompt updates corresponding to `item`.
These need to stay on P0 because, for some models, they
depend on the processed tensor data (cached on P1).
"""
def __init__(
self,
item: MultiModalKwargsItem,
prompt_updates: Sequence["ResolvedPromptUpdate"],
) -> None:
super().__init__()
self.item_size = MultiModalCache.get_item_size(item)
self.prompt_updates = prompt_updates
MultiModalCacheValue: TypeAlias = (
MultiModalProcessorCacheItem
| MultiModalProcessorCacheItemMetadata
| MultiModalKwargsItems
| MultiModalKwargsItem
| Mapping[str, NestedTensors]
)
_V = TypeVar("_V", bound=MultiModalCacheValue)
class MultiModalCache:
@classmethod
def get_leaf_size(cls, leaf: object) -> int:
if isinstance(leaf, MultiModalProcessorCacheItem):
return cls.get_leaf_size(leaf.item)
if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
return leaf.item_size
# These are not subclasses of dict
if isinstance(
leaf,
(MultiModalKwargsItems, MultiModalKwargsItem, MultiModalFieldElem),
):
return cls.get_item_size(leaf.data) # type: ignore
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes
return sys.getsizeof(leaf)
@classmethod
def get_item_size(
cls,
value: MultiModalCacheValue,
*,
debug: bool = False,
) -> int:
size = json_reduce_leaves(
operator.add, json_map_leaves(cls.get_leaf_size, value)
)
if debug:
leaf_count = json_count_leaves(value)
logger.debug(
"Calculated size of %s to be %s GiB (%d leaves)",
type(value),
format_gib(size),
leaf_count,
)
return size
@classmethod
def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
"""
Get the number of leaf elements in a multi-modal cache value.
This provides a measure of structural complexity that can be useful
for debugging cache performance and understanding data patterns.
Args:
value: The multi-modal cache value to analyze.
Returns:
The number of leaf elements in the nested structure.
"""
return json_count_leaves(value)
@classmethod
def get_lru_cache(
cls,
capacity_gb: float,
value_type: type[_V],
*,
debug: bool = False,
) -> LRUCache[str, _V]:
return LRUCache(
GiB_bytes * capacity_gb,
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
)
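# Sizing sketch (illustrative): tensor leaves are measured via `.nbytes`,
# so a float32 image tensor of shape (3, 224, 224) counts as 602112 bytes.
#
#     >>> import torch
#     >>> MultiModalCache.get_item_size({"pixel_values": torch.zeros(3, 224, 224)})
#     602112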
_I = TypeVar("_I", contravariant=True)
_O = TypeVar("_O", covariant=True)
class BaseMultiModalCache(ABC, Generic[_I, _O]):
"""
Abstract base class to read/write multi-modal items from cache.
The idea of multi-modal caching is based on having a client and server
where the client executes in the frontend process (=P0) and
the server in the core process (=P1). The data flow is as follows:
```
is_cached() x N get_and_update()
P0: From API -----------------> -----------------> To P1
get_and_update()
P1: From P0 -----------------> To model
```
`is_cached()` can be called any number of times in P0. However,
`get_and_update()` must be called in P0 and P1 one after another
so that their cache eviction order remains the same.
This ensures that the keys in P0 and P1 caches are mirrored,
allowing us to determine whether a key is cached in P1 by looking
up the P0 cache, without having to communicate with P1.
"""
@abstractmethod
def get_and_update_item(
self,
mm_item: _I,
mm_hash: str,
) -> _O:
"""
Possibly update a multi-modal item based on whether it is
in the underlying cache.
This update is done out-of-place and updates the cache eviction order.
Args:
mm_item: The multi-modal item to update.
mm_hash: The hash of `mm_item`.
Returns:
The updated multi-modal item.
"""
raise NotImplementedError
def get_and_update(
self,
mm_items: Sequence[_I],
mm_hashes: list[str],
) -> list[_O]:
"""
Possibly update a sequence of multi-modal items based on whether they
are in the underlying cache.
This update is done out-of-place and updates the cache eviction order.
Args:
mm_items: The multi-modal items to update.
mm_hashes: The hash of each item in `mm_items`.
Returns:
A new list of updated multi-modal items.
"""
assert len(mm_items) == len(mm_hashes)
return [
self.get_and_update_item(mm_item, mm_hash)
for mm_item, mm_hash in zip(mm_items, mm_hashes)
]
@abstractmethod
def clear_cache(self) -> None:
"""Clear the underlying cache."""
raise NotImplementedError
MultiModalProcessorCacheInItem: TypeAlias = (
tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] | None
)
MultiModalProcessorCacheOutItem: TypeAlias = tuple[
MultiModalKwargsItem | None, Sequence["ResolvedPromptUpdate"]
]
class BaseMultiModalProcessorCache(
BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem]
):
"""The required interface for caches on P0."""
@abstractmethod
def is_cached_item(self, mm_hash: str) -> bool:
"""
Check whether a multi-modal item is
in the underlying cache.
This **DOES NOT** update the cache eviction order.
Args:
mm_hash: The hash of the item to check.
Returns:
`True` if the item is cached, otherwise `False`.
"""
raise NotImplementedError
def is_cached(self, mm_hashes: list[str]) -> list[bool]:
"""
Check whether a sequence of multi-modal items are
in the underlying cache.
This **DOES NOT** update the cache eviction order.
Args:
mm_hashes: The hash of each item to check.
Returns:
For each item, `True` if the item is cached, otherwise `False`.
"""
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
def close(self) -> None:
"""Close the underlying cache, if needed."""
pass
@abstractmethod
def touch_sender_cache_item(self, mm_hash: str) -> None:
"""
Update the cache eviction order for a multi-modal item.
This is used to touch the item in the cache without changing
its value.
Args:
mm_hash: The hash of the multi-modal item.
"""
raise NotImplementedError
@abstractmethod
def make_stats(self, *, delta: bool = False) -> CacheInfo:
"""
Get (and reset) the multi-modal cache stats.
Returns:
The current multi-modal caching stats.
"""
raise NotImplementedError
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
"""
The cache which is used on P0 when IPC caching is disabled.
How to update each item:
- If the item is in the cache, replace the input with the cached item.
- If the item is not in the cache, store that item (which includes
tensor data and metadata) into the cache, and return the input.
"""
def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
mm_config = model_config.get_multimodal_config()
self._cache = MultiModalCache.get_lru_cache(
mm_config.mm_processor_cache_gb,
MultiModalProcessorCacheItem,
)
@override
def is_cached_item(self, mm_hash: str) -> bool:
return mm_hash in self._cache
@override
def get_and_update_item(
self,
mm_item: MultiModalProcessorCacheInItem,
mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
if (cached_item := self._cache.get(mm_hash)) is not None:
return cached_item.item, cached_item.prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
return mm_item
@override
def touch_sender_cache_item(self, mm_hash: str) -> None:
self._cache.touch(mm_hash)
@override
def clear_cache(self) -> None:
self._cache.clear()
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
return self._cache.stat(delta=delta)
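# Usage sketch (illustrative only; `model_config`, `mm_hash`, `kwargs_item`
# and `prompt_updates` are hypothetical stand-ins):
#
#     cache = MultiModalProcessorOnlyCache(model_config)
#     if cache.is_cached_item(mm_hash):  # does not touch eviction order
#         item, updates = cache.get_and_update_item(None, mm_hash)
#     else:  # miss: the caller must supply the processed item
#         item, updates = cache.get_and_update_item(
#             (kwargs_item, prompt_updates), mm_hash
#         )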
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
"""
The cache which is used on P0 when IPC caching is enabled.
How to update each item:
- If the item is already in the cache, clear the input to avoid
unnecessary IPC.
- If the item is not in the cache, store the metadata of that item so
that the eviction policy remains the same as the cache on P1,
and return the input.
By only storing the metadata, we avoid keeping the data itself in
memory inside P0.
"""
def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
mm_config = model_config.get_multimodal_config()
self._cache = MultiModalCache.get_lru_cache(
mm_config.mm_processor_cache_gb,
MultiModalProcessorCacheItemMetadata,
)
@override
def is_cached_item(self, mm_hash: str) -> bool:
return mm_hash in self._cache
@override
def get_and_update_item(
self,
mm_item: MultiModalProcessorCacheInItem,
mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
if (cached_item := self._cache.get(mm_hash)) is not None:
return None, cached_item.prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
return mm_item
@override
def touch_sender_cache_item(self, mm_hash: str) -> None:
self._cache.touch(mm_hash)
@override
def clear_cache(self) -> None:
self._cache.clear()
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
return self._cache.stat(delta=delta)
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
"""
The cache which is used on P0 when IPC caching is enabled.
How to update each item:
- If the item is already in the cache, clear the input to avoid
unnecessary IPC.
- If the item is not in the cache, store the data in shared memory.
"""
def __init__(self, vllm_config: "VllmConfig") -> None:
super().__init__()
self.world_size = vllm_config.parallel_config.world_size
mm_config = vllm_config.model_config.get_multimodal_config()
ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
create=True, # sender is the writer
)
self._shm_cache = SingleWriterShmObjectStorage(
max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
n_readers=self.world_size,
ring_buffer=ring_buffer,
serde_class=MsgpackSerde,
)
# cache prompt_updates for P0 only
self._p0_cache: dict[str, Sequence[ResolvedPromptUpdate]] = {}
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
def _stat(self, *, delta: bool = False) -> CacheInfo:
info = CacheInfo(hits=self._hits, total=self._total)
if delta:
info_delta = info - self._last_info
self._last_info = info
info = info_delta
return info
@override
def is_cached_item(self, mm_hash: str) -> bool:
return self._shm_cache.is_cached(mm_hash)
@override
def get_and_update_item(
self,
mm_item: MultiModalProcessorCacheInItem,
mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
if self._shm_cache.is_cached(mm_hash):
self._hits += 1
self._total += 1
address, monotonic_id = self._shm_cache.get_cached(mm_hash)
prompt_updates = self._p0_cache[mm_hash]
return self.address_as_item(address, monotonic_id), prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
item, prompt_updates = mm_item
self._total += 1
try:
address, monotonic_id = self._shm_cache.put(mm_hash, item)
# Try to remove dangling items if p0 cache is too large.
if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
self.remove_dangling_items()
self._p0_cache[mm_hash] = prompt_updates
return self.address_as_item(address, monotonic_id), prompt_updates
except (ValueError, MemoryError) as e:
# put may fail if the object is too large or
# the cache is full.
# In this case we log the error and keep the original mm_input.
logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
return mm_item
@override
def touch_sender_cache_item(self, mm_hash: str) -> None:
"""Touch the item in shared memory cache to prevent eviction.
Increments writer_flag on sender side."""
self._shm_cache.touch(mm_hash)
@override
def clear_cache(self) -> None:
self._shm_cache.clear()
self._p0_cache.clear()
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
return self._stat(delta=delta)
@override
def close(self) -> None:
self._shm_cache.close()
def remove_dangling_items(self) -> None:
"""Remove items that are no longer in the shared memory cache."""
cached_hashes = self._shm_cache.key_index.keys()
dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
for mm_hash in dangling_hashes:
del self._p0_cache[mm_hash]
def address_as_item(
self,
address: int,
monotonic_id: int,
) -> MultiModalKwargsItem:
addr_elem = MultiModalFieldElem(
data=address,
field=MultiModalBatchedField(),
)
id_elem = MultiModalFieldElem(
data=monotonic_id,
field=MultiModalBatchedField(),
)
return MultiModalKwargsItem({"address": addr_elem, "monotonic_id": id_elem})
class BaseMultiModalReceiverCache(
BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
):
"""The required interface for caches on P1."""
def get_and_update_features(
self,
mm_features: list["MultiModalFeatureSpec"],
) -> list["MultiModalFeatureSpec"]:
"""
Update multimodal features with cached encoder outputs.
All identifiers are touched first, before any update, so that
items in the updated list are not evicted mid-update.
Uses mm_hash as the cache key so entries are shared across LoRAs
(falls back to identifier for backward compatibility).
"""
for feature in mm_features:
cache_key = feature.mm_hash or feature.identifier
self.touch_receiver_cache_item(cache_key, feature.data)
for feature in mm_features:
cache_key = feature.mm_hash or feature.identifier
feature.data = self.get_and_update_item(feature.data, cache_key)
return mm_features
@abstractmethod
def touch_receiver_cache_item(
self,
mm_hash: str,
mm_item: MultiModalKwargsItem | None = None,
) -> None:
"""
Update the cache eviction order for a multi-modal item.
This is used to touch the item in the cache without changing
its value.
Args:
mm_hash: The hash of the multi-modal item.
mm_item: The multi-modal item itself. This is optional and
may not be needed by some cache implementations.
"""
raise NotImplementedError
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
"""
The cache which is used on P1 when IPC caching is enabled.
How to update each item:
- If the item is in the cache, replace the input with the cached item.
- If the item is not in the cache, store that item (which includes tensor
data) into the cache, and return the input.
"""
def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
mm_config = model_config.get_multimodal_config()
self._cache = MultiModalCache.get_lru_cache(
mm_config.mm_processor_cache_gb,
MultiModalKwargsItem,
)
@override
def get_and_update_item(
self,
mm_item: MultiModalKwargsItem | None,
mm_hash: str,
) -> MultiModalKwargsItem:
if (cached_item := self._cache.get(mm_hash)) is not None:
return cached_item
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
self._cache[mm_hash] = mm_item
return mm_item
@override
def touch_receiver_cache_item(
self,
mm_hash: str,
mm_item: MultiModalKwargsItem | None = None,
) -> None:
self._cache.touch(mm_hash)
@override
def clear_cache(self) -> None:
self._cache.clear()
class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
"""
The cache which is used on P1 Worker Process when IPC caching is enabled.
How to update each item:
- If the item has an address, replace the input with the cached item.
- If not, return the input.
"""
def __init__(
self,
vllm_config: "VllmConfig",
shared_worker_lock: LockType,
) -> None:
super().__init__()
self.world_size = vllm_config.parallel_config.world_size
mm_config = vllm_config.model_config.get_multimodal_config()
ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
create=False, # Server is a reader
)
self._shm_cache = SingleWriterShmObjectStorage(
max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
n_readers=self.world_size,
ring_buffer=ring_buffer,
serde_class=MsgpackSerde,
reader_lock=shared_worker_lock,
)
@override
def get_and_update_item(
self,
mm_item: MultiModalKwargsItem | None,
mm_hash: str,
) -> MultiModalKwargsItem:
assert mm_item is not None, f"Expected an address item for {mm_hash=}"
if "address" in mm_item:
address = cast(int, mm_item["address"].data)
monotonic_id = cast(int, mm_item["monotonic_id"].data)
return self._shm_cache.get(address, monotonic_id)
return mm_item
@override
def touch_receiver_cache_item(
self,
mm_hash: str,
mm_item: MultiModalKwargsItem | None = None,
) -> None:
"""Touch the item in shared memory cache to prevent eviction.
Increments reader_count on receiver side."""
assert mm_item is not None
if "address" in mm_item:
address = cast(int, mm_item["address"].data)
monotonic_id = cast(int, mm_item["monotonic_id"].data)
self._shm_cache.touch(mm_hash, address=address, monotonic_id=monotonic_id)
@override
def clear_cache(self) -> None:
self._shm_cache.clear()


@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.registry import MultiModalRegistry
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
logger = init_logger(__name__)
def get_mm_max_toks_per_item(
model_config: ModelConfig,
mm_registry: MultiModalRegistry,
processor: BaseMultiModalProcessor,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
on underlying model configuration.
"""
max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
seq_len=model_config.max_model_len,
mm_counts=mm_counts,
)
if max_tokens_per_item is not None:
return max_tokens_per_item
mm_inputs = mm_registry.get_dummy_mm_inputs(
model_config,
mm_counts=mm_counts,
processor=processor,
)
return {
modality: sum(item.get_num_embeds() for item in placeholders)
for modality, placeholders in mm_inputs["mm_placeholders"].items()
}
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
def __init__(
self,
vllm_config: VllmConfig,
mm_registry: MultiModalRegistry,
) -> None:
super().__init__()
self.model_config = model_config = vllm_config.model_config
self.scheduler_config = scheduler_config = vllm_config.scheduler_config
self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs
with set_default_torch_num_threads(): # Avoid hang during startup
cache = mm_registry.processor_only_cache_from_config(vllm_config)
processor = mm_registry.create_processor(model_config, cache=cache)
self.cache = cache
self.processor = processor
mm_config = model_config.get_multimodal_config()
enable_mm_embeds = mm_config is not None and mm_config.enable_mm_embeds
supported_mm_limits = processor.info.supported_mm_limits
self.mm_limits = mm_limits = processor.info.allowed_mm_limits
# Modalities that pass through the MM encoder tower
tower_modalities = {
modality
for modality in supported_mm_limits
if mm_limits.get(modality, 0) > 0
}
# Modalities that bypass the tower (pre-computed embeddings only)
embed_only_modalities = {
modality
for modality in supported_mm_limits
if enable_mm_embeds and mm_limits.get(modality, 0) == 0
}
active_modalities = tower_modalities | embed_only_modalities
all_mm_max_toks_per_item = get_mm_max_toks_per_item(
model_config,
mm_registry,
processor,
mm_counts=dict.fromkeys(active_modalities, 1),
)
if embed_only_modalities:
logger.info_once(
"enable_mm_embeds is True; modalities handled as embedding-only: %s",
tuple(embed_only_modalities),
)
# Some models (e.g., Qwen3Omni with use_audio_in_video=True) share
# placeholders between modalities, so not all active modalities will
# have their own entry in the returned dict. We filter to only include
# modalities that have independent placeholder tokens.
active_mm_max_toks_per_item = {
modality: all_mm_max_toks_per_item[modality]
for modality in active_modalities
if modality in all_mm_max_toks_per_item
}
tower_mm_max_toks_per_item = {
modality: active_mm_max_toks_per_item[modality]
for modality in tower_modalities
if modality in active_mm_max_toks_per_item
}
# Encoder budget is computed from all active modalities (including
# embedding-only ones that need encoder cache space).
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
scheduler_config,
active_mm_max_toks_per_item,
)
self.encoder_compute_budget = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
mm_max_items_per_prompt = dict[str, int]()
mm_max_items_per_batch = dict[str, int]()
# Per-prompt/per-batch limits are only relevant for tower modalities
# (embedding-only modalities don't go through the encoder tower).
for modality, max_toks_per_item in tower_mm_max_toks_per_item.items():
(
mm_max_items_per_prompt[modality],
mm_max_items_per_batch[modality],
) = self._get_max_items(modality, max_toks_per_item)
self.mm_max_toks_per_item = tower_mm_max_toks_per_item
self.mm_max_items_per_prompt: Mapping[str, int] = mm_max_items_per_prompt
self.mm_max_items_per_batch: Mapping[str, int] = mm_max_items_per_batch
def _get_max_items(
self,
modality: str,
max_tokens_per_item: int,
) -> tuple[int, int]:
if max_tokens_per_item == 0:
return 0, 0
# Check how many items of this modality can be supported by
# the encoder budget.
if (encoder_budget := self.get_encoder_budget()) == 0:
return 0, 0
max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
# Check how many items of this modality can be supported by
# the decoder budget.
mm_limit = self.mm_limits[modality]
max_items_per_prompt = max(
1,
min(mm_limit, self.max_model_len // max_tokens_per_item),
)
scheduler_config = self.scheduler_config
max_num_reqs = self.max_num_reqs
if not scheduler_config.enable_chunked_prefill:
max_num_reqs = min(
max_num_reqs,
scheduler_config.max_num_batched_tokens // max_tokens_per_item,
)
max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
max_items_per_batch = max(
1,
min(max_encoder_items_per_batch, max_decoder_items_per_batch),
)
return max_items_per_prompt, max_items_per_batch
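# Worked example with hypothetical numbers: encoder_budget=8192,
# max_tokens_per_item=1024, mm_limit=4, max_model_len=32768,
# max_num_reqs=16, chunked prefill enabled:
#   max_encoder_items_per_batch = 8192 // 1024 = 8
#   max_items_per_prompt = max(1, min(4, 32768 // 1024)) = 4
#   max_decoder_items_per_batch = 16 * 4 = 64
#   max_items_per_batch = max(1, min(8, 64)) = 8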
def get_modality_with_max_tokens(self) -> str:
mm_max_toks_per_item = self.mm_max_toks_per_item
modality, _ = max(mm_max_toks_per_item.items(), key=lambda x: (x[1], x[0]))
return modality
def get_encoder_budget(self) -> int:
return min(self.encoder_compute_budget, self.encoder_cache_size)
def reset_cache(self) -> None:
if self.cache is not None:
self.cache.clear_cache()

vllm/multimodal/evs.py Normal file

@@ -0,0 +1,294 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import typing
import torch
def compute_retained_tokens_count(
tokens_per_frame: int, num_frames: int, q: float
) -> int:
"""
Compute the number of retained tokens for a given video.
Method ensures that we retain all the tokens from the first frame
regardless of the pruning rate.
Args:
tokens_per_frame: The number of tokens per frame.
num_frames: The total number of frames.
q: The pruning rate.
Returns:
The number of retained tokens.
"""
total_tokens = tokens_per_frame * num_frames
evs_num_tokens = int(total_tokens * (1 - q))
min_num_tokens = tokens_per_frame
return max(min_num_tokens, evs_num_tokens)
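# Quick check of the first-frame floor described above (illustrative values):
#
#     >>> compute_retained_tokens_count(tokens_per_frame=100, num_frames=10, q=0.75)
#     250
#     >>> compute_retained_tokens_count(tokens_per_frame=100, num_frames=10, q=0.999)
#     100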
def compute_retention_mask(
video_embeds: torch.Tensor,
video_size_thw: torch.LongTensor | tuple[int, int, int],
spatial_merge_size: int,
q: float,
) -> torch.Tensor:
"""
Computes the retention mask for input video embeddings.
Args:
video_embeds (`torch.Tensor`): The input video embeddings
of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
video_size_thw (`torch.LongTensor` of shape `(3)`):
The temporal, height and width of video.
spatial_merge_size: Size reduction for rows & cols dimensions.
q (`float`): Pruning rate factor in [0, 1).
Returns:
`torch.Tensor`: The retention mask for the video embeddings of
`(T * H * W // spatial_merge_size ^ 2)` shape.
"""
T, H, W = map(int, video_size_thw)
# Use reshape instead of einops to avoid graph breaks
video_embeds = video_embeds.reshape(
T,
H // spatial_merge_size,
W // spatial_merge_size,
video_embeds.size(-1),
)
tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)
# Core EVS
similarity = torch.nn.functional.cosine_similarity(
video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
)
dissimilarity = 1 - similarity
# Always ensure we include all tokens from the first frame
dissimilarity = torch.cat(
[255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0
)
dissimilarity_flat = dissimilarity.view(-1)
order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
retain_num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_per_frame, num_frames=T, q=q
)
topk_indices = order[:retain_num_tokens]
retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
retention_mask[topk_indices] = True
retention_mask = retention_mask.reshape(dissimilarity.size())
mask = retention_mask.view(-1) # "T H W -> (T H W)"
return mask
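# Sanity sketch (illustrative): with T=4 frames of 4x4 merged tokens and
# q=0.5, exactly compute_retained_tokens_count(16, 4, 0.5) == 32 tokens
# survive, and the 16 first-frame tokens are always among them:
#
#     >>> import torch
#     >>> embeds = torch.randn(4 * 4 * 4, 64)
#     >>> mask = compute_retention_mask(embeds, (4, 8, 8), spatial_merge_size=2, q=0.5)
#     >>> int(mask.sum()), bool(mask[:16].all())
#     (32, True)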
def compute_mrope_for_media(
video_size_thw: torch.LongTensor,
spatial_merge_size: int,
tokens_per_second: float = 1.0,
video_second_per_grid: float = 1.0,
) -> torch.Tensor:
"""
Computes the mrope for video embeddings based on the grid dimensions.
Computed mrope positions match original qwen 2.5 implementation,
but positions are built for media being the first element in sequence.
Args:
video_size_thw: Media size (num frames, rows, cols)
spatial_merge_size: Size reduction for rows & cols dimensions.
tokens_per_second: Number of tokens per second.
video_second_per_grid: Number of seconds of video per temporal grid step.
Returns:
Tensor of shape `(T * H * W, 4)` where the first three channels
hold the mrope positions and the last channel holds the value of
llm_grid_w repeated for all positions.
"""
llm_grid_t = video_size_thw[0]
llm_grid_h = video_size_thw[1] // spatial_merge_size
llm_grid_w = video_size_thw[2] // spatial_merge_size
t_index = (
(
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.mul(tokens_per_second * video_second_per_grid)
)
.long()
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_grid_w = (
torch.tensor([llm_grid_w])
.view(1, 1, 1)
.expand(llm_grid_t, llm_grid_h, llm_grid_w)
.flatten()
)
positions = torch.stack([t_index, h_index, w_index, llm_grid_w], dim=1)
return positions
def recompute_mrope_positions(
input_ids: torch.LongTensor,
multimodal_positions: list[torch.Tensor],
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
vision_start_token_id: int,
image_token_id: int,
video_token_id: int,
) -> tuple[torch.LongTensor, int]:
"""
Update part of input mrope positions.
The original mrope_positions were computed without accounting for pruned
media tokens, so the pruning must be reflected in the positions fed to
the LLM.
This method supports the chunked prefill approach, where multimodal
embeddings are passed to the LLM in chunks, so the input may contain
none, some, or only part of the multimodal embeddings for a given prompt.
Each entry in multimodal_positions has 4 channels: the first 3 correspond
to the original 3 mrope positions, and the last channel repeats the
maximum width of the media. The provided multimodal_positions do not
reflect where the media sits in the sequence; they are computed as if
the media started at position 0.
The method works as follows: it recomputes mrope_positions starting from
`num_computed_tokens` for the total length of the multimodal embeddings,
then shifts all text tokens that come after them.
It also handles the case where the multimodal embeddings are partial
(e.g. one media item split across two prefill stages).
Args:
input_ids: (N,) All input tokens of the prompt (entire sequence).
multimodal_positions: List of mrope positions for each media.
mrope_positions: Existing mrope positions (4, N) for entire sequence.
num_computed_tokens: A number of computed tokens so far.
vision_start_token_id: Token indicating start of vision media.
image_token_id: Image token id
video_token_id: Video token id
Returns:
Tuple of (mrope_positions, mrope_position_delta).
"""
# Tensors
positions: torch.LongTensor = typing.cast(
torch.LongTensor, mrope_positions.clone()
) # (3, N)
N = input_ids.numel()
image_mask = input_ids.eq(image_token_id)
video_mask = input_ids.eq(video_token_id)
media_mask = image_mask | video_mask
text_mask = ~media_mask
# Early exit: no media in this chunk
if len(multimodal_positions) == 0:
delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
return positions, delta
total_mm_tokens = torch.count_nonzero(media_mask)
seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])
# Early exit: we've updated positions for all media tokens
# (and consequently - for all remaining text tokens)
if seen_mm_tokens == total_mm_tokens:
delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
return positions, delta
vision_start_indices = (input_ids == vision_start_token_id).nonzero(as_tuple=True)[
0
]
for mm_pos in multimodal_positions:
# Each mm_pos can be a complete embedding for single media
# or it can be a part of a single media (due to chunked prefill)
# Cases to cover
# - Current prefill chunk has no vision start indexes at all
# - Vision start token appeared in previous prefill round
# - Regular case
seen_vision_start_indices = vision_start_indices[
vision_start_indices < num_computed_tokens
]
if len(seen_vision_start_indices):
# If we have encountered some vision start indexes,
# then we should check the condition:
# | --- prefill 1 ------| ---- prefill 2 ----- |
# | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
last_vision_start_token = seen_vision_start_indices[-1]
seen_mm_tokens_before_last_vision_start = torch.count_nonzero(
media_mask[:last_vision_start_token]
)
in_the_middle_of_media = (
seen_mm_tokens > seen_mm_tokens_before_last_vision_start
)
if in_the_middle_of_media:
mm_embeddings_seen = (
seen_mm_tokens - seen_mm_tokens_before_last_vision_start
)
global_mm_start = last_vision_start_token
else:
# We have completed previous mm_embedding part and
# ready to start a new one
next_vision_start_token = vision_start_indices[
vision_start_indices >= num_computed_tokens
][0]
mm_embeddings_seen = 0
global_mm_start = next_vision_start_token
else:
# If there were no vision start indexes so far,
# let's find first vision start index
next_vision_start_token = vision_start_indices[
vision_start_indices >= num_computed_tokens
][0]
mm_embeddings_seen = 0
global_mm_start = next_vision_start_token
# Offset right after vision_start_token
base = positions[-1, global_mm_start] + 1
local_start = global_mm_start + 1 + mm_embeddings_seen
local_end = local_start + mm_pos.shape[1]
positions[:, local_start:local_end] = mm_pos[0:3] + base
# mm_pos[3, 0] is the max width of the media
offset = mm_pos[3, 0] + base
text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
positions[:, local_end:N] = text_pos_sum + offset - 1
# Include distance to the next vision start token
num_computed_tokens += mm_pos.shape[1]
mrope_positions_delta = (positions.max() + 1 - N).item()
return positions, mrope_positions_delta

vllm/multimodal/hasher.py Normal file

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import hashlib
import pickle
import uuid
from collections.abc import Callable, Iterable
import numpy as np
import torch
from PIL import Image
import vllm.envs as envs
from vllm.logger import init_logger
from .media import MediaWithBytes
logger = init_logger(__name__)
@functools.lru_cache(maxsize=3)
def _get_hasher_factory(algorithm: str) -> Callable[[], "hashlib._Hash"]:
"""
Get the hasher factory based on the configured algorithm.
Args:
algorithm: Hash algorithm name (blake3, sha256, or sha512)
Returns a callable that creates a new hasher instance.
Supports blake3 (default), sha256, and sha512 for FIPS compliance.
See: https://github.com/vllm-project/vllm/issues/18334
"""
algorithm = algorithm.lower()
if algorithm == "blake3":
from blake3 import blake3
return blake3
elif algorithm == "sha256":
return hashlib.sha256
elif algorithm == "sha512":
return hashlib.sha512
else:
# This should never happen due to env_with_choices validation
raise ValueError(f"Unsupported hash algorithm: {algorithm}")
class MultiModalHasher:
@classmethod
def serialize_item(cls, obj: object) -> Iterable[bytes | memoryview]:
# Simple cases
if isinstance(obj, (bytes, memoryview)):
return (obj,)
if isinstance(obj, str):
return (obj.encode("utf-8"),)
if isinstance(obj, (int, float)):
return (np.array(obj).tobytes(),)
if isinstance(obj, Image.Image):
exif = obj.getexif()
if Image.ExifTags.Base.ImageID in exif and isinstance(
exif[Image.ExifTags.Base.ImageID], uuid.UUID
):
return (exif[Image.ExifTags.Base.ImageID].bytes,)
data = {"mode": obj.mode, "data": np.asarray(obj)}
palette = obj.palette
if palette is not None:
data["palette"] = palette.palette
if palette.rawmode is not None:
data["palette_rawmode"] = palette.rawmode
return cls.iter_item_to_bytes("image", data)
if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
exif = obj.media.getexif()
if Image.ExifTags.Base.ImageID in exif and isinstance(
exif[Image.ExifTags.Base.ImageID], uuid.UUID
):
return (exif[Image.ExifTags.Base.ImageID].bytes,)
return cls.iter_item_to_bytes("image", obj.original_bytes)
if isinstance(obj, torch.Tensor):
tensor_obj: torch.Tensor = obj.cpu()
tensor_dtype = tensor_obj.dtype
tensor_shape = tensor_obj.shape
# NumPy does not support bfloat16.
# Workaround: View the tensor as a contiguous 1D array of bytes
if tensor_dtype == torch.bfloat16:
tensor_obj = tensor_obj.contiguous()
tensor_obj = tensor_obj.view((tensor_obj.numel(),)).view(torch.uint8)
return cls.iter_item_to_bytes(
"tensor",
{
"original_dtype": str(tensor_dtype),
"original_shape": tuple(tensor_shape),
"data": tensor_obj.numpy(),
},
)
return cls.iter_item_to_bytes("tensor", tensor_obj.numpy())
if isinstance(obj, np.ndarray):
if obj.ndim == 0:
arr_data = obj.item()
elif obj.flags.c_contiguous:
# Not valid for 0-D arrays
arr_data = obj.view(np.uint8).data
else:
# If the array is non-contiguous, we need to copy it first
arr_data = obj.tobytes()
return cls.iter_item_to_bytes(
"ndarray",
{
"dtype": obj.dtype.str,
"shape": obj.shape,
"data": arr_data,
},
)
logger.warning(
"No serialization method found for %s. Falling back to pickle.", type(obj)
)
return (pickle.dumps(obj),)
@classmethod
def iter_item_to_bytes(
cls,
key: str,
obj: object,
) -> Iterable[bytes | memoryview]:
if obj is None:
yield key.encode("utf-8")
return
# Recursive cases
if isinstance(obj, (list, tuple)):
for i, elem in enumerate(obj):
yield from cls.iter_item_to_bytes(f"{key}.{i}", elem)
elif isinstance(obj, dict):
for k, v in obj.items():
yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
else:
yield key.encode("utf-8")
yield from cls.serialize_item(obj)
@classmethod
def hash_kwargs(cls, **kwargs: object) -> str:
hasher_factory = _get_hasher_factory(envs.VLLM_MM_HASHER_ALGORITHM)
hasher = hasher_factory()
for k, v in sorted(kwargs.items(), key=lambda kv: kv[0]):
for bytes_ in cls.iter_item_to_bytes(k, v):
hasher.update(bytes_)
return hasher.hexdigest()
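Since kwargs are sorted by key before hashing, the digest is independent of argument order. A small sketch (illustrative; assumes the default blake3 backend is available):

import numpy as np

h1 = MultiModalHasher.hash_kwargs(model_id="m", image=np.zeros((2, 2), dtype=np.uint8))
h2 = MultiModalHasher.hash_kwargs(image=np.zeros((2, 2), dtype=np.uint8), model_id="m")
assert h1 == h2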

vllm/multimodal/image.py Normal file

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from PIL import Image
def rescale_image_size(
image: Image.Image, size_factor: float, transpose: int = -1
) -> Image.Image:
"""Rescale the dimensions of an image by a constant factor."""
new_width = int(image.width * size_factor)
new_height = int(image.height * size_factor)
image = image.resize((new_width, new_height))
if transpose >= 0:
image = image.transpose(Image.Transpose(transpose))
return image
def rgba_to_rgb(
image: Image.Image,
background_color: tuple[int, int, int] | list[int] = (255, 255, 255),
) -> Image.Image:
"""Convert an RGBA image to RGB with filled background color."""
assert image.mode == "RGBA"
converted = Image.new("RGB", image.size, background_color)
converted.paste(image, mask=image.split()[3]) # 3 is the alpha channel
return converted
def convert_image_mode(image: Image.Image, to_mode: str):
if image.mode == to_mode:
return image
elif image.mode == "RGBA" and to_mode == "RGB":
return rgba_to_rgb(image)
else:
return image.convert(to_mode)
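A quick sketch of the RGBA path (illustrative): a fully transparent pixel composites to the white default background.

from PIL import Image

rgba = Image.new("RGBA", (4, 4), (255, 0, 0, 0))  # fully transparent red
rgb = convert_image_mode(rgba, "RGB")
assert rgb.getpixel((0, 0)) == (255, 255, 255)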

vllm/multimodal/inputs.py Normal file

File diff suppressed because it is too large


@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
from .base import MediaIO, MediaWithBytes
from .connector import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .video import VIDEO_LOADER_REGISTRY, VideoMediaIO
__all__ = [
"MediaIO",
"MediaWithBytes",
"AudioEmbeddingMediaIO",
"AudioMediaIO",
"ImageEmbeddingMediaIO",
"ImageMediaIO",
"VIDEO_LOADER_REGISTRY",
"VideoMediaIO",
"MEDIA_CONNECTOR_REGISTRY",
"MediaConnector",
]


@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from io import BytesIO
from pathlib import Path
import numpy.typing as npt
import pybase64
import torch
from vllm.utils.import_utils import PlaceholderModule
from vllm.utils.serial_utils import tensor2base64
from .base import MediaIO
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
def __init__(self, **kwargs) -> None:
super().__init__()
# `kwargs` contains custom arguments from
# --media-io-kwargs for this modality.
# They can be passed to the underlying
# media loaders (e.g. custom implementations)
# for flexible control.
self.kwargs = kwargs
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
return librosa.load(BytesIO(data), sr=None)
def load_base64(
self,
media_type: str,
data: str,
) -> tuple[npt.NDArray, float]:
return self.load_bytes(base64.b64decode(data))
def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
return librosa.load(filepath, sr=None)
def encode_base64(
self,
media: tuple[npt.NDArray, int],
*,
audio_format: str = "WAV",
) -> str:
audio, sr = media
with BytesIO() as buffer:
soundfile.write(buffer, audio, sr, format=audio_format)
data = buffer.getvalue()
return base64.b64encode(data).decode("utf-8")
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
def __init__(self) -> None:
super().__init__()
def load_bytes(self, data: bytes) -> torch.Tensor:
buffer = BytesIO(data)
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(buffer, weights_only=True)
return tensor.to_dense()
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
return self.load_bytes(pybase64.b64decode(data, validate=True))
def load_file(self, filepath: Path) -> torch.Tensor:
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(filepath, weights_only=True)
return tensor.to_dense()
def encode_base64(self, media: torch.Tensor) -> str:
return tensor2base64(media)


@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Generic, TypeVar
import numpy as np
_T = TypeVar("_T")
@dataclass
class MediaWithBytes(Generic[_T]):
"""
Wrapper that couples a media object with its original encoded bytes.
This ensures the raw bytes and media object remain synchronized,
preventing cache corruption from in-place modifications.
The wrapper delegates attribute access to the underlying media object,
making it behave transparently like the wrapped type (e.g., PIL.Image).
NOTE: Currently, this wrapper is used only for the image modality.
"""
media: _T
original_bytes: bytes = field(repr=False)
def __array__(self, *args, **kwargs) -> np.ndarray:
"""Allow np.array(obj) to return np.array(obj.media)."""
return np.array(self.media, *args, **kwargs)
def __getstate__(self):
return self.__dict__.copy()
def __setstate__(self, state: dict[str, Any]):
self.__dict__.update(state)
def __getattr__(self, name: str):
"""Delegate attribute access to the underlying media object."""
return getattr(self.media, name)
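# Delegation sketch (illustrative; assumes Pillow):
#
#     >>> from io import BytesIO
#     >>> from PIL import Image
#     >>> buf = BytesIO()
#     >>> Image.new("RGB", (8, 8)).save(buf, format="PNG")
#     >>> wrapped = MediaWithBytes(Image.open(BytesIO(buf.getvalue())), buf.getvalue())
#     >>> wrapped.size  # attribute access falls through to the PIL image
#     (8, 8)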
class MediaIO(ABC, Generic[_T]):
@abstractmethod
def load_bytes(self, data: bytes) -> _T:
raise NotImplementedError
@abstractmethod
def load_base64(self, media_type: str, data: str) -> _T:
"""
List of media types:
https://www.iana.org/assignments/media-types/media-types.xhtml
"""
raise NotImplementedError
@abstractmethod
def load_file(self, filepath: Path) -> _T:
raise NotImplementedError


@@ -0,0 +1,343 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import atexit
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, TypeVar
from urllib.request import url2pathname
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image, UnidentifiedImageError
from urllib3.util import Url, parse_url
import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
from vllm.utils.registry import ExtensionManager
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .video import VideoMediaIO
_M = TypeVar("_M")
global_thread_pool = ThreadPoolExecutor(
max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT
)
atexit.register(global_thread_pool.shutdown)
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
def __init__(
self,
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
connection: HTTPConnection = global_http_connection,
*,
allowed_local_media_path: str = "",
allowed_media_domains: list[str] | None = None,
) -> None:
"""
Args:
media_io_kwargs: Additional args passed to process media
inputs, keyed by modalities. For example,
to set num_frames for video, set
`--media-io-kwargs '{"video":{"num_frames":40}}'`
connection: HTTP connection client to download media contents.
allowed_local_media_path: A local directory to load media files from.
allowed_media_domains: If set, only media URLs that belong to this
domain can be used for multi-modal inputs.
"""
super().__init__()
self.media_io_kwargs: dict[str, dict[str, Any]] = (
media_io_kwargs if media_io_kwargs else {}
)
self.connection = connection
if allowed_local_media_path:
allowed_local_media_path_ = Path(allowed_local_media_path)
if not allowed_local_media_path_.exists():
raise ValueError(
"Invalid `--allowed-local-media-path`: The path "
f"{allowed_local_media_path_} does not exist."
)
if not allowed_local_media_path_.is_dir():
raise ValueError(
"Invalid `--allowed-local-media-path`: The path "
f"{allowed_local_media_path_} must be a directory."
)
else:
allowed_local_media_path_ = None
self.allowed_local_media_path = allowed_local_media_path_
if allowed_media_domains is None:
allowed_media_domains = []
self.allowed_media_domains = allowed_media_domains
def _load_data_url(
self,
url_spec: Url,
media_io: MediaIO[_M],
) -> _M: # type: ignore[type-var]
url_spec_path = url_spec.path or ""
data_spec, data = url_spec_path.split(",", 1)
media_type, data_type = data_spec.split(";", 1)
# media_type starts with a leading "/" (e.g., "/video/jpeg")
media_type = media_type.lstrip("/")
if data_type != "base64":
msg = "Only base64 data URLs are supported for now."
raise NotImplementedError(msg)
return media_io.load_base64(media_type, data)
def _load_file_url(
self,
url_spec: Url,
media_io: MediaIO[_M],
) -> _M: # type: ignore[type-var]
allowed_local_media_path = self.allowed_local_media_path
if allowed_local_media_path is None:
raise RuntimeError(
"Cannot load local files without `--allowed-local-media-path`."
)
url_spec_path = url_spec.path or ""
url_spec_netloc = url_spec.netloc or ""
filepath = Path(url2pathname(url_spec_netloc + url_spec_path))
if allowed_local_media_path not in filepath.resolve().parents:
raise ValueError(
f"The file path {filepath} must be a subpath "
f"of `--allowed-local-media-path {allowed_local_media_path}`."
)
return media_io.load_file(filepath)
def _assert_url_in_allowed_media_domains(self, url_spec: Url) -> None:
if (
self.allowed_media_domains
and url_spec.hostname not in self.allowed_media_domains
):
raise ValueError(
f"The URL must be from one of the allowed domains: "
f"{self.allowed_media_domains}. Input URL domain: "
f"{url_spec.hostname}"
)
def load_from_url(
self,
url: str,
media_io: MediaIO[_M],
*,
fetch_timeout: int | None = None,
) -> _M: # type: ignore[type-var]
url_spec = parse_url(url)
if url_spec.scheme and url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)
connection = self.connection
data = connection.get_bytes(
url_spec.url,
timeout=fetch_timeout,
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
)
return media_io.load_bytes(data)
if url_spec.scheme == "data":
return self._load_data_url(url_spec, media_io)
if url_spec.scheme == "file":
return self._load_file_url(url_spec, media_io)
msg = "The URL must be either a HTTP, data or file URL."
raise ValueError(msg)
async def load_from_url_async(
self,
url: str,
media_io: MediaIO[_M],
*,
fetch_timeout: int | None = None,
) -> _M:
url_spec = parse_url(url)
loop = asyncio.get_running_loop()
if url_spec.scheme and url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)
connection = self.connection
data = await connection.async_get_bytes(
url_spec.url,
timeout=fetch_timeout,
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
)
future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
return await future
if url_spec.scheme == "data":
future = loop.run_in_executor(
global_thread_pool, self._load_data_url, url_spec, media_io
)
return await future
if url_spec.scheme == "file":
future = loop.run_in_executor(
global_thread_pool, self._load_file_url, url_spec, media_io
)
return await future
msg = "The URL must be either a HTTP, data or file URL."
raise ValueError(msg)
def fetch_audio(
self,
audio_url: str,
) -> tuple[np.ndarray, int | float]:
"""
Load audio from a URL.
"""
audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
return self.load_from_url(
audio_url,
audio_io,
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
)
async def fetch_audio_async(
self,
audio_url: str,
) -> tuple[np.ndarray, int | float]:
"""
Asynchronously fetch audio from a URL.
"""
audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
return await self.load_from_url_async(
audio_url,
audio_io,
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
)
def fetch_image(
self,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Load a PIL image from an HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
image_io = ImageMediaIO(
image_mode=image_mode, **self.media_io_kwargs.get("image", {})
)
try:
return self.load_from_url(
image_url,
image_io,
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
except UnidentifiedImageError as e:
# convert to ValueError to be properly caught upstream
raise ValueError(str(e)) from e
async def fetch_image_async(
self,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Asynchronously load a PIL image from an HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
image_io = ImageMediaIO(
image_mode=image_mode, **self.media_io_kwargs.get("image", {})
)
try:
return await self.load_from_url_async(
image_url,
image_io,
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
except UnidentifiedImageError as e:
# convert to ValueError to be properly caught upstream
raise ValueError(str(e)) from e
def fetch_video(
self,
video_url: str,
*,
image_mode: str = "RGB",
) -> tuple[npt.NDArray, dict[str, Any]]:
"""
Load video from an HTTP or base64 data URL.
By default, each frame is converted into RGB format.
"""
image_io = ImageMediaIO(
image_mode=image_mode, **self.media_io_kwargs.get("image", {})
)
video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))
return self.load_from_url(
video_url,
video_io,
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)
async def fetch_video_async(
self,
video_url: str,
*,
image_mode: str = "RGB",
) -> tuple[npt.NDArray, dict[str, Any]]:
"""
Asynchronously load video from an HTTP or base64 data URL.
By default, each frame is converted into RGB format.
"""
image_io = ImageMediaIO(
image_mode=image_mode, **self.media_io_kwargs.get("image", {})
)
video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))
return await self.load_from_url_async(
video_url,
video_io,
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)
def fetch_image_embedding(
self,
data: str,
) -> torch.Tensor:
"""
Load an image embedding from a base64-encoded tensor.
"""
image_embedding_io = ImageEmbeddingMediaIO()
return image_embedding_io.load_base64("", data)
def fetch_audio_embedding(
self,
data: str,
) -> torch.Tensor:
"""
Load an audio embedding from a base64-encoded tensor.
"""
audio_embedding_io = AudioEmbeddingMediaIO()
return audio_embedding_io.load_base64("", data)

View File

@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from io import BytesIO
from pathlib import Path
import pybase64
import torch
from PIL import Image
from vllm.utils.serial_utils import tensor2base64
from ..image import convert_image_mode, rgba_to_rgb
from .base import MediaIO, MediaWithBytes
class ImageMediaIO(MediaIO[Image.Image]):
def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
super().__init__()
self.image_mode = image_mode
# `kwargs` contains custom arguments from
# --media-io-kwargs for this modality.
# They can be passed to the underlying
# media loaders (e.g. custom implementations)
# for flexible control.
self.kwargs = kwargs
# Extract RGBA background color from kwargs if provided
# Default to white background for backward compatibility
rgba_bg = kwargs.get("rgba_background_color", (255, 255, 255))
# Convert list to tuple for consistency
if isinstance(rgba_bg, list):
rgba_bg = tuple(rgba_bg)
# Validate rgba_background_color format
if not (
isinstance(rgba_bg, tuple)
and len(rgba_bg) == 3
and all(isinstance(c, int) and 0 <= c <= 255 for c in rgba_bg)
):
raise ValueError(
"rgba_background_color must be a list or tuple of 3 integers "
"in the range [0, 255]."
)
self.rgba_background_color = rgba_bg
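# The background color can be customized per deployment; assuming the
# standard CLI plumbing, e.g.:
#     --media-io-kwargs '{"image": {"rgba_background_color": [0, 0, 0]}}'
# composites RGBA images onto black instead of the default white.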
def _convert_image_mode(
self, image: Image.Image | MediaWithBytes[Image.Image]
) -> Image.Image:
"""Convert image mode with custom background color."""
if isinstance(image, MediaWithBytes):
image = image.media
if image.mode == self.image_mode:
return image
elif image.mode == "RGBA" and self.image_mode == "RGB":
return rgba_to_rgb(image, self.rgba_background_color)
else:
return convert_image_mode(image, self.image_mode)
def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
image = Image.open(BytesIO(data))
return MediaWithBytes(self._convert_image_mode(image), data)
def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
return self.load_bytes(pybase64.b64decode(data, validate=True))
def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
with open(filepath, "rb") as f:
data = f.read()
image = Image.open(BytesIO(data))
return MediaWithBytes(self._convert_image_mode(image), data)
def encode_base64(
self,
media: Image.Image,
*,
image_format: str = "PNG",
) -> str:
image = media
with BytesIO() as buffer:
image = self._convert_image_mode(image)
image.save(buffer, image_format)
data = buffer.getvalue()
return pybase64.b64encode(data).decode("utf-8")
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
def __init__(self) -> None:
super().__init__()
def load_bytes(self, data: bytes) -> torch.Tensor:
buffer = BytesIO(data)
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(buffer, weights_only=True)
return tensor.to_dense()
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
return self.load_bytes(pybase64.b64decode(data, validate=True))
def load_file(self, filepath: Path) -> torch.Tensor:
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(filepath, weights_only=True)
return tensor.to_dense()
def encode_base64(self, media: torch.Tensor) -> str:
return tensor2base64(media)

View File

@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from functools import partial
from pathlib import Path
from typing import Any
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm import envs
from ..video import VIDEO_LOADER_REGISTRY
from .base import MediaIO
from .image import ImageMediaIO
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
def __init__(
self,
image_io: ImageMediaIO,
num_frames: int = 32,
**kwargs,
) -> None:
super().__init__()
self.image_io = image_io
self.num_frames = num_frames
# `kwargs` contains custom arguments from
# --media-io-kwargs for this modality.
# They can be passed to the underlying
# media loaders (e.g. custom implementations)
# for flexible control.
# Allow per-request override of video backend via kwargs.
# This enables users to specify a different backend than the
# global VLLM_VIDEO_LOADER_BACKEND env var, e.g.:
# --media-io-kwargs '{"video": {"video_backend": "torchcodec"}}'
video_loader_backend = (
kwargs.pop("video_backend", None) or envs.VLLM_VIDEO_LOADER_BACKEND
)
self.kwargs = kwargs
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
return self.video_loader.load_bytes(
data, num_frames=self.num_frames, **self.kwargs
)
def load_base64(
self, media_type: str, data: str
) -> tuple[npt.NDArray, dict[str, Any]]:
if media_type.lower() == "video/jpeg":
load_frame = partial(
self.image_io.load_base64,
"image/jpeg",
)
return np.stack(
# Unwrap MediaWithBytes so np.asarray sees the underlying PIL image.
[np.asarray(load_frame(frame_data).media) for frame_data in data.split(",")]
), {}
return self.load_bytes(base64.b64decode(data))
def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
with filepath.open("rb") as f:
data = f.read()
return self.load_bytes(data)
def encode_base64(
self,
media: npt.NDArray,
*,
video_format: str = "JPEG",
) -> str:
video = media
if video_format == "JPEG":
encode_frame = partial(
self.image_io.encode_base64,
image_format=video_format,
)
return ",".join(encode_frame(Image.fromarray(frame)) for frame in video)
msg = "Only JPEG format is supported for now."
raise NotImplementedError(msg)
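# Round-trip sketch for the "video/jpeg" convention used above, where a
# video is a comma-separated list of base64 JPEG frames (`frames` is
# assumed to be a (num_frames, H, W, 3) uint8 array):
#
#     video_io = VideoMediaIO(ImageMediaIO())
#     payload = video_io.encode_base64(frames, video_format="JPEG")
#     decoded_frames, metadata = video_io.load_base64("video/jpeg", payload)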

710
vllm/multimodal/parse.py Normal file
View File

@@ -0,0 +1,710 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence, Set
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
NamedTuple,
TypeAlias,
TypeGuard,
TypeVar,
)
import numpy as np
import torch
from typing_extensions import assert_never
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader
from .audio import AudioResampler, AudioSpec, normalize_audio
from .inputs import (
AudioItem,
HfAudioItem,
HfImageItem,
HfVideoItem,
ImageItem,
ModalityData,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
MultiModalUUIDDict,
VideoItem,
)
from .media import MediaWithBytes
_T = TypeVar("_T")
_I = TypeVar("_I")
if TYPE_CHECKING:
import PIL.Image as PILImage
else:
PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
class ModalityDataItems(ABC, Generic[_T, _I]):
"""
Represents data items for a modality in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
"""
def __init__(self, data: _T, modality: str) -> None:
super().__init__()
self.data: _T = data
self.modality = modality
def __repr__(self) -> str:
return f"{type(self).__name__}(modality={self.modality!r}, len={len(self)})"
def __len__(self) -> int:
return self.get_count()
def __getitem__(self, index: int) -> _I:
return self.get(index)
if TYPE_CHECKING:
# Auto-generated
def __iter__(self) -> Iterator[_I]: ...
@abstractmethod
def get_count(self) -> int:
"""Get the number of data items."""
raise NotImplementedError
@abstractmethod
def get(self, index: int) -> _I:
"""Get a data item by its index."""
raise NotImplementedError
def get_all(self) -> list[_I]:
"""Get all data items."""
return [self.get(idx) for idx in range(self.get_count())]
def get_item_for_hash(self, index: int) -> object:
return self.get(index)
def get_all_items_for_hash(self) -> list[object]:
return [self.get_item_for_hash(idx) for idx in range(self.get_count())]
@abstractmethod
def get_processor_data(self) -> Mapping[str, object]:
"""Get the data to pass to the HF processor."""
raise NotImplementedError
@abstractmethod
def get_passthrough_data(self) -> Mapping[str, object]:
"""Get the data to pass directly to the model."""
raise NotImplementedError
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
"""Base class for data items that are arranged in a list."""
def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
"""Extract media from wrapper if present."""
return item.media if isinstance(item, MediaWithBytes) else item
def get_count(self) -> int:
return len(self.data)
def get(self, index: int) -> _T:
return self._unwrap(self.data[index])
def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
# Return raw item for hashing (preserves original_bytes if present)
return self.data[index]
def get_processor_data(self) -> Mapping[str, object]:
return {f"{self.modality}s": self.get_all()}
def get_passthrough_data(self) -> Mapping[str, object]:
return {}
def validate_embedding_ndim(
tensor: torch.Tensor,
modality: str,
index: int | None = None,
) -> None:
"""Validate tensor ndim for multimodal embeddings.
Single embeddings should be 2D (seq_len, hidden_size).
Batched embeddings should be 3D (batch, seq_len, hidden_size).
Args:
tensor: The tensor to validate.
modality: The modality name for error messages (e.g., "image", "audio").
index: Optional index for list items, included in error messages.
"""
if tensor.ndim < 2 or tensor.ndim > 3:
idx_str = f" [{index}]" if index is not None else ""
raise ValueError(
f"{modality.capitalize()} embedding{idx_str} must be 2D "
f"(seq_len, hidden_size) or 3D (batch, seq_len, hidden_size), "
f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
)
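# Shape examples illustrating the rule above:
#     validate_embedding_ndim(torch.empty(16, 4096), "image")     # OK: 2D
#     validate_embedding_ndim(torch.empty(2, 16, 4096), "image")  # OK: 3D
#     validate_embedding_ndim(torch.empty(4096), "image")         # ValueError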
class EmbeddingItems(
ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
):
"""
Base class for data items that are expressed as a batched embedding tensor,
or a list of embedding tensors (one per item).
"""
def __init__(
self,
data: torch.Tensor | list[torch.Tensor],
modality: str,
expected_hidden_size: int | None = None,
) -> None:
super().__init__(data, modality)
# Validate ndim first (before hidden_size which depends on correct ndim)
self._validate_ndim()
# Validate hidden dimension if expected size is provided
if expected_hidden_size is not None:
self._validate_hidden_size(expected_hidden_size)
def _validate_ndim(self) -> None:
"""Validate that embedding tensors have correct ndim (2D or 3D)."""
if isinstance(self.data, torch.Tensor):
validate_embedding_ndim(self.data, self.modality)
else:
# List of tensors: each should be 2D (seq_len, hidden_size)
for idx, tensor in enumerate(self.data):
if tensor.ndim != 2:
raise ValueError(
f"{self.modality.capitalize()} embedding [{idx}] must be "
f"2D (seq_len, hidden_size), got {tensor.ndim}D tensor "
f"with shape {tuple(tensor.shape)}"
)
def _validate_hidden_size(self, expected_hidden_size: int) -> None:
"""Validate that embedding hidden dimension matches expected size.
This validates hidden dimensions to prevent vulnerabilities: Embeddings
with correct ndim but wrong hidden dimension could bypass initial
checks and cause crashes during model inference when dimensions don't match.
"""
if isinstance(self.data, torch.Tensor):
# Batched tensor: shape is (batch, seq_len, hidden_size)
actual_hidden_size = self.data.shape[-1]
if actual_hidden_size != expected_hidden_size:
raise ValueError(
f"{self.modality.capitalize()} embedding hidden dimension "
f"mismatch: got {actual_hidden_size}, but model expects "
f"{expected_hidden_size}. Embedding shape: {tuple(self.data.shape)}"
)
else:
# List of tensors: each has shape (seq_len, hidden_size)
for idx, tensor in enumerate(self.data):
actual_hidden_size = tensor.shape[-1]
if actual_hidden_size != expected_hidden_size:
raise ValueError(
f"{self.modality.capitalize()} embedding [{idx}] hidden "
f"dimension mismatch: got {actual_hidden_size}, but model "
f"expects {expected_hidden_size}. "
f"Embedding shape: {tuple(tensor.shape)}"
)
def _unwrap(
self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
) -> torch.Tensor:
"""Extract media from wrapper if present."""
return item.media if isinstance(item, MediaWithBytes) else item
def get_count(self) -> int:
return len(self.data)
def get(self, index: int) -> torch.Tensor:
return self._unwrap(self.data[index])
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return {f"{self.modality}_embeds": self.data}
def get_feature_size(self, item_idx: int) -> int:
return len(self.get(item_idx))
class DictEmbeddingItems(
ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
):
"""
Base class for data items that are expressed as a dictionary of tensors.
Usually, the dictionary keys correspond to the outputs of HF processor.
"""
def __init__(
self,
data: Mapping[str, torch.Tensor],
modality: str,
required_fields: set[str],
fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None:
from transformers.feature_extraction_utils import BatchFeature
super().__init__(data, modality)
missing_required_data_keys = required_fields - data.keys()
if missing_required_data_keys:
data_keys = set(data.keys())
msg = (
f"The data should contain the fields: {required_fields}, "
f"but only found the following keys: {data_keys}"
)
raise ValueError(msg)
fields_config = fields_factory(data)
missing_required_fields = required_fields - fields_config.keys()
if missing_required_fields:
fields = set(fields_config.keys())
msg = f"{required_fields=} should be a subset of {fields=}"
raise ValueError(msg)
self.fields_config = fields_config
self.required_fields = required_fields
self._kwargs = MultiModalKwargsItems.from_hf_inputs(
BatchFeature(dict(data)),
fields_config,
)
def get_count(self) -> int:
return len(self._kwargs[self.modality])
def get(self, index: int) -> Mapping[str, torch.Tensor]:
return self._kwargs[self.modality][index].get_data()
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return self.data
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem | None]):
def __init__(self, data: Sequence[HfAudioItem | None]) -> None:
super().__init__(data, "audio")
def get_audio_length(self, item_idx: int) -> int:
audio = self.get(item_idx)
if audio is None:
raise ValueError(f"Cannot get length of cached audio at {item_idx}")
return len(audio)
class AudioEmbeddingItems(EmbeddingItems):
def __init__(
self,
data: torch.Tensor | list[torch.Tensor],
expected_hidden_size: int | None = None,
) -> None:
super().__init__(data, "audio", expected_hidden_size)
class ImageSize(NamedTuple):
width: int
height: int
class ImageProcessorItems(ProcessorBatchItems[HfImageItem | None]):
def __init__(self, data: Sequence[HfImageItem | None]) -> None:
super().__init__(data, "image")
def get_image_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx)
if image is None:
raise ValueError(f"Cannot get size of cached image at {item_idx}")
if isinstance(image, PILImage.Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
class ImageEmbeddingItems(EmbeddingItems):
def __init__(
self,
data: torch.Tensor | list[torch.Tensor],
expected_hidden_size: int | None = None,
) -> None:
super().__init__(data, "image", expected_hidden_size)
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem | None]):
def __init__(
self,
data: Sequence[HfVideoItem | None],
metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
) -> None:
super().__init__(data, "video")
self.metadata = metadata
def get_num_frames(self, item_idx: int) -> int:
video = self.get(item_idx)
if video is None:
raise ValueError(f"Cannot get length of cached video at {item_idx}")
return len(video)
def get_frame_size(self, item_idx: int) -> ImageSize:
video = self.get(item_idx)
if video is None:
raise ValueError(f"Cannot get size of cached video at {item_idx}")
if len(video) == 0:
raise ValueError(f"Cannot get size of empty video at {item_idx}")
image = video[0]
if isinstance(image, PILImage.Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
class VideoEmbeddingItems(EmbeddingItems):
def __init__(
self,
data: torch.Tensor | list[torch.Tensor],
expected_hidden_size: int | None = None,
) -> None:
super().__init__(data, "video", expected_hidden_size)
class VisionChunkProcessorItems(ProcessorBatchItems[Any]):
"""Processor items for vision chunks (unified image and video chunks)."""
def __init__(self, data: Sequence[Any]) -> None:
super().__init__(data, "vision_chunk")
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
"""
As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
normalized such that each entry corresponds to a list.
"""
def select(self, modalities: Set[str]):
"""
Construct a new `MultiModalDataItems` instance containing only the
selected modalities.
"""
return MultiModalDataItems(
{modality: self[modality] for modality in modalities}
)
def get_count(self, modality: str, *, strict: bool = True) -> int:
"""
Get the number of data items belonging to a modality.
If `strict=False`, return `0` instead of raising [`KeyError`][]
when the modality is not found.
"""
if modality not in self:
if strict:
available_modalities = set(self.keys())
raise KeyError(
f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}"
)
return 0
return self[modality].get_count()
def get_all_counts(self) -> Mapping[str, int]:
"""Get the number of items belonging to each modality."""
return {m: items.get_count() for m, items in self.items()}
def get_items(
self,
modality: str,
typ: type[_D] | tuple[type[_D], ...],
) -> _D:
"""
Get the data items belonging to a modality,
requiring that they belong to a certain type.
"""
if modality not in self:
available_modalities = set(self.keys())
raise KeyError(
f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}"
)
items = self[modality]
if not isinstance(items, typ):
raise TypeError(
f"Invalid type of data items for {modality=}. "
f"Expected type: {typ}, but "
f"found type: {type(items)}"
)
return items # type: ignore[return-value]
ModalityDataParser: TypeAlias = Callable[
[ModalityData[Any]], ModalityDataItems[Any, Any] | None
]
class MultiModalDataParser:
"""
Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
Args:
target_sr (float, optional): Enables automatic resampling of audio
items to the model's expected sampling rate.
target_channels (int, optional): Target number of audio channels.
If provided, normalizes audio to this many channels (e.g., 1 for mono).
If None, audio channels are passed through unchanged.
audio_resample_method: Resampling backend to use, either "librosa"
or "scipy".
video_needs_metadata (bool): Whether each video item must carry
metadata; if True, videos supplied without metadata are rejected.
expected_hidden_size (int, optional): Expected hidden dimension for
embedding inputs. If provided, validates that user-supplied
embeddings have the correct hidden size to prevent crashes
during model inference.
"""
def __init__(
self,
*,
target_sr: float | None = None,
target_channels: int | None = None,
audio_resample_method: Literal["librosa", "scipy"] = "librosa",
video_needs_metadata: bool = False,
expected_hidden_size: int | None = None,
) -> None:
super().__init__()
self.audio_resampler = AudioResampler(
target_sr=target_sr,
method=audio_resample_method,
)
self.target_channels = target_channels
self.video_needs_metadata = video_needs_metadata
self.expected_hidden_size = expected_hidden_size
@classmethod
def is_embeddings(
cls, data: object
) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
if isinstance(data, torch.Tensor):
return data.ndim == 3
if is_list_of(data, torch.Tensor) and len(data) > 0:
return data[0].ndim == 2 # type: ignore[index]
return False
def _get_audio_with_sr(
self,
audio: AudioItem,
) -> tuple[np.ndarray, float | None]:
if isinstance(audio, tuple):
return audio
if isinstance(audio, list):
return np.array(audio), None
if isinstance(audio, np.ndarray):
return audio, None
if isinstance(audio, torch.Tensor):
return audio.numpy(), None
assert_never(audio)
def _get_video_with_metadata(
self,
video: VideoItem,
) -> tuple[np.ndarray, dict[str, Any] | None]:
if isinstance(video, tuple):
return video
if isinstance(video, list):
return np.array(video), None
if isinstance(video, np.ndarray):
return video, None
if isinstance(video, torch.Tensor):
return video.numpy(), None
assert_never(video)
def _parse_audio_data(
self,
data: ModalityData[AudioItem],
) -> ModalityDataItems[Any, Any] | None:
if data is None:
return None
if self.is_embeddings(data):
return AudioEmbeddingItems(data, self.expected_hidden_size)
data_items: list[AudioItem]
if (
(is_list_of(data, float) and len(data) > 0)
or (isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 1)
or isinstance(data, tuple)
):
data_items = [data]
elif isinstance(data, (np.ndarray, torch.Tensor)):
data_items = [elem for elem in data]
else:
data_items = data # type: ignore[assignment]
new_audios = list[np.ndarray]()
for data_item in data_items:
audio, orig_sr = self._get_audio_with_sr(data_item)
if orig_sr is None:
new_audio = audio
else:
new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)
# Apply channel normalization if target_channels is set
if self.target_channels is not None:
spec = AudioSpec(target_channels=self.target_channels)
new_audio = normalize_audio(new_audio, spec)
new_audios.append(new_audio)
return AudioProcessorItems(new_audios)
def _parse_image_data(
self,
data: ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any] | None:
if data is None:
return None
if self.is_embeddings(data):
return ImageEmbeddingItems(data, self.expected_hidden_size)
if isinstance(data, (PILImage.Image, MediaWithBytes)) or (
isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3
):
data_items = [data]
elif isinstance(data, (np.ndarray, torch.Tensor)):
data_items = [elem for elem in data]
else:
data_items = data
return ImageProcessorItems(data_items)
def _parse_video_data(
self,
data: ModalityData[VideoItem],
) -> ModalityDataItems[Any, Any] | None:
if data is None:
return None
if self.is_embeddings(data):
return VideoEmbeddingItems(data, self.expected_hidden_size)
data_items: list[VideoItem]
if (is_list_of(data, PILImage.Image) and len(data) > 0) or (
isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 4
):
data_items = [data]
elif isinstance(data, (np.ndarray, torch.Tensor)):
data_items = [elem for elem in data]
elif isinstance(data, tuple) and len(data) == 2:
data_items = [data]
else:
data_items = data # type: ignore[assignment]
new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
metadata_lst: list[dict[str, Any] | None] = []
for data_item in data_items:
video, metadata = self._get_video_with_metadata(data_item)
if self.video_needs_metadata:
if metadata is None:
raise ValueError(
"Video metadata is required but not found in mm input. "
"Please check your video input in `multi_modal_data`"
)
new_videos.append((video, metadata))
metadata_lst.append(metadata)
else:
new_videos.append(video)
if not self.video_needs_metadata:
# No per-item metadata was collected, so pass None instead of [].
metadata_lst = None
return VideoProcessorItems(new_videos, metadata=metadata_lst)
def _parse_vision_chunk_data(
self,
data: ModalityData[Any],
) -> ModalityDataItems[Any, Any] | None:
"""Parse vision chunk data (unified image and video chunks)."""
if data is None:
return None
if self.is_embeddings(data):
raise ValueError("Do not support embedding data for vision_chunk right now")
if isinstance(data, dict):
data = [data]
return VisionChunkProcessorItems(data)
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
return {
"audio": self._parse_audio_data,
"image": self._parse_image_data,
"video": self._parse_video_data,
"vision_chunk": self._parse_vision_chunk_data,
}
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
subparsers = self._get_subparsers()
mm_items = MultiModalDataItems()
for k, v in mm_data.items():
if k not in subparsers:
raise ValueError(f"Unsupported modality: {k}")
# ignore empty embedding data
if (parsed_data := subparsers[k](v)) is not None:
mm_items[k] = parsed_data
return mm_items
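# Illustrative usage (a sketch; `img` is assumed to be a PIL image and
# `waveform` a 1D numpy array sampled at 44.1 kHz):
#
#     parser = MultiModalDataParser(target_sr=16000, target_channels=1)
#     items = parser.parse_mm_data({
#         "image": img,
#         "audio": (waveform, 44100),  # resampled to 16 kHz, reduced to mono
#     })
#     items.get_all_counts()  # -> {"image": 1, "audio": 1}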
MultiModalUUIDItems: TypeAlias = dict[str, Sequence[str | None]]
"""
As [`MultiModalUUIDDict`][vllm.multimodal.inputs.MultiModalUUIDDict], but
normalized such that each entry corresponds to a list.
"""
def parse_mm_uuids(mm_uuids: MultiModalUUIDDict | None) -> MultiModalUUIDItems:
if mm_uuids is None:
return {}
return {
modality: [uuids] if isinstance(uuids, str) else uuids
for modality, uuids in mm_uuids.items()
}
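# e.g. parse_mm_uuids({"image": "uuid-1", "audio": ["a", None]})
#     -> {"image": ["uuid-1"], "audio": ["a", None]}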

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .context import BaseProcessingInfo, InputProcessingContext, TimingContext
from .dummy_inputs import BaseDummyInputsBuilder
from .inputs import ProcessorInputs
from .processor import (
BaseMultiModalProcessor,
EncDecMultiModalProcessor,
PromptIndexTargets,
PromptInsertion,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
__all__ = [
"BaseProcessingInfo",
"InputProcessingContext",
"TimingContext",
"BaseDummyInputsBuilder",
"ProcessorInputs",
"BaseMultiModalProcessor",
"EncDecMultiModalProcessor",
"PromptUpdate",
"PromptIndexTargets",
"PromptUpdateDetails",
"PromptInsertion",
"PromptReplacement",
]

View File

@@ -0,0 +1,507 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from abc import abstractmethod
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Any, overload
import torch
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.multimodal.parse import (
DictEmbeddingItems,
EmbeddingItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.renderers import TokenizeParams
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
from vllm.utils.mistral import is_mistral_tokenizer
if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig
else:
PretrainedConfig = object
BatchFeature = object
ProcessorMixin = object
ModelConfig = object
logger = init_logger(__name__)
@dataclass
class TimingContext:
"""Helper class to record execution times during multi-modal processing."""
enabled: bool = True
"""If disabled, `TimingContext.record` becomes a no-op."""
stage_secs: dict[str, float] = field(default_factory=dict)
"""The execution time (in seconds) for each processing stage."""
@property
def total_secs(self) -> float:
return sum(self.stage_secs.values())
@contextmanager
def record(self, stage: str):
"""Record the execution time for a processing stage."""
if not self.enabled:
yield
return
start_time = time.perf_counter()
try:
yield
finally:
elapsed = time.perf_counter() - start_time
self.stage_secs.setdefault(stage, 0.0)
self.stage_secs[stage] += elapsed
def get_stats_dict(self) -> dict[str, float]:
stats_dict = {
f"{stage}_secs": time_s for stage, time_s in self.stage_secs.items()
}
stats_dict["preprocessor_total_secs"] = self.total_secs
return stats_dict
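# Illustrative usage:
#
#     timing = TimingContext()
#     with timing.record("hf_processor"):
#         ...  # run the HF processor
#     timing.get_stats_dict()
#     # -> {"hf_processor_secs": <elapsed>, "preprocessor_total_secs": <elapsed>}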
_T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
@dataclass(frozen=True)
class InputProcessingContext:
"""
Contains information about the model which may be used to
modify the inputs.
"""
model_config: ModelConfig
"""The configuration of the model."""
tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs."""
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init=True`"
)
return self.tokenizer
@overload
def get_hf_config(self, /) -> PretrainedConfig: ...
@overload
def get_hf_config(
self,
typ: type[_C] | tuple[type[_C], ...],
/,
) -> _C: ...
def get_hf_config(
self,
typ: type[Any] | tuple[type[Any], ...] | None = None,
/,
) -> Any:
"""
Get the HuggingFace configuration
(`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
TypeError: If the configuration is not of the specified type.
"""
if typ is None:
from transformers.configuration_utils import PretrainedConfig
typ = PretrainedConfig
hf_config = self.model_config.hf_config
if not isinstance(hf_config, typ):
raise TypeError(
"Invalid type of HuggingFace config. "
f"Expected type: {typ}, but "
f"found type: {type(hf_config)}"
)
return hf_config
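# e.g. `ctx.get_hf_config(LlavaConfig)` returns the config typed as
# `LlavaConfig` and raises TypeError if the loaded config is of another
# type (`LlavaConfig` is just an illustrative transformers config class).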
def get_hf_image_processor_config(self) -> dict[str, Any]:
"""
Get the HuggingFace image processor configuration of the model.
"""
return self.model_config.hf_image_processor_config
def get_mm_config(self):
"""
Get the multimodal config of the model.
Raises:
RuntimeError: If the model is not a multimodal model.
"""
mm_config = self.model_config.multimodal_config
if mm_config is None:
raise RuntimeError("Not a multimodal model")
return mm_config
@overload
def get_hf_processor(self, /, **kwargs: object) -> ProcessorMixin: ...
@overload
def get_hf_processor(
self,
typ: type[_P] | tuple[type[_P], ...],
/,
**kwargs: object,
) -> _P: ...
def get_hf_processor(
self,
typ: type[Any] | tuple[type[Any], ...] | None = None,
/,
**kwargs: object,
) -> Any:
"""
Get the HuggingFace processor
(`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
TypeError: If the processor is not of the specified type.
"""
if typ is None:
from transformers.processing_utils import ProcessorMixin
typ = ProcessorMixin
tokenizer = self.tokenizer
if is_mistral_tokenizer(tokenizer):
tokenizer = tokenizer.transformers_tokenizer
merged_kwargs = self.get_merged_mm_kwargs(kwargs)
merged_kwargs.pop("tokenizer", None)
return cached_processor_from_config(
self.model_config,
processor_cls=typ,
tokenizer=tokenizer,
**merged_kwargs,
)
def init_processor(
self,
typ: type[_T],
/,
**kwargs: object,
) -> _T:
"""
Initialize a HuggingFace-like processor class, merging the
keyword arguments with those in the model's configuration.
"""
merged_kwargs = self.get_merged_mm_kwargs(kwargs)
return typ(**merged_kwargs)
def _postprocess_output(
self,
output: JSONTree,
) -> JSONTree:
def _postprocess_one(x: object):
if isinstance(x, torch.Tensor): # noqa: SIM102
# This mimics the behavior of transformers.BatchFeature
if x.is_floating_point():
x = x.to(dtype=self.model_config.dtype)
return x
return json_map_leaves(_postprocess_one, output)
def get_merged_mm_kwargs(self, kwargs: Mapping[str, object]):
mm_config = self.model_config.get_multimodal_config()
return mm_config.merge_mm_processor_kwargs(kwargs)
def call_hf_processor(
self,
hf_processor: ProcessorMixin,
data: Mapping[str, object],
kwargs: Mapping[str, object] = {},
*,
num_tries: int = 1,
max_tries: int = 5,
) -> BatchFeature | JSONTree:
"""
Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`.
"""
assert callable(hf_processor)
merged_kwargs = self.get_merged_mm_kwargs(kwargs)
allowed_kwargs = get_allowed_kwarg_only_overrides(
hf_processor,
merged_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
try:
output = hf_processor(**data, **allowed_kwargs, return_tensors="pt")
except Exception as exc:
# See https://github.com/huggingface/tokenizers/issues/537
if (
isinstance(exc, RuntimeError)
and exc.args
and exc.args[0] == "Already borrowed"
and num_tries < max_tries
):
logger.warning(
"Failed to acquire tokenizer in current thread. "
"Retrying (%d/%d)...",
num_tries,
max_tries,
)
time.sleep(0.5)
return self.call_hf_processor(
hf_processor,
data,
kwargs,
num_tries=num_tries + 1,
max_tries=max_tries,
)
msg = (
f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={allowed_kwargs}"
)
raise ValueError(msg) from exc
# this emulates output.to(dtype=self.model_config.dtype)
from transformers.feature_extraction_utils import BatchFeature
if isinstance(output, BatchFeature):
output_ = self._postprocess_output(output.data)
return BatchFeature(output_)
logger.warning_once(
"%s did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "
"implementing custom processors.",
type(hf_processor).__name__,
)
return self._postprocess_output(output)
class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__()
self.ctx = ctx
@property
def model_id(self) -> str:
return self.ctx.model_config.model
def get_tokenizer(self) -> TokenizerLike:
return self.ctx.get_tokenizer()
def get_hf_config(self) -> PretrainedConfig:
return self.ctx.get_hf_config()
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
"""
Subclasses can override this method to handle
specific kwargs from model config or user inputs.
"""
return self.ctx.get_hf_processor(**kwargs)
def get_default_tok_params(self) -> TokenizeParams:
"""Construct the default parameters for tokenization."""
model_config = self.ctx.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=True,
)
@cached_property
def default_tok_params(self) -> TokenizeParams:
return self.get_default_tok_params()
def _get_expected_hidden_size(self) -> int | None:
"""
Get expected hidden size for embedding validation if `mm_embeds` are enabled.
This validates hidden dimensions to prevent a vulnerability where embeddings
with correct `ndim` but wrong `shape` could cause crashes at inference time.
"""
model_config = self.ctx.model_config
mm_config = model_config.get_multimodal_config()
if mm_config.enable_mm_embeds:
return model_config.get_inputs_embeds_size()
return None
def get_data_parser(self) -> MultiModalDataParser:
"""
Constructs a parser to preprocess multi-modal data items
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
You can support additional modalities by creating a subclass
of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
that has additional subparsers.
"""
return MultiModalDataParser(
expected_hidden_size=self._get_expected_hidden_size(),
)
@cached_property
def data_parser(self) -> MultiModalDataParser:
return self.get_data_parser()
@property
def skip_prompt_length_check(self) -> bool:
return False
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means an unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise NotImplementedError
@cached_property
def supported_mm_limits(self) -> Mapping[str, int | None]:
"""The maximum supported number of items for each modality."""
return self.get_supported_mm_limits()
@cached_property
def allowed_mm_limits(self) -> Mapping[str, int]:
"""The maximum allowed number of items for each modality."""
mm_config = self.ctx.get_mm_config()
allowed_limits = dict[str, int]()
for modality, supported_limit in self.supported_mm_limits.items():
user_limit = mm_config.get_limit_per_prompt(modality)
allowed_limits[modality] = (
user_limit
if supported_limit is None
else min(user_limit, supported_limit)
)
return allowed_limits
def validate_num_items(self, modality: str, num_items: int) -> None:
"""
Raise `ValueError` if the number of input items for the given modality
is invalid.
"""
supported_limit = self.supported_mm_limits.get(modality, 0)
allowed_limit = self.allowed_mm_limits.get(modality, 0)
if supported_limit is None:
supported_limit = allowed_limit
limit = min(supported_limit, allowed_limit)
if num_items > limit:
msg = f"At most {limit} {modality}(s) may be provided in one prompt."
if num_items <= supported_limit:
msg += " Set `--limit-mm-per-prompt` to increase this limit."
raise ValueError(msg)
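# e.g. with supported_mm_limits = {"image": 4} and
# `--limit-mm-per-prompt '{"image": 2}'`, passing 3 images raises:
#     "At most 2 image(s) may be provided in one prompt.
#      Set `--limit-mm-per-prompt` to increase this limit."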
def parse_mm_data(
self,
mm_data: MultiModalDataDict,
*,
validate: bool = True,
) -> MultiModalDataItems:
"""
Normalize
[`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
if validate:
mm_config = self.ctx.get_mm_config()
for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
if not mm_config.enable_mm_embeds:
raise ValueError(
f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`"
)
if mm_config.get_limit_per_prompt(modality) == 0:
logger.debug(
"Skipping count validation for modality "
"'%s' (embeddings with limit=0)",
modality,
)
continue
self.validate_num_items(modality, len(items))
return mm_items
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int] | None:
"""
Return the maximum number of tokens per item for each modality.
When `None` (the default) is returned, vLLM will generate dummy inputs
(images/videos) at maximum possible sizes and process them to determine
the maximum token count per modality.
This approach works but can be very slow for certain models (e.g.,
Qwen2.5-VL), leading to very long startup time. For better performance,
each model can override this method to return pre-computed maximum token
counts, avoiding the need for dummy input generation and processing.
Note:
The maximum number of tokens per item of each modality returned
from this function should respect the model's maximum sequence
length and the maximum number of items of each modality allowed,
and agree with dummy inputs (images/videos) at maximum possible
sizes.
"""
return None

View File

@@ -0,0 +1,187 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Generic, TypeVar
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm.config.multimodal import (
AudioDummyOptions,
BaseDummyOptions,
ImageDummyOptions,
VideoDummyOptions,
)
from vllm.logger import init_logger
from ..inputs import MultiModalDataDict
from .context import BaseProcessingInfo
from .inputs import ProcessorInputs
_I = TypeVar("_I", bound=BaseProcessingInfo)
logger = init_logger(__name__)
class BaseDummyInputsBuilder(ABC, Generic[_I]):
"""
Abstract base class that constructs the dummy data to profile
multi-modal models.
"""
def __init__(self, info: _I) -> None:
super().__init__()
self.info = info
@abstractmethod
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
"""
Build the text input corresponding to `mm_counts`.
"""
raise NotImplementedError
@abstractmethod
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
"""
Build the multimodal input which, after processing, results in
the maximum possible number of placeholder tokens.
Args:
seq_len: Sequence length
mm_counts: Count of items per modality
mm_options: Configurable options per modality (optional).
If None, use model defaults for backward compatibility.
If provided, models can use these to customize dummy
data generation.
"""
raise NotImplementedError
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions],
) -> ProcessorInputs:
"""
Build the input which, after processing, results in
the maximum possible number of placeholder tokens.
Args:
seq_len: Sequence length
mm_counts: Count of items per modality
mm_options: Configurable options per modality (optional)
"""
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False)
tokenization_kwargs = {"truncation": False}
return ProcessorInputs(
prompt=dummy_text,
mm_data_items=dummy_mm_items,
tokenization_kwargs=tokenization_kwargs,
)
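# Sketch of a concrete builder for a hypothetical single-image model
# (`MyProcessingInfo`, the `<image>` token, and the 336x336 resolution
# are placeholders):
#
#     class MyDummyInputsBuilder(BaseDummyInputsBuilder["MyProcessingInfo"]):
#         def get_dummy_text(self, mm_counts):
#             return "<image>" * mm_counts.get("image", 0)
#
#         def get_dummy_mm_data(self, seq_len, mm_counts, mm_options):
#             return {
#                 "image": self._get_dummy_images(
#                     width=336,
#                     height=336,
#                     num_images=mm_counts.get("image", 0),
#                     overrides=mm_options.get("image"),
#                 )
#             }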
def _get_dummy_audios(
self,
*,
length: int,
num_audios: int,
overrides: AudioDummyOptions | None = None,
) -> list[npt.NDArray]:
if num_audios == 0:
return []
if overrides and overrides.length:
if overrides.length > length:
logger.warning(
"audio.length override (%d) exceeds model's "
"maximum length (%d), will be ignored",
overrides.length,
length,
)
length = min(length, overrides.length)
audio = np.zeros((length,))
return [audio] * num_audios
def _get_dummy_images(
self,
*,
width: int,
height: int,
num_images: int,
overrides: ImageDummyOptions | None = None,
) -> list[Image.Image]:
if num_images == 0:
return []
if overrides:
if overrides.width:
if overrides.width > width:
logger.warning(
"image.width override (%d) exceeds model's "
"maximum width (%d), will be ignored",
overrides.width,
width,
)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"image.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height,
height,
)
height = min(height, overrides.height)
image = Image.new("RGB", (width, height), color=255)
return [image] * num_images
def _get_dummy_videos(
self,
*,
width: int,
height: int,
num_frames: int,
num_videos: int,
overrides: VideoDummyOptions | None = None,
) -> list[npt.NDArray]:
if num_videos == 0:
return []
if overrides:
if overrides.num_frames:
if overrides.num_frames > num_frames:
logger.warning(
"video.num_frames override (%d) exceeds model's "
"maximum number of frames (%d), will be ignored",
overrides.num_frames,
num_frames,
)
num_frames = min(num_frames, overrides.num_frames)
if overrides.width:
if overrides.width > width:
logger.warning(
"video.width override (%d) exceeds model's "
"maximum width (%d), will be ignored",
overrides.width,
width,
)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height,
height,
)
height = min(height, overrides.height)
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
return [video] * num_videos

View File

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass, field
from ..hasher import MultiModalHasher
from ..inputs import MultiModalHashes
from ..parse import MultiModalDataItems, MultiModalUUIDItems
@dataclass
class ProcessorInputs:
"""
Represents the keyword arguments to
[`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
"""
prompt: str | list[int]
mm_data_items: MultiModalDataItems
mm_uuid_items: MultiModalUUIDItems | None = None
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
def get_mm_hashes(self, model_id: str) -> MultiModalHashes:
mm_data_items = self.mm_data_items
mm_uuid_items = self.mm_uuid_items or {}
hf_processor_mm_kwargs = self.hf_processor_mm_kwargs
mm_hashes: MultiModalHashes = {}
hasher = MultiModalHasher
for modality, data_items in mm_data_items.items():
if modality in mm_uuid_items:
uuid_items = mm_uuid_items[modality]
# For None entries, compute a hash; otherwise, use provided ID.
hashes: list[str] = []
for i, item in enumerate(data_items.get_all_items_for_hash()):
uuid_item = uuid_items[i]
# NOTE: Even if a uuid_item is provided, we still compute a hash
# if `hf_processor_mm_kwargs` is provided.
# This is because the processed multimodal inputs can be different
# depending on the processor kwargs.
if uuid_item is None or hf_processor_mm_kwargs:
# NOTE: use provided hash string to hash with kwargs
# if available for better performance.
item = uuid_item if uuid_item is not None else item
hashes.append(
hasher.hash_kwargs(
model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs,
)
)
else:
hashes.append(uuid_item)
mm_hashes[modality] = hashes
else:
mm_hashes[modality] = [
hasher.hash_kwargs(
model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs,
)
for item in data_items
]
return mm_hashes
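# e.g. with mm_uuid_items = {"image": ["uuid-1", None]} and no
# hf_processor_mm_kwargs, the first image reuses "uuid-1" as its hash
# while the second gets a content hash from MultiModalHasher.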

File diff suppressed because it is too large

362
vllm/multimodal/registry.py Normal file
View File

@@ -0,0 +1,362 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from .cache import (
BaseMultiModalProcessorCache,
BaseMultiModalReceiverCache,
MultiModalProcessorOnlyCache,
MultiModalProcessorSenderCache,
MultiModalReceiverCache,
ShmObjectStoreReceiverCache,
ShmObjectStoreSenderCache,
)
from .inputs import MultiModalInputs
from .processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
TimingContext,
)
if TYPE_CHECKING:
from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__)
N = TypeVar("N", bound=type["SupportsMultiModal"])
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class ProcessingInfoFactory(Protocol[_I_co]):
"""
Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
instance from the context.
"""
def __call__(
self,
ctx: InputProcessingContext,
) -> _I_co: ...
class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc]
"""
Constructs a
[`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
instance from the context.
"""
def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
class MultiModalProcessorFactory(Protocol[_I]): # type: ignore[misc]
"""
Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
instance from the context.
"""
def __call__(
self,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor[_I]: ...
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
info: ProcessingInfoFactory[_I]
processor: MultiModalProcessorFactory[_I]
dummy_inputs: DummyInputsBuilderFactory[_I]
def build_processor(
self,
ctx: InputProcessingContext,
*,
cache: BaseMultiModalProcessorCache | None = None,
):
info = self.info(ctx)
dummy_inputs_builder = self.dummy_inputs(info)
return self.processor(info, dummy_inputs_builder, cache=cache)
class MultiModalRegistry:
"""
A registry that dispatches data processing according to the model.
"""
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
"""
Checks if the model supports multimodal inputs.
Returns True if the model is multimodal and at least one supported
modality has a non-zero limit (or `enable_mm_embeds` is set);
otherwise returns False, and the model effectively runs in
text-only mode.
"""
if not model_config.is_multimodal_model:
return False
mm_config = model_config.get_multimodal_config()
info = self._create_processing_info(model_config, tokenizer=None)
# Check if all supported modalities have limit == 0
if all(
mm_config.get_limit_per_prompt(modality) == 0
for modality in info.supported_mm_limits
):
# If enable_mm_embeds is True, we still need MM infrastructure
# to process pre-computed embeddings even though encoder won't run
if mm_config.enable_mm_embeds:
return True
logger.info_once(
"All limits of multimodal modalities supported by the model "
"are set to 0, running in text-only mode."
)
return False
return True
def register_processor(
self,
processor: MultiModalProcessorFactory[_I],
*,
info: ProcessingInfoFactory[_I],
dummy_inputs: DummyInputsBuilderFactory[_I],
):
"""
Register a multi-modal processor to a model class. The processor
is constructed lazily, hence a factory method should be passed.
When the model receives multi-modal data, the provided function is
invoked to transform the data into a dictionary of model inputs.
"""
def wrapper(model_cls: N) -> N:
if "_processor_factory" in model_cls.__dict__:
logger.warning(
"Model class %s already has a multi-modal processor "
"registered to %s. It is overwritten by the new one.",
model_cls,
self,
)
model_cls._processor_factory = _ProcessorFactories(
info=info,
dummy_inputs=dummy_inputs,
processor=processor,
)
return model_cls
return wrapper
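# Typical registration (a sketch; the `My*` classes are hypothetical):
#
#     @MULTIMODAL_REGISTRY.register_processor(
#         MyMultiModalProcessor,
#         info=MyProcessingInfo,
#         dummy_inputs=MyDummyInputsBuilder,
#     )
#     class MyModelForCausalLM(nn.Module, SupportsMultiModal):
#         ...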
def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
assert hasattr(model_cls, "_processor_factory")
return cast("SupportsMultiModal", model_cls)
def _create_processing_ctx(
self,
model_config: "ModelConfig",
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer)
def _create_processing_info(
self,
model_config: "ModelConfig",
tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.info(ctx)
def create_processor(
self,
model_config: "ModelConfig",
*,
tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.build_processor(ctx, cache=cache)
def get_dummy_mm_inputs(
self,
model_config: "ModelConfig",
mm_counts: Mapping[str, int],
*,
cache: BaseMultiModalProcessorCache | None = None,
processor: BaseMultiModalProcessor | None = None,
) -> MultiModalInputs:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
seq_len = model_config.max_model_len
if processor is None:
processor = self.create_processor(model_config, cache=cache)
mm_config = model_config.get_multimodal_config()
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=seq_len,
mm_counts=mm_counts,
mm_options=mm_config.limit_per_prompt,
)
mm_inputs = processor.apply(
processor_inputs,
timing_ctx=TimingContext(enabled=False),
)
prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids)
if total_len < seq_len:
prompt_token_ids.extend([0] * (seq_len - total_len))
return mm_inputs
def _get_cache_type(
self,
vllm_config: "VllmConfig",
) -> Literal[None, "processor_only", "lru", "shm"]:
model_config = vllm_config.model_config
if not self.supports_multimodal_inputs(model_config):
return None
# Check if the cache is disabled.
mm_config = model_config.get_multimodal_config()
if mm_config.mm_processor_cache_gb <= 0:
return None
# Check if IPC caching is supported.
parallel_config = vllm_config.parallel_config
is_ipc_supported = parallel_config._api_process_count == 1 and (
parallel_config.data_parallel_size == 1
or parallel_config.data_parallel_external_lb
)
if not is_ipc_supported:
return "processor_only"
        return mm_config.mm_processor_cache_type
def processor_cache_from_config(
self,
vllm_config: "VllmConfig",
) -> BaseMultiModalProcessorCache | None:
"""Return a `BaseMultiModalProcessorCache`, if enabled."""
cache_type = self._get_cache_type(vllm_config)
if cache_type is None:
return None
elif cache_type == "processor_only":
return MultiModalProcessorOnlyCache(vllm_config.model_config)
elif cache_type == "lru":
return MultiModalProcessorSenderCache(vllm_config.model_config)
elif cache_type == "shm":
return ShmObjectStoreSenderCache(vllm_config)
else:
raise ValueError(f"Unknown cache type: {cache_type!r}")
def processor_only_cache_from_config(
self,
vllm_config: "VllmConfig",
) -> MultiModalProcessorOnlyCache | None:
"""Return a `MultiModalProcessorOnlyCache`, if enabled."""
cache_type = self._get_cache_type(vllm_config)
if cache_type is None:
return None
return MultiModalProcessorOnlyCache(vllm_config.model_config)
def engine_receiver_cache_from_config(
self,
vllm_config: "VllmConfig",
) -> BaseMultiModalReceiverCache | None:
"""Return a `BaseMultiModalReceiverCache` for the engine process."""
cache_type = self._get_cache_type(vllm_config)
if cache_type in (None, "processor_only", "shm"):
return None
elif cache_type == "lru":
return MultiModalReceiverCache(vllm_config.model_config)
else:
raise ValueError(f"Unknown cache type: {cache_type!r}")
def worker_receiver_cache_from_config(
self,
vllm_config: "VllmConfig",
shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
"""Return a `BaseMultiModalReceiverCache` for the worker process."""
cache_type = self._get_cache_type(vllm_config)
if cache_type in (None, "processor_only", "lru"):
return None
elif cache_type == "shm":
return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
else:
raise ValueError(f"Unknown cache type: {cache_type!r}")
class MultiModalTimingRegistry:
def __init__(self, observability_config: "ObservabilityConfig | None") -> None:
super().__init__()
if observability_config and observability_config.enable_mm_processor_stats:
self._lock = threading.Lock()
self._ctx_by_request_id = defaultdict[str, TimingContext](TimingContext)
self._enabled = True
else:
self._enabled = False
def get(self, request_id: str) -> TimingContext:
if not self._enabled:
return TimingContext(enabled=False)
with self._lock:
return self._ctx_by_request_id[request_id]
def stat(self) -> dict[str, dict[str, float]]:
if not self._enabled:
return {}
with self._lock:
stats = {
req_id: ctx.get_stats_dict()
for req_id, ctx in self._ctx_by_request_id.items()
}
self._ctx_by_request_id.clear()
return stats
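# Usage sketch (illustrative only; `observability_config` and `log_metrics`
# are hypothetical names): a metrics loop can drain per-request stats.
#
#     timing_registry = MultiModalTimingRegistry(observability_config)
#     ctx = timing_registry.get(request_id)   # TimingContext for the request
#     ...                                     # processor records timings
#     for req_id, stats in timing_registry.stat().items():
#         log_metrics(req_id, stats)          # entries are cleared by stat()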

327
vllm/multimodal/utils.py Normal file
View File

@@ -0,0 +1,327 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import mimetypes
import warnings
from collections import defaultdict
from collections.abc import Generator, Sequence
from itertools import groupby
from typing import TYPE_CHECKING, Any
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm.utils.import_utils import LazyLoader
from .hasher import MultiModalHasher
from .inputs import (
BatchedTensorInputs,
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalPlaceholderDict,
MultiModalSharedField,
)
from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO
if TYPE_CHECKING:
import torch.types
else:
torch = LazyLoader("torch", globals(), "torch")
def __getattr__(name: str):
if name == "MEDIA_CONNECTOR_REGISTRY":
from .media import MEDIA_CONNECTOR_REGISTRY
warnings.warn(
"`vllm.multimodal.utils.MEDIA_CONNECTOR_REGISTRY` "
"has been moved to `vllm.multimodal.media.MEDIA_CONNECTOR_REGISTRY`. "
"The old name will be removed in v0.17.",
DeprecationWarning,
stacklevel=2,
)
return MEDIA_CONNECTOR_REGISTRY
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def encode_audio_base64(
audio: np.ndarray,
sampling_rate: int,
*,
format: str = "WAV",
) -> str:
"""Encode audio as base64."""
audio_io = AudioMediaIO()
return audio_io.encode_base64((audio, sampling_rate), audio_format=format)
def encode_audio_url(
audio: np.ndarray,
sampling_rate: int,
*,
format: str = "WAV",
) -> str:
"""Encode audio as a data URL."""
audio_b64 = encode_audio_base64(audio, sampling_rate, format=format)
mimetype = mimetypes.types_map.get("." + format.lower(), "audio")
return f"data:{mimetype};base64,{audio_b64}"
def encode_image_base64(
image: Image.Image,
*,
image_mode: str = "RGB",
format: str = "PNG",
) -> str:
"""
Encode a pillow image to base64 format.
By default, the image is converted into RGB format before being encoded.
"""
image_io = ImageMediaIO(image_mode=image_mode)
return image_io.encode_base64(image, image_format=format)
def encode_image_url(
image: Image.Image,
*,
image_mode: str = "RGB",
format: str = "PNG",
) -> str:
"""
Encode a pillow image as a data URL.
By default, the image is converted into RGB format before being encoded.
"""
image_b64 = encode_image_base64(image, image_mode=image_mode, format=format)
mimetype = mimetypes.types_map.get("." + format.lower(), "image")
return f"data:{mimetype};base64,{image_b64}"
def encode_video_base64(
frames: npt.NDArray,
*,
format: str = "JPEG",
) -> str:
image_io = ImageMediaIO()
video_io = VideoMediaIO(image_io)
return video_io.encode_base64(frames, video_format=format)
def encode_video_url(
frames: npt.NDArray,
*,
format: str = "JPEG",
) -> str:
video_b64 = encode_video_base64(frames, format=format)
if format.lower() == "jpeg":
mimetype = "video/jpeg"
else:
mimetype = mimetypes.types_map.get("." + format.lower(), "video")
return f"data:{mimetype};base64,{video_b64}"
def argsort_mm_positions(
mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
"""
Given a `MultiModalPlaceholderDict`, output a sequence of keys to
sort the dictionary by `offset` (starting index in the input sequence)
in ascending order.
Returns:
A list of `(modality, idx)`, which can be used to access an item
by `mm_positions[modality][idx]`.
"""
flat_items = (
(modality, idx, item)
for modality, items in mm_positions.items()
for idx, item in enumerate(items)
)
sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset)
return [(modality, idx) for modality, idx, _ in sorted_flat_items]
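# A small worked example, assuming `PlaceholderRange` (with `offset` and
# `length` fields) from `.inputs` is the element type stored in
# `MultiModalPlaceholderDict`.
def _example_argsort_mm_positions() -> None:
    from .inputs import PlaceholderRange

    mm_positions = {
        "image": [PlaceholderRange(offset=10, length=5)],
        "audio": [PlaceholderRange(offset=4, length=3)],
    }
    # The audio placeholder starts earlier in the prompt, so it sorts first.
    assert argsort_mm_positions(mm_positions) == [("audio", 0), ("image", 0)]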
def _get_group_hash(elem: MultiModalFieldElem):
if not isinstance(elem.field, MultiModalSharedField):
return None
return MultiModalHasher.hash_kwargs(data=elem.data)
def _batch_mm_items(
items: Sequence[MultiModalKwargsItem],
*,
device: torch.types.Device = None,
pin_memory: bool = False,
):
    elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
    for item in items:
        for key, elem in item.items():
            elems_by_key[key].append(elem)
    return {
        key: elems[0].field.reduce_data(
            elems,
            device=device,
            pin_memory=pin_memory,
        )
        for key, elems in elems_by_key.items()
    }
def group_and_batch_mm_items(
items: Sequence[MultiModalKwargsItem],
*,
device: torch.types.Device = None,
pin_memory: bool = False,
) -> Generator[tuple[int, BatchedTensorInputs]]:
"""
Group consecutive items (possibly from different requests) into batches.
Items must be split across groups if any of the following occurs,
as the batch would otherwise be invalid:
- They have different fields (e.g. mixed image and embedding inputs).
- They have different values in `MultiModalSharedField`.
Args:
items: List of `MultiModalKwargsItem`.
device: The device to place the grouped tensors on.
pin_memory: Whether to pin memory for faster host-to-device transfer.
Yields:
        A tuple `(num_items, grouped_kwargs)`, where:
        - `num_items` is the number of items in the batch;
        - `grouped_kwargs` is a dictionary of keyword arguments to pass
          to the model.
"""
group_ids = [
tuple(
(key, _get_group_hash(elem))
for key, elem in sorted(item.items(), key=lambda kv: kv[0])
)
for item in items
]
group_sizes = [sum(1 for _ in group) for _, group in groupby(group_ids)]
start_idx = 0
for group_size in group_sizes:
group_data = _batch_mm_items(
items[start_idx : start_idx + group_size],
device=device,
pin_memory=pin_memory,
)
yield group_size, group_data
start_idx += group_size
assert start_idx == len(items)
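# Illustrative sketch (comments only): given consecutive items A, A, B, A,
# where A and B carry different field sets, `group_and_batch_mm_items`
# yields batches of sizes [2, 1, 1] -- items are never reordered, runs are
# only split where the grouping key changes.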
def group_mm_kwargs_by_modality(
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
*,
device: torch.types.Device = None,
pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
"""
Group consecutive items (possibly from different requests) into batches.
Items must be split across groups if any of the following occurs,
as the batch would otherwise be invalid:
- They have different fields (e.g. mixed image and embedding inputs).
- They have different values in `MultiModalSharedField`.
To simplify the implementation of `embed_multimodal`, we add another
restriction that the items in a batch must belong to the same modality.
Args:
mm_kwargs: List of `(modality, item)`.
device: The device to place the grouped tensors on.
pin_memory: Whether to pin memory for faster host-to-device transfer.
Yields:
        A tuple `(modality, num_items, grouped_kwargs)`, where:
        - `modality` is the modality of the batch;
        - `num_items` is the number of items in the batch;
        - `grouped_kwargs` is a dictionary of keyword arguments to pass
          to the model.
"""
for modality, group in groupby(mm_kwargs, key=lambda x: x[0]):
items_lst = [item for _, item in group]
for num_items, mm_kwargs_batch in group_and_batch_mm_items(
items_lst,
device=device,
pin_memory=pin_memory,
):
yield modality, num_items, mm_kwargs_batch
def fetch_audio(
audio_url: str,
audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
"""
Args:
audio_url: URL of the audio file to fetch.
audio_io_kwargs: Additional kwargs passed to handle audio IO.
Warning:
This method has direct access to local files and is only intended
to be called by user code. Never call this from the online server!
"""
media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs}
media_connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path="/",
)
return media_connector.fetch_audio(audio_url)
def fetch_image(
image_url: str,
image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
"""
Args:
image_url: URL of the image file to fetch.
image_io_kwargs: Additional kwargs passed to handle image IO.
Warning:
This method has direct access to local files and is only intended
to be called by user code. Never call this from the online server!
"""
media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs}
media_connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path="/",
)
return media_connector.fetch_image(image_url)
def fetch_video(
video_url: str,
video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
"""
Args:
video_url: URL of the video file to fetch.
video_io_kwargs: Additional kwargs passed to handle video IO.
Warning:
This method has direct access to local files and is only intended
to be called by user code. Never call this from the online server!
"""
media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs}
media_connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path="/",
)
return media_connector.fetch_video(video_url)
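# Usage sketch (offline/user code only, per the warnings above; the file
# paths are hypothetical):
#
#     image = fetch_image("file:///tmp/example.png")
#     audio, sampling_rate = fetch_audio("file:///tmp/example.wav")
#     frames, metadata = fetch_video("file:///tmp/example.mp4")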

836
vllm/multimodal/video.py Normal file
View File

@@ -0,0 +1,836 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from abc import abstractmethod
from io import BytesIO
from typing import TYPE_CHECKING, Any, cast
import numpy as np
import numpy.typing as npt
if TYPE_CHECKING:
import cv2
from vllm.logger import init_logger
from vllm.utils.registry import ExtensionManager
logger = init_logger(__name__)
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
num_frames, _, _, channels = frames.shape
new_height, new_width = size
resized_frames = np.empty(
(num_frames, new_height, new_width, channels), dtype=frames.dtype
)
# lazy import cv2 to avoid bothering users who only use text models
import cv2
for i, frame in enumerate(frames):
resized_frame = cv2.resize(frame, (new_width, new_height))
resized_frames[i] = resized_frame
return resized_frames
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
_, height, width, _ = frames.shape
new_height = int(height * size_factor)
new_width = int(width * size_factor)
return resize_video(frames, (new_height, new_width))
def sample_frames_from_video(frames: npt.NDArray, num_frames: int) -> npt.NDArray:
total_frames = frames.shape[0]
if num_frames == -1:
return frames
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
sampled_frames = frames[frame_indices, ...]
return sampled_frames
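# A minimal, runnable sketch: uniformly sampling 4 of 10 dummy frames picks
# indices np.linspace(0, 9, 4, dtype=int) == [0, 3, 6, 9].
def _example_sample_frames_from_video() -> None:
    frames = np.zeros((10, 2, 2, 3), dtype=np.uint8)
    sampled = sample_frames_from_video(frames, num_frames=4)
    assert sampled.shape == (4, 2, 2, 3)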
class VideoLoader:
@classmethod
@abstractmethod
def load_bytes(
cls, data: bytes, num_frames: int = -1, **kwargs
) -> tuple[npt.NDArray, dict[str, Any]]:
raise NotImplementedError
@staticmethod
def _can_use_for_recovery(
idx: int,
failed_frames: list[int],
next_target_map: dict[int, int],
total_frames: int,
) -> bool:
"""Check if current frame can recover the oldest failed frame."""
if not failed_frames:
return False
oldest_failed = failed_frames[0]
limit = next_target_map.get(oldest_failed, total_frames)
return idx < limit
@staticmethod
def _read_frames_with_recovery(
cap: "cv2.VideoCapture",
frame_indices: list[int],
total_frames: int,
) -> tuple[npt.NDArray, list[int], dict[int, int]]:
"""
Read frames with dynamic window forward-scan recovery.
When a target frame fails to load, the next successfully grabbed
frame (before the next target frame) will be used to recover it.
Args:
cap: OpenCV VideoCapture object
frame_indices: Sorted list of target frame indices to load
total_frames: Total number of frames in the video
Returns:
Tuple of (frames_array, valid_frame_indices, recovered_map)
- frames_array: Array of loaded frames
- valid_frame_indices: List of frame indices that were loaded
- recovered_map: Dict mapping recovered_idx -> source_idx
"""
import cv2
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
assert width > 0 and height > 0, (
f"Invalid video frame size: width={width}, height={height}"
)
frame_idx_set = set(frame_indices)
max_frame_idx = frame_indices[-1] if frame_indices else 0
# Build map: target_idx -> next_target_idx (for recovery window)
next_target_map: dict[int, int] = {}
for k in range(len(frame_indices) - 1):
next_target_map[frame_indices[k]] = frame_indices[k + 1]
        if frame_indices:
            next_target_map[frame_indices[-1]] = total_frames
frames_list: list[npt.NDArray] = []
valid_frame_indices: list[int] = []
failed_frames_idx: list[int] = []
recovered_map: dict[int, int] = {}
for idx in range(max_frame_idx + 1):
is_target_frame = idx in frame_idx_set
# Attempt to grab the current frame
ok = cap.grab()
if not ok:
if is_target_frame:
logger.warning(
"Failed to grab frame %d during video loading.",
idx,
)
failed_frames_idx.append(idx)
continue
# Check if we should retrieve: target frame OR can recover a failed one
can_recover = VideoLoader._can_use_for_recovery(
idx, failed_frames_idx, next_target_map, total_frames
)
if is_target_frame or can_recover:
ret, frame = cap.retrieve()
if ret and frame is not None and frame.size > 0:
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames_list.append(rgb_frame)
valid_frame_indices.append(idx)
if can_recover:
recovered_idx = failed_frames_idx.pop(0)
recovered_map[recovered_idx] = idx
logger.info(
"Recovered frame %d using frame %d (delay: %d)",
recovered_idx,
idx,
idx - recovered_idx,
)
elif is_target_frame:
logger.warning(
"Failed to retrieve frame %d during video loading.",
idx,
)
failed_frames_idx.append(idx)
# Log any remaining failed frames
for failed_idx in failed_frames_idx:
logger.warning(
"Frame %d could not be recovered (end of video).",
failed_idx,
)
# Stack frames
if frames_list:
frames = np.stack(frames_list)
else:
frames = np.empty((0, height, width, 3), dtype=np.uint8)
return frames, valid_frame_indices, recovered_map
@staticmethod
def _read_frames(
cap,
frame_indices: set[int],
num_expected_frames: int,
max_frame_idx: int,
) -> tuple[npt.NDArray, int, list[int]]:
import cv2
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
frames = np.empty((num_expected_frames, height, width, 3), dtype=np.uint8)
i = 0
valid_frame_indices = []
for idx in range(max_frame_idx + 1):
ok = cap.grab()
if not ok:
# Frame is broken/unreadable, log warning
if idx in frame_indices:
logger.warning(
"Failed to grab frame %d during video loading. "
"This frame will be skipped.",
idx,
)
continue
if idx in frame_indices:
ret, frame = cap.retrieve()
if ret:
frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
valid_frame_indices.append(idx)
i += 1
else:
# retrieve() failed even though grab() succeeded
logger.warning(
"Failed to retrieve frame %d during video loading. "
"This frame will be skipped.",
idx,
)
valid_num_frames = len(valid_frame_indices)
if valid_num_frames < num_expected_frames:
logger.warning(
"Video loading completed with %d broken/unreadable frames. "
"Expected %d frames but only loaded %d frames.",
num_expected_frames - valid_num_frames,
num_expected_frames,
valid_num_frames,
)
return frames[:valid_num_frames], valid_num_frames, valid_frame_indices
VIDEO_LOADER_REGISTRY = ExtensionManager()
@VIDEO_LOADER_REGISTRY.register("opencv")
class OpenCVVideoBackend(VideoLoader):
def get_cv2_video_api(self):
import cv2.videoio_registry as vr
api_pref = None
for backend in vr.getStreamBufferedBackends():
if not vr.hasBackend(backend):
continue
if not vr.isBackendBuiltIn(backend):
_, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
if abi < 1 or (abi == 1 and api < 2):
continue
api_pref = backend
break
return api_pref
@classmethod
def load_bytes(
cls,
data: bytes,
num_frames: int = -1,
fps: int = -1,
max_duration: int = 300,
frame_recovery: bool = False,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
"""
Load video frames from bytes.
Args:
data: Raw video bytes
num_frames: Target number of frames to sample (-1 for all)
fps: Target FPS for sampling (-1 for original)
max_duration: Maximum duration (unused in base backend)
frame_recovery: Enable forward-scan recovery for failed frames
Returns:
Tuple of (frames_array, metadata_dict)
"""
import cv2
backend = cls().get_cv2_video_api()
cap = cv2.VideoCapture(BytesIO(data), backend, [])
if not cap.isOpened():
raise ValueError("Could not open video stream")
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
original_fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames_num / original_fps if original_fps > 0 else 0
# resample video to target num_frames and fps
# - the minimum of the two will be used
num_frames_to_sample = total_frames_num
if num_frames > 0:
num_frames_to_sample = min(num_frames, total_frames_num)
if fps > 0:
num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
num_frames_to_sample = max(1, num_frames_to_sample) # at least one sample
if num_frames_to_sample == total_frames_num:
frame_idx = list(range(0, num_frames_to_sample))
else:
uniform_sampled_frames = np.linspace(
0, total_frames_num - 1, num_frames_to_sample, dtype=int
)
frame_idx = uniform_sampled_frames.tolist()
if frame_recovery:
frames, valid_frame_indices, recovered_map = cls._read_frames_with_recovery(
cap, frame_idx, total_frames_num
)
valid_num_frames = len(valid_frame_indices)
if recovered_map:
logger.info(
"Frame recovery: %d frames recovered using forward scan.",
len(recovered_map),
)
else:
frame_idx_set = set(frame_idx)
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
cap, frame_idx_set, num_frames_to_sample, max(frame_idx)
)
        # Use the `transformers.video_utils.VideoMetadata` format
# NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
# can cause incorrect timestamp calculation without num_frames=-1.
metadata = {
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": duration,
"video_backend": "opencv",
"frames_indices": valid_frame_indices,
# extra field used to control hf processor's video
# sampling behavior
"do_sample_frames": valid_num_frames == total_frames_num,
}
return frames, metadata
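# Illustrative sketch (comments only): for a 10 s clip at 30 fps (300
# frames), `load_bytes(data, num_frames=64, fps=2)` samples
# min(64, floor(10 * 2)) = 20 frames -- the stricter of the two limits wins.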
@VIDEO_LOADER_REGISTRY.register("opencv_dynamic")
class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
@classmethod
def load_bytes(
cls,
data: bytes,
num_frames: int = -1,
fps: int = 2,
max_duration: int = 300,
frame_recovery: bool = False,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
"""
Load video frames with dynamic sampling based on duration.
Args:
data: Raw video bytes
num_frames: Not used in dynamic backend
fps: Target FPS for sampling (default: 2)
max_duration: Maximum video duration to process (default: 300s)
frame_recovery: Enable forward-scan recovery for failed frames
Returns:
Tuple of (frames_array, metadata_dict)
"""
import cv2
backend = cls().get_cv2_video_api()
cap = cv2.VideoCapture(BytesIO(data), backend, [])
if not cap.isOpened():
raise ValueError("Could not open video stream")
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
original_fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames_num / original_fps if original_fps > 0 else 0
# resample video to target num_frames
max_frame_idx = total_frames_num - 1
duration = duration or round(max_frame_idx / original_fps) + 1
# Refer to:
# https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
frame_indices_list: list[int]
if duration <= max_duration:
n = int(math.floor(duration * fps))
frame_indices_list = sorted(
{
min(max_frame_idx, int(math.ceil(i * original_fps / fps)))
for i in range(n)
}
)
else:
num_samples = int(max_duration * fps)
if num_samples >= total_frames_num:
frame_indices_list = list(range(total_frames_num))
else:
target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
frame_indices_list = sorted(
{
min(max_frame_idx, int(math.ceil(t * original_fps)))
for t in target_seconds
}
)
if frame_recovery:
frames, valid_frame_indices, recovered_map = cls._read_frames_with_recovery(
cap, frame_indices_list, total_frames_num
)
valid_num_frames = len(valid_frame_indices)
if recovered_map:
logger.info(
"Frame recovery: %d frames recovered using forward scan.",
len(recovered_map),
)
else:
frame_indices_set = set(frame_indices_list)
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
cap, frame_indices_set, len(frame_indices_list), total_frames_num - 1
)
        # Use the `transformers.video_utils.VideoMetadata` format
metadata = {
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": duration,
"video_backend": "opencv_dynamic",
"frames_indices": valid_frame_indices,
"do_sample_frames": False,
}
return frames, metadata
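# Illustrative sketch (comments only): for a 4 s clip at 10 fps with the
# default fps=2 (duration <= max_duration), the sampled indices are
# {ceil(i * 10 / 2) for i in range(8)} = {0, 5, 10, 15, 20, 25, 30, 35}.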
@VIDEO_LOADER_REGISTRY.register("molmo2")
class Molmo2VideoBackend(VideoLoader):
def get_cv2_video_api(self):
import cv2.videoio_registry as vr
api_pref = None
for backend in vr.getStreamBufferedBackends():
if not vr.hasBackend(backend):
continue
if not vr.isBackendBuiltIn(backend):
_, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
if abi < 1 or (abi == 1 and api < 2):
continue
api_pref = backend
break
return api_pref
@classmethod
def get_candidate_target_fps(
cls,
video_fps: float,
sampling_fps: float,
max_fps: float = 8.0,
) -> list[float]:
"""
Return the subset of `video_fps` factors that remain multiples
of `sampling_fps`.
Examples:
            >>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
            [2.0, 6.0]
            >>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
            [1.0, 5.0]
            >>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
            [2.0]
            >>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
            Traceback (most recent call last):
                ...
            ValueError: sampling_fps=2 must divide video_fps=5.
"""
        if sampling_fps is None:
            raise ValueError("sampling_fps must be provided")
        video_fps = int(video_fps)
        sampling_fps = int(sampling_fps)
        max_fps = int(max_fps)
if video_fps <= 0 or sampling_fps <= 0:
raise ValueError(
"video_fps and sampling_fps must be positive "
f"(got {video_fps}, {sampling_fps})"
)
if video_fps % sampling_fps != 0:
raise ValueError(
f"sampling_fps={sampling_fps} must divide video_fps={video_fps}."
)
candidates = []
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
if candidate > max_fps:
break
if video_fps % candidate == 0:
candidates.append(float(candidate))
return candidates
@classmethod
def get_target_fps(
cls,
video_fps: float,
max_frames: int,
total_frames: int,
frame_sample_mode: str,
candidate_target_fps: list[float],
) -> float | None:
"""
        Get the target FPS that best spans the video while sampling the most
        frames.
"""
num_frames_sampled = 0
selected_target_fps = None
for target_fps in candidate_target_fps:
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if (
"uniform" in frame_sample_mode
and num_frames_sampled_at_fps > max_frames
):
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
else:
# the candidate sampling fps increases so frame count can't decrease
assert num_frames_sampled <= num_frames_sampled_at_fps
if num_frames_sampled_at_fps > max_frames:
# choose the sampling fps that spans the video
continue
elif num_frames_sampled_at_fps > num_frames_sampled:
# both are less than max_frames; choose the one with higher
# density of frames sampled
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
return selected_target_fps
@classmethod
def get_frame_times_and_chosen_fps(
cls,
selected_target_fps: float | None,
total_frames: int,
max_frames: int,
video_fps: float,
) -> tuple[float | None, npt.NDArray]:
if selected_target_fps is None:
frame_indices = np.linspace(
0, total_frames, max_frames, endpoint=False, dtype=int
)
else:
step_size = max(int(video_fps / selected_target_fps), 1)
frame_indices = np.arange(0, total_frames, step_size)
if len(frame_indices) > max_frames:
frame_indices = frame_indices[:max_frames]
return selected_target_fps, frame_indices
@classmethod
def sample_times(
cls,
duration: float,
max_frames: int,
frame_sample_mode: str,
max_fps: int | None,
candidate_target_fps: list[float] | None = None,
**kwargs,
) -> npt.NDArray:
if frame_sample_mode == "fps":
assert candidate_target_fps is not None
# Try larger and larger FPSs until we hit one that can't span the video
sampling_fps = candidate_target_fps[0]
for candidate_fps in candidate_target_fps[1:]:
if max_frames / candidate_fps < duration:
break
sampling_fps = candidate_fps
times = np.arange(0, max_frames) / sampling_fps
times = times[times < duration]
return times
elif frame_sample_mode == "uniform_last_frame":
if max_fps is not None:
max_duration = (
max_frames - 1
) / max_fps # -1 to include the last frame
if max_duration < duration:
times = np.linspace(
0, duration, num=max_frames, endpoint=True, dtype=np.float64
)
else:
times = np.arange(0.0, stop=duration, step=1 / max_fps)
times = np.concatenate([times, [duration]], axis=0)
assert len(times) <= max_frames
else:
times = np.linspace(
0, duration, num=max_frames, endpoint=True, dtype=np.float64
)
return times
else:
raise NotImplementedError(frame_sample_mode)
@classmethod
def _sample_frames(
cls,
total_num_frames: int,
video_fps: float,
duration: float,
frame_sample_mode: str,
num_frames: int,
max_fps: int,
sampling_fps: int,
) -> npt.NDArray:
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
if total_num_frames <= 2:
indices = np.arange(total_num_frames).astype(int)
elif duration > (num_frames - 1) / max_fps: # -1 to include the last frame
# uniform fallback
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
else:
float_indices = np.arange(
0.0,
stop=total_num_frames - 1,
step=float(video_fps / max_fps),
)
if np.round(float_indices[-1]) != total_num_frames - 1:
float_indices = np.concatenate(
[float_indices, [total_num_frames - 1]], axis=0
)
indices = np.round(float_indices).astype(int)
assert indices[-1] < total_num_frames
assert len(float_indices) <= num_frames
elif frame_sample_mode == "uniform_last_frame":
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
elif frame_sample_mode == "fps":
candidate_target_fps = cls.get_candidate_target_fps(video_fps, sampling_fps)
selected_target_fps = cls.get_target_fps(
video_fps,
num_frames,
total_num_frames,
frame_sample_mode,
candidate_target_fps,
)
_, indices = cls.get_frame_times_and_chosen_fps(
selected_target_fps,
total_num_frames,
num_frames,
video_fps,
)
else:
raise NotImplementedError(frame_sample_mode)
return indices
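    # Illustrative sketch (comments only): with video_fps=6, sampling_fps=2,
    # num_frames=8 and a 60-frame video, the candidates are [2.0, 6.0];
    # 2.0 samples 60 // 3 = 20 frames and 6.0 samples 60, both above the
    # 8-frame budget, so the spanning fps 2.0 is kept and the first 8
    # indices of arange(0, 60, 3) are returned.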
@classmethod
def load_bytes_opencv(
cls,
data: bytes,
frame_sample_mode: str | None = None,
num_frames: int = -1,
max_fps: int = 2,
sampling_fps: int = 2,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
import cv2
backend = cls().get_cv2_video_api()
cap = cv2.VideoCapture(BytesIO(data), backend, [])
if not cap.isOpened():
raise ValueError("Could not open video stream")
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
original_fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames_num / original_fps if original_fps > 0 else 0
if frame_sample_mode is None:
            # Use the `transformers.video_utils.VideoMetadata` format
frame_idx = list(range(0, total_frames_num))
frame_idx_set = set(frame_idx)
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
cap, frame_idx_set, total_frames_num, max(frame_idx)
)
do_sample_frames = valid_num_frames == total_frames_num
metadata = {
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": duration,
"video_backend": "opencv",
"do_sample_frames": do_sample_frames,
}
if not do_sample_frames:
metadata["frames_indices"] = valid_frame_indices
return frames, metadata
frame_idx = cls._sample_frames(
total_frames_num,
original_fps,
duration,
frame_sample_mode,
num_frames,
max_fps,
sampling_fps,
).tolist()
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
cap,
set(frame_idx),
len(frame_idx),
total_frames_num - 1,
)
metadata = {
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": duration,
"video_backend": "opencv",
"frames_indices": valid_frame_indices,
"do_sample_frames": False,
}
return frames, metadata
@classmethod
def load_bytes(
cls,
data: bytes,
num_frames: int = -1,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
frame_sample_mode = cast(str | None, kwargs.pop("frame_sample_mode", None))
max_fps = cast(int, kwargs.pop("max_fps", 2))
sampling_fps = cast(int, kwargs.pop("sampling_fps", 2))
out = cls.load_bytes_opencv(
data,
frame_sample_mode,
num_frames,
max_fps,
sampling_fps,
**kwargs,
)
return out
@VIDEO_LOADER_REGISTRY.register("openpangu")
class OpenCVDynamicOpenPanguVideoBackend(OpenCVVideoBackend):
@classmethod
def load_bytes(
cls,
data: bytes,
num_frames: int = 32,
fps: int = 1,
max_duration: int = 300,
frame_recovery: bool = False,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
"""
Load video frames with dynamic sampling based on duration.
        Assume that total_num_frames = 10 and fps = 1: the timestamp of
        frame 0 is 0.0, the timestamp of frame 1 is 1.0, and so on; the
        timestamp of frame 9 (the last frame) should be 9.0, that is,
        (total_frames_num - 1) / original_fps.
        Args:
            data: Raw video bytes
            num_frames: Maximum number of frames to sample (default: 32)
            fps: Target FPS for sampling (default: 1); -1 disables the
                FPS limit
Returns:
Tuple of (frames_array, metadata_dict)
"""
import cv2
backend = cls().get_cv2_video_api()
cap = cv2.VideoCapture(BytesIO(data), backend, [])
if not cap.isOpened():
raise ValueError("Could not open video stream")
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
original_fps = float(cap.get(cv2.CAP_PROP_FPS))
        # Duration measured up to the timestamp of the last (rightmost)
        # frame; frame 0 sits at t = 0 and contributes no duration.
if total_frames_num >= 1 and original_fps > 0:
total_duration = (total_frames_num - 1) / original_fps
else:
total_duration = 0
        # `fps` is the target sampling FPS passed in by the caller;
        # -1 means sample directly without an FPS limit.
        if fps > 0:
            # `num_frames` is the maximum number of frames to sample. If
            # sampling at `fps` over the full duration would yield fewer
            # frames, shrink `num_frames` to that count.
            if num_frames >= int(total_duration * fps) + 1:
                num_frames = int(total_duration * fps) + 1
            # Clamp the duration to the timestamp of the last frame under
            # the new frame budget (again, frame 0 sits at t = 0).
            total_duration = min(total_duration, (num_frames - 1) / fps)
elif fps != -1:
            raise ValueError(f"fps must be -1 or greater than 0, but got {fps}")
sample_frame_timestamps = np.linspace(
0, total_duration, num_frames, dtype=float
)
frames_indices = [
min(total_frames_num - 1, round(t * original_fps))
for t in sample_frame_timestamps
]
frames, valid_frame_indices, recovered_map = cls._read_frames_with_recovery(
cap, frames_indices, total_frames_num
)
if recovered_map:
logger.info(
"Frame recovery: %d frames recovered using forward scan.",
len(recovered_map),
)
metadata = {
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": total_duration,
"video_backend": "opencv_dynamic_openpangu",
"frames_indices": valid_frame_indices,
"do_sample_frames": False,
}
return frames, metadata
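# Illustrative sketch (comments only): a 90-frame clip at 30 fps with
# num_frames=32 and fps=1 gives total_duration = 89 / 30 ~= 2.97 s; the FPS
# cap shrinks num_frames to int(2.97) + 1 = 3 and clamps total_duration to
# (3 - 1) / 1 = 2.0 s, so the timestamps are linspace(0, 2, 3) =
# [0.0, 1.0, 2.0] and the frame indices are [0, 30, 60].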