add qwen3

This commit is contained in:
Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

View File

@@ -0,0 +1,44 @@
from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .inputs import (BatchedTensorInputs, MultiModalData,
MultiModalDataBuiltins, MultiModalDataDict,
MultiModalKwargs, MultiModalPlaceholderDict,
NestedTensors)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global :class:`~MultiModalRegistry` is used by model runners to
dispatch data processing according to the modality of the data and the
target model.
See also:
:ref:`input_processing_pipeline`
"""
__all__ = [
"BatchedTensorInputs",
"MultiModalData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalKwargs",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
"MultiModalPlugin",
"NestedTensors",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
]
def __getattr__(name: str):
import warnings
if name == "MultiModalInputs":
msg = ("MultiModalInputs has been renamed to MultiModalKwargs. "
"The original name will take another meaning in an upcoming "
"version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return MultiModalKwargs
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -0,0 +1,23 @@
from vllm.inputs.registry import InputContext
from .base import MultiModalPlugin
from .inputs import AudioItem, MultiModalData, MultiModalKwargs
class AudioPlugin(MultiModalPlugin):
"""Plugin for audio data."""
def get_data_key(self) -> str:
return "audio"
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[AudioItem],
**mm_processor_kwargs,
) -> MultiModalKwargs:
raise NotImplementedError("There is no default audio input mapper")
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
raise NotImplementedError(
"There is no default maximum multimodal tokens")

View File

@@ -0,0 +1,450 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional, Sequence, Tuple, Type, TypeVar, Union)
from torch import nn
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs)
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.sequence import SequenceGroupMetadata
from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs,
PlaceholderRange)
logger = init_logger(__name__)
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
MultiModalKwargs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.
If the data is not supported, raise :exc:`TypeError`.
"""
MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""
_T = TypeVar("_T")
N = TypeVar("N", bound=Type[nn.Module])
class MultiModalPlugin(ABC):
"""
Base class that defines data processing logic for a specific modality.
In particular, we adopt a registry pattern to dispatch data processing
according to the model being used (considering that different models may
process the same data differently). This registry is in turn used by
:class:`~MultiModalRegistry` which acts at a higher level
(i.e., the modality of the data).
See also:
:ref:`adding_multimodal_plugin`
"""
def __init__(self) -> None:
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}
@abstractmethod
def get_data_key(self) -> str:
"""
Get the data key corresponding to the modality.
"""
raise NotImplementedError
@abstractmethod
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[Any],
**mm_processor_kwargs,
) -> MultiModalKwargs:
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
        If the data is not supported, raise :exc:`TypeError`.
"""
raise NotImplementedError
def register_input_mapper(
self,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper to a model class.
When the model receives input data that matches the modality served by
this plugin (see :meth:`get_data_key`), the provided function is
invoked to transform the data into a dictionary of model inputs.
If `None` is provided, then the default input mapper is used instead.
See also:
- :ref:`input_processing_pipeline`
- :ref:`enabling_multimodal_inputs`
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_mappers:
logger.warning(
"Model class %s already has an input mapper "
"registered to %s. It is overwritten by the new one.",
model_cls,
self,
)
self._input_mappers[model_cls] = (mapper
or self._default_input_mapper)
return model_cls
return wrapper
def map_input(
self,
model_config: "ModelConfig",
data: MultiModalData[Any],
mm_processor_kwargs: Optional[Dict[str, Any]],
) -> MultiModalKwargs:
"""
Transform the data into a dictionary of model inputs using the
input mapper registered for that model.
The model is identified by ``model_config``.
Raises:
TypeError: If the data type is not supported.
See also:
- :ref:`input_processing_pipeline`
- :ref:`enabling_multimodal_inputs`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
mapper = self._input_mappers.get(model_cls)
if mapper is None:
raise KeyError(f"No input mapper in {self} is registered for "
f"model class {model_cls.__name__}.")
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
        # In the case of the default mapper, we have to get the resource
        # processor through its HuggingFace autoclass; since this goes
        # through **kwargs, we can't inspect it the same way, so we allow
        # dropping mm_processor_kwargs based on signature inspection
        # if we're using the default mapper.
        #
        # This should be safe in general due to the sanitization, since the
        # transformers resource should filter unused kwargs anyway.
uses_default_mapper = mapper == self._default_input_mapper
mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs,
mm_processor_kwargs,
callable=mapper,
allow_var_kwargs=uses_default_mapper,
)
return mapper(InputContext(model_config), data, **mm_processor_kwargs)
@abstractmethod
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
"""
Calculate the maximum number of tokens, corresponding to a single
instance of multimodal data, that are passed to the language model.
"""
raise NotImplementedError
def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
if max_mm_tokens < 1:
raise ValueError("You should set the number of tokens to a "
f"positive integer. Found: {max_mm_tokens}")
def register_max_multimodal_tokens(
self,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of tokens, corresponding to a single
instance of multimodal data, that are passed to the language model
for a model class.
If `None` is provided, then the default calculation is used instead.
See also:
:ref:`enabling_multimodal_inputs`
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._max_mm_tokens:
logger.warning(
"Model class %s already calculates maximum number of "
"tokens in %s. It is overwritten by the new one.",
model_cls,
self,
)
if isinstance(max_mm_tokens, int):
self._validate_max_multimodal_tokens(max_mm_tokens)
self._max_mm_tokens[model_cls] = (
max_mm_tokens or self._default_max_multimodal_tokens)
return model_cls
return wrapper
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
        If this plugin is not applicable to the model, `0` is returned.
The model is identified by ``model_config``.
See also:
:ref:`enabling_multimodal_inputs`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
if model_cls not in self._input_mappers:
return 0
max_mm_tokens = self._max_mm_tokens.get(model_cls)
if max_mm_tokens is None:
raise KeyError(f"No maximum number of multi-modal tokens is given "
f"for model class {model_cls.__name__} in {self}.")
if callable(max_mm_tokens):
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
max_mm_tokens, overrides=model_config.mm_processor_kwargs)
max_mm_tokens = max_mm_tokens(InputContext(model_config),
**mm_processor_kwargs)
self._validate_max_multimodal_tokens(max_mm_tokens)
return max_mm_tokens
class MultiModalPlaceholderMap:
"""
Relates multi-modal embeddings to their corresponding placeholders.
"""
class IndexMap(NamedTuple):
src: List[int]
dest: List[int]
src_ranges: List[range]
"""
The indices of the multi-modal embeddings that will replace the
corresponding placeholder embeddings pointed to by ``dest_ranges``.
"""
src_len: int
"""
The total number of flattened multi-modal embeddings.
"""
dest_ranges: List[range]
"""
The indices of the placeholder embeddings that will be replaced by the
multimodal embeddings.
"""
dest_len: int
"""
The total number of embeddings in the destination tensor.
"""
def __init__(self):
self.src_ranges = []
self.src_len = 0
self.dest_ranges = []
self.dest_len = 0
@classmethod
def from_seq_group(
cls, seq_group: "SequenceGroupMetadata", positions: range
) -> Tuple[Optional[MultiModalDataDict], Dict[str,
"MultiModalPlaceholderMap"]]:
"""
Returns the multi-modal items that intersect with the portion of a
prompt (``seq_group``) represented by ``positions``, as well as a
``MultiModalPlaceholderMap`` that relates the multi-modal embedding
vectors to their corresponding placeholders.
Consider the following scenarios:
Prompt: |AAAA BBBB What's in these images?|
Positions: |.................................|
images = [A, B]
src_ranges = [(0, 4), (4, 8)]
dest_ranges = [(0, 4), (5, 9)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | ..... |
images = [A, B]
src_ranges = [(2, 4), (4, 6)]
dest_ranges = [(0, 2), (3, 5)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | ......... |
images = [B]
src_ranges = [(0, 4)]
dest_ranges = [(0, 4)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | .......................|
images = []
src_ranges = []
dest_ranges = []
"""
if (not seq_group.multi_modal_data
or not seq_group.multi_modal_placeholders):
return seq_group.multi_modal_data, {}
mm_data = {**seq_group.multi_modal_data}
placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
MultiModalPlaceholderMap)
for (
modality,
placeholders,
) in seq_group.multi_modal_placeholders.items():
mm_items = mm_data.pop(modality)
if not isinstance(mm_items, list):
mm_items = [mm_items]
if positions:
intersecting_items = placeholder_maps[
modality].append_items_from_seq_group(
positions, mm_items, placeholders)
if intersecting_items:
mm_data[modality] = intersecting_items
return mm_data, placeholder_maps
def append_items_from_seq_group(
self,
positions: range,
multi_modal_items: List[_T],
multi_modal_placeholders: Sequence[PlaceholderRange],
) -> List[_T]:
"""
        Adds the multi-modal items that intersect ``positions`` to this
placeholder map and returns the intersecting items.
"""
intersecting_items = []
if len(multi_modal_items) != len(multi_modal_placeholders):
raise ValueError(
"Multi-modal placeholders and items must have the same length."
)
for placeholder_dict, mm_item in zip(multi_modal_placeholders,
multi_modal_items):
placeholder = range(
placeholder_dict["offset"],
placeholder_dict["offset"] + placeholder_dict["length"],
)
intersection = range(
max(positions.start, placeholder.start),
min(positions.stop, placeholder.stop),
)
if not intersection:
# Skip this multi-modal item.
continue
token_embedding_range = range(
intersection.start - positions.start,
intersection.stop - positions.start,
)
multimodal_embedding_range = range(
intersection.start - placeholder.start + self.src_len,
intersection.stop - placeholder.start + self.src_len,
)
intersecting_items.append(mm_item)
self.dest_ranges.append(token_embedding_range)
self.src_ranges.append(multimodal_embedding_range)
self.src_len += len(placeholder)
self.dest_len += len(positions)
return intersecting_items
def extend(self, other: "MultiModalPlaceholderMap"):
"""
Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
instance based on the source and destination tensors being
concatenated.
"""
self.src_ranges.extend(
range(self.src_len + r.start, self.src_len + r.stop)
for r in other.src_ranges)
self.src_len += other.src_len
self.dest_ranges.extend(
range(self.dest_len + r.start, self.dest_len + r.stop)
for r in other.dest_ranges)
self.dest_len += other.dest_len
def index_map(self) -> "IndexMap":
"""
Finalizes the placeholder map into lists of indices that can be used to
index the source and destination tensors.
"""
src_indices = [i for r in self.src_ranges for i in r]
dest_indices = [i for r in self.dest_ranges for i in r]
if len(src_indices) != len(dest_indices):
raise ValueError(
f"The number of source ({len(src_indices)}) and destination "
f"indices ({len(dest_indices)}) must be the same.")
return MultiModalPlaceholderMap.IndexMap(src=src_indices,
dest=dest_indices)
def __getattr__(name: str):
import warnings
if name == "MultiModalInputs":
msg = ("MultiModalInputs has been renamed to MultiModalKwargs. "
"The original name will take another meaning in an upcoming "
"version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return MultiModalKwargs
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -0,0 +1,86 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, Optional
import torch
from PIL import Image
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of
from .base import MultiModalPlugin
from .inputs import ImageItem, MultiModalData, MultiModalKwargs
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
cached_get_image_processor = lru_cache(get_image_processor)
class ImagePlugin(MultiModalPlugin):
"""Plugin for image data."""
def get_data_key(self) -> str:
return "image"
def _get_hf_image_processor(
self,
model_config: "ModelConfig",
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
):
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[ImageItem],
**mm_processor_kwargs,
) -> MultiModalKwargs:
model_config = ctx.model_config
# PIL image
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor(
model_config,
mm_processor_kwargs,
)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
try:
                # NOTE: It may make sense to forward the mm_processor_kwargs
                # here too. For now, to keep it simple, we only allow it to be
                # used for the initialization call, just in case the
                # signatures of the preprocessor initializer don't match
                # preprocess()
batch_data = image_processor \
.preprocess(data, return_tensors="pt") \
.data
except Exception:
logger.error(
"Failed to process image (%s) with the default mapper. "
"This is most likely an edge-case with this model's image "
"processor in transformers (type: %s), and not vLLM.",
data,
type(image_processor).__name__)
raise
return MultiModalKwargs(batch_data)
# Image embedding
elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
return MultiModalKwargs({"image_embeds": data})
raise TypeError(f"Invalid image type: {type(data)}")
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
return 3000
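The embedding branch of ``_default_input_mapper`` forwards pre-computed features under the ``image_embeds`` key, bypassing the HF image processor entirely; a minimal sketch (the embedding shape is model-specific and assumed here):

import torch

from vllm.multimodal.inputs import MultiModalKwargs

image_embeds = torch.randn(576, 1024)  # hypothetical [num_patches, hidden_dim]
mm_kwargs = MultiModalKwargs({"image_embeds": image_embeds})
assert isinstance(mm_kwargs["image_embeds"], torch.Tensor)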

View File

@@ -0,0 +1,225 @@
from collections import UserDict, defaultdict
from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple,
TypedDict, TypeVar, Union, cast, final)
import numpy as np
import torch
import torch.types
from PIL.Image import Image
from typing_extensions import TypeAlias
from vllm.utils import JSONTree, is_list_of, json_map_leaves
_T = TypeVar("_T")
# yapf: disable
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image,
which can be passed to a HuggingFace :code:`ImageProcessor`.
"""
VideoItem: TypeAlias = Union[
List[Image],
np.ndarray,
torch.Tensor,
List[np.ndarray],
List[torch.Tensor],
]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video,
which can be passed to a HuggingFace :code:`VideoProcessor`.
"""
AudioItem: TypeAlias = Union[
np.ndarray,
List[float],
Tuple[np.ndarray, float], # DEPRECATED: Use mm_processor_kwargs instead
]
"""
Represents a single audio item that can be passed to a HuggingFace
:code:`AudioProcessor`.
"""
# yapf: enable
MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data item, or a list of data items.
The number of data items allowed per modality is restricted by
:code:`--limit-mm-per-prompt`.
"""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: MultiModalData[ImageItem]
"""The input image(s)."""
video: MultiModalData[VideoItem]
"""The input video(s)."""
audio: MultiModalData[AudioItem]
"""The input audio(s)."""
MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]]
"""
A dictionary containing an entry for each modality type to input.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalDataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
class PlaceholderRange(TypedDict):
"""
Placeholder location information for multi-modal data.
For example:
Prompt: AAAA BBBB What is in these images?
Images A and B will have:
A: { "offset": 0, "length": 4 }
B: { "offset": 5, "length": 4 }
"""
offset: int
"""The start index of the placeholder in the prompt."""
length: int
"""The length of the placeholder."""
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`.
"""
@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
"""
Stack the inner dimensions that have the same shape in
a nested list of tensors.
Thus, a dimension represented by a list means that the inner
dimensions are different for each element along that dimension.
"""
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
# TODO: Remove these once all models have been migrated
if isinstance(nested_tensors, np.ndarray):
return torch.from_numpy(nested_tensors)
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)
stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
return stacked
tensors_ = cast(List[torch.Tensor], stacked)
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
return torch.stack(tensors_)
@staticmethod
def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs:
"""
Batch multiple inputs together into a dictionary.
The resulting dictionary has the same keys as the inputs.
If the corresponding value from each input is a tensor and they all
share the same shape, the output value is a single batched tensor;
otherwise, the output value is a list containing the original value
from each input.
"""
if len(inputs_list) == 0:
return {}
# We need to consider the case where each item in the batch
# contains different modalities (i.e. different keys).
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
for inputs in inputs_list:
for k, v in inputs.items():
item_lists[k].append(v)
return {
k: MultiModalKwargs._try_stack(item_list)
for k, item_list in item_lists.items()
}
@staticmethod
def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
json_inputs,
)
return cast(BatchedTensorInputs, json_mapped)
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""
class MultiModalInputsV2(TypedDict):
"""
Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`,
ready to be passed to vLLM internals.
"""
type: Literal["multimodal"]
"""The type of inputs."""
prompt: str
"""
The original, unprocessed prompt text.
Note:
Since prompt text is not required by vLLM internals, we leave this
unprocessed to save CPU computation. You can still call
:code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
"""
prompt_token_ids: List[int]
"""The processed token IDs which includes placeholder tokens."""
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""
mm_placeholders: MultiModalPlaceholderDict
"""
For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`.
"""

View File

@@ -0,0 +1,273 @@
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import (Any, Callable, Collection, Dict, Generic, List, Mapping,
                    Optional, TypedDict, TypeVar, final)
from transformers import BatchFeature
from typing_extensions import TypeAlias
from vllm.inputs import InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem)
_T = TypeVar("_T")
ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]]
"""
Given the original data item, HF-processed data, and index of the processed
item, output the replacement token IDs to be allocated in vLLM.
"""
@dataclass
class ModalityProcessingMetadata(Generic[_T]):
placeholder_replacements: Mapping[str, ReplacementFunc]
"""
A dictionary where each item represents the original placeholder in the
prompt text and the corresponding replacement.
"""
class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: ModalityProcessingMetadata[ImageItem]
video: ModalityProcessingMetadata[VideoItem]
audio: ModalityProcessingMetadata[AudioItem]
MultiModalProcessingMetadata: TypeAlias = \
Mapping[str, ModalityProcessingMetadata[Any]]
"""
A dictionary containing an entry for each modality type to process.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
MultiModalMultiData: TypeAlias = List[_T]
"""
A list of data items, where the number of data items allowed
per modality is restricted by :code:`--limit-mm-per-prompt`.
"""
@final
class MultiModalMultiDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: MultiModalMultiData[ImageItem]
"""The input images."""
video: MultiModalMultiData[VideoItem]
"""The input videos."""
audio: MultiModalMultiData[AudioItem]
"""The input audios."""
MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]]
"""
A dictionary containing an entry for each modality type to input.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalMultiDataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
"""
Convert a :class:`MultiModalDataDict` containing single data items
to a :class:`MultiModalMultiDataDict` containing multiple data items
per entry.
"""
    multi_data: Dict[str, MultiModalMultiData[Any]] = {}
    for k, v in data.items():
        if k == "video":
            # Special case since even a single video item can be a list
            # (of frames)
            multi_data[k] = v if is_list_of(v, list) else [v]
        else:
            multi_data[k] = v if isinstance(v, list) else [v]
    return multi_data
def encode_no_special_tokens(
tokenizer: AnyTokenizer,
text: str,
) -> List[int]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=False)`.
"""
if isinstance(tokenizer, MistralTokenizer):
return tokenizer.tokenizer.encode(text, bos=False, eos=False)
return tokenizer.encode(text, add_special_tokens=False)
@lru_cache
def candidate_placeholders(
tokenizer: AnyTokenizer,
placeholder_text: str,
) -> Collection[List[int]]:
"""Generate token ID sequences that may represent a placeholder text."""
# When the placeholder text is not mapped to a special token ID,
# it may be tokenized differently based on whether it is at the start/end
# of the string. So, we go through each combination of whether the text
# is at the start and end boundaries of the string
# Matches the placeholder when it is in the middle of the string
start_id, = encode_no_special_tokens(tokenizer, "a")
end_id, = encode_no_special_tokens(tokenizer, "b")
candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)
start_id_, *candidate_a = encode_no_special_tokens(
tokenizer,
f"a{placeholder_text}",
)
assert start_id == start_id_
start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
tokenizer,
f"a{placeholder_text}b",
)
assert start_id == start_id_ and end_id == end_id_
*candidate_b, end_id_ = encode_no_special_tokens(
tokenizer,
f"{placeholder_text}b",
)
assert end_id == end_id_
# Remove duplicates (need to convert to tuple to be hashable)
unique_candidates = {
tuple(c)
for c in [candidate_basic, candidate_a, candidate_ab, candidate_b]
}
# Convert back to list
return [list(c) for c in unique_candidates]
def apply_placeholders(
token_ids: List[int],
placeholder_ids: List[int],
get_replacement_ids: Callable[[], List[int]],
) -> Optional[PlaceholderRange]:
"""
Find the first occurrence of :code:`placeholder_ids`,
and replace it with the output of :code:`get_replacement_ids`.
This function updates :code:`token_ids` in place.
"""
placeholder_length = len(placeholder_ids)
    for start_idx in range(len(token_ids) - placeholder_length + 1):
        # Compare (and splice) the full window starting at start_idx;
        # slicing token_ids[start_idx:placeholder_length] would shrink
        # the window as start_idx grows.
        if token_ids[start_idx:start_idx + placeholder_length] == placeholder_ids:
            replacement_ids = get_replacement_ids()
            token_ids[start_idx:start_idx + placeholder_length] = replacement_ids
            # The returned range covers the inserted replacement tokens.
            return PlaceholderRange(offset=start_idx,
                                    length=len(replacement_ids))
    return None
class MultiModalProcessor:
"""
Helper class to process multi-modal inputs to be used in vLLM.
"""
def __init__(
self,
ctx: InputProcessingContext,
metadata: MultiModalProcessingMetadata,
) -> None:
super().__init__()
self.ctx = ctx
self.metadata = metadata
def __call__(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, mm_processor_kwargs)
def apply(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
tokenizer = self.ctx.tokenizer
hf_processor = self.ctx.get_hf_processor()
processed_inputs = hf_processor(
text=prompt, # type: ignore
**mm_data,
**mm_processor_kwargs,
)
new_token_ids, = processed_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(processed_inputs)
        mm_placeholders: Dict[str, List[PlaceholderRange]] = {}
for modality, orig_inputs in to_multi_format(mm_data).items():
assert isinstance(orig_inputs, list)
metadata = self.metadata[modality]
placeholder_replacements = metadata.placeholder_replacements
modality_placeholders: List[PlaceholderRange] = []
for item_idx, orig_item in enumerate(orig_inputs):
for match_text, replace_fn in placeholder_replacements.items():
candidates = candidate_placeholders(tokenizer, match_text)
get_replacement_ids = partial(
replace_fn,
orig_item,
processed_inputs,
item_idx,
)
for match_ids in candidates:
# TODO(youkaichao): Don't update new_token_ids
placeholders = apply_placeholders(
new_token_ids,
match_ids,
get_replacement_ids,
)
if placeholders is not None:
modality_placeholders.append(placeholders)
            mm_placeholders[modality] = modality_placeholders
return MultiModalInputsV2(
type="multimodal",
prompt=prompt,
prompt_token_ids=new_token_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders,
)
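``apply_placeholders`` mutates the token list in place and reports where the replacement landed; a minimal sketch (32000 stands in for a hypothetical <image> token ID, and the module path is assumed):

from vllm.multimodal.processing import apply_placeholders

token_ids = [1, 2, 32000, 3]
placeholder = apply_placeholders(
    token_ids,
    placeholder_ids=[32000],
    get_replacement_ids=lambda: [32000] * 4,  # expand to 4 image tokens
)
assert token_ids == [1, 2, 32000, 32000, 32000, 32000, 3]
assert placeholder == {"offset": 2, "length": 4}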

View File

@@ -0,0 +1,321 @@
import functools
from collections import UserDict
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
Sequence, Type, TypeVar)
import torch.nn as nn
from typing_extensions import TypeAlias
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import MultiModalProcessor
from .video import VideoPlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
N = TypeVar("N", bound=Type[nn.Module])
MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext],
MultiModalProcessor]
"""
Constructs a :class:`MultiModalProcessor` instance from the context.
The processing metadata should be derived from the context.
"""
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
"""
Wraps `_limits_by_model` for a more informative error message
when attempting to access a model that does not exist.
"""
def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
"forget to call `init_mm_limits_per_prompt`?")
raise KeyError(msg) from exc
class MultiModalRegistry:
"""
A registry that dispatches data processing to the
:class:`~vllm.multimodal.MultiModalPlugin` for each modality.
"""
DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
def __init__(
self,
*,
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins}
self._processor_factories: Dict[Type[nn.Module],
MultiModalProcessorFactory] = {}
# This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
self._limits_by_model = _MultiModalLimits()
def register_plugin(self, plugin: MultiModalPlugin) -> None:
"""
Register a multi-modal plugin so it can be recognized by vLLM.
See also:
:ref:`adding_multimodal_plugin`
"""
data_type_key = plugin.get_data_key()
if data_type_key in self._plugins:
logger.warning(
"A plugin is already registered for data type %s, "
"and will be overwritten by the new plugin %s.", data_type_key,
plugin)
self._plugins[data_type_key] = plugin
def _get_plugin(self, data_type_key: str):
plugin = self._plugins.get(data_type_key)
if plugin is not None:
return plugin
msg = f"Unknown multi-modal data type: {data_type_key}"
raise NotImplementedError(msg)
def register_input_mapper(
self,
data_type_key: str,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper for a specific modality to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self._get_plugin(data_type_key).register_input_mapper(mapper)
def register_image_input_mapper(
self,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper for image data to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self.register_input_mapper("image", mapper)
def map_input(
self,
model_config: "ModelConfig",
data: MultiModalDataDict,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> MultiModalKwargs:
"""
Apply an input mapper to the data passed to the model.
        The data belonging to each modality is passed to the corresponding
        plugin, which in turn converts the data into keyword arguments
        via the input mapper registered for that model.
See :meth:`MultiModalPlugin.map_input` for more details.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
merged_dict: Dict[str, NestedTensors] = {}
for data_key, data_value in data.items():
plugin = self._get_plugin(data_key)
num_items = len(data_value) if isinstance(data_value, list) else 1
max_items = self._limits_by_model[model_config][data_key]
if num_items > max_items:
raise ValueError(
f"You set {data_key}={max_items} (or defaulted to 1) in "
f"`--limit-mm-per-prompt`, but found {num_items} items "
"in the same prompt.")
input_dict = plugin.map_input(model_config, data_value,
mm_processor_kwargs)
for input_key, input_tensor in input_dict.items():
if input_key in merged_dict:
raise ValueError(f"The input mappers (keys={set(data)}) "
f"resulted in a conflicting keyword "
f"argument to `forward()`: {input_key}")
merged_dict[input_key] = input_tensor
return MultiModalKwargs(merged_dict)
def create_input_mapper(self, model_config: "ModelConfig"):
"""
Create an input mapper (see :meth:`map_input`) for a specific model.
"""
# NOTE - we currently make the assumption that if a model has multiple
# supported modalities, they take the same kwargs. For the default,
# this could be an issue in the future if it falls back to two HF
# resources and we can't inspect the signature easily since it's
# getting initialized through the autoclass.
#
# If this is a problem in the future, we should revisit it, but since
# it potentially introduces a lot of complexity for a currently
# uncommon case, we do not for simplicity of both use & implementation
return functools.partial(self.map_input, model_config)
def register_max_multimodal_tokens(
self,
data_type_key: str,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of tokens, corresponding to a single
instance of multimodal data belonging to a specific modality, that are
passed to the language model for a model class.
"""
return self._get_plugin(data_type_key) \
.register_max_multimodal_tokens(max_mm_tokens)
def register_max_image_tokens(
self,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of image tokens, corresponding to a single
image, that are passed to the language model for a model class.
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
limits_per_plugin = self._limits_by_model[model_config]
return sum((limits_per_plugin[key] *
plugin.get_max_multimodal_tokens(model_config))
for key, plugin in self._plugins.items())
def init_mm_limits_per_prompt(
self,
model_config: "ModelConfig",
) -> None:
"""
Initialize the maximum number of multi-modal input instances for each
modality that are allowed per prompt for a model class.
"""
if model_config in self._limits_by_model:
logger.warning(
"`mm_limits` has already been set for model=%s, and will "
"be overwritten by the new values.", model_config.model)
multimodal_config = model_config.multimodal_config
if multimodal_config is None:
limits_per_plugin = self._disabled_limits_per_plugin
else:
config_limits_per_plugin = multimodal_config.limit_per_prompt
extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
if extra_keys:
logger.warning(
"Detected extra keys in `--limit-mm-per-prompt` which "
"are not registered as multi-modal plugins: %s. "
"They will be ignored.", extra_keys)
# NOTE: Currently the default is set to 1 for each plugin
# TODO: Automatically determine the limits based on budget
# once more models support multi-image inputs
limits_per_plugin = {
key: config_limits_per_plugin.get(key, 1)
for key in self._plugins
}
self._limits_by_model[model_config] = limits_per_plugin
def get_mm_limits_per_prompt(
self,
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
that are allowed per prompt for a model class.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
return self._limits_by_model[model_config]
def register_processor(
self,
factory: MultiModalProcessorFactory,
):
"""
Register a multi-modal processor to a model class.
When the model receives multi-modal data, the provided function is
invoked to transform the data into a dictionary of model inputs.
See also:
- :ref:`input_processing_pipeline`
- :ref:`enabling_multimodal_inputs`
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._processor_factories:
                logger.warning(
                    "Model class %s already has a processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)
self._processor_factories[model_cls] = factory
return model_cls
return wrapper
def has_processor(self, model_config: "ModelConfig") -> bool:
"""
Test whether a multi-modal processor is defined for a specific model.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
return model_cls in self._processor_factories
def create_processor(
self,
model_config: "ModelConfig",
tokenizer: AnyTokenizer,
) -> MultiModalProcessor:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
processor_factory = self._processor_factories[model_cls]
ctx = InputProcessingContext(model_config, tokenizer)
return processor_factory(ctx)
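Putting the registry API together, a hedged sketch of how a hypothetical audio model could hook into ``MULTIMODAL_REGISTRY`` (``MyAudioLM`` and the toy mapper are illustrative, not part of this commit):

import torch
from torch import nn

from vllm.inputs import InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs

def my_audio_mapper(ctx: InputContext, data: object,
                    **mm_processor_kwargs) -> MultiModalKwargs:
    # Toy mapper: every clip becomes a fixed-size feature tensor.
    return MultiModalKwargs({"audio_features": torch.zeros(1, 80, 3000)})

@MULTIMODAL_REGISTRY.register_input_mapper("audio", my_audio_mapper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", 1500)
class MyAudioLM(nn.Module):
    ...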

View File

@@ -0,0 +1,501 @@
import base64
import os
from functools import lru_cache
from io import BytesIO
from typing import Any, List, Optional, Tuple, TypeVar, Union
import numpy as np
import numpy.typing as npt
from PIL import Image
import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from .inputs import MultiModalDataDict, PlaceholderRange
logger = init_logger(__name__)
cached_get_tokenizer = lru_cache(get_tokenizer)
def _load_image_from_bytes(b: bytes) -> Image.Image:
image = Image.open(BytesIO(b))
image.load()
return image
def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
# Get the common path
common_path = os.path.commonpath([
os.path.abspath(image_path),
os.path.abspath(allowed_local_media_path)
])
# Check if the common path is the same as allowed_local_media_path
return common_path == os.path.abspath(allowed_local_media_path)
def _load_image_from_file(image_url: str,
allowed_local_media_path: str) -> Image.Image:
    if not allowed_local_media_path:
        raise ValueError("Invalid 'image_url': Cannot load local files "
                         "without '--allowed-local-media-path'.")
if allowed_local_media_path:
if not os.path.exists(allowed_local_media_path):
raise ValueError(
"Invalid '--allowed-local-media-path': "
f"The path {allowed_local_media_path} does not exist.")
if not os.path.isdir(allowed_local_media_path):
raise ValueError(
"Invalid '--allowed-local-media-path': "
f"The path {allowed_local_media_path} must be a directory.")
# Only split once and assume the second part is the image path
_, image_path = image_url.split("file://", 1)
if not _is_subpath(image_path, allowed_local_media_path):
raise ValueError(
f"Invalid 'image_url': The file path {image_path} must"
" be a subpath of '--allowed-local-media-path'"
f" '{allowed_local_media_path}'.")
image = Image.open(image_path)
image.load()
return image
def _load_image_from_data_url(image_url: str) -> Image.Image:
# Only split once and assume the second part is the base64 encoded image
_, image_base64 = image_url.split(",", 1)
return load_image_from_base64(image_base64)
def fetch_image(image_url: str,
*,
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
image_raw = global_http_connection.get_bytes(
image_url,
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image', 'file://' or 'http'.")
return image.convert(image_mode)
async def async_fetch_image(image_url: str,
*,
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
image_raw = await global_http_connection.async_get_bytes(
image_url,
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image', 'file://' or 'http'.")
return image.convert(image_mode)
def _load_video_frames_from_bytes(b: bytes):
frame = Image.open(BytesIO(b))
return np.array(frame)
def load_video_frames_from_base64(frame: Union[bytes, str]):
"""Load frame from base64 format."""
return _load_video_frames_from_bytes(base64.b64decode(frame))
def _load_video_from_bytes(b: bytes, num_frames: int = 32):
_, decord = try_import_video_packages()
video_path = BytesIO(b)
vr = decord.VideoReader(video_path, num_threads=1)
total_frame_num = len(vr)
if total_frame_num > num_frames:
uniform_sampled_frames = np.linspace(0,
total_frame_num - 1,
num_frames,
dtype=int)
frame_idx = uniform_sampled_frames.tolist()
else:
frame_idx = [i for i in range(0, total_frame_num)]
frames = vr.get_batch(frame_idx).asnumpy()
return frames
def _load_video_from_data_url(video_url: str):
    # Everything after the data URL prefix is a comma-separated list of
    # base64-encoded frames
    frames_base64 = video_url.split(",")[1:]
return np.stack([
load_video_frames_from_base64(frame_base64)
for frame_base64 in frames_base64
])
def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray:
"""
Load video from a HTTP or base64 data URL.
"""
    if video_url.startswith('http'):  # also covers 'https'
video_raw = global_http_connection.get_bytes(
video_url,
timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)
video = _load_video_from_bytes(video_raw, num_frames)
elif video_url.startswith('data:video'):
video = _load_video_from_data_url(video_url)
else:
raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
"with either 'data:video' or 'http'.")
return video
async def async_fetch_video(video_url: str,
*,
num_frames: int = 32) -> npt.NDArray:
"""
Asynchronously load video from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
    if video_url.startswith('http'):  # also covers 'https'
video_raw = await global_http_connection.async_get_bytes(
video_url,
timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)
video = _load_video_from_bytes(video_raw, num_frames)
elif video_url.startswith('data:video'):
video = _load_video_from_data_url(video_url)
else:
raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
"with either 'data:video' or 'http'.")
return video
def try_import_audio_packages() -> Tuple[Any, Any]:
try:
import librosa
import soundfile
except ImportError as exc:
raise ImportError(
"Please install vllm[audio] for audio support.") from exc
return librosa, soundfile
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
"""
Load audio from a URL.
"""
librosa, _ = try_import_audio_packages()
if audio_url.startswith("http"):
audio_bytes = global_http_connection.get_bytes(
audio_url,
timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
)
elif audio_url.startswith("data:audio"):
_, audio_base64 = audio_url.split(",", 1)
audio_bytes = base64.b64decode(audio_base64)
else:
raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
"with either 'data:audio' or 'http'.")
return librosa.load(BytesIO(audio_bytes), sr=None)
async def async_fetch_audio(
audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
"""
Asynchronously fetch audio from a URL.
"""
librosa, _ = try_import_audio_packages()
if audio_url.startswith("http"):
audio_bytes = await global_http_connection.async_get_bytes(
audio_url,
timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
)
elif audio_url.startswith("data:audio"):
_, audio_base64 = audio_url.split(",", 1)
audio_bytes = base64.b64decode(audio_base64)
else:
raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
"with either 'data:audio' or 'http'.")
return librosa.load(BytesIO(audio_bytes), sr=None)
def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
audio, sr = fetch_audio(audio_url)
return {"audio": (audio, sr)}
def get_and_parse_image(
image_url: str,
*,
allowed_local_media_path: str = "") -> MultiModalDataDict:
image = fetch_image(image_url,
allowed_local_media_path=allowed_local_media_path)
return {"image": image}
def get_and_parse_video(video_url: str) -> MultiModalDataDict:
video = fetch_video(video_url)
return {"video": video}
async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
audio, sr = await async_fetch_audio(audio_url)
return {"audio": (audio, sr)}
async def async_get_and_parse_image(
image_url: str,
*,
allowed_local_media_path: str = "") -> MultiModalDataDict:
image = await async_fetch_image(
image_url, allowed_local_media_path=allowed_local_media_path)
return {"image": image}
async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict:
video = await async_fetch_video(video_url)
return {"video": video}
def encode_audio_base64(
audio: np.ndarray,
sampling_rate: int,
) -> str:
"""Encode audio as base64."""
_, soundfile = try_import_audio_packages()
buffered = BytesIO()
soundfile.write(buffered, audio, sampling_rate, format="WAV")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def encode_image_base64(
image: Image.Image,
*,
image_mode: str = "RGB",
format: str = "JPEG",
) -> str:
"""
Encode a pillow image to base64 format.
By default, the image is converted into RGB format before being encoded.
"""
buffered = BytesIO()
image = image.convert(image_mode)
image.save(buffered, format)
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
"""Load image from base64 format."""
return _load_image_from_bytes(base64.b64decode(image))
def rescale_image_size(image: Image.Image,
size_factor: float,
transpose: int = -1) -> Image.Image:
"""Rescale the dimensions of an image by a constant factor."""
new_width = int(image.width * size_factor)
new_height = int(image.height * size_factor)
image = image.resize((new_width, new_height))
if transpose >= 0:
image = image.transpose(Image.Transpose(transpose))
return image
def try_import_video_packages() -> Tuple[Any, Any]:
try:
import cv2
import decord
except ImportError as exc:
raise ImportError(
"Please install vllm[video] for video support.") from exc
return cv2, decord
def resize_video(frames: npt.NDArray, size: Tuple[int, int]) -> npt.NDArray:
cv2, _ = try_import_video_packages()
num_frames, _, _, channels = frames.shape
new_height, new_width = size
resized_frames = np.empty((num_frames, new_height, new_width, channels),
dtype=frames.dtype)
for i, frame in enumerate(frames):
resized_frame = cv2.resize(frame, (new_width, new_height))
resized_frames[i] = resized_frame
return resized_frames
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
_, height, width, _ = frames.shape
new_height = int(height * size_factor)
new_width = int(width * size_factor)
return resize_video(frames, (new_height, new_width))
def sample_frames_from_video(frames: npt.NDArray,
num_frames: int) -> npt.NDArray:
total_frames = frames.shape[0]
if num_frames == -1:
return frames
else:
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
sampled_frames = frames[frame_indices, ...]
return sampled_frames
def encode_video_base64(frames: npt.NDArray):
base64_frames = []
frames_list = [frames[i] for i in range(frames.shape[0])]
for frame in frames_list:
img_base64 = encode_image_base64(Image.fromarray(frame))
base64_frames.append(img_base64)
return ",".join(base64_frames)
# Utilities for input processors
_T = TypeVar("_T", str, int)
def repeat_and_pad_token(
token: _T,
*,
repeat_count: int = 1,
pad_token_left: Optional[_T] = None,
pad_token_right: Optional[_T] = None,
) -> List[_T]:
replacement = [token] * repeat_count
if pad_token_left is not None:
replacement = [pad_token_left] + replacement
if pad_token_right is not None:
replacement = replacement + [pad_token_right]
return replacement
def repeat_and_pad_placeholder_tokens(
tokenizer: AnyTokenizer,
prompt: Optional[str],
prompt_token_ids: List[int],
*,
placeholder_token_id: int,
repeat_count: Union[int, List[int]],
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]:
if isinstance(repeat_count, int):
repeat_count = [repeat_count]
if prompt is None:
new_prompt = None
else:
placeholder_token_str = tokenizer.decode(placeholder_token_id)
pad_token_str_left = (None if pad_token_left is None else
tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right))
placeholder_token_count = prompt.count(placeholder_token_str)
        # Heuristic threshold: a large count suggests the prompt already
        # repeats the placeholder token itself rather than using one
        # placeholder per multi-modal item
        if placeholder_token_count > 16:
logger.warning(
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating %s tokens.", placeholder_token_str)
if placeholder_token_count < len(repeat_count):
logger.warning(
"The number of multi-modal placeholder tokens in the prompt "
"is less than the number of multi-modal inputs. Extra "
"placeholder tokens will be treated as plain text")
repeat_count = repeat_count[:placeholder_token_count]
prompt_parts = prompt.split(placeholder_token_str,
maxsplit=len(repeat_count))
new_prompt = ""
for i, repeat_count_item in enumerate(repeat_count):
replacement_str = "".join(
repeat_and_pad_token(
placeholder_token_str,
repeat_count=repeat_count_item,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
# The image tokens are removed to be consistent with HuggingFace
new_prompt += prompt_parts[i] + replacement_str
new_prompt += prompt_parts[-1]
new_token_ids: List[int] = []
placeholder_ranges: List[PlaceholderRange] = []
placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id:
replacement_ids = repeat_and_pad_token(
placeholder_token_id,
repeat_count=repeat_count[placeholder_token_idx],
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
placeholder_ranges.append({
"offset": len(new_token_ids),
"length": len(replacement_ids)
})
new_token_ids.extend(replacement_ids)
placeholder_token_idx += 1
# No need to further scan the list since we replaced all tokens
if placeholder_token_idx >= len(repeat_count):
new_token_ids.extend(prompt_token_ids[i + 1:])
break
else:
new_token_ids.append(token)
return new_prompt, new_token_ids, placeholder_ranges
def consecutive_placeholder_ranges(num_items: int,
item_size: int) -> List[PlaceholderRange]:
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""
return [
PlaceholderRange(offset=i * item_size, length=item_size)
for i in range(num_items)
]
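Two of the small helpers above, exercised directly (a minimal sketch; 32000 stands in for a placeholder token ID, and the module path is assumed):

from vllm.multimodal.utils import (consecutive_placeholder_ranges,
                                   repeat_and_pad_token)

assert repeat_and_pad_token(
    32000, repeat_count=3, pad_token_left=1, pad_token_right=2,
) == [1, 32000, 32000, 32000, 2]
assert consecutive_placeholder_ranges(num_items=2, item_size=4) == [
    {"offset": 0, "length": 4},
    {"offset": 4, "length": 4},
]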

View File

@@ -0,0 +1,77 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, Optional
import numpy as np
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of
from .base import MultiModalData
from .image import ImagePlugin
from .inputs import MultiModalKwargs, VideoItem
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor)
cached_get_tokenizer = lru_cache(get_tokenizer)
class VideoPlugin(ImagePlugin):
"""Plugin for video data."""
def get_data_key(self) -> str:
return "video"
def _get_hf_video_processor(
self,
model_config: "ModelConfig",
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
):
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return cached_get_video_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[VideoItem],
**mm_processor_kwargs,
) -> MultiModalKwargs:
model_config = ctx.model_config
if isinstance(data, list) and len(data) == 1:
data = data[0] # type: ignore
if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
video_processor = self._get_hf_video_processor(
model_config,
mm_processor_kwargs,
)
if video_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the video object")
try:
# NOTE: Similar to image; it may be a good idea to filter and
# pass mm_processor_kwargs here too, but for now we don't to
# avoid extra complexity if the initializer and preprocess
# signatures of the processor don't align
batch_data = video_processor(data, return_tensors="pt").data
except Exception:
logger.error("Failed to process video (%s)", data)
raise
return MultiModalKwargs(batch_data)
raise TypeError(f"Invalid video type: {type(data)}")
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
return 4096
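The frame arrays consumed by ``VideoPlugin`` are typically produced by the uniform sampling helper from ``utils.py`` above; a minimal sketch:

import numpy as np

from vllm.multimodal.utils import sample_frames_from_video

video = np.zeros((100, 224, 224, 3), dtype=np.uint8)  # 100 dummy RGB frames
sampled = sample_frames_from_video(video, num_frames=32)
assert sampled.shape == (32, 224, 224, 3)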