vlm: support video as an input modality (#5888)
@@ -65,6 +65,8 @@ class GenerateReqInput:
     ] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[List[str]], List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
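A rough client-side illustration of the new field (not part of the diff; the endpoint, port, and use of a base64 string are assumptions based on the comment above):

import base64
import requests

# Hypothetical call against a locally running server.
with open("demo.mp4", "rb") as f:
    video_b64 = base64.b64encode(f.read()).decode()

payload = {
    "text": "Describe what happens in this video.",
    "video_data": video_b64,  # a file name, a url, or a base64 encoded string
    "sampling_params": {"max_new_tokens": 64},
}
resp = requests.post("http://localhost:30000/generate", json=payload)
print(resp.json())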
@@ -110,7 +112,11 @@ class GenerateReqInput:
     data_parallel_rank: Optional[int] = None

     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def normalize_batch_and_arguments(self):
         """
@@ -232,6 +238,7 @@ class GenerateReqInput:
         self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
+        self._normalize_video_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
         self._normalize_logprob_params(num)
@@ -300,6 +307,15 @@ class GenerateReqInput:
             self.image_data = wrapped_images * self.parallel_sample_num
             self.modalities = ["image"] * num

+    def _normalize_video_data(self, num):
+        """Normalize video data for batch processing."""
+        if self.video_data is None:
+            self.video_data = [None] * num
+        elif not isinstance(self.video_data, list):
+            self.video_data = [self.video_data] * num
+        elif isinstance(self.video_data, list):
+            self.video_data = self.video_data * self.parallel_sample_num
+
     def _normalize_audio_data(self, num):
         """Normalize audio data for batch processing."""
         if self.audio_data is None:
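For intuition, the branches above behave roughly as in this standalone sketch (a re-implementation for illustration only, not the method itself):

# Mirrors _normalize_video_data as a free function.
def normalize_video_data(video_data, num, parallel_sample_num=1):
    if video_data is None:
        return [None] * num                  # no video for any request
    if not isinstance(video_data, list):
        return [video_data] * num            # one video, replicated across the batch
    return video_data * parallel_sample_num  # already a list: repeat per parallel sample

print(normalize_video_data("clip.mp4", num=2))      # ['clip.mp4', 'clip.mp4']
print(normalize_video_data(["a.mp4", "b.mp4"], 2))  # ['a.mp4', 'b.mp4'] when parallel_sample_num == 1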
@@ -408,6 +424,7 @@ class GenerateReqInput:
                 self.input_embeds[i] if self.input_embeds is not None else None
             ),
             image_data=self.image_data[i],
+            video_data=self.video_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
@@ -507,6 +524,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[str], str]] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
@@ -578,7 +597,11 @@ class EmbeddingReqInput:
         return self.rid

     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def __getitem__(self, i):
         if self.is_cross_encoder_request:

@@ -4,7 +4,7 @@ Multi-modality utils

 import hashlib
 from abc import abstractmethod
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import numpy as np
 import torch
@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
+        print(f"{mm_inputs.mm_items=}")
         data_token_pairs = self.data_token_id_pairs
         mm_inputs.data_offsets = []
         if data_token_pairs is None:
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern)
         return ret_input_ids


-embedding_cache = None
+embedding_cache: Optional[MultiModalCache] = None


-def init_embedding_cache(max_size: int):
+def init_embedding_cache(max_size: int = 0):
     global embedding_cache
     embedding_cache = MultiModalCache(max_size)

@@ -255,6 +256,7 @@ def _get_chunked_prefill_embedding(
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
         items_offset = items_offset_list[i]
+        assert items_offset is not None, items_offset
         embedding_items_hash = get_embedding_hash(embedding_items_per_req)
         # if all items has been prefixed, we do not need to calculate embedding
         if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -380,11 +382,9 @@ def embed_mm_inputs(
     extend_seq_lens: List[int],
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-    image_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
-    ] = None,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    multimodal_model: nn.Module = None,
+    data_embedding_func_mapping: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
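The per-modality callbacks are replaced by a single mapping. A caller-side sketch of the expected shape (the model object and its feature-method names are assumptions; only the Modality keys and callable signature come from the diff):

# Hypothetical mapping from modality to an embedding callable.
data_embedding_func_mapping = {
    Modality.IMAGE: my_model.get_image_feature,  # assumed methods on the model
    Modality.VIDEO: my_model.get_video_feature,
    Modality.AUDIO: my_model.get_audio_feature,
}
# Any modality missing from the mapping falls back to a
# get_<modality>_feature lookup on multimodal_model (see the hunk below).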
@@ -397,8 +397,6 @@ def embed_mm_inputs(
         extend_seq_lens: Sequence lengths for each request
         input_ids: Input token IDs tensor
         input_embedding: Embedding layer for text tokens
-        image_data_embedding_func: Function to embed image data
-        audio_data_embedding_func: Function to embed audio data
         placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

     Returns:
@@ -415,88 +413,53 @@ def embed_mm_inputs(
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

     embeddings, masks = [], []

     # 2. Get multimodal embedding separately
     # TODO: make this more generic
-    # Try get image embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_image())
-        and image_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_image()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        items_offsets = []
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
-            items_size[i + 1] = len(image_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.image_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_image()
-                    ]
-                )
-            )
-        items_size = torch.cumsum(items_size, dim=0).tolist()
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=image_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
-
-    # Try get audio embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_audio())
-        and audio_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_audio()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        items_offsets = []
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
-            items_size[i + 1] = len(audio_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.audio_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_audio()
-                    ]
-                )
-            )
-        items_size = torch.cumsum(items_size, dim=0)
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=audio_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
+    # Try get mm embedding if any
+    for modality in Modality.all():
+        items = [
+            item for item in item_flatten_list if item.is_modality(modality=modality)
+        ]
+        embedder = (
+            None
+            if data_embedding_func_mapping is None
+            else data_embedding_func_mapping.get(modality, None)
+        )
+        if embedder is None:
+            # "image", "video", etc
+            modality_id = modality.name.lower()
+            embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
+        if len(items) != 0 and embedder is not None:
+            placeholder_tensor = torch.tensor(
+                [item.pad_value for item in items],
+                device=input_ids.device,
+            )
+            # calculate per request items length offset
+            items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+            items_offsets = []
+            for i, mm_inputs in enumerate(mm_inputs_list):
+                mm_items = [
+                    item
+                    for item in mm_inputs.mm_items
+                    if item.is_modality(modality=modality)
+                ]
+                items_size[i + 1] = len(mm_items)
+                items_offsets.append(
+                    flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
+                )
+            items_size = torch.cumsum(items_size, dim=0).tolist()
+
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=embedder,
+                embedding_items=items,
+                placeholder_tensor=placeholder_tensor,
+                input_ids=input_ids,
+                items_size=items_size,
+                prefix_length=extend_prefix_lens,
+                extend_length=extend_seq_lens,
+                items_offset_list=items_offsets,
+            )
+            embeddings += [embedding]
+            masks += [mask]

     # 3. Get input embeddings
     vocab_size = input_embedding.num_embeddings
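The loop above first consults data_embedding_func_mapping and then falls back to a get_<modality>_feature attribute on the model. A minimal sketch of that fallback convention (toy class for illustration; only the attribute-naming scheme is taken from the diff):

class ToyMultimodalModel:
    # Method names chosen to match the f"get_{modality_id}_feature" lookup above.
    def get_image_feature(self, items):
        ...

    def get_video_feature(self, items):
        ...

model = ToyMultimodalModel()
for modality in Modality.all():
    modality_id = modality.name.lower()
    embedder = getattr(model, f"get_{modality_id}_feature", None)
    # embedder is None for AUDIO here, so that modality is simply skipped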
@@ -523,11 +486,9 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    image_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
-    ] = None,
-    audio_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
+    multimodal_model: Optional[nn.Module] = None,
+    data_embedding_funcs: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
@@ -572,8 +533,8 @@ def general_mm_embed_routine(
             extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
             input_embedding=embed_tokens,
-            image_data_embedding_func=image_data_embedding_func,
-            audio_data_embedding_func=audio_data_embedding_func,
+            multimodal_model=multimodal_model,
+            data_embedding_func_mapping=data_embedding_funcs,
             placeholder_tokens=placeholder_tokens,
         )
         # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models

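On the model side, this means a multimodal model can pass itself (or an explicit per-modality dict) instead of separate image/audio callbacks. A hedged sketch of a forward method after this change (class, attribute, and extra-kwarg names are assumptions):

import torch.nn as nn

class MyVLM(nn.Module):
    def forward(self, input_ids, positions, forward_batch):
        # Illustrative only; extra kwargs are assumed to be forwarded to the language model.
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            multimodal_model=self,  # embedder lookup falls back to self.get_<modality>_feature
            positions=positions,
        )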
@@ -185,6 +185,10 @@ class Modality(Enum):
             f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
         )

+    @staticmethod
+    def all():
+        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]
+

 @dataclasses.dataclass
 class MultimodalDataItem:
@@ -200,7 +204,7 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     image_sizes: Tuple[int, int] = None
-    image_offsets: Optional[list] = None
+    offsets: Optional[list] = None

     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
@@ -253,12 +257,17 @@ class MultimodalDataItem:
                 self.hash = hash_feature(self.audio_features)
             elif self.input_features is not None:
                 self.hash = hash_feature(self.input_features)
+        elif self.is_video():
+            self.hash = hash_feature(self.pixel_values_videos)
         else:
             self.hash = hash_feature(self.pixel_values)

         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)

+    def is_modality(self, modality: Modality) -> bool:
+        return self.modality == modality
+
     def is_audio(self):
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
@@ -268,7 +277,7 @@ class MultimodalDataItem:

     def is_image(self):
         return (
-            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+            self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
         ) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.pixel_values)
@@ -277,7 +286,7 @@ class MultimodalDataItem:
     def is_video(self):
         return (self.modality == Modality.VIDEO) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+            or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
         )

     def is_valid(self) -> bool:
@@ -351,6 +360,7 @@ class MultimodalInputs:
             "im_token_id",
             "im_start_id",
             "im_end_id",
+            "video_token_id",
             "slice_start_id",
             "slice_end_id",
             "audio_start_id",
@@ -364,11 +374,12 @@ class MultimodalInputs:
         return ret

     def contains_image_inputs(self) -> bool:
-        """ """
         return any(item.is_image() for item in self.mm_items)

+    def contains_video_inputs(self) -> bool:
+        return any(item.is_video() for item in self.mm_items)
+
     def contains_audio_inputs(self) -> bool:
-        """ """
         return any(item.is_audio() for item in self.mm_items)

     def contains_mm_input(self) -> bool: