Sync from v0.13
This commit is contained in:
40
vllm/multimodal/__init__.py
Normal file
40
vllm/multimodal/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .hasher import MultiModalHasher
|
||||
from .inputs import (
|
||||
BatchedTensorInputs,
|
||||
ModalityData,
|
||||
MultiModalDataBuiltins,
|
||||
MultiModalDataDict,
|
||||
MultiModalKwargs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalPlaceholderDict,
|
||||
MultiModalUUIDDict,
|
||||
NestedTensors,
|
||||
)
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
MULTIMODAL_REGISTRY = MultiModalRegistry()
|
||||
"""
|
||||
The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry]
|
||||
is used by model runners to dispatch data processing according to the target
|
||||
model.
|
||||
|
||||
Info:
|
||||
[mm_processing](../../../design/mm_processing.md)
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"BatchedTensorInputs",
|
||||
"ModalityData",
|
||||
"MultiModalDataBuiltins",
|
||||
"MultiModalDataDict",
|
||||
"MultiModalHasher",
|
||||
"MultiModalKwargs",
|
||||
"MultiModalKwargsItems",
|
||||
"MultiModalPlaceholderDict",
|
||||
"MultiModalUUIDDict",
|
||||
"NestedTensors",
|
||||
"MULTIMODAL_REGISTRY",
|
||||
"MultiModalRegistry",
|
||||
]
|
||||
147
vllm/multimodal/audio.py
Normal file
147
vllm/multimodal/audio.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pybase64
|
||||
import torch
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
from .base import MediaIO
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import soundfile
|
||||
except ImportError:
|
||||
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
||||
|
||||
|
||||
def resample_audio_librosa(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
target_sr: float,
|
||||
) -> npt.NDArray[np.floating]:
|
||||
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
|
||||
|
||||
|
||||
def resample_audio_scipy(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
target_sr: float,
|
||||
):
|
||||
# lazy import scipy.signal, otherwise it will crash doc build.
|
||||
import scipy.signal
|
||||
|
||||
if orig_sr > target_sr:
|
||||
return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr)
|
||||
elif orig_sr < target_sr:
|
||||
return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1)
|
||||
return audio
|
||||
|
||||
|
||||
class AudioResampler:
|
||||
"""Resample audio data to a target sample rate."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_sr: float | None = None,
|
||||
method: Literal["librosa", "scipy"] = "librosa",
|
||||
):
|
||||
self.target_sr = target_sr
|
||||
self.method = method
|
||||
|
||||
def resample(
|
||||
self,
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
) -> npt.NDArray[np.floating]:
|
||||
if self.target_sr is None:
|
||||
raise RuntimeError(
|
||||
"Audio resampling is not supported when `target_sr` is not provided"
|
||||
)
|
||||
if self.method == "librosa":
|
||||
return resample_audio_librosa(
|
||||
audio, orig_sr=orig_sr, target_sr=self.target_sr
|
||||
)
|
||||
elif self.method == "scipy":
|
||||
return resample_audio_scipy(
|
||||
audio, orig_sr=orig_sr, target_sr=self.target_sr
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid resampling method: {self.method}. "
|
||||
"Supported methods are 'librosa' and 'scipy'."
|
||||
)
|
||||
|
||||
|
||||
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
self.kwargs = kwargs
|
||||
|
||||
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
|
||||
return librosa.load(BytesIO(data), sr=None)
|
||||
|
||||
def load_base64(
|
||||
self,
|
||||
media_type: str,
|
||||
data: str,
|
||||
) -> tuple[npt.NDArray, float]:
|
||||
return self.load_bytes(base64.b64decode(data))
|
||||
|
||||
def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
|
||||
return librosa.load(filepath, sr=None)
|
||||
|
||||
def encode_base64(self, media: tuple[npt.NDArray, int]) -> str:
|
||||
audio, sr = media
|
||||
|
||||
with BytesIO() as buffer:
|
||||
soundfile.write(buffer, audio, sr, format="WAV")
|
||||
data = buffer.getvalue()
|
||||
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def load_bytes(self, data: bytes) -> torch.Tensor:
|
||||
buffer = BytesIO(data)
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(buffer, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> torch.Tensor:
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(filepath, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def encode_base64(self, media: torch.Tensor) -> str:
|
||||
return tensor2base64(media)
|
||||
56
vllm/multimodal/base.py
Normal file
56
vllm/multimodal/base.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaWithBytes(Generic[_T]):
|
||||
"""
|
||||
Wrapper that couples a media object with its original encoded bytes.
|
||||
|
||||
This ensures the raw bytes and media object remain synchronized,
|
||||
preventing cache corruption from in-place modifications.
|
||||
|
||||
The wrapper delegates attribute access to the underlying media object,
|
||||
making it behave transparently like the wrapped type (e.g., PIL.Image).
|
||||
|
||||
NOTE: Currently, this wrapper is used only for the image modality.
|
||||
"""
|
||||
|
||||
media: _T
|
||||
original_bytes: bytes
|
||||
|
||||
def __array__(self, *args, **kwargs) -> np.ndarray:
|
||||
"""Allow np.array(obj) to return np.array(obj.media)."""
|
||||
return np.array(self.media, *args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Delegate attribute access to the underlying media object."""
|
||||
# This is only called when the attribute is not found on self
|
||||
return getattr(self.media, name)
|
||||
|
||||
|
||||
class MediaIO(ABC, Generic[_T]):
|
||||
@abstractmethod
|
||||
def load_bytes(self, data: bytes) -> _T:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_base64(self, media_type: str, data: str) -> _T:
|
||||
"""
|
||||
List of media types:
|
||||
https://www.iana.org/assignments/media-types/media-types.xhtml
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_file(self, filepath: Path) -> _T:
|
||||
raise NotImplementedError
|
||||
823
vllm/multimodal/cache.py
Normal file
823
vllm/multimodal/cache.py
Normal file
@@ -0,0 +1,823 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import operator
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||
MsgpackSerde,
|
||||
SingleWriterShmObjectStorage,
|
||||
SingleWriterShmRingBuffer,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.cache import CacheInfo, LRUCache
|
||||
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
|
||||
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
|
||||
|
||||
from .inputs import (
|
||||
MultiModalBatchedField,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldElem,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
NestedTensors,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
from .processing import ResolvedPromptUpdate
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItem:
|
||||
"""
|
||||
The data to store inside `MultiModalProcessorOnlyCache`.
|
||||
|
||||
Args:
|
||||
item: The processed tensor data corresponding to a multi-modal item.
|
||||
prompt_updates: The prompt updates corresponding to `item`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item: MultiModalKwargsItem,
|
||||
prompt_updates: Sequence["ResolvedPromptUpdate"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.item = item
|
||||
self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItemMetadata:
|
||||
"""
|
||||
The metadata to store inside `MultiModalProcessorSenderCache`.
|
||||
|
||||
Args:
|
||||
item: The processed tensor data corresponding to a multi-modal item.
|
||||
Since P1 already stores the tensor data, we only store its size
|
||||
metadata in P0 to reduce memory usage. The size metadata is still
|
||||
needed to keep the same cache eviction policy as P0.
|
||||
prompt_updates: The prompt updates corresponding to `item`.
|
||||
This needs to stay on P0 because for some models, they are
|
||||
dependent on the processed tensor data (cached on P1).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item: MultiModalKwargsItem,
|
||||
prompt_updates: Sequence["ResolvedPromptUpdate"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.item_size = MultiModalCache.get_item_size(item)
|
||||
self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
MultiModalCacheValue: TypeAlias = (
|
||||
MultiModalProcessorCacheItem
|
||||
| MultiModalProcessorCacheItemMetadata
|
||||
| MultiModalKwargsItems
|
||||
| MultiModalKwargsItem
|
||||
| Mapping[str, NestedTensors]
|
||||
)
|
||||
|
||||
_V = TypeVar("_V", bound=MultiModalCacheValue)
|
||||
|
||||
|
||||
class MultiModalCache:
|
||||
@classmethod
|
||||
def get_leaf_size(cls, leaf: object) -> int:
|
||||
if isinstance(leaf, MultiModalProcessorCacheItem):
|
||||
return cls.get_leaf_size(leaf.item)
|
||||
if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
|
||||
return leaf.item_size
|
||||
|
||||
# These are not subclasses of dict
|
||||
if isinstance(
|
||||
leaf,
|
||||
(MultiModalKwargsItems, MultiModalKwargsItem, MultiModalFieldElem),
|
||||
):
|
||||
return cls.get_item_size(leaf.data) # type: ignore
|
||||
|
||||
# sys.getsizeof doesn't work for tensors
|
||||
if isinstance(leaf, torch.Tensor):
|
||||
return leaf.nbytes
|
||||
|
||||
return sys.getsizeof(leaf)
|
||||
|
||||
@classmethod
|
||||
def get_item_size(
|
||||
cls,
|
||||
value: MultiModalCacheValue,
|
||||
*,
|
||||
debug: bool = False,
|
||||
) -> int:
|
||||
size = json_reduce_leaves(
|
||||
operator.add, json_map_leaves(cls.get_leaf_size, value)
|
||||
)
|
||||
|
||||
if debug:
|
||||
leaf_count = json_count_leaves(value)
|
||||
logger.debug(
|
||||
"Calculated size of %s to be %.2f GiB (%d leaves)",
|
||||
type(value),
|
||||
size / GiB_bytes,
|
||||
leaf_count,
|
||||
)
|
||||
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
|
||||
"""
|
||||
Get the number of leaf elements in a multi-modal cache value.
|
||||
|
||||
This provides a measure of structural complexity that can be useful
|
||||
for debugging cache performance and understanding data patterns.
|
||||
|
||||
Args:
|
||||
value: The multi-modal cache value to analyze.
|
||||
|
||||
Returns:
|
||||
The number of leaf elements in the nested structure.
|
||||
"""
|
||||
return json_count_leaves(value)
|
||||
|
||||
@classmethod
|
||||
def get_lru_cache(
|
||||
cls,
|
||||
capacity_gb: float,
|
||||
value_type: type[_V],
|
||||
*,
|
||||
debug: bool = False,
|
||||
) -> LRUCache[str, _V]:
|
||||
return LRUCache(
|
||||
GiB_bytes * capacity_gb,
|
||||
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
|
||||
)
|
||||
|
||||
|
||||
_I = TypeVar("_I", contravariant=True)
|
||||
_O = TypeVar("_O", covariant=True)
|
||||
|
||||
|
||||
class BaseMultiModalCache(ABC, Generic[_I, _O]):
|
||||
"""
|
||||
Abstract base class to read/write multi-modal items from cache.
|
||||
|
||||
The idea of multi-modal caching is based on having a client and server
|
||||
where the client executes in the frontend process (=P0) and
|
||||
the server in the core process (=P1). The data flow is as follows:
|
||||
|
||||
```
|
||||
is_cached() x N get_and_update()
|
||||
P0: From API -----------------> -----------------> To P1
|
||||
|
||||
get_and_update()
|
||||
P1: From P0 -----------------> To model
|
||||
```
|
||||
|
||||
`is_cached()` can be called any number of times in P0. However,
|
||||
`get_and_update()` must be called in P0 and P1 one after another
|
||||
so that their cache eviction order remains the same.
|
||||
|
||||
This ensures that the keys in P0 and P1 caches are mirrored,
|
||||
allowing us to determine whether a key is cached in P1 by looking
|
||||
up the P0 cache, without having to communicate with P1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: _I,
|
||||
mm_hash: str,
|
||||
) -> _O:
|
||||
"""
|
||||
Possibly update a multi-modal item based on whether it is
|
||||
in the underlying cache.
|
||||
|
||||
This update is done out-of-place and updates the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_item: The multi-modal item to update.
|
||||
mm_hash: The hash of `mm_item`.
|
||||
|
||||
Returns:
|
||||
The update multi-modal item.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_and_update(
|
||||
self,
|
||||
mm_items: Sequence[_I],
|
||||
mm_hashes: list[str],
|
||||
) -> list[_O]:
|
||||
"""
|
||||
Possibly update a sequence of multi-modal items based on whether they
|
||||
are in the underlying cache.
|
||||
|
||||
This update is done out-of-place and updates the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_items: The multi-modal items to update.
|
||||
mm_hashes: The hash of each item in `mm_items`.
|
||||
|
||||
Returns:
|
||||
A new list of updated multi-modal items.
|
||||
"""
|
||||
assert len(mm_items) == len(mm_hashes)
|
||||
|
||||
return [
|
||||
self.get_and_update_item(mm_item, mm_hash)
|
||||
for mm_item, mm_hash in zip(mm_items, mm_hashes)
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the underlying cache."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
MultiModalProcessorCacheInItem: TypeAlias = (
|
||||
tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] | None
|
||||
)
|
||||
|
||||
|
||||
MultiModalProcessorCacheOutItem: TypeAlias = tuple[
|
||||
MultiModalKwargsItem | None, Sequence["ResolvedPromptUpdate"]
|
||||
]
|
||||
|
||||
|
||||
class BaseMultiModalProcessorCache(
|
||||
BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem]
|
||||
):
|
||||
"""The required interface for caches on P0."""
|
||||
|
||||
@abstractmethod
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
"""
|
||||
Check whether a multi-modal item is
|
||||
in the underlying cache.
|
||||
|
||||
This **DOES NOT** update the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_hash: The hash of the item to check.
|
||||
|
||||
Returns:
|
||||
`True` if the item is cached, otherwise `False`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_cached(self, mm_hashes: list[str]) -> list[bool]:
|
||||
"""
|
||||
Check whether a sequence of multi-modal items are
|
||||
in the underlying cache.
|
||||
|
||||
This **DOES NOT** update the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_hashes: The hash of each item to check.
|
||||
|
||||
Returns:
|
||||
For each item, `True` if the item is cached, otherwise `False`.
|
||||
"""
|
||||
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
|
||||
|
||||
@abstractmethod
|
||||
def touch_sender_cache_item(self, mm_hash: str) -> None:
|
||||
"""
|
||||
Update the cache eviction order for a multi-modal item.
|
||||
|
||||
This is used to touch the item in the cache without changing
|
||||
its value.
|
||||
|
||||
Args:
|
||||
mm_hash: The hash of the multi-modal item.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def make_stats(self, *, delta: bool = False) -> CacheInfo:
|
||||
"""
|
||||
Get (and reset) the multi-modal cache stats.
|
||||
|
||||
Returns:
|
||||
The current multi-modal caching stats.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is disabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is in the cache, replace the input with the cached item.
|
||||
- If the item is not in the cache, store that item (which includes
|
||||
tensor data and metadata) into the cache, and return the input.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalProcessorCacheItem,
|
||||
)
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return mm_hash in self._cache
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return cached_item.item, cached_item.prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def touch_sender_cache_item(self, mm_hash: str) -> None:
|
||||
self._cache.touch(mm_hash)
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
@override
|
||||
def make_stats(self, *, delta: bool = False) -> CacheInfo:
|
||||
return self._cache.stat(delta=delta)
|
||||
|
||||
|
||||
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is already in the cache, clear the input to avoid
|
||||
unnecessary IPC.
|
||||
|
||||
- If the item is not in the cache, store the metadata of that item so
|
||||
that the eviction policy remains the same as the cache on P1,
|
||||
and return the input.
|
||||
By only storing the metadata, we avoid keeping the data itself in
|
||||
memory inside P0.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalProcessorCacheItemMetadata,
|
||||
)
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return mm_hash in self._cache
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return None, cached_item.prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def touch_sender_cache_item(self, mm_hash: str) -> None:
|
||||
self._cache.touch(mm_hash)
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
@override
|
||||
def make_stats(self, *, delta: bool = False) -> CacheInfo:
|
||||
return self._cache.stat(delta=delta)
|
||||
|
||||
|
||||
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is already in the cache, clear the input to avoid
|
||||
unnecessary IPC.
|
||||
|
||||
- If the item is not in the cache, store the data in shared memory.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
mm_config = vllm_config.model_config.get_multimodal_config()
|
||||
|
||||
ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
|
||||
name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
|
||||
create=True, # sender is the writer
|
||||
)
|
||||
self._shm_cache = SingleWriterShmObjectStorage(
|
||||
max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
|
||||
n_readers=self.world_size,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=MsgpackSerde,
|
||||
)
|
||||
# cache (prompt_updates, modality) for P0 only
|
||||
self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {}
|
||||
|
||||
self._hits = 0
|
||||
self._total = 0
|
||||
self._last_info = CacheInfo(hits=0, total=0)
|
||||
|
||||
def _stat(self, *, delta: bool = False) -> CacheInfo:
|
||||
info = CacheInfo(hits=self._hits, total=self._total)
|
||||
|
||||
if delta:
|
||||
info_delta = info - self._last_info
|
||||
self._last_info = info
|
||||
info = info_delta
|
||||
|
||||
return info
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return self._shm_cache.is_cached(mm_hash)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if self._shm_cache.is_cached(mm_hash):
|
||||
self._hits += 1
|
||||
self._total += 1
|
||||
|
||||
address, monotonic_id = self._shm_cache.get_cached(mm_hash)
|
||||
prompt_updates, modality = self._p0_cache[mm_hash]
|
||||
return self.address_as_item(address, monotonic_id, modality), prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._total += 1
|
||||
|
||||
try:
|
||||
address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
|
||||
# Try to remove dangling items if p0 cache is too large.
|
||||
if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
|
||||
self.remove_dangling_items()
|
||||
self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
|
||||
address_item = self.address_as_item(
|
||||
address, monotonic_id, mm_item[0].modality
|
||||
)
|
||||
return address_item, mm_item[1]
|
||||
except (ValueError, MemoryError) as e:
|
||||
# put may fail if the object is too large or
|
||||
# the cache is full.
|
||||
# In this case we log the error and keep the original mm_input.
|
||||
logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def touch_sender_cache_item(self, mm_hash: str) -> None:
|
||||
"""Touch the item in shared memory cache to prevent eviction.
|
||||
Increments writer_flag on sender side."""
|
||||
self._shm_cache.touch(mm_hash)
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._shm_cache.clear()
|
||||
self._p0_cache.clear()
|
||||
|
||||
self._hits = 0
|
||||
self._total = 0
|
||||
self._last_info = CacheInfo(hits=0, total=0)
|
||||
|
||||
@override
|
||||
def make_stats(self, *, delta: bool = False) -> CacheInfo:
|
||||
return self._stat(delta=delta)
|
||||
|
||||
def remove_dangling_items(self) -> None:
|
||||
"""Remove items that are no longer in the shared memory cache."""
|
||||
cached_hashes = self._shm_cache.key_index.keys()
|
||||
dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
|
||||
for mm_hash in dangling_hashes:
|
||||
del self._p0_cache[mm_hash]
|
||||
|
||||
def address_as_item(
|
||||
self, address: int, monotonic_id: int, modality: str
|
||||
) -> MultiModalKwargsItem:
|
||||
addr_elem = MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key="address",
|
||||
data=address,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
id_elem = MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key="monotonic_id",
|
||||
data=monotonic_id,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
|
||||
return mm_item
|
||||
|
||||
|
||||
def _enable_processor_cache(
|
||||
model_config: "ModelConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> bool:
|
||||
if not mm_registry.supports_multimodal_inputs(model_config):
|
||||
return False
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
return mm_config.mm_processor_cache_gb > 0
|
||||
|
||||
|
||||
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
supports_ipc_cache = (
|
||||
parallel_config._api_process_count == 1
|
||||
and parallel_config.data_parallel_size == 1
|
||||
) or parallel_config.data_parallel_external_lb
|
||||
|
||||
return supports_ipc_cache
|
||||
|
||||
|
||||
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
|
||||
"""Whether the shared memory based cache should be enabled."""
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return False
|
||||
|
||||
mm_config = vllm_config.model_config.get_multimodal_config()
|
||||
|
||||
return mm_config.mm_processor_cache_type == "shm"
|
||||
|
||||
|
||||
def processor_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> BaseMultiModalProcessorCache | None:
|
||||
"""Return a `BaseMultiModalProcessorCache`, if enabled."""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
if not _enable_mm_input_shm_cache(vllm_config):
|
||||
return MultiModalProcessorSenderCache(model_config)
|
||||
return ShmObjectStoreSenderCache(vllm_config)
|
||||
|
||||
|
||||
def processor_only_cache_from_config(
|
||||
model_config: "ModelConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
):
|
||||
"""Return a `MultiModalProcessorOnlyCache`, if enabled."""
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
return MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
|
||||
class BaseMultiModalReceiverCache(
|
||||
BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
|
||||
):
|
||||
"""The required interface for caches on P1."""
|
||||
|
||||
def get_and_update_features(
|
||||
self,
|
||||
mm_features: list["MultiModalFeatureSpec"],
|
||||
) -> list["MultiModalFeatureSpec"]:
|
||||
"""
|
||||
Update multimodal features with cached encoder outputs.
|
||||
Touch all identifier at first before update to avoid
|
||||
item in updated list evict during update.
|
||||
"""
|
||||
for feature in mm_features:
|
||||
self.touch_receiver_cache_item(feature.identifier, feature.data)
|
||||
|
||||
for feature in mm_features:
|
||||
feature.data = self.get_and_update_item(feature.data, feature.identifier)
|
||||
return mm_features
|
||||
|
||||
@abstractmethod
|
||||
def touch_receiver_cache_item(
|
||||
self,
|
||||
mm_hash: str,
|
||||
mm_item: MultiModalKwargsItem | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update the cache eviction order for a multi-modal item.
|
||||
|
||||
This is used to touch the item in the cache without changing
|
||||
its value.
|
||||
|
||||
Args:
|
||||
mm_hash: The hash of the multi-modal item.
|
||||
mm_item: The multi-modal item itself. This is optional and
|
||||
may not be needed by some cache implementations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
|
||||
"""
|
||||
The cache which is used on P1 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is in the cache, replace the input with the cached item.
|
||||
- If the item is not in the cache, store that item (which includes tensor
|
||||
data) into the cache, and return the input.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalKwargsItem,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalKwargsItem | None,
|
||||
mm_hash: str,
|
||||
) -> MultiModalKwargsItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return cached_item
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = mm_item
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def touch_receiver_cache_item(
|
||||
self,
|
||||
mm_hash: str,
|
||||
mm_item: MultiModalKwargsItem | None = None,
|
||||
) -> None:
|
||||
self._cache.touch(mm_hash)
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
|
||||
"""
|
||||
The cache which is used on P1 Worker Process when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item has an address, replace the input with the cached item.
|
||||
- If not, return the input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
shared_worker_lock: LockType,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
mm_config = vllm_config.model_config.get_multimodal_config()
|
||||
|
||||
ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
|
||||
name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
|
||||
create=False, # Server is a reader
|
||||
)
|
||||
self._shm_cache = SingleWriterShmObjectStorage(
|
||||
max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
|
||||
n_readers=self.world_size,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=MsgpackSerde,
|
||||
reader_lock=shared_worker_lock,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalKwargsItem | None,
|
||||
mm_hash: str,
|
||||
) -> MultiModalKwargsItem:
|
||||
assert mm_item is not None, f"Expected an address item for {mm_hash=}"
|
||||
if "address" in mm_item:
|
||||
address = cast(int, mm_item["address"].data)
|
||||
monotonic_id = cast(int, mm_item["monotonic_id"].data)
|
||||
return self._shm_cache.get(address, monotonic_id)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def touch_receiver_cache_item(
|
||||
self,
|
||||
mm_hash: str,
|
||||
mm_item: MultiModalKwargsItem | None = None,
|
||||
) -> None:
|
||||
"""Touch the item in shared memory cache to prevent eviction.
|
||||
Increments reader_count on receiver side."""
|
||||
assert mm_item is not None
|
||||
if "address" in mm_item:
|
||||
address = cast(int, mm_item["address"].data)
|
||||
monotonic_id = cast(int, mm_item["monotonic_id"].data)
|
||||
self._shm_cache.touch(mm_hash, address=address, monotonic_id=monotonic_id)
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._shm_cache.clear()
|
||||
|
||||
|
||||
def engine_receiver_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> BaseMultiModalReceiverCache | None:
|
||||
"""
|
||||
This is used in the engine process.
|
||||
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
|
||||
mm_processor_cache_type=="lru".
|
||||
"""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return None
|
||||
|
||||
if not _enable_mm_input_shm_cache(vllm_config):
|
||||
return MultiModalReceiverCache(model_config)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def worker_receiver_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
shared_worker_lock: LockType,
|
||||
) -> BaseMultiModalReceiverCache | None:
|
||||
"""
|
||||
This is used in the worker process.
|
||||
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
|
||||
mm_processor_cache_type=="shm".
|
||||
"""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return None
|
||||
|
||||
if not _enable_mm_input_shm_cache(vllm_config):
|
||||
return None
|
||||
|
||||
return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
|
||||
294
vllm/multimodal/evs.py
Normal file
294
vllm/multimodal/evs.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def compute_retained_tokens_count(
|
||||
tokens_per_frame: int, num_frames: int, q: float
|
||||
) -> int:
|
||||
"""
|
||||
Compute the number of retained tokens for a given video.
|
||||
Method ensures that we retain all the tokens from the first frame
|
||||
regardless of the pruning rate.
|
||||
|
||||
Args:
|
||||
tokens_per_frame: The number of tokens per frame.
|
||||
num_frames: The total number of frames.
|
||||
q: The pruning rate.
|
||||
|
||||
Returns:
|
||||
The number of retained tokens.
|
||||
"""
|
||||
total_tokens = tokens_per_frame * num_frames
|
||||
evs_num_tokens = int(total_tokens * (1 - q))
|
||||
min_num_tokens = tokens_per_frame
|
||||
return max(min_num_tokens, evs_num_tokens)
|
||||
|
||||
|
||||
def compute_retention_mask(
|
||||
video_embeds: torch.Tensor,
|
||||
video_size_thw: torch.LongTensor | tuple[int, int, int],
|
||||
spatial_merge_size: int,
|
||||
q: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the retention mask for input video embeddings.
|
||||
|
||||
Args:
|
||||
video_embeds (`torch.Tensor`): The input video embeddings
|
||||
of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
|
||||
video_size_thw (`torch.LongTensor` of shape `(3)`):
|
||||
The temporal, height and width of video.
|
||||
spatial_merge_size: Size reduction for rows & cols dimensions.
|
||||
q: (`float`): Pruning rate factor [0,1)
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The retention mask for the video embeddings of
|
||||
`(T * H * W // spatial_merge_size ^ 2)` shape.
|
||||
"""
|
||||
T, H, W = map(int, video_size_thw)
|
||||
|
||||
# Use reshape instead of einops to avoid graph breaks
|
||||
video_embeds = video_embeds.reshape(
|
||||
T,
|
||||
H // spatial_merge_size,
|
||||
W // spatial_merge_size,
|
||||
video_embeds.size(-1),
|
||||
)
|
||||
tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)
|
||||
# Core EVS
|
||||
similarity = torch.nn.functional.cosine_similarity(
|
||||
video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
|
||||
)
|
||||
dissimilarity = 1 - similarity
|
||||
|
||||
# Always ensure we include all tokens from the first frame
|
||||
dissimilarity = torch.cat(
|
||||
[255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0
|
||||
)
|
||||
|
||||
dissimilarity_flat = dissimilarity.view(-1)
|
||||
order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
|
||||
retain_num_tokens = compute_retained_tokens_count(
|
||||
tokens_per_frame=tokens_per_frame, num_frames=T, q=q
|
||||
)
|
||||
topk_indices = order[:retain_num_tokens]
|
||||
|
||||
retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
|
||||
retention_mask[topk_indices] = True
|
||||
retention_mask = retention_mask.reshape(dissimilarity.size())
|
||||
|
||||
mask = retention_mask.view(-1) # "T H W -> (T H W)"
|
||||
return mask
|
||||
|
||||
|
||||
def compute_mrope_for_media(
|
||||
video_size_thw: torch.LongTensor,
|
||||
spatial_merge_size: int,
|
||||
tokens_per_second: float = 1.0,
|
||||
video_second_per_grid: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the mrope for video embeddings based on the grid dimensions.
|
||||
Computed mrope positions match original qwen 2.5 implementation,
|
||||
but positions are built for media being the first element in sequence.
|
||||
|
||||
Args:
|
||||
video_size_thw: Media size (num frames, rows, cols)
|
||||
spatial_merge_size: Size reduction for rows & cols dimensions.
|
||||
tokens_per_second: Number of tokens per second.
|
||||
video_second_per_grid: Number of seconds per video.
|
||||
|
||||
Returns:
|
||||
Tensor of shape `(T * H * W, 4)` where last dimension
|
||||
represents mrope positions [0:3), while the last channel
|
||||
contains value of llm_grid_w repeated for all positions.
|
||||
"""
|
||||
llm_grid_t = video_size_thw[0]
|
||||
llm_grid_h = video_size_thw[1] // spatial_merge_size
|
||||
llm_grid_w = video_size_thw[2] // spatial_merge_size
|
||||
|
||||
t_index = (
|
||||
(
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.mul(tokens_per_second * video_second_per_grid)
|
||||
)
|
||||
.long()
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_grid_w = (
|
||||
torch.tensor([llm_grid_w])
|
||||
.view(1, 1, 1)
|
||||
.expand(llm_grid_t, llm_grid_h, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
|
||||
positions = torch.stack([t_index, h_index, w_index, llm_grid_w], dim=1)
|
||||
return positions
|
||||
|
||||
|
||||
def recompute_mrope_positions(
|
||||
input_ids: torch.LongTensor,
|
||||
multimodal_positions: list[torch.Tensor],
|
||||
mrope_positions: torch.LongTensor,
|
||||
num_computed_tokens: int,
|
||||
vision_start_token_id: int,
|
||||
image_token_id: int,
|
||||
video_token_id: int,
|
||||
) -> tuple[torch.LongTensor, int]:
|
||||
"""
|
||||
Update part of input mrope positions.
|
||||
Original mrope_positions are computed incorrectly, so once we prune media
|
||||
tokens we should reflect this in the mrope positions for the LLM.
|
||||
|
||||
This method supports chunked prefill approach where
|
||||
multimodal_embeddings are passed to LLM in chunks, so input
|
||||
multimodal_embeddings may contain zero, some or even some part of all
|
||||
multimodal_embeddings for a given prompt.
|
||||
|
||||
Each multimodal_positions has 4 extra channels
|
||||
(First 3 channels corresponds to original 3 mrope positions, last channel
|
||||
is the maximum width of the media repeated). Provided multimodal_positions
|
||||
do not reflect location of media position in sequence - they are computed
|
||||
like the media is in the 0-th position in the sequence.
|
||||
|
||||
Method works as follows: it recomputes mrope_positions starting from the
|
||||
`num_computed_tokens` for `total_len_of_multimodal_embeddings` and then
|
||||
shifts all text tokens that goes after total_len_of_multimodal_embeddings.
|
||||
|
||||
It also handles case when multimodal_embeddings is partial
|
||||
(e.g. one media is split into two prefill stages)
|
||||
|
||||
Args:
|
||||
input_ids: (N,) All input tokens of the prompt (entire sequence).
|
||||
multimodal_positions: List of mrope positions for each media.
|
||||
mrope_positions: Existing mrope positions (4, N) for entire sequence.
|
||||
num_computed_tokens: A number of computed tokens so far.
|
||||
vision_start_token_id: Token indicating start of vision media.
|
||||
image_token_id: Image token id
|
||||
video_token_id: Video token id
|
||||
|
||||
Returns:
|
||||
Tuple of (mrope_positions, mrope_position_delta).
|
||||
"""
|
||||
|
||||
# Tensors
|
||||
positions: torch.LongTensor = typing.cast(
|
||||
torch.LongTensor, mrope_positions.clone()
|
||||
) # (3, N)
|
||||
N = input_ids.numel()
|
||||
|
||||
image_mask = input_ids.eq(image_token_id)
|
||||
video_mask = input_ids.eq(video_token_id)
|
||||
media_mask = image_mask | video_mask
|
||||
text_mask = ~media_mask
|
||||
|
||||
# Early exit: no media in this chunk
|
||||
if len(multimodal_positions) == 0:
|
||||
delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
|
||||
return positions, delta
|
||||
|
||||
total_mm_tokens = torch.count_nonzero(media_mask)
|
||||
seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])
|
||||
|
||||
# Early exit: we've updated positions for all media tokens
|
||||
# (and consequently - for all remaining text tokens)
|
||||
if seen_mm_tokens == total_mm_tokens:
|
||||
delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
|
||||
return positions, delta
|
||||
|
||||
vision_start_indices = (input_ids == vision_start_token_id).nonzero(as_tuple=True)[
|
||||
0
|
||||
]
|
||||
|
||||
for mm_pos in multimodal_positions:
|
||||
# Each mm_pos can be a complete embedding for single media
|
||||
# or it can be a part of a single media (due to chunked prefill)
|
||||
|
||||
# Cases to cover
|
||||
# - Current prefill chunk has no vision start indexes at all
|
||||
# - Vision start token appeared in previous prefill round
|
||||
# - Regular case
|
||||
seen_vision_start_indices = vision_start_indices[
|
||||
vision_start_indices < num_computed_tokens
|
||||
]
|
||||
|
||||
if len(seen_vision_start_indices):
|
||||
# If we have encountered some vision start indexes,
|
||||
# then we should check the condition:
|
||||
# | --- prefill 1 ------| ---- prefill 2 ----- |
|
||||
# | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
|
||||
last_vision_start_token = seen_vision_start_indices[-1]
|
||||
seem_mm_tokens_before_last_vision_start = torch.count_nonzero(
|
||||
media_mask[:last_vision_start_token]
|
||||
)
|
||||
in_the_middle_of_media = (
|
||||
seen_mm_tokens > seem_mm_tokens_before_last_vision_start
|
||||
)
|
||||
|
||||
if in_the_middle_of_media:
|
||||
mm_embeddings_seen = (
|
||||
seen_mm_tokens - seem_mm_tokens_before_last_vision_start
|
||||
)
|
||||
global_mm_start = last_vision_start_token
|
||||
else:
|
||||
# We have completed previous mm_embedding part and
|
||||
# ready to start a new one
|
||||
next_vision_start_token = vision_start_indices[
|
||||
vision_start_indices >= num_computed_tokens
|
||||
][0]
|
||||
mm_embeddings_seen = 0
|
||||
global_mm_start = next_vision_start_token
|
||||
|
||||
else:
|
||||
# If there were no vision start indexes so far,
|
||||
# let's find first vision start index
|
||||
next_vision_start_token = vision_start_indices[
|
||||
vision_start_indices >= num_computed_tokens
|
||||
][0]
|
||||
|
||||
mm_embeddings_seen = 0
|
||||
global_mm_start = next_vision_start_token
|
||||
|
||||
# Offset right after vision_start_token
|
||||
base = positions[-1, global_mm_start] + 1
|
||||
local_start = global_mm_start + 1 + mm_embeddings_seen
|
||||
local_end = local_start + mm_pos.shape[1]
|
||||
positions[:, local_start:local_end] = mm_pos[0:3] + base
|
||||
|
||||
# mm_pos[3, 0] is the max width of the media
|
||||
offset = mm_pos[3, 0] + base
|
||||
|
||||
text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
|
||||
|
||||
positions[:, local_end:N] = text_pos_sum + offset - 1
|
||||
|
||||
# Include distance to the next vision start token
|
||||
num_computed_tokens += mm_pos.shape[1]
|
||||
|
||||
mrope_positions_delta = (positions.max() + 1 - N).item()
|
||||
return positions, mrope_positions_delta
|
||||
120
vllm/multimodal/hasher.py
Normal file
120
vllm/multimodal/hasher.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pickle
|
||||
import uuid
|
||||
from collections.abc import Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from blake3 import blake3
|
||||
from PIL import Image
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base import MediaWithBytes
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalHasher:
|
||||
@classmethod
|
||||
def serialize_item(cls, obj: object) -> Iterable[bytes | memoryview]:
|
||||
# Simple cases
|
||||
if isinstance(obj, (bytes, memoryview)):
|
||||
return (obj,)
|
||||
if isinstance(obj, str):
|
||||
return (obj.encode("utf-8"),)
|
||||
if isinstance(obj, (int, float)):
|
||||
return (np.array(obj).tobytes(),)
|
||||
|
||||
if isinstance(obj, Image.Image):
|
||||
exif = obj.getexif()
|
||||
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
||||
exif[Image.ExifTags.Base.ImageID], uuid.UUID
|
||||
):
|
||||
return (exif[Image.ExifTags.Base.ImageID].bytes,)
|
||||
|
||||
data = {"mode": obj.mode, "data": np.asarray(obj)}
|
||||
palette = obj.palette
|
||||
if palette is not None:
|
||||
data["palette"] = palette.palette
|
||||
if palette.rawmode is not None:
|
||||
data["palette_rawmode"] = palette.rawmode
|
||||
|
||||
return cls.iter_item_to_bytes("image", data)
|
||||
|
||||
if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
|
||||
exif = obj.media.getexif()
|
||||
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
||||
exif[Image.ExifTags.Base.ImageID], uuid.UUID
|
||||
):
|
||||
return (exif[Image.ExifTags.Base.ImageID].bytes,)
|
||||
|
||||
return cls.iter_item_to_bytes("image", obj.original_bytes)
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor_obj: torch.Tensor = obj.cpu()
|
||||
tensor_dtype = tensor_obj.dtype
|
||||
tensor_shape = tensor_obj.shape
|
||||
|
||||
# NumPy does not support bfloat16.
|
||||
# Workaround: View the tensor as a contiguous 1D array of bytes
|
||||
if tensor_dtype == torch.bfloat16:
|
||||
tensor_obj = tensor_obj.contiguous()
|
||||
tensor_obj = tensor_obj.view((tensor_obj.numel(),)).view(torch.uint8)
|
||||
|
||||
return cls.iter_item_to_bytes(
|
||||
"tensor",
|
||||
{
|
||||
"original_dtype": str(tensor_dtype),
|
||||
"original_shape": tuple(tensor_shape),
|
||||
"data": tensor_obj.numpy(),
|
||||
},
|
||||
)
|
||||
return cls.iter_item_to_bytes("tensor", tensor_obj.numpy())
|
||||
if isinstance(obj, np.ndarray):
|
||||
# If the array is non-contiguous, we need to copy it first
|
||||
arr_data = (
|
||||
obj.view(np.uint8).data if obj.flags.c_contiguous else obj.tobytes()
|
||||
)
|
||||
return cls.iter_item_to_bytes(
|
||||
"ndarray",
|
||||
{
|
||||
"dtype": obj.dtype.str,
|
||||
"shape": obj.shape,
|
||||
"data": arr_data,
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
"No serialization method found for %s. Falling back to pickle.", type(obj)
|
||||
)
|
||||
|
||||
return (pickle.dumps(obj),)
|
||||
|
||||
@classmethod
|
||||
def iter_item_to_bytes(
|
||||
cls,
|
||||
key: str,
|
||||
obj: object,
|
||||
) -> Iterable[bytes | memoryview]:
|
||||
# Recursive cases
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for i, elem in enumerate(obj):
|
||||
yield from cls.iter_item_to_bytes(f"{key}.{i}", elem)
|
||||
elif isinstance(obj, dict):
|
||||
for k, v in obj.items():
|
||||
yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
|
||||
else:
|
||||
yield key.encode("utf-8")
|
||||
yield from cls.serialize_item(obj)
|
||||
|
||||
@classmethod
|
||||
def hash_kwargs(cls, **kwargs: object) -> str:
|
||||
hasher = blake3()
|
||||
|
||||
for k, v in kwargs.items():
|
||||
for bytes_ in cls.iter_item_to_bytes(k, v):
|
||||
hasher.update(bytes_)
|
||||
|
||||
return hasher.hexdigest()
|
||||
142
vllm/multimodal/image.py
Normal file
142
vllm/multimodal/image.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from .base import MediaIO, MediaWithBytes
|
||||
|
||||
|
||||
def rescale_image_size(
|
||||
image: Image.Image, size_factor: float, transpose: int = -1
|
||||
) -> Image.Image:
|
||||
"""Rescale the dimensions of an image by a constant factor."""
|
||||
new_width = int(image.width * size_factor)
|
||||
new_height = int(image.height * size_factor)
|
||||
image = image.resize((new_width, new_height))
|
||||
if transpose >= 0:
|
||||
image = image.transpose(Image.Transpose(transpose))
|
||||
return image
|
||||
|
||||
|
||||
def rgba_to_rgb(
|
||||
image: Image.Image,
|
||||
background_color: tuple[int, int, int] | list[int] = (255, 255, 255),
|
||||
) -> Image.Image:
|
||||
"""Convert an RGBA image to RGB with filled background color."""
|
||||
assert image.mode == "RGBA"
|
||||
converted = Image.new("RGB", image.size, background_color)
|
||||
converted.paste(image, mask=image.split()[3]) # 3 is the alpha channel
|
||||
return converted
|
||||
|
||||
|
||||
def convert_image_mode(image: Image.Image, to_mode: str):
|
||||
if image.mode == to_mode:
|
||||
return image
|
||||
elif image.mode == "RGBA" and to_mode == "RGB":
|
||||
return rgba_to_rgb(image)
|
||||
else:
|
||||
return image.convert(to_mode)
|
||||
|
||||
|
||||
class ImageMediaIO(MediaIO[Image.Image]):
|
||||
def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.image_mode = image_mode
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Extract RGBA background color from kwargs if provided
|
||||
# Default to white background for backward compatibility
|
||||
rgba_bg = kwargs.get("rgba_background_color", (255, 255, 255))
|
||||
# Convert list to tuple for consistency
|
||||
if isinstance(rgba_bg, list):
|
||||
rgba_bg = tuple(rgba_bg)
|
||||
|
||||
# Validate rgba_background_color format
|
||||
if not (
|
||||
isinstance(rgba_bg, tuple)
|
||||
and len(rgba_bg) == 3
|
||||
and all(isinstance(c, int) and 0 <= c <= 255 for c in rgba_bg)
|
||||
):
|
||||
raise ValueError(
|
||||
"rgba_background_color must be a list or tuple of 3 integers "
|
||||
"in the range [0, 255]."
|
||||
)
|
||||
self.rgba_background_color = rgba_bg
|
||||
|
||||
def _convert_image_mode(
|
||||
self, image: Image.Image | MediaWithBytes[Image.Image]
|
||||
) -> Image.Image:
|
||||
"""Convert image mode with custom background color."""
|
||||
if isinstance(image, MediaWithBytes):
|
||||
image = image.media
|
||||
if image.mode == self.image_mode:
|
||||
return image
|
||||
elif image.mode == "RGBA" and self.image_mode == "RGB":
|
||||
return rgba_to_rgb(image, self.rgba_background_color)
|
||||
else:
|
||||
return convert_image_mode(image, self.image_mode)
|
||||
|
||||
def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
|
||||
image = Image.open(BytesIO(data))
|
||||
return MediaWithBytes(self._convert_image_mode(image), data)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
|
||||
with open(filepath, "rb") as f:
|
||||
data = f.read()
|
||||
image = Image.open(BytesIO(data))
|
||||
return MediaWithBytes(self._convert_image_mode(image), data)
|
||||
|
||||
def encode_base64(
|
||||
self,
|
||||
media: Image.Image,
|
||||
*,
|
||||
image_format: str = "JPEG",
|
||||
) -> str:
|
||||
image = media
|
||||
|
||||
with BytesIO() as buffer:
|
||||
image = self._convert_image_mode(image)
|
||||
image.save(buffer, image_format)
|
||||
data = buffer.getvalue()
|
||||
|
||||
return pybase64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def load_bytes(self, data: bytes) -> torch.Tensor:
|
||||
buffer = BytesIO(data)
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(buffer, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> torch.Tensor:
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(filepath, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def encode_base64(self, media: torch.Tensor) -> str:
|
||||
return pybase64.b64encode(media.numpy()).decode("utf-8")
|
||||
1089
vllm/multimodal/inputs.py
Normal file
1089
vllm/multimodal/inputs.py
Normal file
File diff suppressed because it is too large
Load Diff
565
vllm/multimodal/parse.py
Normal file
565
vllm/multimodal/parse.py
Normal file
@@ -0,0 +1,565 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from collections.abc import Callable, Iterator, Mapping, Sequence
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
TypeAlias,
|
||||
TypeGuard,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
from .audio import AudioResampler
|
||||
from .base import MediaWithBytes
|
||||
from .inputs import (
|
||||
AudioItem,
|
||||
HfAudioItem,
|
||||
HfImageItem,
|
||||
HfVideoItem,
|
||||
ImageItem,
|
||||
ModalityData,
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
VideoItem,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_I = TypeVar("_I")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL.Image as PILImage
|
||||
else:
|
||||
PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
|
||||
|
||||
|
||||
class ModalityDataItems(ABC, Generic[_T, _I]):
|
||||
"""
|
||||
Represents data items for a modality in
|
||||
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
|
||||
"""
|
||||
|
||||
def __init__(self, data: _T, modality: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.data: _T = data
|
||||
self.modality = modality
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{type(self).__name__}(modality={self.modality!r}, len={len(self)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.get_count()
|
||||
|
||||
def __getitem__(self, index: int) -> _I:
|
||||
return self.get(index)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Auto-generated
|
||||
def __iter__(self) -> Iterator[_I]: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_count(self) -> int:
|
||||
"""Get the number of data items."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self, index: int) -> _I:
|
||||
"""Get a data item by its index."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_all(self) -> list[_I]:
|
||||
"""Get all data items."""
|
||||
return [self.get(idx) for idx in range(self.get_count())]
|
||||
|
||||
def get_item_for_hash(self, index: int) -> object:
|
||||
return self.get(index)
|
||||
|
||||
def get_all_items_for_hash(self) -> list[object]:
|
||||
return [self.get_item_for_hash(idx) for idx in range(self.get_count())]
|
||||
|
||||
@abstractmethod
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
"""Get the data to pass to the HF processor."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_passthrough_data(self) -> Mapping[str, object]:
|
||||
"""Get the data to pass directly to the model."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
|
||||
"""Base class for data items that are arranged in a list."""
|
||||
|
||||
def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
|
||||
"""Extract media from wrapper if present."""
|
||||
return item.media if isinstance(item, MediaWithBytes) else item
|
||||
|
||||
def get_count(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get(self, index: int) -> _T:
|
||||
return self._unwrap(self.data[index])
|
||||
|
||||
def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
|
||||
# Return raw item for hashing (preserves original_bytes if present)
|
||||
return self.data[index]
|
||||
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
return {f"{self.modality}s": self.get_all()}
|
||||
|
||||
def get_passthrough_data(self) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
|
||||
class EmbeddingItems(
|
||||
ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
|
||||
):
|
||||
"""
|
||||
Base class for data items that are expressed as a batched embedding tensor,
|
||||
or a list of embedding tensors (one per item).
|
||||
"""
|
||||
|
||||
def _unwrap(
|
||||
self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""Extract media from wrapper if present."""
|
||||
return item.media if isinstance(item, MediaWithBytes) else item
|
||||
|
||||
def get_count(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get(self, index: int) -> torch.Tensor:
|
||||
return self._unwrap(self.data[index])
|
||||
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
def get_passthrough_data(self) -> Mapping[str, object]:
|
||||
return {f"{self.modality}_embeds": self.data}
|
||||
|
||||
def get_feature_size(self, item_idx: int) -> int:
|
||||
return len(self.get(item_idx))
|
||||
|
||||
|
||||
class DictEmbeddingItems(
|
||||
ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Base class for data items that are expressed as a dictionary of tensors.
|
||||
|
||||
Usually, the dictionary keys correspond to the outputs of HF processor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Mapping[str, torch.Tensor],
|
||||
modality: str,
|
||||
required_fields: set[str],
|
||||
fields_factory: Callable[
|
||||
[Mapping[str, torch.Tensor]],
|
||||
Mapping[str, MultiModalFieldConfig],
|
||||
],
|
||||
) -> None:
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
super().__init__(data, modality)
|
||||
|
||||
missing_required_data_keys = required_fields - data.keys()
|
||||
if missing_required_data_keys:
|
||||
data_keys = set(data.keys())
|
||||
msg = (
|
||||
f"The data should contain the fields: {required_fields}, "
|
||||
f"but only found the following keys: {data_keys}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
fields_config = fields_factory(data)
|
||||
missing_required_fields = required_fields - fields_config.keys()
|
||||
if missing_required_fields:
|
||||
fields = set(fields_config.keys())
|
||||
msg = f"{required_fields=} should be a subset of {fields=}"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.fields_config = fields_config
|
||||
self.required_fields = required_fields
|
||||
|
||||
self._kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||
BatchFeature(dict(data)),
|
||||
fields_config,
|
||||
)
|
||||
|
||||
def get_count(self) -> int:
|
||||
return len(self._kwargs[self.modality])
|
||||
|
||||
def get(self, index: int) -> Mapping[str, torch.Tensor]:
|
||||
return self._kwargs[self.modality][index].get_data()
|
||||
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
def get_passthrough_data(self) -> Mapping[str, object]:
|
||||
return self.data
|
||||
|
||||
|
||||
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
|
||||
def __init__(self, data: Sequence[HfAudioItem] | None) -> None:
|
||||
if data is None:
|
||||
data = [None]
|
||||
super().__init__(data, "audio")
|
||||
|
||||
def get_audio_length(self, item_idx: int) -> int:
|
||||
audio = self.get(item_idx)
|
||||
return len(audio)
|
||||
|
||||
|
||||
class AudioEmbeddingItems(EmbeddingItems):
|
||||
def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
|
||||
super().__init__(data, "audio")
|
||||
|
||||
|
||||
class ImageSize(NamedTuple):
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
|
||||
def __init__(self, data: Sequence[HfImageItem] | None) -> None:
|
||||
if data is None:
|
||||
data = [None]
|
||||
super().__init__(data, "image")
|
||||
|
||||
def get_image_size(self, item_idx: int) -> ImageSize:
|
||||
image = self.get(item_idx)
|
||||
|
||||
if isinstance(image, PILImage.Image):
|
||||
return ImageSize(*image.size)
|
||||
if isinstance(image, (np.ndarray, torch.Tensor)):
|
||||
_, h, w = image.shape
|
||||
return ImageSize(w, h)
|
||||
|
||||
assert_never(image)
|
||||
|
||||
|
||||
class ImageEmbeddingItems(EmbeddingItems):
|
||||
def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
|
||||
super().__init__(data, "image")
|
||||
|
||||
|
||||
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
|
||||
def __init__(
|
||||
self,
|
||||
data: Sequence[HfVideoItem] | None,
|
||||
metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
|
||||
) -> None:
|
||||
if data is None:
|
||||
data = [None]
|
||||
super().__init__(data, "video")
|
||||
self.metadata = metadata
|
||||
|
||||
def get_num_frames(self, item_idx: int) -> int:
|
||||
return len(self.get(item_idx))
|
||||
|
||||
def get_frame_size(self, item_idx: int) -> ImageSize:
|
||||
image = self.get(item_idx)[0] # Assume that the video isn't empty
|
||||
|
||||
if isinstance(image, PILImage.Image):
|
||||
return ImageSize(*image.size)
|
||||
if isinstance(image, (np.ndarray, torch.Tensor)):
|
||||
_, h, w = image.shape
|
||||
return ImageSize(w, h)
|
||||
|
||||
assert_never(image)
|
||||
|
||||
|
||||
class VideoEmbeddingItems(EmbeddingItems):
|
||||
def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
|
||||
super().__init__(data, "video")
|
||||
|
||||
|
||||
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
|
||||
|
||||
|
||||
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
|
||||
"""
|
||||
As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
|
||||
normalized such that each entry corresponds to a list.
|
||||
"""
|
||||
|
||||
def get_count(self, modality: str, *, strict: bool = True) -> int:
|
||||
"""
|
||||
Get the number of data items belonging to a modality.
|
||||
|
||||
If `strict=False`, return `0` instead of raising [`KeyError`][]
|
||||
even if the modality is not found.
|
||||
"""
|
||||
if modality not in self:
|
||||
if strict:
|
||||
available_modalities = set(self.keys())
|
||||
raise KeyError(
|
||||
f"Modality {modality!r} not found. "
|
||||
f"Available modalities: {available_modalities}"
|
||||
)
|
||||
|
||||
return 0
|
||||
|
||||
return self[modality].get_count()
|
||||
|
||||
def get_all_counts(self) -> Mapping[str, int]:
|
||||
"""Get the number of items belonging to each modality."""
|
||||
return {m: items.get_count() for m, items in self.items()}
|
||||
|
||||
def get_items(
|
||||
self,
|
||||
modality: str,
|
||||
typ: type[_D] | tuple[type[_D], ...],
|
||||
) -> _D:
|
||||
"""
|
||||
Get the data items belonging to a modality,
|
||||
requiring that they belong to a certain type.
|
||||
"""
|
||||
if modality not in self:
|
||||
available_modalities = set(self.keys())
|
||||
raise KeyError(
|
||||
f"Modality {modality!r} not found. "
|
||||
f"Available modalities: {available_modalities}"
|
||||
)
|
||||
|
||||
items = self[modality]
|
||||
if not isinstance(items, typ):
|
||||
raise TypeError(
|
||||
f"Invalid type of data items for {modality=}. "
|
||||
f"Expected type: {typ}, but "
|
||||
f"found type: {type(items)}"
|
||||
)
|
||||
|
||||
return items # type: ignore[return-value]
|
||||
|
||||
|
||||
ModalityDataParser: TypeAlias = Callable[
|
||||
[ModalityData[Any]], ModalityDataItems[Any, Any] | None
|
||||
]
|
||||
|
||||
|
||||
class MultiModalDataParser:
|
||||
"""
|
||||
Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
|
||||
into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
|
||||
|
||||
Args:
|
||||
target_sr (float, optional): Enables automatic resampling of audio
|
||||
items to the model's expected sampling rate.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
target_sr: float | None = None,
|
||||
audio_resample_method: Literal["librosa", "scipy"] = "librosa",
|
||||
video_needs_metadata: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.audio_resampler = AudioResampler(
|
||||
target_sr=target_sr,
|
||||
method=audio_resample_method,
|
||||
)
|
||||
self.video_needs_metadata = video_needs_metadata
|
||||
|
||||
@classmethod
|
||||
def is_embeddings(
|
||||
cls, data: object
|
||||
) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.ndim == 3
|
||||
if is_list_of(data, torch.Tensor):
|
||||
return data[0].ndim == 2 # type: ignore[index]
|
||||
|
||||
return False
|
||||
|
||||
def _is_empty(self, data: object) -> TypeGuard[None]:
|
||||
if isinstance(data, list):
|
||||
return len(data) == 0
|
||||
if isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
return data.size == 0
|
||||
|
||||
return False
|
||||
|
||||
def _get_audio_with_sr(
|
||||
self,
|
||||
audio: AudioItem,
|
||||
) -> tuple[np.ndarray, float | None]:
|
||||
if isinstance(audio, tuple):
|
||||
return audio
|
||||
if isinstance(audio, list):
|
||||
return np.array(audio), None
|
||||
if isinstance(audio, np.ndarray):
|
||||
return audio, None
|
||||
if isinstance(audio, torch.Tensor):
|
||||
return audio.numpy(), None
|
||||
|
||||
assert_never(audio)
|
||||
|
||||
def _get_video_with_metadata(
|
||||
self,
|
||||
video: VideoItem,
|
||||
) -> tuple[np.ndarray, dict[str, Any] | None]:
|
||||
if isinstance(video, tuple):
|
||||
return video
|
||||
if isinstance(video, list):
|
||||
return np.array(video), None
|
||||
if isinstance(video, np.ndarray):
|
||||
return video, None
|
||||
if isinstance(video, torch.Tensor):
|
||||
return video.numpy(), None
|
||||
|
||||
assert_never(video)
|
||||
|
||||
def _parse_audio_data(
|
||||
self,
|
||||
data: ModalityData[AudioItem],
|
||||
) -> ModalityDataItems[Any, Any] | None:
|
||||
if data is None:
|
||||
return AudioProcessorItems(None)
|
||||
|
||||
# also check single audio item with sampling rate
|
||||
if self._is_empty(data) or (
|
||||
isinstance(data, tuple) and self._is_empty(data[0])
|
||||
):
|
||||
return None
|
||||
|
||||
if self.is_embeddings(data):
|
||||
return AudioEmbeddingItems(data)
|
||||
|
||||
data_items: list[AudioItem]
|
||||
if (
|
||||
is_list_of(data, float)
|
||||
or isinstance(data, (np.ndarray, torch.Tensor))
|
||||
and data.ndim == 1
|
||||
or isinstance(data, tuple)
|
||||
):
|
||||
data_items = [data]
|
||||
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
data_items = [elem for elem in data]
|
||||
else:
|
||||
data_items = data # type: ignore[assignment]
|
||||
|
||||
new_audios = list[np.ndarray]()
|
||||
for data_item in data_items:
|
||||
audio, orig_sr = self._get_audio_with_sr(data_item)
|
||||
if orig_sr is None:
|
||||
new_audio = audio
|
||||
else:
|
||||
new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)
|
||||
|
||||
new_audios.append(new_audio)
|
||||
|
||||
return AudioProcessorItems(new_audios)
|
||||
|
||||
def _parse_image_data(
|
||||
self,
|
||||
data: ModalityData[ImageItem],
|
||||
) -> ModalityDataItems[Any, Any] | None:
|
||||
if data is None:
|
||||
return ImageProcessorItems(None)
|
||||
|
||||
if self._is_empty(data):
|
||||
return None
|
||||
|
||||
if self.is_embeddings(data):
|
||||
return ImageEmbeddingItems(data)
|
||||
|
||||
if (
|
||||
isinstance(data, (PILImage.Image, MediaWithBytes))
|
||||
or isinstance(data, (np.ndarray, torch.Tensor))
|
||||
and data.ndim == 3
|
||||
):
|
||||
data_items = [data]
|
||||
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
data_items = [elem for elem in data]
|
||||
else:
|
||||
data_items = data
|
||||
|
||||
return ImageProcessorItems(data_items)
|
||||
|
||||
def _parse_video_data(
|
||||
self,
|
||||
data: ModalityData[VideoItem],
|
||||
) -> ModalityDataItems[Any, Any] | None:
|
||||
if data is None:
|
||||
return VideoProcessorItems(None)
|
||||
|
||||
if self._is_empty(data):
|
||||
return None
|
||||
|
||||
if self.is_embeddings(data):
|
||||
return VideoEmbeddingItems(data)
|
||||
|
||||
data_items: list[VideoItem]
|
||||
if (
|
||||
is_list_of(data, PILImage.Image)
|
||||
or isinstance(data, (np.ndarray, torch.Tensor))
|
||||
and data.ndim == 4
|
||||
):
|
||||
data_items = [data]
|
||||
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
data_items = [elem for elem in data]
|
||||
elif isinstance(data, tuple) and len(data) == 2:
|
||||
data_items = [data]
|
||||
else:
|
||||
data_items = data # type: ignore[assignment]
|
||||
|
||||
new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
|
||||
metadata_lst: list[dict[str, Any] | None] = []
|
||||
for data_item in data_items:
|
||||
video, metadata = self._get_video_with_metadata(data_item)
|
||||
if self.video_needs_metadata:
|
||||
if metadata is None:
|
||||
raise ValueError(
|
||||
"Video metadata is required but not found in mm input. "
|
||||
"Please check your video input in `multi_modal_data`"
|
||||
)
|
||||
new_videos.append((video, metadata))
|
||||
metadata_lst.append(metadata)
|
||||
else:
|
||||
new_videos.append(video)
|
||||
|
||||
if not self.video_needs_metadata:
|
||||
metadata = None
|
||||
|
||||
return VideoProcessorItems(new_videos, metadata=metadata_lst)
|
||||
|
||||
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
|
||||
return {
|
||||
"audio": self._parse_audio_data,
|
||||
"image": self._parse_image_data,
|
||||
"video": self._parse_video_data,
|
||||
}
|
||||
|
||||
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
|
||||
subparsers = self._get_subparsers()
|
||||
|
||||
mm_items = MultiModalDataItems()
|
||||
for k, v in mm_data.items():
|
||||
if k not in subparsers:
|
||||
raise ValueError(f"Unsupported modality: {k}")
|
||||
|
||||
# ignore empty embedding data
|
||||
if (parsed_data := subparsers[k](v)) is not None:
|
||||
mm_items[k] = parsed_data
|
||||
|
||||
return mm_items
|
||||
2240
vllm/multimodal/processing.py
Normal file
2240
vllm/multimodal/processing.py
Normal file
File diff suppressed because it is too large
Load Diff
351
vllm/multimodal/profiling.py
Normal file
351
vllm/multimodal/profiling.py
Normal file
@@ -0,0 +1,351 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, NamedTuple, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm.config.multimodal import (
|
||||
AudioDummyOptions,
|
||||
BaseDummyOptions,
|
||||
ImageDummyOptions,
|
||||
VideoDummyOptions,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalEncDecInputs,
|
||||
MultiModalInputs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalPlaceholderDict,
|
||||
)
|
||||
from .processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
EncDecMultiModalProcessor,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessorInputs:
|
||||
"""
|
||||
Represents the keyword arguments to
|
||||
[`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
|
||||
"""
|
||||
|
||||
prompt: str | list[int]
|
||||
mm_data: MultiModalDataDict
|
||||
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DummyEncoderData(NamedTuple):
|
||||
"""Dummy data used for profiling."""
|
||||
|
||||
prompt_token_ids: list[int]
|
||||
|
||||
|
||||
class DummyDecoderData(NamedTuple):
|
||||
"""Dummy data used for profiling."""
|
||||
|
||||
prompt_token_ids: list[int]
|
||||
multi_modal_data: MultiModalKwargsItems
|
||||
multi_modal_placeholders: MultiModalPlaceholderDict
|
||||
|
||||
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
|
||||
|
||||
class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
||||
"""
|
||||
Abstract base class that constructs the dummy data to profile
|
||||
multi-modal models.
|
||||
"""
|
||||
|
||||
def __init__(self, info: _I) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.info = info
|
||||
|
||||
@abstractmethod
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
"""
|
||||
Build the text input corresponding to `mm_counts`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
"""
|
||||
Build the multimodal input which, after processing, results in
|
||||
the maximum possible number of placeholder tokens.
|
||||
|
||||
Args:
|
||||
seq_len: Sequence length
|
||||
mm_counts: Count of items per modality
|
||||
mm_options: Configurable options per modality (optional).
|
||||
If None, use model defaults for backward compatibility.
|
||||
If provided, models can use these to customize dummy
|
||||
data generation.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> ProcessorInputs:
|
||||
"""
|
||||
Build the input which, after processing, results in
|
||||
the maximum possible number of placeholder tokens.
|
||||
|
||||
Args:
|
||||
seq_len: Sequence length
|
||||
mm_counts: Count of items per modality
|
||||
mm_options: Configurable options per modality (optional)
|
||||
"""
|
||||
dummy_text = self.get_dummy_text(mm_counts)
|
||||
|
||||
# Use the unified function for both legacy and configurable cases
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
|
||||
tokenization_kwargs = {"truncation": False}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt=dummy_text,
|
||||
mm_data=dummy_mm_data,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
def _get_dummy_audios(
|
||||
self,
|
||||
*,
|
||||
length: int,
|
||||
num_audios: int,
|
||||
overrides: AudioDummyOptions | None = None,
|
||||
) -> list[npt.NDArray]:
|
||||
if num_audios == 0:
|
||||
return []
|
||||
if overrides and overrides.length:
|
||||
if overrides.length > length:
|
||||
logger.warning(
|
||||
"audio.length override (%d) exceeds model's "
|
||||
"maximum length (%d), will be ignored",
|
||||
overrides.length,
|
||||
length,
|
||||
)
|
||||
length = min(length, overrides.length)
|
||||
audio = np.zeros((length,))
|
||||
return [audio] * num_audios
|
||||
|
||||
def _get_dummy_images(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
height: int,
|
||||
num_images: int,
|
||||
overrides: ImageDummyOptions | None = None,
|
||||
) -> list[Image.Image]:
|
||||
if num_images == 0:
|
||||
return []
|
||||
if overrides:
|
||||
if overrides.width:
|
||||
if overrides.width > width:
|
||||
logger.warning(
|
||||
"image.width override (%d) exceeds model's "
|
||||
"maximum width (%d), will be ignored",
|
||||
overrides.width,
|
||||
width,
|
||||
)
|
||||
width = min(width, overrides.width)
|
||||
if overrides.height:
|
||||
if overrides.height > height:
|
||||
logger.warning(
|
||||
"image.height override (%d) exceeds model's "
|
||||
"maximum height (%d), will be ignored",
|
||||
overrides.height,
|
||||
height,
|
||||
)
|
||||
height = min(height, overrides.height)
|
||||
image = Image.new("RGB", (width, height), color=255)
|
||||
return [image] * num_images
|
||||
|
||||
def _get_dummy_videos(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
height: int,
|
||||
num_frames: int,
|
||||
num_videos: int,
|
||||
overrides: VideoDummyOptions | None = None,
|
||||
) -> list[npt.NDArray]:
|
||||
if num_videos == 0:
|
||||
return []
|
||||
if overrides:
|
||||
if overrides.num_frames:
|
||||
if overrides.num_frames > num_frames:
|
||||
logger.warning(
|
||||
"video.num_frames override (%d) exceeds model's "
|
||||
"maximum number of frames (%d), will be ignored",
|
||||
overrides.num_frames,
|
||||
num_frames,
|
||||
)
|
||||
num_frames = min(num_frames, overrides.num_frames)
|
||||
if overrides.width:
|
||||
if overrides.width > width:
|
||||
logger.warning(
|
||||
"video.width override (%d) exceeds model's "
|
||||
"maximum width (%d), will be ignored",
|
||||
overrides.width,
|
||||
width,
|
||||
)
|
||||
width = min(width, overrides.width)
|
||||
if overrides.height:
|
||||
if overrides.height > height:
|
||||
logger.warning(
|
||||
"video.height override (%d) exceeds model's "
|
||||
"maximum height (%d), will be ignored",
|
||||
overrides.height,
|
||||
height,
|
||||
)
|
||||
height = min(height, overrides.height)
|
||||
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
|
||||
return [video] * num_videos
|
||||
|
||||
|
||||
class MultiModalProfiler(Generic[_I]):
|
||||
"""
|
||||
Contains code for running memory profiling for multi-modal models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor: BaseMultiModalProcessor[_I],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.processor = processor
|
||||
|
||||
@property
|
||||
def processing_info(self) -> BaseProcessingInfo:
|
||||
return self.processor.info
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
|
||||
return self.processor.dummy_inputs
|
||||
|
||||
def get_mm_limits(self) -> Mapping[str, int]:
|
||||
return self.processor.allowed_mm_limits
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalInputs:
|
||||
if mm_counts is None:
|
||||
mm_counts = self.get_mm_limits()
|
||||
|
||||
factory = self.dummy_inputs
|
||||
processor_inputs = factory.get_dummy_processor_inputs(
|
||||
seq_len, mm_counts, mm_options
|
||||
)
|
||||
|
||||
return self.processor.apply(
|
||||
prompt=processor_inputs.prompt,
|
||||
mm_data=processor_inputs.mm_data,
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=processor_inputs.tokenization_kwargs,
|
||||
)
|
||||
|
||||
def _get_mm_num_tokens(
|
||||
self,
|
||||
mm_inputs: MultiModalInputs,
|
||||
) -> Mapping[str, int]:
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
|
||||
return {
|
||||
modality: sum(item.get_num_embeds for item in placeholders)
|
||||
for modality, placeholders in placeholders_by_modality.items()
|
||||
}
|
||||
|
||||
def get_encoder_dummy_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> DummyEncoderData:
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
|
||||
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
|
||||
|
||||
# For encoder-decoder models, use encoder prompt token ids instead of
|
||||
# decoder prompt to construct dummy seq_data for encoder profiling.
|
||||
encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
|
||||
|
||||
total_len = len(encoder_prompt_token_ids)
|
||||
|
||||
processor = cast(EncDecMultiModalProcessor, self.processor)
|
||||
if processor.pad_dummy_encoder_prompt:
|
||||
num_tokens_to_pad = max(total_len, seq_len) - total_len
|
||||
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
|
||||
|
||||
return DummyEncoderData(encoder_prompt_token_ids)
|
||||
|
||||
def get_decoder_dummy_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> DummyDecoderData:
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
|
||||
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
total_len = len(prompt_token_ids)
|
||||
|
||||
if total_len < seq_len:
|
||||
prompt_token_ids.extend([0] * (seq_len - total_len))
|
||||
|
||||
return DummyDecoderData(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
)
|
||||
|
||||
def get_mm_max_tokens(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Returns the maximum number of embeddings per item of each modality, excluding
|
||||
any break/text tokens in-between multimodal embeddings/encoder outputs.
|
||||
"""
|
||||
if mm_counts is None:
|
||||
mm_counts = self.get_mm_limits()
|
||||
|
||||
max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
|
||||
seq_len=seq_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
if max_tokens_per_item is not None:
|
||||
return {
|
||||
modality: max_tokens
|
||||
for modality, max_tokens in max_tokens_per_item.items()
|
||||
if mm_counts.get(modality, 0) > 0
|
||||
}
|
||||
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||
return self._get_mm_num_tokens(mm_inputs)
|
||||
357
vllm/multimodal/registry.py
Normal file
357
vllm/multimodal/registry.py
Normal file
@@ -0,0 +1,357 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
|
||||
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
)
|
||||
from .profiling import (
|
||||
BaseDummyInputsBuilder,
|
||||
DummyDecoderData,
|
||||
DummyEncoderData,
|
||||
MultiModalProfiler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
N = TypeVar("N", bound=type["SupportsMultiModal"])
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
|
||||
|
||||
|
||||
class ProcessingInfoFactory(Protocol[_I_co]):
|
||||
"""
|
||||
Constructs a
|
||||
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
|
||||
instance from the context.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
ctx: InputProcessingContext,
|
||||
) -> _I_co: ...
|
||||
|
||||
|
||||
class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc]
|
||||
"""
|
||||
Constructs a
|
||||
[`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
|
||||
instance from the context.
|
||||
"""
|
||||
|
||||
def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
|
||||
|
||||
|
||||
class MultiModalProcessorFactory(Protocol[_I]): # type: ignore[misc]
|
||||
"""
|
||||
Constructs a
|
||||
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
|
||||
instance from the context.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
info: _I,
|
||||
dummy_inputs: BaseDummyInputsBuilder[_I],
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
) -> BaseMultiModalProcessor[_I]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ProcessorFactories(Generic[_I]):
|
||||
info: ProcessingInfoFactory[_I]
|
||||
processor: MultiModalProcessorFactory[_I]
|
||||
dummy_inputs: DummyInputsBuilderFactory[_I]
|
||||
|
||||
def build_processor(
|
||||
self,
|
||||
ctx: InputProcessingContext,
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
):
|
||||
info = self.info(ctx)
|
||||
dummy_inputs_builder = self.dummy_inputs(info)
|
||||
return self.processor(info, dummy_inputs_builder, cache=cache)
|
||||
|
||||
|
||||
class MultiModalRegistry:
|
||||
"""
|
||||
A registry that dispatches data processing according to the model.
|
||||
"""
|
||||
|
||||
def _extract_mm_options(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
) -> Mapping[str, BaseDummyOptions] | None:
|
||||
"""
|
||||
Extract multimodal dummy options from model config.
|
||||
|
||||
Returns None if no configurable options are found, otherwise returns
|
||||
a mapping of modality names to their dummy options.
|
||||
"""
|
||||
if not model_config.multimodal_config:
|
||||
return None
|
||||
|
||||
mm_options = {
|
||||
m: opt
|
||||
for m in model_config.multimodal_config.limit_per_prompt
|
||||
if (opt := model_config.multimodal_config.get_dummy_options(m)) is not None
|
||||
}
|
||||
|
||||
return mm_options if len(mm_options) > 0 else None
|
||||
|
||||
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
|
||||
"""
|
||||
Checks if the model supports multimodal inputs.
|
||||
Returns True if the model is multimodal with any non-zero supported
|
||||
modalities, otherwise returns False, effectively running in
|
||||
text-only mode.
|
||||
"""
|
||||
if not model_config.is_multimodal_model:
|
||||
return False
|
||||
|
||||
info = self._create_processing_info(model_config, tokenizer=None)
|
||||
supported_modalities = info.get_supported_mm_limits()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
# Check if all supported modalities have limit == 0
|
||||
if all(
|
||||
mm_config.get_limit_per_prompt(modality) == 0
|
||||
for modality in supported_modalities
|
||||
):
|
||||
logger.info_once(
|
||||
"All limits of multimodal modalities supported by the model "
|
||||
"are set to 0, running in text-only mode."
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_max_tokens_per_item_by_modality(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
profiler_limits: Mapping[str, int] | None = None,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Get the maximum number of tokens per data item from each modality based
|
||||
on underlying model configuration.
|
||||
"""
|
||||
if not model_config.is_multimodal_model:
|
||||
return {}
|
||||
|
||||
processor = self.create_processor(model_config, cache=cache)
|
||||
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
||||
|
||||
seq_len = model_config.max_model_len
|
||||
profiler_limits = (
|
||||
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
||||
)
|
||||
|
||||
return profiler.get_mm_max_tokens(
|
||||
seq_len,
|
||||
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
||||
)
|
||||
|
||||
def get_mm_limits_per_prompt(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Get the maximum number of multi-modal input instances for each modality
|
||||
that are allowed per prompt for a model class.
|
||||
"""
|
||||
if not model_config.is_multimodal_model:
|
||||
return {}
|
||||
|
||||
processor = self.create_processor(model_config, cache=cache)
|
||||
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
||||
return profiler.get_mm_limits()
|
||||
|
||||
def register_processor(
|
||||
self,
|
||||
processor: MultiModalProcessorFactory[_I],
|
||||
*,
|
||||
info: ProcessingInfoFactory[_I],
|
||||
dummy_inputs: DummyInputsBuilderFactory[_I],
|
||||
):
|
||||
"""
|
||||
Register a multi-modal processor to a model class. The processor
|
||||
is constructed lazily, hence a factory method should be passed.
|
||||
|
||||
When the model receives multi-modal data, the provided function is
|
||||
invoked to transform the data into a dictionary of model inputs.
|
||||
"""
|
||||
|
||||
def wrapper(model_cls: N) -> N:
|
||||
if "_processor_factory" in model_cls.__dict__:
|
||||
logger.warning(
|
||||
"Model class %s already has a multi-modal processor "
|
||||
"registered to %s. It is overwritten by the new one.",
|
||||
model_cls,
|
||||
self,
|
||||
)
|
||||
|
||||
model_cls._processor_factory = _ProcessorFactories(
|
||||
info=info,
|
||||
dummy_inputs=dummy_inputs,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
return model_cls
|
||||
|
||||
return wrapper
|
||||
|
||||
def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
assert hasattr(model_cls, "_processor_factory")
|
||||
return cast("SupportsMultiModal", model_cls)
|
||||
|
||||
def _create_processing_ctx(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
) -> InputProcessingContext:
|
||||
if tokenizer is None and not model_config.skip_tokenizer_init:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
|
||||
return InputProcessingContext(model_config, tokenizer)
|
||||
|
||||
def _create_processing_info(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
*,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
) -> BaseProcessingInfo:
|
||||
model_cls = self._get_model_cls(model_config)
|
||||
factories = model_cls._processor_factory
|
||||
ctx = self._create_processing_ctx(model_config, tokenizer)
|
||||
return factories.info(ctx)
|
||||
|
||||
def create_processor(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
*,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
|
||||
"""
|
||||
Create a multi-modal processor for a specific model and tokenizer.
|
||||
"""
|
||||
if not model_config.is_multimodal_model:
|
||||
raise ValueError(f"{model_config.model} is not a multimodal model")
|
||||
|
||||
model_cls = self._get_model_cls(model_config)
|
||||
factories = model_cls._processor_factory
|
||||
|
||||
ctx = self._create_processing_ctx(model_config, tokenizer)
|
||||
|
||||
return factories.build_processor(ctx, cache=cache)
|
||||
|
||||
def get_decoder_dummy_data(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
) -> DummyDecoderData:
|
||||
"""
|
||||
Create dummy data for profiling the memory usage of a model.
|
||||
|
||||
The model is identified by `model_config`.
|
||||
"""
|
||||
processor = self.create_processor(model_config, cache=cache)
|
||||
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
||||
|
||||
# Extract configurable options from multimodal config.
|
||||
# Only include modalities that use advanced option types so legacy
|
||||
# count-only behavior remains unchanged.
|
||||
mm_options = self._extract_mm_options(model_config)
|
||||
|
||||
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)
|
||||
|
||||
# Having more tokens is over-conservative but otherwise fine
|
||||
token_ids = dummy_data.prompt_token_ids
|
||||
if len(token_ids) < seq_len:
|
||||
raise AssertionError(
|
||||
f"Expected at least {seq_len} dummy tokens for profiling, "
|
||||
f"but found {len(token_ids)} tokens instead."
|
||||
)
|
||||
|
||||
return dummy_data
|
||||
|
||||
def get_encoder_dummy_data(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
) -> DummyEncoderData:
|
||||
"""
|
||||
Create dummy data for profiling the memory usage of a model.
|
||||
|
||||
The model is identified by `model_config`.
|
||||
"""
|
||||
processor = self.create_processor(model_config, cache=cache)
|
||||
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
||||
|
||||
# Extract configurable options from multimodal config.
|
||||
# Only include modalities that use advanced option types so legacy
|
||||
# count-only behavior remains unchanged.
|
||||
mm_options = self._extract_mm_options(model_config)
|
||||
|
||||
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options)
|
||||
|
||||
# Having more tokens is over-conservative but otherwise fine
|
||||
token_ids = dummy_data.prompt_token_ids
|
||||
if len(token_ids) < seq_len:
|
||||
logger.warning_once(
|
||||
"Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.", # noqa: E501
|
||||
seq_len,
|
||||
len(token_ids),
|
||||
)
|
||||
|
||||
return dummy_data
|
||||
|
||||
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
|
||||
"""
|
||||
Get the maximum length of the encoder input for encoder-decoder models.
|
||||
"""
|
||||
if not model_config.is_encoder_decoder:
|
||||
return 0
|
||||
max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
|
||||
if not max_tokens:
|
||||
# TODO - this function assumes encoder-decoder models are
|
||||
# multimodal. This will need to change when adding support for more
|
||||
# than whisper.
|
||||
return 0
|
||||
assert len(max_tokens) == 1, (
|
||||
"Encoder-decoder models are expected \
|
||||
to implement the multimodal interface with at most one modality."
|
||||
)
|
||||
|
||||
first_modality = next(iter(max_tokens))
|
||||
return max_tokens[first_modality]
|
||||
513
vllm/multimodal/utils.py
Normal file
513
vllm/multimodal/utils.py
Normal file
@@ -0,0 +1,513 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
from collections.abc import Generator, Set
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
from urllib.request import url2pathname
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
|
||||
from .base import MediaIO
|
||||
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .video import VideoMediaIO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .inputs import (
|
||||
BatchedTensorInputs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalPlaceholderDict,
|
||||
)
|
||||
else:
|
||||
BatchedTensorInputs = Any
|
||||
MultiModalKwargsItem = Any
|
||||
MultiModalPlaceholderDict = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
global_thread_pool = ThreadPoolExecutor(
|
||||
max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT
|
||||
)
|
||||
atexit.register(global_thread_pool.shutdown)
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
|
||||
|
||||
|
||||
@MEDIA_CONNECTOR_REGISTRY.register("http")
|
||||
class MediaConnector:
|
||||
def __init__(
|
||||
self,
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
|
||||
connection: HTTPConnection = global_http_connection,
|
||||
*,
|
||||
allowed_local_media_path: str = "",
|
||||
allowed_media_domains: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
media_io_kwargs: Additional args passed to process media
|
||||
inputs, keyed by modalities. For example,
|
||||
to set num_frames for video, set
|
||||
`--media-io-kwargs '{"video":{"num_frames":40}}'`
|
||||
connection: HTTP connection client to download media contents.
|
||||
allowed_local_media_path: A local directory to load media files from.
|
||||
allowed_media_domains: If set, only media URLs that belong to this
|
||||
domain can be used for multi-modal inputs.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.media_io_kwargs: dict[str, dict[str, Any]] = (
|
||||
media_io_kwargs if media_io_kwargs else {}
|
||||
)
|
||||
self.connection = connection
|
||||
|
||||
if allowed_local_media_path:
|
||||
allowed_local_media_path_ = Path(allowed_local_media_path)
|
||||
|
||||
if not allowed_local_media_path_.exists():
|
||||
raise ValueError(
|
||||
"Invalid `--allowed-local-media-path`: The path "
|
||||
f"{allowed_local_media_path_} does not exist."
|
||||
)
|
||||
if not allowed_local_media_path_.is_dir():
|
||||
raise ValueError(
|
||||
"Invalid `--allowed-local-media-path`: The path "
|
||||
f"{allowed_local_media_path_} must be a directory."
|
||||
)
|
||||
else:
|
||||
allowed_local_media_path_ = None
|
||||
|
||||
self.allowed_local_media_path = allowed_local_media_path_
|
||||
if allowed_media_domains is None:
|
||||
allowed_media_domains = []
|
||||
self.allowed_media_domains = allowed_media_domains
|
||||
|
||||
def _load_data_url(
|
||||
self,
|
||||
url_spec: ParseResult,
|
||||
media_io: MediaIO[_M],
|
||||
) -> _M: # type: ignore[type-var]
|
||||
data_spec, data = url_spec.path.split(",", 1)
|
||||
media_type, data_type = data_spec.split(";", 1)
|
||||
|
||||
if data_type != "base64":
|
||||
msg = "Only base64 data URLs are supported for now."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return media_io.load_base64(media_type, data)
|
||||
|
||||
def _load_file_url(
|
||||
self,
|
||||
url_spec: ParseResult,
|
||||
media_io: MediaIO[_M],
|
||||
) -> _M: # type: ignore[type-var]
|
||||
allowed_local_media_path = self.allowed_local_media_path
|
||||
if allowed_local_media_path is None:
|
||||
raise RuntimeError(
|
||||
"Cannot load local files without `--allowed-local-media-path`."
|
||||
)
|
||||
|
||||
filepath = Path(url2pathname(url_spec.netloc + url_spec.path))
|
||||
if allowed_local_media_path not in filepath.resolve().parents:
|
||||
raise ValueError(
|
||||
f"The file path {filepath} must be a subpath "
|
||||
f"of `--allowed-local-media-path {allowed_local_media_path}`."
|
||||
)
|
||||
|
||||
return media_io.load_file(filepath)
|
||||
|
||||
def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None:
|
||||
if (
|
||||
self.allowed_media_domains
|
||||
and url_spec.hostname not in self.allowed_media_domains
|
||||
):
|
||||
raise ValueError(
|
||||
f"The URL must be from one of the allowed domains: "
|
||||
f"{self.allowed_media_domains}. Input URL domain: "
|
||||
f"{url_spec.hostname}"
|
||||
)
|
||||
|
||||
def load_from_url(
|
||||
self,
|
||||
url: str,
|
||||
media_io: MediaIO[_M],
|
||||
*,
|
||||
fetch_timeout: int | None = None,
|
||||
) -> _M: # type: ignore[type-var]
|
||||
url_spec = urlparse(url)
|
||||
|
||||
if url_spec.scheme.startswith("http"):
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
connection = self.connection
|
||||
data = connection.get_bytes(
|
||||
url,
|
||||
timeout=fetch_timeout,
|
||||
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
|
||||
)
|
||||
|
||||
return media_io.load_bytes(data)
|
||||
|
||||
if url_spec.scheme == "data":
|
||||
return self._load_data_url(url_spec, media_io)
|
||||
|
||||
if url_spec.scheme == "file":
|
||||
return self._load_file_url(url_spec, media_io)
|
||||
|
||||
msg = "The URL must be either a HTTP, data or file URL."
|
||||
raise ValueError(msg)
|
||||
|
||||
async def load_from_url_async(
|
||||
self,
|
||||
url: str,
|
||||
media_io: MediaIO[_M],
|
||||
*,
|
||||
fetch_timeout: int | None = None,
|
||||
) -> _M:
|
||||
url_spec = urlparse(url)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if url_spec.scheme.startswith("http"):
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
connection = self.connection
|
||||
data = await connection.async_get_bytes(
|
||||
url,
|
||||
timeout=fetch_timeout,
|
||||
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
|
||||
)
|
||||
future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
|
||||
return await future
|
||||
|
||||
if url_spec.scheme == "data":
|
||||
future = loop.run_in_executor(
|
||||
global_thread_pool, self._load_data_url, url_spec, media_io
|
||||
)
|
||||
return await future
|
||||
|
||||
if url_spec.scheme == "file":
|
||||
future = loop.run_in_executor(
|
||||
global_thread_pool, self._load_file_url, url_spec, media_io
|
||||
)
|
||||
return await future
|
||||
msg = "The URL must be either a HTTP, data or file URL."
|
||||
raise ValueError(msg)
|
||||
|
||||
def fetch_audio(
|
||||
self,
|
||||
audio_url: str,
|
||||
) -> tuple[np.ndarray, int | float]:
|
||||
"""
|
||||
Load audio from a URL.
|
||||
"""
|
||||
audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
|
||||
|
||||
return self.load_from_url(
|
||||
audio_url,
|
||||
audio_io,
|
||||
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
|
||||
)
|
||||
|
||||
async def fetch_audio_async(
|
||||
self,
|
||||
audio_url: str,
|
||||
) -> tuple[np.ndarray, int | float]:
|
||||
"""
|
||||
Asynchronously fetch audio from a URL.
|
||||
"""
|
||||
audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
|
||||
|
||||
return await self.load_from_url_async(
|
||||
audio_url,
|
||||
audio_io,
|
||||
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
|
||||
)
|
||||
|
||||
def fetch_image(
|
||||
self,
|
||||
image_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB",
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Load a PIL image from an HTTP or base64 data URL.
|
||||
|
||||
By default, the image is converted into RGB format.
|
||||
"""
|
||||
image_io = ImageMediaIO(
|
||||
image_mode=image_mode, **self.media_io_kwargs.get("image", {})
|
||||
)
|
||||
|
||||
try:
|
||||
return self.load_from_url(
|
||||
image_url,
|
||||
image_io,
|
||||
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
|
||||
)
|
||||
except UnidentifiedImageError as e:
|
||||
# convert to ValueError to be properly caught upstream
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
async def fetch_image_async(
|
||||
self,
|
||||
image_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB",
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Asynchronously load a PIL image from an HTTP or base64 data URL.
|
||||
|
||||
By default, the image is converted into RGB format.
|
||||
"""
|
||||
image_io = ImageMediaIO(
|
||||
image_mode=image_mode, **self.media_io_kwargs.get("image", {})
|
||||
)
|
||||
|
||||
try:
|
||||
return await self.load_from_url_async(
|
||||
image_url,
|
||||
image_io,
|
||||
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
|
||||
)
|
||||
except UnidentifiedImageError as e:
|
||||
# convert to ValueError to be properly caught upstream
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
def fetch_video(
|
||||
self,
|
||||
video_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB",
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
"""
|
||||
Load video from an HTTP or base64 data URL.
|
||||
"""
|
||||
image_io = ImageMediaIO(
|
||||
image_mode=image_mode, **self.media_io_kwargs.get("image", {})
|
||||
)
|
||||
video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))
|
||||
|
||||
return self.load_from_url(
|
||||
video_url,
|
||||
video_io,
|
||||
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
|
||||
)
|
||||
|
||||
async def fetch_video_async(
|
||||
self,
|
||||
video_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB",
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
"""
|
||||
Asynchronously load video from an HTTP or base64 data URL.
|
||||
|
||||
By default, the image is converted into RGB format.
|
||||
"""
|
||||
image_io = ImageMediaIO(
|
||||
image_mode=image_mode, **self.media_io_kwargs.get("image", {})
|
||||
)
|
||||
video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))
|
||||
|
||||
return await self.load_from_url_async(
|
||||
video_url,
|
||||
video_io,
|
||||
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
|
||||
)
|
||||
|
||||
def fetch_image_embedding(
|
||||
self,
|
||||
data: str,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Load image embedding from a URL.
|
||||
"""
|
||||
image_embedding_io = ImageEmbeddingMediaIO()
|
||||
|
||||
return image_embedding_io.load_base64("", data)
|
||||
|
||||
def fetch_audio_embedding(
|
||||
self,
|
||||
data: str,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Load audio embedding from a URL.
|
||||
"""
|
||||
audio_embedding_io = AudioEmbeddingMediaIO()
|
||||
|
||||
return audio_embedding_io.load_base64("", data)
|
||||
|
||||
|
||||
def encode_audio_base64(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: int,
|
||||
) -> str:
|
||||
"""Encode audio as base64."""
|
||||
audio_io = AudioMediaIO()
|
||||
return audio_io.encode_base64((audio, sampling_rate))
|
||||
|
||||
|
||||
def encode_image_base64(
|
||||
image: Image.Image,
|
||||
*,
|
||||
image_mode: str = "RGB",
|
||||
format: str = "JPEG",
|
||||
) -> str:
|
||||
"""
|
||||
Encode a pillow image to base64 format.
|
||||
|
||||
By default, the image is converted into RGB format before being encoded.
|
||||
"""
|
||||
image_io = ImageMediaIO(image_mode=image_mode)
|
||||
return image_io.encode_base64(image, image_format=format)
|
||||
|
||||
|
||||
def encode_video_base64(frames: npt.NDArray) -> str:
|
||||
image_io = ImageMediaIO()
|
||||
video_io = VideoMediaIO(image_io)
|
||||
return video_io.encode_base64(frames)
|
||||
|
||||
|
||||
def argsort_mm_positions(
|
||||
mm_positions: MultiModalPlaceholderDict,
|
||||
) -> list[tuple[str, int]]:
|
||||
"""
|
||||
Given a `MultiModalPlaceholderDict`, output a sequence of keys to
|
||||
sort the dictionary by `offset` (starting index in the input sequence)
|
||||
in ascending order.
|
||||
|
||||
Returns:
|
||||
A list of `(modality, idx)`, which can be used to access an item
|
||||
by `mm_positions[modality][idx]`.
|
||||
"""
|
||||
flat_items = (
|
||||
(modality, idx, item)
|
||||
for modality, items in mm_positions.items()
|
||||
for idx, item in enumerate(items)
|
||||
)
|
||||
|
||||
sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset)
|
||||
|
||||
return [(modality, idx) for modality, idx, _ in sorted_flat_items]
|
||||
|
||||
|
||||
def group_mm_kwargs_by_modality(
|
||||
mm_kwargs: list[MultiModalKwargsItem],
|
||||
*,
|
||||
device: torch.types.Device = None,
|
||||
pin_memory: bool = False,
|
||||
merge_by_field_config: bool | None = None,
|
||||
multimodal_cpu_fields: Set[str] | None = None,
|
||||
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
|
||||
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
|
||||
modality together into the same `MultiModalKwargs` instance.
|
||||
|
||||
Args:
|
||||
mm_kwargs: List of `MultiModalKwargsItem`.
|
||||
device: The device to place the grouped tensors on.
|
||||
pin_memory: Whether to pin memory for faster host-to-device transfer.
|
||||
|
||||
Yields:
|
||||
A tuple `(modality, num_items, grouped_kwargs)`.
|
||||
"""
|
||||
if merge_by_field_config is not None:
|
||||
logger.warning_once(
|
||||
"The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
|
||||
"is deprecated and will be removed in v0.14."
|
||||
)
|
||||
if multimodal_cpu_fields is not None:
|
||||
logger.warning_once(
|
||||
"The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
|
||||
"is deprecated and will be removed in v0.14."
|
||||
)
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItems
|
||||
|
||||
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
|
||||
items_lst = list(items)
|
||||
mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst)
|
||||
mm_kwargs_data = mm_kwargs_items.get_data(
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
|
||||
yield modality, len(items_lst), mm_kwargs_data
|
||||
|
||||
|
||||
def fetch_audio(
|
||||
audio_url: str,
|
||||
audio_io_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[np.ndarray, int | float]:
|
||||
"""
|
||||
Args:
|
||||
audio_url: URL of the audio file to fetch.
|
||||
audio_io_kwargs: Additional kwargs passed to handle audio IO.
|
||||
|
||||
Warning:
|
||||
This method has direct access to local files and is only intended
|
||||
to be called by user code. Never call this from the online server!
|
||||
"""
|
||||
media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs}
|
||||
media_connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path="/",
|
||||
)
|
||||
return media_connector.fetch_audio(audio_url)
|
||||
|
||||
|
||||
def fetch_image(
|
||||
image_url: str,
|
||||
image_io_kwargs: dict[str, Any] | None = None,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Args:
|
||||
image_url: URL of the image file to fetch.
|
||||
image_io_kwargs: Additional kwargs passed to handle image IO.
|
||||
|
||||
Warning:
|
||||
This method has direct access to local files and is only intended
|
||||
to be called by user code. Never call this from the online server!
|
||||
"""
|
||||
media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs}
|
||||
media_connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path="/",
|
||||
)
|
||||
return media_connector.fetch_image(image_url)
|
||||
|
||||
|
||||
def fetch_video(
|
||||
video_url: str,
|
||||
video_io_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
"""
|
||||
Args:
|
||||
video_url: URL of the video file to fetch.
|
||||
video_io_kwargs: Additional kwargs passed to handle video IO.
|
||||
|
||||
Warning:
|
||||
This method has direct access to local files and is only intended
|
||||
to be called by user code. Never call this from the online server!
|
||||
"""
|
||||
media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs}
|
||||
media_connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path="/",
|
||||
)
|
||||
return media_connector.fetch_video(video_url)
|
||||
340
vllm/multimodal/video.py
Normal file
340
vllm/multimodal/video.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
from .base import MediaIO
|
||||
from .image import ImageMediaIO
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
|
||||
num_frames, _, _, channels = frames.shape
|
||||
new_height, new_width = size
|
||||
resized_frames = np.empty(
|
||||
(num_frames, new_height, new_width, channels), dtype=frames.dtype
|
||||
)
|
||||
# lazy import cv2 to avoid bothering users who only use text models
|
||||
import cv2
|
||||
|
||||
for i, frame in enumerate(frames):
|
||||
resized_frame = cv2.resize(frame, (new_width, new_height))
|
||||
resized_frames[i] = resized_frame
|
||||
return resized_frames
|
||||
|
||||
|
||||
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
|
||||
_, height, width, _ = frames.shape
|
||||
new_height = int(height * size_factor)
|
||||
new_width = int(width * size_factor)
|
||||
|
||||
return resize_video(frames, (new_height, new_width))
|
||||
|
||||
|
||||
def sample_frames_from_video(frames: npt.NDArray, num_frames: int) -> npt.NDArray:
|
||||
total_frames = frames.shape[0]
|
||||
if num_frames == -1:
|
||||
return frames
|
||||
|
||||
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
||||
sampled_frames = frames[frame_indices, ...]
|
||||
return sampled_frames
|
||||
|
||||
|
||||
class VideoLoader:
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load_bytes(
|
||||
cls, data: bytes, num_frames: int = -1, **kwargs
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _read_frames(
|
||||
cap,
|
||||
frame_indices: set[int],
|
||||
num_expected_frames: int,
|
||||
max_frame_idx: int,
|
||||
) -> tuple[npt.NDArray, int, list[int]]:
|
||||
import cv2
|
||||
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
frames = np.empty((num_expected_frames, height, width, 3), dtype=np.uint8)
|
||||
|
||||
i = 0
|
||||
valid_frame_indices = []
|
||||
for idx in range(max_frame_idx + 1):
|
||||
ok = cap.grab()
|
||||
if not ok:
|
||||
# Frame is broken/unreadable, log warning
|
||||
if idx in frame_indices:
|
||||
logger.warning(
|
||||
"Failed to grab frame %d during video loading. "
|
||||
"This frame will be skipped.",
|
||||
idx,
|
||||
)
|
||||
continue
|
||||
if idx in frame_indices:
|
||||
ret, frame = cap.retrieve()
|
||||
if ret:
|
||||
frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
valid_frame_indices.append(idx)
|
||||
i += 1
|
||||
else:
|
||||
# retrieve() failed even though grab() succeeded
|
||||
logger.warning(
|
||||
"Failed to retrieve frame %d during video loading. "
|
||||
"This frame will be skipped.",
|
||||
idx,
|
||||
)
|
||||
|
||||
valid_num_frames = len(valid_frame_indices)
|
||||
if valid_num_frames < num_expected_frames:
|
||||
logger.warning(
|
||||
"Video loading completed with %d broken/unreadable frames. "
|
||||
"Expected %d frames but only loaded %d frames.",
|
||||
num_expected_frames - valid_num_frames,
|
||||
num_expected_frames,
|
||||
valid_num_frames,
|
||||
)
|
||||
|
||||
return frames[:valid_num_frames], valid_num_frames, valid_frame_indices
|
||||
|
||||
|
||||
VIDEO_LOADER_REGISTRY = ExtensionManager()
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv")
|
||||
class OpenCVVideoBackend(VideoLoader):
|
||||
def get_cv2_video_api(self):
|
||||
import cv2.videoio_registry as vr
|
||||
|
||||
api_pref = None
|
||||
for backend in vr.getStreamBufferedBackends():
|
||||
if not vr.hasBackend(backend):
|
||||
continue
|
||||
if not vr.isBackendBuiltIn(backend):
|
||||
_, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
|
||||
if abi < 1 or (abi == 1 and api < 2):
|
||||
continue
|
||||
api_pref = backend
|
||||
break
|
||||
return api_pref
|
||||
|
||||
@classmethod
|
||||
def load_bytes(
|
||||
cls,
|
||||
data: bytes,
|
||||
num_frames: int = -1,
|
||||
fps: int = -1,
|
||||
**kwargs,
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
import cv2
|
||||
|
||||
backend = cls().get_cv2_video_api()
|
||||
cap = cv2.VideoCapture(BytesIO(data), backend, [])
|
||||
if not cap.isOpened():
|
||||
raise ValueError("Could not open video stream")
|
||||
|
||||
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_frames_num / original_fps if original_fps > 0 else 0
|
||||
|
||||
# resample video to target num_frames and fps
|
||||
# - the minimum of the two will be used
|
||||
num_frames_to_sample = total_frames_num
|
||||
if num_frames > 0:
|
||||
num_frames_to_sample = min(num_frames, total_frames_num)
|
||||
if fps > 0:
|
||||
num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
|
||||
num_frames_to_sample = max(1, num_frames_to_sample) # at least one sample
|
||||
|
||||
if num_frames_to_sample == total_frames_num:
|
||||
frame_idx = list(range(0, num_frames_to_sample))
|
||||
else:
|
||||
uniform_sampled_frames = np.linspace(
|
||||
0, total_frames_num - 1, num_frames_to_sample, dtype=int
|
||||
)
|
||||
frame_idx = uniform_sampled_frames.tolist()
|
||||
|
||||
# Convert to set for O(1) lookup performance
|
||||
frame_idx_set = set(frame_idx)
|
||||
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
|
||||
cap, frame_idx_set, num_frames_to_sample, max(frame_idx)
|
||||
)
|
||||
|
||||
# Use transformers transformers.video_utils.VideoMetadata format
|
||||
# NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
|
||||
# can cause incorrect timestamp calculation without num_frames=-1.
|
||||
metadata = {
|
||||
"total_num_frames": total_frames_num,
|
||||
"fps": original_fps,
|
||||
"duration": duration,
|
||||
"video_backend": "opencv",
|
||||
"frames_indices": valid_frame_indices,
|
||||
# extra field used to control hf processor's video
|
||||
# sampling behavior
|
||||
"do_sample_frames": valid_num_frames == total_frames_num,
|
||||
}
|
||||
|
||||
return frames, metadata
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv_dynamic")
|
||||
class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
|
||||
@classmethod
|
||||
def load_bytes(
|
||||
cls,
|
||||
data: bytes,
|
||||
num_frames: int = -1,
|
||||
fps: int = 2,
|
||||
max_duration: int = 300,
|
||||
**kwargs,
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
import cv2
|
||||
|
||||
backend = cls().get_cv2_video_api()
|
||||
cap = cv2.VideoCapture(BytesIO(data), backend, [])
|
||||
if not cap.isOpened():
|
||||
raise ValueError("Could not open video stream")
|
||||
|
||||
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_frames_num / original_fps if original_fps > 0 else 0
|
||||
|
||||
# resample video to target num_frames
|
||||
max_frame_idx = total_frames_num - 1
|
||||
duration = duration or round(max_frame_idx / original_fps) + 1
|
||||
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
|
||||
frame_indices_list: list[int]
|
||||
if duration <= max_duration:
|
||||
n = int(math.floor(duration * fps))
|
||||
frame_indices_list = sorted(
|
||||
{
|
||||
min(max_frame_idx, int(math.ceil(i * original_fps / fps)))
|
||||
for i in range(n)
|
||||
}
|
||||
)
|
||||
else:
|
||||
num_samples = int(max_duration * fps)
|
||||
if num_samples >= total_frames_num:
|
||||
frame_indices_list = list(range(total_frames_num))
|
||||
else:
|
||||
target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
|
||||
frame_indices_list = sorted(
|
||||
{
|
||||
min(max_frame_idx, int(math.ceil(t * original_fps)))
|
||||
for t in target_seconds
|
||||
}
|
||||
)
|
||||
|
||||
# Convert to set for O(1) lookup performance
|
||||
frame_indices_set = set(frame_indices_list)
|
||||
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
|
||||
cap,
|
||||
frame_indices_set,
|
||||
len(frame_indices_list),
|
||||
total_frames_num - 1,
|
||||
)
|
||||
|
||||
# Use transformers transformers.video_utils.VideoMetadata format
|
||||
metadata = {
|
||||
"total_num_frames": total_frames_num,
|
||||
"fps": original_fps,
|
||||
"duration": duration,
|
||||
"video_backend": "opencv_dynamic",
|
||||
"frames_indices": valid_frame_indices,
|
||||
"do_sample_frames": False,
|
||||
}
|
||||
|
||||
return frames, metadata
|
||||
|
||||
|
||||
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
|
||||
def __init__(
|
||||
self,
|
||||
image_io: ImageMediaIO,
|
||||
num_frames: int = 32,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.image_io = image_io
|
||||
self.num_frames = num_frames
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
|
||||
# Allow per-request override of video backend via kwargs.
|
||||
# This enables users to specify a different backend than the
|
||||
# global VLLM_VIDEO_LOADER_BACKEND env var, e.g.:
|
||||
# --media-io-kwargs '{"video": {"video_backend": "torchcodec"}}'
|
||||
video_loader_backend = (
|
||||
kwargs.pop("video_backend", None) or envs.VLLM_VIDEO_LOADER_BACKEND
|
||||
)
|
||||
self.kwargs = kwargs
|
||||
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
|
||||
|
||||
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
return self.video_loader.load_bytes(
|
||||
data, num_frames=self.num_frames, **self.kwargs
|
||||
)
|
||||
|
||||
def load_base64(
|
||||
self, media_type: str, data: str
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
if media_type.lower() == "video/jpeg":
|
||||
load_frame = partial(
|
||||
self.image_io.load_base64,
|
||||
"image/jpeg",
|
||||
)
|
||||
|
||||
return np.stack(
|
||||
[np.asarray(load_frame(frame_data)) for frame_data in data.split(",")]
|
||||
), {}
|
||||
|
||||
return self.load_bytes(base64.b64decode(data))
|
||||
|
||||
def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
with filepath.open("rb") as f:
|
||||
data = f.read()
|
||||
|
||||
return self.load_bytes(data)
|
||||
|
||||
def encode_base64(
|
||||
self,
|
||||
media: npt.NDArray,
|
||||
*,
|
||||
video_format: str = "JPEG",
|
||||
) -> str:
|
||||
video = media
|
||||
|
||||
if video_format == "JPEG":
|
||||
encode_frame = partial(
|
||||
self.image_io.encode_base64,
|
||||
image_format=video_format,
|
||||
)
|
||||
|
||||
return ",".join(encode_frame(Image.fromarray(frame)) for frame in video)
|
||||
|
||||
msg = "Only JPEG format is supported for now."
|
||||
raise NotImplementedError(msg)
|
||||
Reference in New Issue
Block a user