Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm/multimodal/processing.py
2026-02-04 17:22:39 +08:00

274 lines
8.8 KiB
Python

from dataclasses import dataclass
from functools import lru_cache, partial
from typing import (Any, Callable, Collection, Generic, List, Mapping,
Optional, TypedDict, TypeVar, final)
from transformers import BatchFeature
from typing_extensions import TypeAlias
from vllm.inputs import InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem)
_T = TypeVar("_T")

ReplacementFunc: TypeAlias = Callable[
    [_T, BatchFeature, int],
    List[int]
]
"""
Callback that yields the replacement token IDs to be allocated in vLLM.

It receives the original data item, the HF-processed data, and the index of
the processed item, in that order.
"""
@dataclass
class ModalityProcessingMetadata(Generic[_T]):
    """Placeholder replacement rules for one modality."""

    placeholder_replacements: Mapping[str, ReplacementFunc]
    """
    Maps each original placeholder found in the prompt text to the
    function that produces its replacement.
    """
class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
    """
    Type annotations for the modality types predefined by vLLM.

    Every key is optional (:code:`total=False`).
    """

    image: ModalityProcessingMetadata[ImageItem]
    video: ModalityProcessingMetadata[VideoItem]
    audio: ModalityProcessingMetadata[AudioItem]
MultiModalProcessingMetadata: TypeAlias = Mapping[
    str, ModalityProcessingMetadata[Any]]
"""
A dictionary with one entry per modality type to process.

Note:
    Modality keys that are not declared in
    :class:`MultiModalProcessingMetadataBuiltins` are accepted as well,
    provided that a customized plugin is registered through the
    :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

MultiModalMultiData: TypeAlias = List[_T]
"""
A list of data items. The number of items allowed per modality is
restricted by :code:`--limit-mm-per-prompt`.
"""
@final
class MultiModalMultiDataBuiltins(TypedDict, total=False):
    """
    Type annotations for the modality types predefined by vLLM.

    Every key is optional (:code:`total=False`).
    """

    image: MultiModalMultiData[ImageItem]
    """The images passed as input."""

    video: MultiModalMultiData[VideoItem]
    """The videos passed as input."""

    audio: MultiModalMultiData[AudioItem]
    """The audios passed as input."""
MultiModalMultiDataDict: TypeAlias = Mapping[
    str, MultiModalMultiData[Any]]
"""
A dictionary with one entry per modality type to input.

Note:
    Modality keys that are not declared in
    :class:`MultiModalMultiDataBuiltins` are accepted as well, provided
    that a customized plugin is registered through the
    :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
    """
    Normalize a :class:`MultiModalDataDict` so that every modality maps to
    a list of data items, as described by :class:`MultiModalMultiDataDict`.

    A value that already counts as a list of items is passed through
    unchanged; any other value is wrapped in a one-element list.
    """

    def _holds_multiple_items(modality: str, value: Any) -> bool:
        if modality == "video":
            # Even a single video item can be a list, so a plain
            # ``isinstance`` check is not enough for this modality.
            return is_list_of(value, list)

        # "image", "audio" and any other modality share the same rule.
        return isinstance(value, list)

    return {
        modality: value if _holds_multiple_items(modality, value) else [value]
        for modality, value in data.items()
    }
def encode_no_special_tokens(
    tokenizer: AnyTokenizer,
    text: str,
) -> List[int]:
    """
    Encode :code:`text` into token IDs without adding special tokens.

    This is a backend-agnostic counterpart of HF's
    :code:`tokenizer.encode(text, add_special_tokens=False)`.
    """
    if not isinstance(tokenizer, MistralTokenizer):
        return tokenizer.encode(text, add_special_tokens=False)

    # The Mistral wrapper exposes its inner tokenizer, whose ``encode``
    # takes explicit BOS/EOS flags instead of ``add_special_tokens``.
    inner_tokenizer = tokenizer.tokenizer
    return inner_tokenizer.encode(text, bos=False, eos=False)
