update
This commit is contained in:
@@ -1,40 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .hasher import MultiModalHasher
|
||||
from .inputs import (
|
||||
BatchedTensorInputs,
|
||||
ModalityData,
|
||||
MultiModalDataBuiltins,
|
||||
MultiModalDataDict,
|
||||
MultiModalKwargs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalPlaceholderDict,
|
||||
MultiModalUUIDDict,
|
||||
NestedTensors,
|
||||
)
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
MULTIMODAL_REGISTRY = MultiModalRegistry()
|
||||
"""
|
||||
The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry]
|
||||
is used by model runners to dispatch data processing according to the target
|
||||
model.
|
||||
|
||||
Info:
|
||||
[mm_processing](../../../design/mm_processing.md)
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"BatchedTensorInputs",
|
||||
"ModalityData",
|
||||
"MultiModalDataBuiltins",
|
||||
"MultiModalDataDict",
|
||||
"MultiModalHasher",
|
||||
"MultiModalKwargs",
|
||||
"MultiModalKwargsItems",
|
||||
"MultiModalPlaceholderDict",
|
||||
"MultiModalUUIDDict",
|
||||
"NestedTensors",
|
||||
"MULTIMODAL_REGISTRY",
|
||||
"MultiModalRegistry",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,118 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .base import MediaIO
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import soundfile
|
||||
except ImportError:
|
||||
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
||||
|
||||
|
||||
def resample_audio_librosa(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
target_sr: float,
|
||||
) -> npt.NDArray[np.floating]:
|
||||
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
|
||||
|
||||
|
||||
def resample_audio_scipy(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
target_sr: float,
|
||||
):
|
||||
# lazy import scipy.signal, otherwise it will crash doc build.
|
||||
import scipy.signal
|
||||
|
||||
if orig_sr > target_sr:
|
||||
return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr)
|
||||
elif orig_sr < target_sr:
|
||||
return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1)
|
||||
return audio
|
||||
|
||||
|
||||
class AudioResampler:
|
||||
"""Resample audio data to a target sample rate."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_sr: float | None = None,
|
||||
method: Literal["librosa", "scipy"] = "librosa",
|
||||
):
|
||||
self.target_sr = target_sr
|
||||
self.method = method
|
||||
|
||||
def resample(
|
||||
self,
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
) -> npt.NDArray[np.floating]:
|
||||
if self.target_sr is None:
|
||||
raise RuntimeError(
|
||||
"Audio resampling is not supported when `target_sr` is not provided"
|
||||
)
|
||||
if self.method == "librosa":
|
||||
return resample_audio_librosa(
|
||||
audio, orig_sr=orig_sr, target_sr=self.target_sr
|
||||
)
|
||||
elif self.method == "scipy":
|
||||
return resample_audio_scipy(
|
||||
audio, orig_sr=orig_sr, target_sr=self.target_sr
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid resampling method: {self.method}. "
|
||||
"Supported methods are 'librosa' and 'scipy'."
|
||||
)
|
||||
|
||||
|
||||
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
self.kwargs = kwargs
|
||||
|
||||
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
|
||||
return librosa.load(BytesIO(data), sr=None)
|
||||
|
||||
def load_base64(
|
||||
self,
|
||||
media_type: str,
|
||||
data: str,
|
||||
) -> tuple[npt.NDArray, float]:
|
||||
return self.load_bytes(base64.b64decode(data))
|
||||
|
||||
def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
|
||||
return librosa.load(filepath, sr=None)
|
||||
|
||||
def encode_base64(self, media: tuple[npt.NDArray, int]) -> str:
|
||||
audio, sr = media
|
||||
|
||||
with BytesIO() as buffer:
|
||||
soundfile.write(buffer, audio, sr, format="WAV")
|
||||
data = buffer.getvalue()
|
||||
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
@@ -1,26 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class MediaIO(ABC, Generic[_T]):
|
||||
@abstractmethod
|
||||
def load_bytes(self, data: bytes) -> _T:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_base64(self, media_type: str, data: str) -> _T:
|
||||
"""
|
||||
List of media types:
|
||||
https://www.iana.org/assignments/media-types/media-types.xhtml
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_file(self, filepath: Path) -> _T:
|
||||
raise NotImplementedError
|
||||
@@ -1,755 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import operator
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||
MsgpackSerde,
|
||||
SingleWriterShmObjectStorage,
|
||||
SingleWriterShmRingBuffer,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.cache import CacheInfo, LRUCache
|
||||
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
|
||||
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
|
||||
|
||||
from .inputs import (
|
||||
MultiModalBatchedField,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldElem,
|
||||
MultiModalKwargs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
NestedTensors,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
from .processing import ResolvedPromptUpdate
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItem:
|
||||
"""
|
||||
The data to store inside `MultiModalProcessorOnlyCache`.
|
||||
|
||||
Args:
|
||||
item: The processed tensor data corresponding to a multi-modal item.
|
||||
prompt_updates: The prompt updates corresponding to `item`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item: MultiModalKwargsItem,
|
||||
prompt_updates: Sequence["ResolvedPromptUpdate"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.item = item
|
||||
self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItemMetadata:
|
||||
"""
|
||||
The metadata to store inside `MultiModalProcessorSenderCache`.
|
||||
|
||||
Args:
|
||||
item: The processed tensor data corresponding to a multi-modal item.
|
||||
Since P1 already stores the tensor data, we only store its size
|
||||
metadata in P0 to reduce memory usage. The size metadata is still
|
||||
needed to keep the same cache eviction policy as P0.
|
||||
prompt_updates: The prompt updates corresponding to `item`.
|
||||
This needs to stay on P0 because for some models, they are
|
||||
dependent on the processed tensor data (cached on P1).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item: MultiModalKwargsItem,
|
||||
prompt_updates: Sequence["ResolvedPromptUpdate"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.item_size = MultiModalCache.get_item_size(item)
|
||||
self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
MultiModalCacheValue: TypeAlias = (
|
||||
MultiModalProcessorCacheItem
|
||||
| MultiModalProcessorCacheItemMetadata
|
||||
| MultiModalKwargsItems
|
||||
| MultiModalKwargsItem
|
||||
| MultiModalKwargs
|
||||
| Mapping[str, NestedTensors]
|
||||
)
|
||||
|
||||
_V = TypeVar("_V", bound=MultiModalCacheValue)
|
||||
|
||||
|
||||
class MultiModalCache:
|
||||
@classmethod
|
||||
def get_leaf_size(cls, leaf: object) -> int:
|
||||
if isinstance(leaf, MultiModalProcessorCacheItem):
|
||||
return cls.get_leaf_size(leaf.item)
|
||||
if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
|
||||
return leaf.item_size
|
||||
|
||||
# These are not subclasses of dict
|
||||
if isinstance(
|
||||
leaf,
|
||||
(
|
||||
MultiModalKwargs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalFieldElem,
|
||||
),
|
||||
):
|
||||
return cls.get_item_size(leaf.data) # type: ignore
|
||||
|
||||
# sys.getsizeof doesn't work for tensors
|
||||
if isinstance(leaf, torch.Tensor):
|
||||
return leaf.nbytes
|
||||
|
||||
return sys.getsizeof(leaf)
|
||||
|
||||
@classmethod
|
||||
def get_item_size(
|
||||
cls,
|
||||
value: MultiModalCacheValue,
|
||||
*,
|
||||
debug: bool = False,
|
||||
) -> int:
|
||||
size = json_reduce_leaves(
|
||||
operator.add, json_map_leaves(cls.get_leaf_size, value)
|
||||
)
|
||||
|
||||
if debug:
|
||||
leaf_count = json_count_leaves(value)
|
||||
logger.debug(
|
||||
"Calculated size of %s to be %.2f GiB (%d leaves)",
|
||||
type(value),
|
||||
size / GiB_bytes,
|
||||
leaf_count,
|
||||
)
|
||||
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
|
||||
"""
|
||||
Get the number of leaf elements in a multi-modal cache value.
|
||||
|
||||
This provides a measure of structural complexity that can be useful
|
||||
for debugging cache performance and understanding data patterns.
|
||||
|
||||
Args:
|
||||
value: The multi-modal cache value to analyze.
|
||||
|
||||
Returns:
|
||||
The number of leaf elements in the nested structure.
|
||||
"""
|
||||
return json_count_leaves(value)
|
||||
|
||||
@classmethod
|
||||
def get_lru_cache(
|
||||
cls,
|
||||
capacity_gb: float,
|
||||
value_type: type[_V],
|
||||
*,
|
||||
debug: bool = False,
|
||||
) -> LRUCache[str, _V]:
|
||||
return LRUCache(
|
||||
GiB_bytes * capacity_gb,
|
||||
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
|
||||
)
|
||||
|
||||
|
||||
_I = TypeVar("_I", contravariant=True)
|
||||
_O = TypeVar("_O", covariant=True)
|
||||
|
||||
|
||||
class BaseMultiModalCache(ABC, Generic[_I, _O]):
|
||||
"""
|
||||
Abstract base class to read/write multi-modal items from cache.
|
||||
|
||||
The idea of multi-modal caching is based on having a client and server
|
||||
where the client executes in the frontend process (=P0) and
|
||||
the server in the core process (=P1). The data flow is as follows:
|
||||
|
||||
```
|
||||
is_cached() x N get_and_update()
|
||||
P0: From API -----------------> -----------------> To P1
|
||||
|
||||
get_and_update()
|
||||
P1: From P0 -----------------> To model
|
||||
```
|
||||
|
||||
`is_cached()` can be called any number of times in P0. However,
|
||||
`get_and_update()` must be called in P0 and P1 one after another
|
||||
so that their cache eviction order remains the same.
|
||||
|
||||
This ensures that the keys in P0 and P1 caches are mirrored,
|
||||
allowing us to determine whether a key is cached in P1 by looking
|
||||
up the P0 cache, without having to communicate with P1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: _I,
|
||||
mm_hash: str,
|
||||
) -> _O:
|
||||
"""
|
||||
Possibly update a multi-modal item based on whether it is
|
||||
in the underlying cache.
|
||||
|
||||
This update is done out-of-place and updates the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_item: The multi-modal item to update.
|
||||
mm_hash: The hash of `mm_item`.
|
||||
|
||||
Returns:
|
||||
The update multi-modal item.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_and_update(
|
||||
self,
|
||||
mm_items: Sequence[_I],
|
||||
mm_hashes: list[str],
|
||||
) -> list[_O]:
|
||||
"""
|
||||
Possibly update a sequence of multi-modal items based on whether they
|
||||
are in the underlying cache.
|
||||
|
||||
This update is done out-of-place and updates the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_items: The multi-modal items to update.
|
||||
mm_hashes: The hash of each item in `mm_items`.
|
||||
|
||||
Returns:
|
||||
A new list of updated multi-modal items.
|
||||
"""
|
||||
assert len(mm_items) == len(mm_hashes)
|
||||
|
||||
return [
|
||||
self.get_and_update_item(mm_item, mm_hash)
|
||||
for mm_item, mm_hash in zip(mm_items, mm_hashes)
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the underlying cache."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
MultiModalProcessorCacheInItem: TypeAlias = (
|
||||
tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] | None
|
||||
)
|
||||
|
||||
|
||||
MultiModalProcessorCacheOutItem: TypeAlias = tuple[
|
||||
MultiModalKwargsItem | None, Sequence["ResolvedPromptUpdate"]
|
||||
]
|
||||
|
||||
|
||||
class BaseMultiModalProcessorCache(
|
||||
BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem]
|
||||
):
|
||||
"""The required interface for caches on P0."""
|
||||
|
||||
@abstractmethod
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
"""
|
||||
Check whether a multi-modal item is
|
||||
in the underlying cache.
|
||||
|
||||
This **DOES NOT** update the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_hash: The hash of the item to check.
|
||||
|
||||
Returns:
|
||||
`True` if the item is cached, otherwise `False`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_cached(self, mm_hashes: list[str]) -> list[bool]:
|
||||
"""
|
||||
Check whether a sequence of multi-modal items are
|
||||
in the underlying cache.
|
||||
|
||||
This **DOES NOT** update the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_hashes: The hash of each item to check.
|
||||
|
||||
Returns:
|
||||
For each item, `True` if the item is cached, otherwise `False`.
|
||||
"""
|
||||
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
|
||||
|
||||
@abstractmethod
|
||||
def make_stats(self, *, delta: bool = False) -> CacheInfo:
|
||||
"""
|
||||
Get (and reset) the multi-modal cache stats.
|
||||
|
||||
Returns:
|
||||
The current multi-modal caching stats.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is disabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is in the cache, replace the input with the cached item.
|
||||
- If the item is not in the cache, store that item (which includes
|
||||
tensor data and metadata) into the cache, and return the input.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalProcessorCacheItem,
|
||||
)
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return mm_hash in self._cache
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return cached_item.item, cached_item.prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
@override
|
||||
def make_stats(self, *, delta: bool = False) -> CacheInfo:
|
||||
return self._cache.stat(delta=delta)
|
||||
|
||||
|
||||
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is already in the cache, clear the input to avoid
|
||||
unnecessary IPC.
|
||||
|
||||
- If the item is not in the cache, store the metadata of that item so
|
||||
that the eviction policy remains the same as the cache on P1,
|
||||
and return the input.
|
||||
By only storing the metadata, we avoid keeping the data itself in
|
||||
memory inside P0.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalProcessorCacheItemMetadata,
|
||||
)
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return mm_hash in self._cache
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return None, cached_item.prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
@override
|
||||
def make_stats(self, *, delta: bool = False) -> CacheInfo:
|
||||
return self._cache.stat(delta=delta)
|
||||
|
||||
|
||||
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is already in the cache, clear the input to avoid
|
||||
unnecessary IPC.
|
||||
|
||||
- If the item is not in the cache, store the data in shared memory.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
mm_config = vllm_config.model_config.get_multimodal_config()
|
||||
|
||||
ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
|
||||
name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
|
||||
create=True, # sender is the writer
|
||||
)
|
||||
self._shm_cache = SingleWriterShmObjectStorage(
|
||||
max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
|
||||
n_readers=self.world_size,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=MsgpackSerde,
|
||||
)
|
||||
# cache (prompt_updates, modality) for P0 only
|
||||
self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {}
|
||||
|
||||
self._hits = 0
|
||||
self._total = 0
|
||||
self._last_info = CacheInfo(hits=0, total=0)
|
||||
|
||||
def _stat(self, *, delta: bool = False) -> CacheInfo:
|
||||
info = CacheInfo(hits=self._hits, total=self._total)
|
||||
|
||||
if delta:
|
||||
info_delta = info - self._last_info
|
||||
self._last_info = info
|
||||
info = info_delta
|
||||
|
||||
return info
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return self._shm_cache.is_cached(mm_hash)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if self._shm_cache.is_cached(mm_hash):
|
||||
self._hits += 1
|
||||
self._total += 1
|
||||
|
||||
address, monotonic_id = self._shm_cache.get_cached(mm_hash)
|
||||
prompt_updates, modality = self._p0_cache[mm_hash]
|
||||
return self.address_as_item(address, monotonic_id, modality), prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._total += 1
|
||||
|
||||
try:
|
||||
address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
|
||||
# Try to remove dangling items if p0 cache is too large.
|
||||
if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
|
||||
self.remove_dangling_items()
|
||||
self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
|
||||
address_item = self.address_as_item(
|
||||
address, monotonic_id, mm_item[0].modality
|
||||
)
|
||||
return address_item, mm_item[1]
|
||||
except (ValueError, MemoryError) as e:
|
||||
# put may fail if the object is too large or
|
||||
# the cache is full.
|
||||
# In this case we log the error and keep the original mm_input.
|
||||
logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._shm_cache.clear()
|
||||
self._p0_cache.clear()
|
||||
|
||||
self._hits = 0
|
||||
self._total = 0
|
||||
self._last_info = CacheInfo(hits=0, total=0)
|
||||
|
||||
@override
|
||||
def make_stats(self, *, delta: bool = False) -> CacheInfo:
|
||||
return self._stat(delta=delta)
|
||||
|
||||
def remove_dangling_items(self) -> None:
|
||||
"""Remove items that are no longer in the shared memory cache."""
|
||||
cached_hashes = self._shm_cache.key_index.keys()
|
||||
dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
|
||||
for mm_hash in dangling_hashes:
|
||||
del self._p0_cache[mm_hash]
|
||||
|
||||
def address_as_item(
|
||||
self, address: int, monotonic_id: int, modality: str
|
||||
) -> MultiModalKwargsItem:
|
||||
addr_elem = MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key="address",
|
||||
data=address,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
id_elem = MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key="monotonic_id",
|
||||
data=monotonic_id,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
|
||||
return mm_item
|
||||
|
||||
|
||||
def _enable_processor_cache(
|
||||
model_config: "ModelConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> bool:
|
||||
if not mm_registry.supports_multimodal_inputs(model_config):
|
||||
return False
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
return mm_config.mm_processor_cache_gb > 0
|
||||
|
||||
|
||||
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
supports_ipc_cache = (
|
||||
parallel_config._api_process_count == 1
|
||||
and parallel_config.data_parallel_size == 1
|
||||
) or parallel_config.data_parallel_external_lb
|
||||
|
||||
return supports_ipc_cache
|
||||
|
||||
|
||||
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
|
||||
"""Whether the shared memory based cache should be enabled."""
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return False
|
||||
|
||||
mm_config = vllm_config.model_config.get_multimodal_config()
|
||||
|
||||
return mm_config.mm_processor_cache_type == "shm"
|
||||
|
||||
|
||||
def processor_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> BaseMultiModalProcessorCache | None:
|
||||
"""Return a `BaseMultiModalProcessorCache`, if enabled."""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
if not _enable_mm_input_shm_cache(vllm_config):
|
||||
return MultiModalProcessorSenderCache(model_config)
|
||||
return ShmObjectStoreSenderCache(vllm_config)
|
||||
|
||||
|
||||
def processor_only_cache_from_config(
|
||||
model_config: "ModelConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
):
|
||||
"""Return a `MultiModalProcessorOnlyCache`, if enabled."""
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
return MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
|
||||
class BaseMultiModalReceiverCache(
|
||||
BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
|
||||
):
|
||||
"""The required interface for caches on P1."""
|
||||
|
||||
def get_and_update_features(
|
||||
self,
|
||||
mm_features: list["MultiModalFeatureSpec"],
|
||||
) -> list["MultiModalFeatureSpec"]:
|
||||
"""Update multimodal features with cached encoder outputs."""
|
||||
for feature in mm_features:
|
||||
feature.data = self.get_and_update_item(feature.data, feature.identifier)
|
||||
return mm_features
|
||||
|
||||
|
||||
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
|
||||
"""
|
||||
The cache which is used on P1 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is in the cache, replace the input with the cached item.
|
||||
- If the item is not in the cache, store that item (which includes tensor
|
||||
data) into the cache, and return the input.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalKwargsItem,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalKwargsItem | None,
|
||||
mm_hash: str,
|
||||
) -> MultiModalKwargsItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return cached_item
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = mm_item
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
|
||||
"""
|
||||
The cache which is used on P1 Worker Process when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item has an address, replace the input with the cached item.
|
||||
- If not, return the input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
shared_worker_lock: LockType,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
mm_config = vllm_config.model_config.get_multimodal_config()
|
||||
|
||||
ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
|
||||
name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
|
||||
create=False, # Server is a reader
|
||||
)
|
||||
self._shm_cache = SingleWriterShmObjectStorage(
|
||||
max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
|
||||
n_readers=self.world_size,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=MsgpackSerde,
|
||||
reader_lock=shared_worker_lock,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalKwargsItem | None,
|
||||
mm_hash: str,
|
||||
) -> MultiModalKwargsItem:
|
||||
assert mm_item is not None, f"Expected an address item for {mm_hash=}"
|
||||
if "address" in mm_item:
|
||||
address = cast(int, mm_item["address"].data)
|
||||
monotonic_id = cast(int, mm_item["monotonic_id"].data)
|
||||
return self._shm_cache.get(address, monotonic_id)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._shm_cache.clear()
|
||||
|
||||
|
||||
def engine_receiver_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> BaseMultiModalReceiverCache | None:
|
||||
"""
|
||||
This is used in the engine process.
|
||||
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
|
||||
mm_processor_cache_type=="lru".
|
||||
"""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return None
|
||||
|
||||
if not _enable_mm_input_shm_cache(vllm_config):
|
||||
return MultiModalReceiverCache(model_config)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def worker_receiver_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
shared_worker_lock: LockType,
|
||||
) -> BaseMultiModalReceiverCache | None:
|
||||
"""
|
||||
This is used in the worker process.
|
||||
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
|
||||
mm_processor_cache_type=="shm".
|
||||
"""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return None
|
||||
|
||||
if not _enable_mm_input_shm_cache(vllm_config):
|
||||
return None
|
||||
|
||||
return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
|
||||
@@ -1,294 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def compute_retained_tokens_count(
|
||||
tokens_per_frame: int, num_frames: int, q: float
|
||||
) -> int:
|
||||
"""
|
||||
Compute the number of retained tokens for a given video.
|
||||
Method ensures that we retain all the tokens from the first frame
|
||||
regardless of the pruning rate.
|
||||
|
||||
Args:
|
||||
tokens_per_frame: The number of tokens per frame.
|
||||
num_frames: The total number of frames.
|
||||
q: The pruning rate.
|
||||
|
||||
Returns:
|
||||
The number of retained tokens.
|
||||
"""
|
||||
total_tokens = tokens_per_frame * num_frames
|
||||
evs_num_tokens = int(total_tokens * (1 - q))
|
||||
min_num_tokens = tokens_per_frame
|
||||
return max(min_num_tokens, evs_num_tokens)
|
||||
|
||||
|
||||
def compute_retention_mask(
|
||||
video_embeds: torch.Tensor,
|
||||
video_size_thw: torch.LongTensor | tuple[int, int, int],
|
||||
spatial_merge_size: int,
|
||||
q: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the retention mask for input video embeddings.
|
||||
|
||||
Args:
|
||||
video_embeds (`torch.Tensor`): The input video embeddings
|
||||
of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
|
||||
video_size_thw (`torch.LongTensor` of shape `(3)`):
|
||||
The temporal, height and width of video.
|
||||
spatial_merge_size: Size reduction for rows & cols dimensions.
|
||||
q: (`float`): Pruning rate factor [0,1)
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The retention mask for the video embeddings of
|
||||
`(T * H * W // spatial_merge_size ^ 2)` shape.
|
||||
"""
|
||||
T, H, W = map(int, video_size_thw)
|
||||
|
||||
# Use reshape instead of einops to avoid graph breaks
|
||||
video_embeds = video_embeds.reshape(
|
||||
T,
|
||||
H // spatial_merge_size,
|
||||
W // spatial_merge_size,
|
||||
video_embeds.size(-1),
|
||||
)
|
||||
tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)
|
||||
# Core EVS
|
||||
similarity = torch.nn.functional.cosine_similarity(
|
||||
video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
|
||||
)
|
||||
dissimilarity = 1 - similarity
|
||||
|
||||
# Always ensure we include all tokens from the first frame
|
||||
dissimilarity = torch.cat(
|
||||
[255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0
|
||||
)
|
||||
|
||||
dissimilarity_flat = dissimilarity.view(-1)
|
||||
order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
|
||||
retain_num_tokens = compute_retained_tokens_count(
|
||||
tokens_per_frame=tokens_per_frame, num_frames=T, q=q
|
||||
)
|
||||
topk_indices = order[:retain_num_tokens]
|
||||
|
||||
retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
|
||||
retention_mask[topk_indices] = True
|
||||
retention_mask = retention_mask.reshape(dissimilarity.size())
|
||||
|
||||
mask = retention_mask.view(-1) # "T H W -> (T H W)"
|
||||
return mask
|
||||
|
||||
|
||||
def compute_mrope_for_media(
|
||||
video_size_thw: torch.LongTensor,
|
||||
spatial_merge_size: int,
|
||||
tokens_per_second: float = 1.0,
|
||||
video_second_per_grid: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the mrope for video embeddings based on the grid dimensions.
|
||||
Computed mrope positions match original qwen 2.5 implementation,
|
||||
but positions are built for media being the first element in sequence.
|
||||
|
||||
Args:
|
||||
video_size_thw: Media size (num frames, rows, cols)
|
||||
spatial_merge_size: Size reduction for rows & cols dimensions.
|
||||
tokens_per_second: Number of tokens per second.
|
||||
video_second_per_grid: Number of seconds per video.
|
||||
|
||||
Returns:
|
||||
Tensor of shape `(T * H * W, 4)` where last dimension
|
||||
represents mrope positions [0:3), while the last channel
|
||||
contains value of llm_grid_w repeated for all positions.
|
||||
"""
|
||||
llm_grid_t = video_size_thw[0]
|
||||
llm_grid_h = video_size_thw[1] // spatial_merge_size
|
||||
llm_grid_w = video_size_thw[2] // spatial_merge_size
|
||||
|
||||
t_index = (
|
||||
(
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.mul(tokens_per_second * video_second_per_grid)
|
||||
)
|
||||
.long()
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_grid_w = (
|
||||
torch.tensor([llm_grid_w])
|
||||
.view(1, 1, 1)
|
||||
.expand(llm_grid_t, llm_grid_h, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
|
||||
positions = torch.stack([t_index, h_index, w_index, llm_grid_w], dim=1)
|
||||
return positions
|
||||
|
||||
|
||||
def recompute_mrope_positions(
|
||||
input_ids: torch.LongTensor,
|
||||
multimodal_positions: list[torch.Tensor],
|
||||
mrope_positions: torch.LongTensor,
|
||||
num_computed_tokens: int,
|
||||
vision_start_token_id: int,
|
||||
image_token_id: int,
|
||||
video_token_id: int,
|
||||
) -> tuple[torch.LongTensor, int]:
|
||||
"""
|
||||
Update part of input mrope positions.
|
||||
Original mrope_positions are computed incorrectly, so once we prune media
|
||||
tokens we should reflect this in the mrope positions for the LLM.
|
||||
|
||||
This method supports chunked prefill approach where
|
||||
multimodal_embeddings are passed to LLM in chunks, so input
|
||||
multimodal_embeddings may contain zero, some or even some part of all
|
||||
multimodal_embeddings for a given prompt.
|
||||
|
||||
Each multimodal_positions has 4 extra channels
|
||||
(First 3 channels corresponds to original 3 mrope positions, last channel
|
||||
is the maximum width of the media repeated). Provided multimodal_positions
|
||||
do not reflect location of media position in sequence - they are computed
|
||||
like the media is in the 0-th position in the sequence.
|
||||
|
||||
Method works as follows: it recomputes mrope_positions starting from the
|
||||
`num_computed_tokens` for `total_len_of_multimodal_embeddings` and then
|
||||
shifts all text tokens that goes after total_len_of_multimodal_embeddings.
|
||||
|
||||
It also handles case when multimodal_embeddings is partial
|
||||
(e.g. one media is split into two prefill stages)
|
||||
|
||||
Args:
|
||||
input_ids: (N,) All input tokens of the prompt (entire sequence).
|
||||
multimodal_positions: List of mrope positsions for each media.
|
||||
mrope_positions: Existing mrope positions (4, N) for entire sequence.
|
||||
num_computed_tokens: A number of computed tokens so far.
|
||||
vision_start_token_id: Token indicating start of vision media.
|
||||
image_token_id: Image token id
|
||||
video_token_id: Video token id
|
||||
|
||||
Returns:
|
||||
Tuple of (mrope_positions, mrope_position_delta).
|
||||
"""
|
||||
|
||||
# Tensors
|
||||
positions: torch.LongTensor = typing.cast(
|
||||
torch.LongTensor, mrope_positions.clone()
|
||||
) # (3, N)
|
||||
N = input_ids.numel()
|
||||
|
||||
image_mask = input_ids.eq(image_token_id)
|
||||
video_mask = input_ids.eq(video_token_id)
|
||||
media_mask = image_mask | video_mask
|
||||
text_mask = ~media_mask
|
||||
|
||||
# Early exit: no media in this chunk
|
||||
if len(multimodal_positions) == 0:
|
||||
delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
|
||||
return positions, delta
|
||||
|
||||
total_mm_tokens = torch.count_nonzero(media_mask)
|
||||
seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])
|
||||
|
||||
# Early exit: we've updated positions for all media tokens
|
||||
# (and consequently - for all remaining text tokens)
|
||||
if seen_mm_tokens == total_mm_tokens:
|
||||
delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
|
||||
return positions, delta
|
||||
|
||||
vision_start_indices = (input_ids == vision_start_token_id).nonzero(as_tuple=True)[
|
||||
0
|
||||
]
|
||||
|
||||
for mm_pos in multimodal_positions:
|
||||
# Each mm_pos can be a complete embedding for single media
|
||||
# or it can be a part of a single media (due to chunked prefill)
|
||||
|
||||
# Cases to cover
|
||||
# - Current prefill chunk has no vision start indexes at all
|
||||
# - Vision start token appeared in previous prefill round
|
||||
# - Regular case
|
||||
seen_vision_start_indices = vision_start_indices[
|
||||
vision_start_indices < num_computed_tokens
|
||||
]
|
||||
|
||||
if len(seen_vision_start_indices):
|
||||
# If we have encountered some vision start indexes,
|
||||
# then we should check the condition:
|
||||
# | --- prefill 1 ------| ---- prefill 2 ----- |
|
||||
# | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
|
||||
last_vision_start_token = seen_vision_start_indices[-1]
|
||||
seem_mm_tokens_before_last_vision_start = torch.count_nonzero(
|
||||
media_mask[:last_vision_start_token]
|
||||
)
|
||||
in_the_middle_of_media = (
|
||||
seen_mm_tokens > seem_mm_tokens_before_last_vision_start
|
||||
)
|
||||
|
||||
if in_the_middle_of_media:
|
||||
mm_embeddings_seen = (
|
||||
seen_mm_tokens - seem_mm_tokens_before_last_vision_start
|
||||
)
|
||||
global_mm_start = last_vision_start_token
|
||||
else:
|
||||
# We have completed previous mm_embedding part and
|
||||
# ready to start a new one
|
||||
next_vision_start_token = vision_start_indices[
|
||||
vision_start_indices >= num_computed_tokens
|
||||
][0]
|
||||
mm_embeddings_seen = 0
|
||||
global_mm_start = next_vision_start_token
|
||||
|
||||
else:
|
||||
# If there were no vision start indexes so far,
|
||||
# let's find first vision start index
|
||||
next_vision_start_token = vision_start_indices[
|
||||
vision_start_indices >= num_computed_tokens
|
||||
][0]
|
||||
|
||||
mm_embeddings_seen = 0
|
||||
global_mm_start = next_vision_start_token
|
||||
|
||||
# Offset right after vision_start_token
|
||||
base = positions[-1, global_mm_start] + 1
|
||||
local_start = global_mm_start + 1 + mm_embeddings_seen
|
||||
local_end = local_start + mm_pos.shape[1]
|
||||
positions[:, local_start:local_end] = mm_pos[0:3] + base
|
||||
|
||||
# mm_pos[3, 0] is the max width of the media
|
||||
offset = mm_pos[3, 0] + base
|
||||
|
||||
text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
|
||||
|
||||
positions[:, local_end:N] = text_pos_sum + offset - 1
|
||||
|
||||
# Include distance to the next vision start token
|
||||
num_computed_tokens += mm_pos.shape[1]
|
||||
|
||||
mrope_positions_delta = (positions.max() + 1 - N).item()
|
||||
return positions, mrope_positions_delta
|
||||
@@ -1,106 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pickle
|
||||
import uuid
|
||||
from collections.abc import Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from blake3 import blake3
|
||||
from PIL import Image
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalHasher:
|
||||
@classmethod
|
||||
def serialize_item(cls, obj: object) -> Iterable[bytes | memoryview]:
|
||||
# Simple cases
|
||||
if isinstance(obj, (bytes, memoryview)):
|
||||
return (obj,)
|
||||
if isinstance(obj, str):
|
||||
return (obj.encode("utf-8"),)
|
||||
if isinstance(obj, (int, float)):
|
||||
return (np.array(obj).tobytes(),)
|
||||
|
||||
if isinstance(obj, Image.Image):
|
||||
exif = obj.getexif()
|
||||
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
||||
exif[Image.ExifTags.Base.ImageID], uuid.UUID
|
||||
):
|
||||
# If the image has exif ImageID tag, use that
|
||||
return (exif[Image.ExifTags.Base.ImageID].bytes,)
|
||||
data = {"mode": obj.mode, "data": np.asarray(obj)}
|
||||
if obj.palette is not None:
|
||||
data["palette"] = obj.palette.palette
|
||||
if obj.palette.rawmode is not None:
|
||||
data["palette_rawmode"] = obj.palette.rawmode
|
||||
return cls.iter_item_to_bytes("image", data)
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor_obj: torch.Tensor = obj.cpu()
|
||||
tensor_dtype = tensor_obj.dtype
|
||||
tensor_shape = tensor_obj.shape
|
||||
|
||||
# NumPy does not support bfloat16.
|
||||
# Workaround: View the tensor as a contiguous 1D array of bytes
|
||||
if tensor_dtype == torch.bfloat16:
|
||||
tensor_obj = tensor_obj.contiguous()
|
||||
tensor_obj = tensor_obj.view((tensor_obj.numel(),)).view(torch.uint8)
|
||||
|
||||
return cls.iter_item_to_bytes(
|
||||
"tensor",
|
||||
{
|
||||
"original_dtype": str(tensor_dtype),
|
||||
"original_shape": tuple(tensor_shape),
|
||||
"data": tensor_obj.numpy(),
|
||||
},
|
||||
)
|
||||
return cls.iter_item_to_bytes("tensor", tensor_obj.numpy())
|
||||
if isinstance(obj, np.ndarray):
|
||||
# If the array is non-contiguous, we need to copy it first
|
||||
arr_data = (
|
||||
obj.view(np.uint8).data if obj.flags.c_contiguous else obj.tobytes()
|
||||
)
|
||||
return cls.iter_item_to_bytes(
|
||||
"ndarray",
|
||||
{
|
||||
"dtype": obj.dtype.str,
|
||||
"shape": obj.shape,
|
||||
"data": arr_data,
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
"No serialization method found for %s. Falling back to pickle.", type(obj)
|
||||
)
|
||||
|
||||
return (pickle.dumps(obj),)
|
||||
|
||||
@classmethod
|
||||
def iter_item_to_bytes(
|
||||
cls,
|
||||
key: str,
|
||||
obj: object,
|
||||
) -> Iterable[bytes | memoryview]:
|
||||
# Recursive cases
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for i, elem in enumerate(obj):
|
||||
yield from cls.iter_item_to_bytes(f"{key}.{i}", elem)
|
||||
elif isinstance(obj, dict):
|
||||
for k, v in obj.items():
|
||||
yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
|
||||
else:
|
||||
yield key.encode("utf-8")
|
||||
yield from cls.serialize_item(obj)
|
||||
|
||||
@classmethod
|
||||
def hash_kwargs(cls, **kwargs: object) -> str:
|
||||
hasher = blake3()
|
||||
|
||||
for k, v in kwargs.items():
|
||||
for bytes_ in cls.iter_item_to_bytes(k, v):
|
||||
hasher.update(bytes_)
|
||||
|
||||
return hasher.hexdigest()
|
||||
@@ -1,130 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from .base import MediaIO
|
||||
|
||||
|
||||
def rescale_image_size(
|
||||
image: Image.Image, size_factor: float, transpose: int = -1
|
||||
) -> Image.Image:
|
||||
"""Rescale the dimensions of an image by a constant factor."""
|
||||
new_width = int(image.width * size_factor)
|
||||
new_height = int(image.height * size_factor)
|
||||
image = image.resize((new_width, new_height))
|
||||
if transpose >= 0:
|
||||
image = image.transpose(Image.Transpose(transpose))
|
||||
return image
|
||||
|
||||
|
||||
def rgba_to_rgb(
|
||||
image: Image.Image,
|
||||
background_color: tuple[int, int, int] | list[int] = (255, 255, 255),
|
||||
) -> Image.Image:
|
||||
"""Convert an RGBA image to RGB with filled background color."""
|
||||
assert image.mode == "RGBA"
|
||||
converted = Image.new("RGB", image.size, background_color)
|
||||
converted.paste(image, mask=image.split()[3]) # 3 is the alpha channel
|
||||
return converted
|
||||
|
||||
|
||||
def convert_image_mode(image: Image.Image, to_mode: str):
|
||||
if image.mode == to_mode:
|
||||
return image
|
||||
elif image.mode == "RGBA" and to_mode == "RGB":
|
||||
return rgba_to_rgb(image)
|
||||
else:
|
||||
return image.convert(to_mode)
|
||||
|
||||
|
||||
class ImageMediaIO(MediaIO[Image.Image]):
|
||||
def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.image_mode = image_mode
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Extract RGBA background color from kwargs if provided
|
||||
# Default to white background for backward compatibility
|
||||
rgba_bg = kwargs.get("rgba_background_color", (255, 255, 255))
|
||||
# Convert list to tuple for consistency
|
||||
if isinstance(rgba_bg, list):
|
||||
rgba_bg = tuple(rgba_bg)
|
||||
|
||||
# Validate rgba_background_color format
|
||||
if not (
|
||||
isinstance(rgba_bg, tuple)
|
||||
and len(rgba_bg) == 3
|
||||
and all(isinstance(c, int) and 0 <= c <= 255 for c in rgba_bg)
|
||||
):
|
||||
raise ValueError(
|
||||
"rgba_background_color must be a list or tuple of 3 integers "
|
||||
"in the range [0, 255]."
|
||||
)
|
||||
self.rgba_background_color = rgba_bg
|
||||
|
||||
def _convert_image_mode(self, image: Image.Image) -> Image.Image:
|
||||
"""Convert image mode with custom background color."""
|
||||
if image.mode == self.image_mode:
|
||||
return image
|
||||
elif image.mode == "RGBA" and self.image_mode == "RGB":
|
||||
return rgba_to_rgb(image, self.rgba_background_color)
|
||||
else:
|
||||
return convert_image_mode(image, self.image_mode)
|
||||
|
||||
def load_bytes(self, data: bytes) -> Image.Image:
|
||||
image = Image.open(BytesIO(data))
|
||||
image.load()
|
||||
return self._convert_image_mode(image)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> Image.Image:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> Image.Image:
|
||||
image = Image.open(filepath)
|
||||
image.load()
|
||||
return self._convert_image_mode(image)
|
||||
|
||||
def encode_base64(
|
||||
self,
|
||||
media: Image.Image,
|
||||
*,
|
||||
image_format: str = "JPEG",
|
||||
) -> str:
|
||||
image = media
|
||||
|
||||
with BytesIO() as buffer:
|
||||
image = self._convert_image_mode(image)
|
||||
image.save(buffer, image_format)
|
||||
data = buffer.getvalue()
|
||||
|
||||
return pybase64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def load_bytes(self, data: bytes) -> torch.Tensor:
|
||||
buffer = BytesIO(data)
|
||||
return torch.load(buffer, weights_only=True)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> torch.Tensor:
|
||||
return torch.load(filepath, weights_only=True)
|
||||
|
||||
def encode_base64(self, media: torch.Tensor) -> str:
|
||||
return pybase64.b64encode(media.numpy()).decode("utf-8")
|
||||
1036
multimodal/inputs.py
1036
multimodal/inputs.py
File diff suppressed because it is too large
Load Diff
@@ -1,544 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from collections.abc import Callable, Iterator, Mapping, Sequence
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
TypeAlias,
|
||||
TypeGuard,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
from .audio import AudioResampler
|
||||
from .inputs import (
|
||||
AudioItem,
|
||||
HfAudioItem,
|
||||
HfImageItem,
|
||||
HfVideoItem,
|
||||
ImageItem,
|
||||
ModalityData,
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
VideoItem,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_I = TypeVar("_I")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL.Image as PILImage
|
||||
else:
|
||||
PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
|
||||
|
||||
|
||||
class ModalityDataItems(ABC, Generic[_T, _I]):
|
||||
"""
|
||||
Represents data items for a modality in
|
||||
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
|
||||
"""
|
||||
|
||||
def __init__(self, data: _T, modality: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.data: _T = data
|
||||
self.modality = modality
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{type(self).__name__}(modality={self.modality!r}, len={len(self)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.get_count()
|
||||
|
||||
def __getitem__(self, index: int) -> _I:
|
||||
return self.get(index)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Auto-generated
|
||||
def __iter__(self) -> Iterator[_I]: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_count(self) -> int:
|
||||
"""Get the number of data items."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self, index: int) -> _I:
|
||||
"""Get a data item by its index."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_all(self) -> list[_I]:
|
||||
"""Get all data items."""
|
||||
return [self.get(idx) for idx in range(self.get_count())]
|
||||
|
||||
@abstractmethod
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
"""Get the data to pass to the HF processor."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_passthrough_data(self) -> Mapping[str, object]:
|
||||
"""Get the data to pass directly to the model."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
|
||||
"""Base class for data items that are arranged in a list."""
|
||||
|
||||
def get_count(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get(self, index: int) -> _T:
|
||||
return self.data[index]
|
||||
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
return {f"{self.modality}s": self.data}
|
||||
|
||||
def get_passthrough_data(self) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
|
||||
class EmbeddingItems(
|
||||
ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
|
||||
):
|
||||
"""
|
||||
Base class for data items that are expressed as a batched embedding tensor,
|
||||
or a list of embedding tensors (one per item).
|
||||
"""
|
||||
|
||||
def get_count(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get(self, index: int) -> torch.Tensor:
|
||||
return self.data[index]
|
||||
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
def get_passthrough_data(self) -> Mapping[str, object]:
|
||||
return {f"{self.modality}_embeds": self.data}
|
||||
|
||||
def get_feature_size(self, item_idx: int) -> int:
|
||||
return len(self.get(item_idx))
|
||||
|
||||
|
||||
class DictEmbeddingItems(
|
||||
ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Base class for data items that are expressed as a dictionary of tensors.
|
||||
|
||||
Usually, the dictionary keys correspond to the outputs of HF processor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Mapping[str, torch.Tensor],
|
||||
modality: str,
|
||||
required_fields: set[str],
|
||||
fields_factory: Callable[
|
||||
[Mapping[str, torch.Tensor]],
|
||||
Mapping[str, MultiModalFieldConfig],
|
||||
],
|
||||
) -> None:
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
super().__init__(data, modality)
|
||||
|
||||
missing_required_data_keys = required_fields - data.keys()
|
||||
if missing_required_data_keys:
|
||||
data_keys = set(data.keys())
|
||||
msg = (
|
||||
f"The data should contain the fields: {required_fields}, "
|
||||
f"but only found the following keys: {data_keys}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
fields_config = fields_factory(data)
|
||||
missing_required_fields = required_fields - fields_config.keys()
|
||||
if missing_required_fields:
|
||||
fields = set(fields_config.keys())
|
||||
msg = f"{required_fields=} should be a subset of {fields=}"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.fields_config = fields_config
|
||||
self.required_fields = required_fields
|
||||
|
||||
self._kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||
BatchFeature(dict(data)),
|
||||
fields_config,
|
||||
)
|
||||
|
||||
def get_count(self) -> int:
|
||||
return len(self._kwargs[self.modality])
|
||||
|
||||
def get(self, index: int) -> Mapping[str, torch.Tensor]:
|
||||
return self._kwargs[self.modality][index].get_data()
|
||||
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
def get_passthrough_data(self) -> Mapping[str, object]:
|
||||
return self.data
|
||||
|
||||
|
||||
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
|
||||
def __init__(self, data: Sequence[HfAudioItem] | None) -> None:
|
||||
if data is None:
|
||||
data = [None]
|
||||
super().__init__(data, "audio")
|
||||
|
||||
def get_audio_length(self, item_idx: int) -> int:
|
||||
audio = self.get(item_idx)
|
||||
return len(audio)
|
||||
|
||||
|
||||
class AudioEmbeddingItems(EmbeddingItems):
|
||||
def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
|
||||
super().__init__(data, "audio")
|
||||
|
||||
|
||||
class ImageSize(NamedTuple):
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
|
||||
def __init__(self, data: Sequence[HfImageItem] | None) -> None:
|
||||
if data is None:
|
||||
data = [None]
|
||||
super().__init__(data, "image")
|
||||
|
||||
def get_image_size(self, item_idx: int) -> ImageSize:
|
||||
image = self.get(item_idx)
|
||||
|
||||
if isinstance(image, PILImage.Image):
|
||||
return ImageSize(*image.size)
|
||||
if isinstance(image, (np.ndarray, torch.Tensor)):
|
||||
_, h, w = image.shape
|
||||
return ImageSize(w, h)
|
||||
|
||||
assert_never(image)
|
||||
|
||||
|
||||
class ImageEmbeddingItems(EmbeddingItems):
|
||||
def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
|
||||
super().__init__(data, "image")
|
||||
|
||||
|
||||
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
|
||||
def __init__(
|
||||
self,
|
||||
data: Sequence[HfVideoItem] | None,
|
||||
metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
|
||||
) -> None:
|
||||
if data is None:
|
||||
data = [None]
|
||||
super().__init__(data, "video")
|
||||
self.metadata = metadata
|
||||
|
||||
def get_num_frames(self, item_idx: int) -> int:
|
||||
return len(self.get(item_idx))
|
||||
|
||||
def get_frame_size(self, item_idx: int) -> ImageSize:
|
||||
image = self.get(item_idx)[0] # Assume that the video isn't empty
|
||||
|
||||
if isinstance(image, PILImage.Image):
|
||||
return ImageSize(*image.size)
|
||||
if isinstance(image, (np.ndarray, torch.Tensor)):
|
||||
_, h, w = image.shape
|
||||
return ImageSize(w, h)
|
||||
|
||||
assert_never(image)
|
||||
|
||||
|
||||
class VideoEmbeddingItems(EmbeddingItems):
|
||||
def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
|
||||
super().__init__(data, "video")
|
||||
|
||||
|
||||
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
|
||||
|
||||
|
||||
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
|
||||
"""
|
||||
As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
|
||||
normalized such that each entry corresponds to a list.
|
||||
"""
|
||||
|
||||
def get_count(self, modality: str, *, strict: bool = True) -> int:
|
||||
"""
|
||||
Get the number of data items belonging to a modality.
|
||||
|
||||
If `strict=False`, return `0` instead of raising [`KeyError`][]
|
||||
even if the modality is not found.
|
||||
"""
|
||||
if modality not in self:
|
||||
if strict:
|
||||
available_modalities = set(self.keys())
|
||||
raise KeyError(
|
||||
f"Modality {modality!r} not found. "
|
||||
f"Available modalities: {available_modalities}"
|
||||
)
|
||||
|
||||
return 0
|
||||
|
||||
return self[modality].get_count()
|
||||
|
||||
def get_all_counts(self) -> Mapping[str, int]:
|
||||
"""Get the number of items belonging to each modality."""
|
||||
return {m: items.get_count() for m, items in self.items()}
|
||||
|
||||
def get_items(
|
||||
self,
|
||||
modality: str,
|
||||
typ: type[_D] | tuple[type[_D], ...],
|
||||
) -> _D:
|
||||
"""
|
||||
Get the data items belonging to a modality,
|
||||
requiring that they belong to a certain type.
|
||||
"""
|
||||
if modality not in self:
|
||||
available_modalities = set(self.keys())
|
||||
raise KeyError(
|
||||
f"Modality {modality!r} not found. "
|
||||
f"Available modalities: {available_modalities}"
|
||||
)
|
||||
|
||||
items = self[modality]
|
||||
if not isinstance(items, typ):
|
||||
raise TypeError(
|
||||
f"Invalid type of data items for {modality=}. "
|
||||
f"Expected type: {typ}, but "
|
||||
f"found type: {type(items)}"
|
||||
)
|
||||
|
||||
return items # type: ignore[return-value]
|
||||
|
||||
|
||||
ModalityDataParser: TypeAlias = Callable[
|
||||
[ModalityData[Any]], ModalityDataItems[Any, Any] | None
|
||||
]
|
||||
|
||||
|
||||
class MultiModalDataParser:
|
||||
"""
|
||||
Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
|
||||
into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
|
||||
|
||||
Args:
|
||||
target_sr (float, optional): Enables automatic resampling of audio
|
||||
items to the model's expected sampling rate.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
target_sr: float | None = None,
|
||||
audio_resample_method: Literal["librosa", "scipy"] = "librosa",
|
||||
video_needs_metadata: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.audio_resampler = AudioResampler(
|
||||
target_sr=target_sr,
|
||||
method=audio_resample_method,
|
||||
)
|
||||
self.video_needs_metadata = video_needs_metadata
|
||||
|
||||
@classmethod
|
||||
def is_embeddings(
|
||||
cls, data: object
|
||||
) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.ndim == 3
|
||||
if is_list_of(data, torch.Tensor):
|
||||
return data[0].ndim == 2 # type: ignore[index]
|
||||
|
||||
return False
|
||||
|
||||
def _is_empty(self, data: object) -> TypeGuard[None]:
|
||||
if isinstance(data, list):
|
||||
return len(data) == 0
|
||||
if isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
return data.size == 0
|
||||
|
||||
return False
|
||||
|
||||
def _get_audio_with_sr(
|
||||
self,
|
||||
audio: AudioItem,
|
||||
) -> tuple[np.ndarray, float | None]:
|
||||
if isinstance(audio, tuple):
|
||||
return audio
|
||||
if isinstance(audio, list):
|
||||
return np.array(audio), None
|
||||
if isinstance(audio, np.ndarray):
|
||||
return audio, None
|
||||
if isinstance(audio, torch.Tensor):
|
||||
return audio.numpy(), None
|
||||
|
||||
assert_never(audio)
|
||||
|
||||
def _get_video_with_metadata(
|
||||
self,
|
||||
video: VideoItem,
|
||||
) -> tuple[np.ndarray, dict[str, Any] | None]:
|
||||
if isinstance(video, tuple):
|
||||
return video
|
||||
if isinstance(video, list):
|
||||
return np.array(video), None
|
||||
if isinstance(video, np.ndarray):
|
||||
return video, None
|
||||
if isinstance(video, torch.Tensor):
|
||||
return video.numpy(), None
|
||||
|
||||
assert_never(video)
|
||||
|
||||
def _parse_audio_data(
|
||||
self,
|
||||
data: ModalityData[AudioItem],
|
||||
) -> ModalityDataItems[Any, Any] | None:
|
||||
if data is None:
|
||||
return AudioProcessorItems(None)
|
||||
|
||||
# also check single audio item with sampling rate
|
||||
if self._is_empty(data) or (
|
||||
isinstance(data, tuple) and self._is_empty(data[0])
|
||||
):
|
||||
return None
|
||||
|
||||
if self.is_embeddings(data):
|
||||
return AudioEmbeddingItems(data)
|
||||
|
||||
data_items: list[AudioItem]
|
||||
if (
|
||||
is_list_of(data, float)
|
||||
or isinstance(data, (np.ndarray, torch.Tensor))
|
||||
and data.ndim == 1
|
||||
or isinstance(data, tuple)
|
||||
):
|
||||
data_items = [data]
|
||||
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
data_items = [elem for elem in data]
|
||||
else:
|
||||
data_items = data # type: ignore[assignment]
|
||||
|
||||
new_audios = list[np.ndarray]()
|
||||
for data_item in data_items:
|
||||
audio, orig_sr = self._get_audio_with_sr(data_item)
|
||||
if orig_sr is None:
|
||||
new_audio = audio
|
||||
else:
|
||||
new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)
|
||||
|
||||
new_audios.append(new_audio)
|
||||
|
||||
return AudioProcessorItems(new_audios)
|
||||
|
||||
def _parse_image_data(
    self,
    data: ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any] | None:
    """Normalize raw image inputs into processor-ready items.

    Returns ``None`` for empty inputs (so the modality is dropped) and
    ``ImageEmbeddingItems`` for precomputed embeddings.
    """
    if data is None:
        return ImageProcessorItems(None)

    if self._is_empty(data):
        return None

    if self.is_embeddings(data):
        return ImageEmbeddingItems(data)

    # A single PIL image or a single rank-3 array/tensor is wrapped into
    # a one-element batch; a higher-rank array is split along dim 0.
    single_array = (
        isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3
    )
    if isinstance(data, PILImage.Image) or single_array:
        data_items = [data]
    elif isinstance(data, (np.ndarray, torch.Tensor)):
        data_items = list(data)
    else:
        data_items = data

    return ImageProcessorItems(data_items)
|
||||
|
||||
def _parse_video_data(
    self,
    data: ModalityData[VideoItem],
) -> ModalityDataItems[Any, Any] | None:
    """Normalize raw video inputs into processor-ready items.

    Returns ``None`` for empty inputs (so the modality is dropped) and
    ``VideoEmbeddingItems`` for precomputed embeddings. When
    ``self.video_needs_metadata`` is set, every item must carry a
    metadata dict; otherwise bare frame arrays are stored.

    Raises:
        ValueError: If metadata is required but missing for an item.
    """
    if data is None:
        return VideoProcessorItems(None)

    if self._is_empty(data):
        return None

    if self.is_embeddings(data):
        return VideoEmbeddingItems(data)

    data_items: list[VideoItem]
    if (
        is_list_of(data, PILImage.Image)
        or isinstance(data, (np.ndarray, torch.Tensor))
        and data.ndim == 4
    ):
        # A single video: a list of PIL frames, or one rank-4 array.
        data_items = [data]
    elif isinstance(data, (np.ndarray, torch.Tensor)):
        # Batched array/tensor of videos: split along the leading dim.
        data_items = [elem for elem in data]
    elif isinstance(data, tuple) and len(data) == 2:
        # A single (frames, metadata) pair.
        data_items = [data]
    else:
        data_items = data  # type: ignore[assignment]

    new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
    metadata_lst: list[dict[str, Any] | None] = []
    for data_item in data_items:
        video, metadata = self._get_video_with_metadata(data_item)
        if self.video_needs_metadata:
            if metadata is None:
                raise ValueError(
                    "Video metadata is required but not found in mm input. "
                    "Please check your video input in `multi_modal_data`"
                )
            new_videos.append((video, metadata))
            metadata_lst.append(metadata)
        else:
            # Metadata not needed: store the bare frame array and leave
            # metadata_lst empty.
            new_videos.append(video)

    # NOTE: the previous version reset a loop-local `metadata` variable
    # here when metadata was not needed; it was dead code (never read
    # afterwards) and has been removed.
    return VideoProcessorItems(new_videos, metadata=metadata_lst)
|
||||
|
||||
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
|
||||
return {
|
||||
"audio": self._parse_audio_data,
|
||||
"image": self._parse_image_data,
|
||||
"video": self._parse_video_data,
|
||||
}
|
||||
|
||||
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
    """Parse a user-supplied multi-modal data dict into typed items.

    Empty modalities (where the subparser returns ``None``) are dropped
    from the result.

    Raises:
        ValueError: If `mm_data` contains an unsupported modality key.
    """
    subparsers = self._get_subparsers()

    mm_items = MultiModalDataItems()
    for modality, value in mm_data.items():
        parse = subparsers.get(modality)
        if parse is None:
            raise ValueError(f"Unsupported modality: {modality}")

        # ignore empty embedding data
        parsed = parse(value)
        if parsed is not None:
            mm_items[modality] = parsed

    return mm_items
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,369 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, NamedTuple, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm.config.multimodal import (
|
||||
AudioDummyOptions,
|
||||
BaseDummyOptions,
|
||||
ImageDummyOptions,
|
||||
VideoDummyOptions,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalEncDecInputs,
|
||||
MultiModalInputs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalPlaceholderDict,
|
||||
)
|
||||
from .processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
EncDecMultiModalProcessor,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # Text prompt, or pre-tokenized prompt token IDs.
    prompt: str | list[int]
    # Multi-modal data keyed by modality (e.g. "audio", "image", "video").
    mm_data: MultiModalDataDict
    # Extra keyword arguments forwarded to the HF processor via `apply`.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Extra keyword arguments forwarded to tokenization via `apply`.
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DummyEncoderData(NamedTuple):
    """Dummy encoder data used for profiling."""

    # Token IDs of the dummy encoder prompt.
    prompt_token_ids: list[int]
|
||||
|
||||
|
||||
class DummyDecoderData(NamedTuple):
    """Dummy decoder data used for profiling."""

    # Token IDs of the dummy decoder prompt (padded to the sequence length
    # by the profiler).
    prompt_token_ids: list[int]
    # Processed multi-modal keyword-argument items for the dummy inputs.
    multi_modal_data: MultiModalKwargsItems
    # Placeholder ranges per modality within the dummy prompt.
    multi_modal_placeholders: MultiModalPlaceholderDict
|
||||
|
||||
|
||||
# Type of the processing-info object a dummy-inputs builder is bound to.
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
|
||||
|
||||
class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                If None, use model defaults for backward compatibility.
                If provided, models can use these to customize dummy
                data generation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional)
        """
        dummy_text = self.get_dummy_text(mm_counts)

        # Use the unified function for both legacy and configurable cases
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)

        # Dummy prompts must never be truncated during profiling.
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(
            prompt=dummy_text,
            mm_data=dummy_mm_data,
            tokenization_kwargs=tokenization_kwargs,
        )

    @staticmethod
    def _resolve_dummy_dim(
        field: str,
        dim_desc: str,
        value: int,
        override: int | None,
    ) -> int:
        """Clamp one dummy-data dimension to an optional user override.

        Overrides may only shrink the model's maximum `value`; a larger
        override is ignored (with a warning) so profiling stays
        conservative. A falsy override leaves `value` unchanged.
        """
        if not override:
            return value
        if override > value:
            logger.warning(
                "%s override (%d) exceeds model's maximum %s (%d), will be ignored",
                field,
                override,
                dim_desc,
                value,
            )
        return min(value, override)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """Return `num_audios` references to one silent waveform of
        `length` samples (optionally shortened by `overrides`)."""
        if num_audios == 0:
            return []
        if overrides:
            length = self._resolve_dummy_dim(
                "audio.length", "length", length, overrides.length
            )
        audio = np.zeros((length,))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """Return `num_images` references to one solid-color RGB image
        (dimensions optionally shrunk by `overrides`)."""
        if num_images == 0:
            return []
        if overrides:
            width = self._resolve_dummy_dim(
                "image.width", "width", width, overrides.width
            )
            height = self._resolve_dummy_dim(
                "image.height", "height", height, overrides.height
            )
        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """Return `num_videos` references to one constant-valued video
        array (dimensions optionally shrunk by `overrides`)."""
        if num_videos == 0:
            return []
        if overrides:
            num_frames = self._resolve_dummy_dim(
                "video.num_frames", "number of frames", num_frames, overrides.num_frames
            )
            width = self._resolve_dummy_dim(
                "video.width", "width", width, overrides.width
            )
            height = self._resolve_dummy_dim(
                "video.height", "height", height, overrides.height
            )
        # NOTE(review): axis order is (frames, width, height, 3) as in the
        # original code — confirm callers expect width before height.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        return [video] * num_videos
|
||||
|
||||
|
||||
class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        # Maximum number of items allowed per prompt, per modality.
        return self.processor.allowed_mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalInputs:
        """Build dummy processor inputs and run them through the full
        multi-modal processor."""
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts, mm_options
        )

        return self.processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """Sum placeholder token counts per modality.

        With `mm_embeddings_only` True, only embedding positions are
        counted; otherwise the full placeholder length is used.
        """
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality: sum(
                item.get_num_embeds() if mm_embeddings_only else item.length
                for item in placeholders
            )
            for modality, placeholders in placeholders_by_modality.items()
        }

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> DummyEncoderData:
        """Build dummy encoder data for profiling an encoder-decoder
        model."""
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            # Pad with zeros up to seq_len (no-op if already long enough).
            num_tokens_to_pad = max(total_len, seq_len) - total_len
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> DummyDecoderData:
        """Build dummy decoder data for profiling, padded to `seq_len`."""
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        # Pad with zeros so the dummy prompt fills the whole sequence.
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def _get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """Maximum multi-modal tokens per modality, preferring the model's
        own per-item report over running the processor on dummy data."""
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        # Fast path: the model reports per-item maxima directly.
        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len=seq_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            return {
                modality: max_tokens
                for modality, max_tokens in max_tokens_per_item.items()
                if mm_counts.get(modality, 0) > 0
            }

        # Slow path: process dummy data and count placeholder tokens.
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)

    def get_mm_max_contiguous_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
    ) -> Mapping[str, int]:
        """
        Returns the maximum length of the multimodal (image placeholders+text)
        tokens, including any break/text tokens in-between image embeddings.

        `<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
        Returns 9, even when the number of image embeddings is 6.

        This is important to take into account when profiling and
        initializing the encoder cache size.
        """
        return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
|
||||
@@ -1,360 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
|
||||
from vllm.utils.collection_utils import ClassRegistry
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
)
|
||||
from .profiling import (
|
||||
BaseDummyInputsBuilder,
|
||||
DummyDecoderData,
|
||||
DummyEncoderData,
|
||||
MultiModalProfiler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
logger = init_logger(__name__)

# Model class being registered with the multi-modal registry.
N = TypeVar("N", bound=type[nn.Module])
# Processing-info type (invariant / covariant variants for protocols below).
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
|
||||
|
||||
|
||||
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
|
||||
|
||||
|
||||
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
    instance from the context.
    """

    # Called with the processing info produced by `ProcessingInfoFactory`.
    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
|
||||
|
||||
|
||||
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the context.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """Bundle of the three factory callables registered for one model
    class."""

    # Builds the processing info from the input-processing context.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy-inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy-inputs builder from the processing info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ):
        """Construct a multi-modal processor by chaining the factories."""
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)
|
||||
|
||||
|
||||
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def __init__(self) -> None:
        # Maps each registered model class to its processor factories.
        self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]()

    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, BaseDummyOptions] | None:
        """
        Extract multimodal dummy options from model config.

        Returns None if no configurable options are found, otherwise returns
        a mapping of modality names to their dummy options.
        """
        if not model_config.multimodal_config:
            return None

        mm_options = {
            m: opt
            for m in model_config.multimodal_config.limit_per_prompt
            if (opt := model_config.multimodal_config.get_dummy_options(m)) is not None
        }

        return mm_options if len(mm_options) > 0 else None

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.
        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        info = self._create_processing_info(model_config, tokenizer=None)
        supported_modalities = info.get_supported_mm_limits()

        mm_config = model_config.get_multimodal_config()

        # Check if all supported modalities have limit == 0
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
            for modality in supported_modalities
        ):
            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode."
            )
            return False

        return True

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        profiler_limits: Mapping[str, int] | None = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)

        seq_len = model_config.max_model_len
        profiler_limits = (
            profiler.get_mm_limits() if profiler_limits is None else profiler_limits
        )

        # Profile with one item per modality that is actually enabled.
        return profiler.get_mm_max_contiguous_tokens(
            seq_len,
            {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
        )

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)
        return profiler.get_mm_limits()

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.
        """

        def wrapper(model_cls: N) -> N:
            # Re-registration is allowed but flagged: the new factories win.
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class for `model_config`."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        tokenizer: AnyTokenizer | None = None,
    ) -> InputProcessingContext:
        """Build the input-processing context, loading a cached tokenizer
        when one was not supplied and tokenizer init is not skipped."""
        if tokenizer is None and not model_config.skip_tokenizer_init:
            tokenizer = cached_tokenizer_from_config(model_config)
        return InputProcessingContext(model_config, tokenizer)

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: AnyTokenizer | None = None,
    ) -> BaseProcessingInfo:
        """Build only the processing info for the model (no processor)."""
        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]
        ctx = self._create_processing_ctx(model_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: AnyTokenizer | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If the model is not multimodal.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = self._create_processing_ctx(model_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)

        # Extract configurable options from multimodal config.
        # Only include modalities that use advanced option types so legacy
        # count-only behavior remains unchanged.
        mm_options = self._extract_mm_options(model_config)

        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead."
            )

        return dummy_data

    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)

        # Extract configurable options from multimodal config.
        # Only include modalities that use advanced option types so legacy
        # count-only behavior remains unchanged.
        mm_options = self._extract_mm_options(model_config)

        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            # Unlike the decoder path, a short encoder prompt is only a
            # warning, not an error.
            logger.warning_once(
                "Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.",  # noqa: E501
                seq_len,
                len(token_ids),
            )

        return dummy_data

    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum length of the encoder input for encoder-decoder models.
        """
        if not model_config.is_encoder_decoder:
            return 0
        max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
        if not max_tokens:
            # TODO - this function assumes encoder-decoder models are
            # multimodal. This will need to change when adding support for more
            # than whisper.
            return 0
        assert len(max_tokens) == 1, (
            "Encoder-decoder models are expected \
to implement the multimodal interface with at most one modality."
        )

        first_modality = next(iter(max_tokens))
        return max_tokens[first_modality]
|
||||
@@ -1,512 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
from collections.abc import Iterable, Set
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
from urllib.request import url2pathname
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
from .audio import AudioMediaIO
|
||||
from .base import MediaIO
|
||||
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .video import VideoMediaIO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .inputs import (
|
||||
BatchedTensorInputs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalPlaceholderDict,
|
||||
)
|
||||
else:
|
||||
BatchedTensorInputs = Any
|
||||
MultiModalKwargsItem = Any
|
||||
MultiModalPlaceholderDict = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Shared executor used to offload blocking media decoding work (e.g. from
# async URL loaders) off the event loop.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT
)
# Shut the worker threads down cleanly at interpreter exit.
atexit.register(global_thread_pool.shutdown)

# Media type produced by a MediaIO implementation.
_M = TypeVar("_M")

# Registry of media-connector implementations (e.g. MediaConnector is
# registered under "http" below).
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
|
||||
|
||||
|
||||
@MEDIA_CONNECTOR_REGISTRY.register("http")
|
||||
class MediaConnector:
|
||||
def __init__(
|
||||
self,
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
|
||||
connection: HTTPConnection = global_http_connection,
|
||||
*,
|
||||
allowed_local_media_path: str = "",
|
||||
allowed_media_domains: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
media_io_kwargs: Additional args passed to process media
|
||||
inputs, keyed by modalities. For example,
|
||||
to set num_frames for video, set
|
||||
`--media-io-kwargs '{"video":{"num_frames":40}}'`
|
||||
connection: HTTP connection client to download media contents.
|
||||
allowed_local_media_path: A local directory to load media files
|
||||
from.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.media_io_kwargs: dict[str, dict[str, Any]] = (
|
||||
media_io_kwargs if media_io_kwargs else {}
|
||||
)
|
||||
self.connection = connection
|
||||
|
||||
if allowed_local_media_path:
|
||||
allowed_local_media_path_ = Path(allowed_local_media_path)
|
||||
|
||||
if not allowed_local_media_path_.exists():
|
||||
raise ValueError(
|
||||
"Invalid `--allowed-local-media-path`: The path "
|
||||
f"{allowed_local_media_path_} does not exist."
|
||||
)
|
||||
if not allowed_local_media_path_.is_dir():
|
||||
raise ValueError(
|
||||
"Invalid `--allowed-local-media-path`: The path "
|
||||
f"{allowed_local_media_path_} must be a directory."
|
||||
)
|
||||
else:
|
||||
allowed_local_media_path_ = None
|
||||
|
||||
self.allowed_local_media_path = allowed_local_media_path_
|
||||
if allowed_media_domains is None:
|
||||
allowed_media_domains = []
|
||||
self.allowed_media_domains = allowed_media_domains
|
||||
|
||||
def _load_data_url(
|
||||
self,
|
||||
url_spec: ParseResult,
|
||||
media_io: MediaIO[_M],
|
||||
) -> _M: # type: ignore[type-var]
|
||||
data_spec, data = url_spec.path.split(",", 1)
|
||||
media_type, data_type = data_spec.split(";", 1)
|
||||
|
||||
if data_type != "base64":
|
||||
msg = "Only base64 data URLs are supported for now."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return media_io.load_base64(media_type, data)
|
||||
|
||||
def _load_file_url(
|
||||
self,
|
||||
url_spec: ParseResult,
|
||||
media_io: MediaIO[_M],
|
||||
) -> _M: # type: ignore[type-var]
|
||||
allowed_local_media_path = self.allowed_local_media_path
|
||||
if allowed_local_media_path is None:
|
||||
raise RuntimeError(
|
||||
"Cannot load local files without `--allowed-local-media-path`."
|
||||
)
|
||||
|
||||
filepath = Path(url2pathname(url_spec.path))
|
||||
if allowed_local_media_path not in filepath.resolve().parents:
|
||||
raise ValueError(
|
||||
f"The file path {filepath} must be a subpath "
|
||||
f"of `--allowed-local-media-path` {allowed_local_media_path}."
|
||||
)
|
||||
|
||||
return media_io.load_file(filepath)
|
||||
|
||||
def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
    """Raise `ValueError` when a domain allowlist is set and the URL's host is not on it."""
    allowed = self.allowed_media_domains
    if not allowed:
        return
    if url_spec.hostname in allowed:
        return
    raise ValueError(
        f"The URL must be from one of the allowed domains: "
        f"{allowed}. Input URL domain: "
        f"{url_spec.hostname}"
    )
|
||||
|
||||
def load_from_url(
    self,
    url: str,
    media_io: MediaIO[_M],
    *,
    fetch_timeout: int | None = None,
) -> _M:  # type: ignore[type-var]
    """Synchronously load media from an HTTP(S), ``data:`` or ``file:`` URL.

    Args:
        url: The URL to load from.
        media_io: Modality-specific decoder for the fetched bytes.
        fetch_timeout: Timeout (seconds); applies to HTTP fetches only.

    Raises:
        ValueError: If the URL scheme is not http(s), data, or file.
    """
    url_spec = urlparse(url)
    scheme = url_spec.scheme

    if scheme.startswith("http"):
        # Enforce the optional domain allowlist before any network I/O.
        self._assert_url_in_allowed_media_domains(url_spec)

        data = self.connection.get_bytes(
            url,
            timeout=fetch_timeout,
            allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
        )
        return media_io.load_bytes(data)

    if scheme == "data":
        return self._load_data_url(url_spec, media_io)

    if scheme == "file":
        return self._load_file_url(url_spec, media_io)

    raise ValueError("The URL must be either a HTTP, data or file URL.")
|
||||
|
||||
async def load_from_url_async(
    self,
    url: str,
    media_io: MediaIO[_M],
    *,
    fetch_timeout: int | None = None,
) -> _M:
    """Asynchronously load media from an HTTP(S), ``data:`` or ``file:`` URL.

    HTTP bytes are fetched with the async connection; decoding (and the
    data/file paths) run on `global_thread_pool` so the event loop does
    not block on CPU-bound work.
    """
    url_spec = urlparse(url)
    loop = asyncio.get_running_loop()
    scheme = url_spec.scheme

    if scheme.startswith("http"):
        self._assert_url_in_allowed_media_domains(url_spec)

        data = await self.connection.async_get_bytes(
            url,
            timeout=fetch_timeout,
            allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
        )
        return await loop.run_in_executor(
            global_thread_pool, media_io.load_bytes, data
        )

    if scheme == "data":
        return await loop.run_in_executor(
            global_thread_pool, self._load_data_url, url_spec, media_io
        )

    if scheme == "file":
        return await loop.run_in_executor(
            global_thread_pool, self._load_file_url, url_spec, media_io
        )

    raise ValueError("The URL must be either a HTTP, data or file URL.")
|
||||
|
||||
def fetch_audio(
    self,
    audio_url: str,
) -> tuple[np.ndarray, int | float]:
    """Load audio from a URL."""
    io_kwargs = self.media_io_kwargs.get("audio", {})
    return self.load_from_url(
        audio_url,
        AudioMediaIO(**io_kwargs),
        fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
    )
|
||||
|
||||
async def fetch_audio_async(
    self,
    audio_url: str,
) -> tuple[np.ndarray, int | float]:
    """Asynchronously fetch audio from a URL."""
    io_kwargs = self.media_io_kwargs.get("audio", {})
    return await self.load_from_url_async(
        audio_url,
        AudioMediaIO(**io_kwargs),
        fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
    )
|
||||
|
||||
def fetch_image(
    self,
    image_url: str,
    *,
    image_mode: str = "RGB",
) -> Image.Image:
    """Load a PIL image from an HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    io_kwargs = self.media_io_kwargs.get("image", {})
    image_io = ImageMediaIO(image_mode=image_mode, **io_kwargs)

    try:
        return self.load_from_url(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
    except UnidentifiedImageError as e:
        # Convert to ValueError to be properly caught upstream.
        raise ValueError(str(e)) from e
|
||||
|
||||
async def fetch_image_async(
    self,
    image_url: str,
    *,
    image_mode: str = "RGB",
) -> Image.Image:
    """Asynchronously load a PIL image from an HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    io_kwargs = self.media_io_kwargs.get("image", {})
    image_io = ImageMediaIO(image_mode=image_mode, **io_kwargs)

    try:
        return await self.load_from_url_async(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
    except UnidentifiedImageError as e:
        # Convert to ValueError to be properly caught upstream.
        raise ValueError(str(e)) from e
|
||||
|
||||
def fetch_video(
    self,
    video_url: str,
    *,
    image_mode: str = "RGB",
) -> tuple[npt.NDArray, dict[str, Any]]:
    """Load video from an HTTP or base64 data URL."""
    frame_io = ImageMediaIO(
        image_mode=image_mode, **self.media_io_kwargs.get("image", {})
    )
    video_io = VideoMediaIO(frame_io, **self.media_io_kwargs.get("video", {}))
    return self.load_from_url(
        video_url,
        video_io,
        fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
    )
|
||||
|
||||
async def fetch_video_async(
    self,
    video_url: str,
    *,
    image_mode: str = "RGB",
) -> tuple[npt.NDArray, dict[str, Any]]:
    """Asynchronously load video from an HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    frame_io = ImageMediaIO(
        image_mode=image_mode, **self.media_io_kwargs.get("image", {})
    )
    video_io = VideoMediaIO(frame_io, **self.media_io_kwargs.get("video", {}))
    return await self.load_from_url_async(
        video_url,
        video_io,
        fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
    )
|
||||
|
||||
def fetch_image_embedding(
    self,
    data: str,
) -> torch.Tensor:
    """Load an image embedding from base64-encoded data."""
    # Media type is unused by the embedding IO, hence the empty string.
    return ImageEmbeddingMediaIO().load_base64("", data)
|
||||
|
||||
|
||||
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Encode audio as base64."""
    return AudioMediaIO().encode_base64((audio, sampling_rate))
|
||||
|
||||
|
||||
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.
    """
    media_io = ImageMediaIO(image_mode=image_mode)
    return media_io.encode_base64(image, image_format=format)
|
||||
|
||||
|
||||
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode video frames as base64."""
    return VideoMediaIO(ImageMediaIO()).encode_base64(frames)
|
||||
|
||||
|
||||
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholderDict`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    flattened = [
        (item.offset, modality, idx)
        for modality, items in mm_positions.items()
        for idx, item in enumerate(items)
    ]
    # Python's sort is stable, so ties keep their dict/iteration order.
    flattened.sort(key=lambda entry: entry[0])
    return [(modality, idx) for _, modality, idx in flattened]
|
||||
|
||||
|
||||
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
    merge_by_field_config: bool | None = None,
    multimodal_cpu_fields: Set[str] = frozenset(),
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together into the same `MultiModalKwargs` instance.

    Note: `itertools.groupby` only merges *consecutive* runs, so items of the
    same modality that are not adjacent yield separate groups.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
        merge_by_field_config: Must be passed explicitly; selects the new
            field-config batching path (True) or the deprecated legacy
            path (False).
        multimodal_cpu_fields: Field names kept on CPU even when `device`
            is given.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    # The flag became mandatory in vLLM PR #25676; fail loudly for model
    # runners that were not updated rather than silently picking a path.
    if merge_by_field_config is None:
        raise RuntimeError(
            "`group_mm_kwargs_by_modality` now requires "
            "`merge_by_field_config` arg, please update your model runner "
            "according to https://github.com/vllm-project/vllm/pull/25676."
        )
    if merge_by_field_config is False:
        logger.warning_once(
            "The legacy code for batching multi-modal kwargs is deprecated and "
            "will be removed in v0.12. Please update your model with "
            "`merge_by_field_config=True` to use the new code defined by "
            "`MultiModalFieldConfig`. You can refer to "
            "https://github.com/vllm-project/vllm/issues/26149 "
            "for some examples on how to do this."
        )

    # Local import, presumably to avoid a circular dependency at module load.
    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems

    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        items_lst = list(items)

        if merge_by_field_config:
            # New path: batch via per-field config, then optionally move
            # tensor leaves to `device` (skipping CPU-pinned fields).
            mm_kwargs_group: BatchedTensorInputs = dict(
                MultiModalKwargsItems.from_seq(items_lst).get_data(
                    pin_memory=pin_memory
                )
            )

            if device is not None:
                mm_kwargs_group = {
                    k: json_map_leaves(
                        lambda x: x.to(device=device, non_blocking=True)
                        if isinstance(x, torch.Tensor)
                        else x,
                        v,
                    )
                    if k not in multimodal_cpu_fields
                    else v
                    for k, v in mm_kwargs_group.items()
                }
        else:
            # Legacy path (deprecated): batch each item separately, then
            # merge and transfer via `MultiModalKwargs`.
            mm_kwargs_group = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch(
                    [
                        MultiModalKwargsItems.from_seq([item]).get_data()
                        for item in items_lst
                    ],
                    pin_memory=pin_memory,
                ),
                device=device,
            )

        yield modality, len(items_lst), mm_kwargs_group
|
||||
|
||||
|
||||
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.
    """
    io_kwargs = {"audio": audio_io_kwargs} if audio_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_audio(audio_url)
|
||||
|
||||
|
||||
def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.
    """
    io_kwargs = {"image": image_io_kwargs} if image_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_image(image_url)
|
||||
|
||||
|
||||
def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.
    """
    io_kwargs = {"video": video_io_kwargs} if video_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_video(video_url)
|
||||
@@ -1,306 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
from .base import MediaIO
|
||||
from .image import ImageMediaIO
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    """Resize every frame of an `(N, H, W, C)` video to `size` = (height, width)."""
    # Lazy import cv2 to avoid bothering users who only use text models.
    import cv2

    num_frames = frames.shape[0]
    channels = frames.shape[-1]
    target_h, target_w = size
    resized = np.empty(
        (num_frames, target_h, target_w, channels), dtype=frames.dtype
    )
    for i, frame in enumerate(frames):
        # cv2.resize takes (width, height), the reverse of our `size` order.
        resized[i] = cv2.resize(frame, (target_w, target_h))
    return resized
|
||||
|
||||
|
||||
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
    """Uniformly scale a video's height and width by `size_factor`."""
    _, height, width, _ = frames.shape
    scaled = (int(height * size_factor), int(width * size_factor))
    return resize_video(frames, scaled)
|
||||
|
||||
|
||||
def sample_frames_from_video(frames: npt.NDArray, num_frames: int) -> npt.NDArray:
    """Uniformly sample `num_frames` frames; `num_frames == -1` returns all frames."""
    if num_frames == -1:
        return frames

    total = frames.shape[0]
    # Evenly spaced indices over [0, total - 1], truncated to ints.
    indices = np.linspace(0, total - 1, num_frames, dtype=int)
    return frames[indices, ...]
|
||||
|
||||
|
||||
class VideoLoader:
    """Abstract interface for decoding raw video bytes into frames + metadata."""

    @classmethod
    @abstractmethod
    def load_bytes(
        cls, data: bytes, num_frames: int = -1, **kwargs
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        # Implementations return (frames, metadata); sampling policy is
        # backend-specific.
        raise NotImplementedError
|
||||
|
||||
|
||||
# Registry of named VideoLoader backends; the active one is selected by
# name via envs.VLLM_VIDEO_LOADER_BACKEND (see VideoMediaIO.__init__).
VIDEO_LOADER_REGISTRY = ExtensionManager()
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv")
class OpenCVVideoBackend(VideoLoader):
    """Video loader that decodes in-memory bytes with OpenCV's VideoCapture."""

    def get_cv2_video_api(self):
        """Return a buffered-stream capture backend, or None for cv2's default.

        Plugin (non-built-in) backends are only accepted when their plugin
        version is at least ABI 1 / API 2.
        """
        import cv2.videoio_registry as vr

        api_pref = None
        for backend in vr.getStreamBufferedBackends():
            if not vr.hasBackend(backend):
                continue
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                # Skip plugin backends that are too old to support
                # buffered-stream capture.
                if abi < 1 or (abi == 1 and api < 2):
                    continue
            api_pref = backend
            break
        return api_pref

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        fps: int = -1,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode video bytes into uniformly sampled RGB frames.

        Args:
            data: Raw encoded video bytes.
            num_frames: Maximum number of frames to sample (-1 = no limit).
            fps: Target sampling rate (-1 = no fps-based limit); when both
                limits are set, the smaller resulting count wins.

        Returns:
            `(frames, metadata)`: frames is `(N, H, W, 3)` uint8; metadata
            follows transformers' VideoMetadata format.

        Raises:
            ValueError: If the byte stream cannot be opened as a video.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        # Guard against containers that report a non-positive FPS.
        duration = total_frames_num / original_fps if original_fps > 0 else 0

        # resample video to target num_frames and fps
        # - the minimum of the two will be used
        num_frames_to_sample = total_frames_num
        if num_frames > 0:
            num_frames_to_sample = min(num_frames, total_frames_num)
        if fps > 0:
            num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
        num_frames_to_sample = max(1, num_frames_to_sample)  # at least one sample

        if num_frames_to_sample == total_frames_num:
            frame_idx = list(range(0, num_frames_to_sample))
        else:
            uniform_sampled_frames = np.linspace(
                0, total_frames_num - 1, num_frames_to_sample, dtype=int
            )
            frame_idx = uniform_sampled_frames.tolist()

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)

        # Sequentially grab() every frame up to the last wanted index and
        # retrieve() (i.e. fully decode) only the sampled ones.
        i = 0
        for idx in range(max(frame_idx) + 1):
            ok = cap.grab()
            if not ok:
                break
            if idx in frame_idx:
                ret, frame = cap.retrieve()
                if ret:
                    # OpenCV decodes BGR; convert to RGB for consumers.
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    i += 1

        assert i == num_frames_to_sample, (
            f"Expected reading {num_frames_to_sample} frames, "
            f"but only loaded {i} frames from video."
        )

        # Use transformers transformers.video_utils.VideoMetadata format
        # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
        # can cause incorrect timestamp calculation without num_frames=-1.
        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": duration,
            "video_backend": "opencv",
            "frames_indices": list(frame_idx),
            # extra field used to control hf processor's video
            # sampling behavior
            "do_sample_frames": num_frames_to_sample == total_frames_num,
        }

        return frames, metadata
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv_dynamic")
class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
    """OpenCV loader with duration-aware sampling mirroring HF's GLM4V
    video processor (see the link in the body)."""

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        fps: int = 2,
        max_duration: int = 300,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode video bytes, sampling ~`fps` frames per second, with the
        total sample count capped for videos longer than `max_duration` s.

        Returns:
            `(frames, metadata)`; `do_sample_frames` is always False so
            downstream HF processors do not resample.

        Raises:
            ValueError: If the byte stream cannot be opened as a video.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames_num / original_fps if original_fps > 0 else 0

        # resample video to target num_frames
        max_frame_idx = total_frames_num - 1
        # Fall back to an estimated duration when the container reports 0.
        duration = duration or round(max_frame_idx / original_fps) + 1

        # Refer to:
        # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
        frame_indices: range | list[int]
        if duration <= max_duration:
            # Short video: take `fps` samples per second of footage; the set
            # deduplicates indices that round to the same frame.
            n = int(math.floor(duration * fps))
            frame_indices = sorted(
                {
                    min(max_frame_idx, int(math.ceil(i * original_fps / fps)))
                    for i in range(n)
                }
            )
        else:
            # Long video: cap the sample count and spread it evenly over
            # the full duration.
            num_samples = int(max_duration * fps)
            if num_samples >= total_frames_num:
                frame_indices = range(total_frames_num)
            else:
                target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
                frame_indices = sorted(
                    {
                        min(max_frame_idx, int(math.ceil(t * original_fps)))
                        for t in target_seconds
                    }
                )

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = np.empty((len(frame_indices), height, width, 3), dtype=np.uint8)

        # grab() every frame sequentially; decode (retrieve) only sampled ones.
        i = 0
        for idx in range(total_frames_num):
            ok = cap.grab()
            if not ok:
                break
            if idx in frame_indices:
                ret, frame = cap.retrieve()
                if ret:
                    # Convert OpenCV's BGR output to RGB.
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    i += 1

        assert i == len(frame_indices), (
            f"Expected reading {len(frame_indices)} frames, "
            f"but only loaded {i} frames from video."
        )

        # Use transformers transformers.video_utils.VideoMetadata format
        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": duration,
            "video_backend": "opencv_dynamic",
            "frames_indices": list(frame_indices),
            "do_sample_frames": False,
        }

        return frames, metadata
|
||||
|
||||
|
||||
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Media IO for videos, delegating decoding to a pluggable loader backend."""

    def __init__(
        self,
        image_io: ImageMediaIO,
        num_frames: int = 32,
        **kwargs,
    ) -> None:
        super().__init__()

        self.image_io = image_io
        self.num_frames = num_frames
        # `kwargs` carries custom arguments from --media-io-kwargs for this
        # modality; they are forwarded verbatim to the underlying loader
        # (e.g. custom implementations) for flexible control.
        self.kwargs = kwargs
        self.video_loader = VIDEO_LOADER_REGISTRY.load(
            envs.VLLM_VIDEO_LOADER_BACKEND
        )

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode raw video bytes into `(frames, metadata)` via the backend."""
        return self.video_loader.load_bytes(
            data, num_frames=self.num_frames, **self.kwargs
        )

    def load_base64(
        self, media_type: str, data: str
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode base64 video data.

        `video/jpeg` is treated as a comma-separated list of base64 JPEG
        frames; anything else is decoded as a single base64 video blob.
        """
        if media_type.lower() != "video/jpeg":
            return self.load_bytes(base64.b64decode(data))

        decode_frame = partial(self.image_io.load_base64, "image/jpeg")
        stacked = np.stack(
            [np.asarray(decode_frame(chunk)) for chunk in data.split(",")]
        )
        return stacked, {}

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
        """Read a video file from disk and decode it."""
        return self.load_bytes(filepath.read_bytes())

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode frames as a comma-separated list of base64 JPEG images."""
        if video_format != "JPEG":
            raise NotImplementedError("Only JPEG format is supported for now.")

        to_b64 = partial(self.image_io.encode_base64, image_format=video_format)
        return ",".join(to_b64(Image.fromarray(frame)) for frame in media)
|
||||
Reference in New Issue
Block a user