@lru_cache
def candidate_placeholders(
    tokenizer: AnyTokenizer,
    placeholder_text: str,
) -> Collection[List[int]]:
    """
    Generate token ID sequences that may represent a placeholder text.

    When the placeholder text is not mapped to a special token ID, its
    tokenization may depend on whether it sits at the start and/or end of
    the string. Each of the four boundary combinations is therefore encoded
    and the distinct results are returned.
    """
    encode = partial(encode_no_special_tokens, tokenizer)

    # Neighbouring text used to simulate "not at the boundary". Each of
    # these must encode to exactly one token, otherwise the unpacking fails.
    (prefix_id, ) = encode("a")
    (suffix_id, ) = encode("b")

    # Placeholder is the whole string.
    alone_ids = encode(placeholder_text)

    # Placeholder is preceded by other text.
    head_id, *after_prefix_ids = encode(f"a{placeholder_text}")
    assert prefix_id == head_id

    # Placeholder is in the middle of the string.
    head_id, *surrounded_ids, tail_id = encode(f"a{placeholder_text}b")
    assert prefix_id == head_id and suffix_id == tail_id

    # Placeholder is followed by other text.
    *before_suffix_ids, tail_id = encode(f"{placeholder_text}b")
    assert suffix_id == tail_id

    # Lists are not hashable, so deduplicate through tuples and convert
    # back to lists afterwards.
    distinct = {
        tuple(ids)
        for ids in (alone_ids, after_prefix_ids, surrounded_ids,
                    before_suffix_ids)
    }
    return [list(ids) for ids in distinct]
def apply_placeholders(
    token_ids: List[int],
    placeholder_ids: List[int],
    get_replacement_ids: Callable[[], List[int]],
) -> Optional[PlaceholderRange]:
    """
    Find the first occurrence of :code:`placeholder_ids` in
    :code:`token_ids`, and replace it with the output of
    :code:`get_replacement_ids`.

    This function updates :code:`token_ids` in place.

    Args:
        token_ids: The token IDs to search; modified when a match is found.
        placeholder_ids: The token ID sequence to look for.
        get_replacement_ids: Called (only when a match is found) to obtain
            the token IDs that are substituted for the matched span.

    Returns:
        A :class:`PlaceholderRange` whose :code:`offset` is the index of the
        match and whose :code:`length` is the length of
        :code:`placeholder_ids`, or :code:`None` if there is no match.
    """
    placeholder_length = len(placeholder_ids)

    for start_idx in range(len(token_ids) - placeholder_length + 1):
        # The window has to be anchored at ``start_idx``. Using
        # ``placeholder_length`` as an absolute end index would only ever
        # compare a full-length window when ``start_idx`` is 0.
        end_idx = start_idx + placeholder_length
        if token_ids[start_idx:end_idx] != placeholder_ids:
            continue

        token_ids[start_idx:end_idx] = get_replacement_ids()

        # NOTE(review): ``length`` is the size of the matched placeholder,
        # not of the replacement that now occupies ``token_ids``. Confirm
        # which one consumers of the range expect.
        return PlaceholderRange(offset=start_idx, length=placeholder_length)

    return None
class MultiModalProcessor:
    """
    Helper class to process multi-modal inputs to be used in vLLM.

    It runs the HF processor on the prompt and multi-modal data, then
    substitutes the placeholders in the resulting token IDs according to
    the per-modality metadata.
    """

    def __init__(
        self,
        ctx: InputProcessingContext,
        metadata: MultiModalProcessingMetadata,
    ) -> None:
        """Store the processing context and the per-modality metadata."""
        super().__init__()

        self.ctx = ctx
        self.metadata = metadata

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for :meth:`apply` with the same arguments."""
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def apply(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process :code:`prompt` together with :code:`mm_data`.

        The HF processor obtained from the context is called with the
        prompt, the multi-modal data and :code:`mm_processor_kwargs`.
        Placeholders found in the resulting token IDs are then replaced,
        and the location of each replaced placeholder is recorded per
        modality.
        """
        tokenizer = self.ctx.tokenizer
        hf_processor = self.ctx.get_hf_processor()

        hf_outputs = hf_processor(
            text=prompt,  # type: ignore
            **mm_data,
            **mm_processor_kwargs,
        )

        # ``input_ids`` must hold exactly one sequence, otherwise the
        # unpacking fails. The remaining outputs become the model kwargs.
        (prompt_token_ids, ) = hf_outputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(hf_outputs)

        mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}

        for modality, items in to_multi_format(mm_data).items():
            assert isinstance(items, list)

            replacements = self.metadata[modality].placeholder_replacements
            found_ranges: List[PlaceholderRange] = []

            for item_idx, item in enumerate(items):
                for match_text, replace_fn in replacements.items():
                    candidates = candidate_placeholders(tokenizer, match_text)
                    # Bind the arguments now; the function is only invoked
                    # if one of the candidates is actually found.
                    get_replacement_ids = partial(
                        replace_fn,
                        item,
                        hf_outputs,
                        item_idx,
                    )

                    for match_ids in candidates:
                        # TODO(youkaichao): Don't update prompt_token_ids
                        found = apply_placeholders(
                            prompt_token_ids,
                            match_ids,
                            get_replacement_ids,
                        )
                        if found is None:
                            continue

                        found_ranges.append(found)

            # yapf: disable
            mm_placeholders[modality] = found_ranges  # type: ignore[index]
            # yapf: enable

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )