vlm: support video as an input modality (#5888)

This commit is contained in:
Mick
2025-07-10 14:48:35 +08:00
committed by GitHub
parent 4ed57807c2
commit b5e3d6031c
42 changed files with 887 additions and 524 deletions

View File

@@ -65,6 +65,8 @@ class GenerateReqInput:
] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
video_data: Optional[Union[List[List[str]], List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
@@ -110,7 +112,11 @@ class GenerateReqInput:
data_parallel_rank: Optional[int] = None
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
return (
has_valid_data(self.image_data)
or has_valid_data(self.video_data)
or has_valid_data(self.audio_data)
)
def normalize_batch_and_arguments(self):
"""
@@ -232,6 +238,7 @@ class GenerateReqInput:
self._normalize_rid(num)
self._normalize_lora_paths(num)
self._normalize_image_data(num)
self._normalize_video_data(num)
self._normalize_audio_data(num)
self._normalize_sampling_params(num)
self._normalize_logprob_params(num)
@@ -300,6 +307,15 @@ class GenerateReqInput:
self.image_data = wrapped_images * self.parallel_sample_num
self.modalities = ["image"] * num
def _normalize_video_data(self, num):
"""Normalize video data for batch processing."""
if self.video_data is None:
self.video_data = [None] * num
elif not isinstance(self.video_data, list):
self.video_data = [self.video_data] * num
elif isinstance(self.video_data, list):
self.video_data = self.video_data * self.parallel_sample_num
def _normalize_audio_data(self, num):
"""Normalize audio data for batch processing."""
if self.audio_data is None:
@@ -408,6 +424,7 @@ class GenerateReqInput:
self.input_embeds[i] if self.input_embeds is not None else None
),
image_data=self.image_data[i],
video_data=self.video_data[i],
audio_data=self.audio_data[i],
sampling_params=self.sampling_params[i],
rid=self.rid[i],
@@ -507,6 +524,8 @@ class EmbeddingReqInput:
image_data: Optional[
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
] = None
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
video_data: Optional[Union[List[str], str]] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
@@ -578,7 +597,11 @@ class EmbeddingReqInput:
return self.rid
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
return (
has_valid_data(self.image_data)
or has_valid_data(self.video_data)
or has_valid_data(self.audio_data)
)
def __getitem__(self, i):
if self.is_cross_encoder_request:

View File

@@ -4,7 +4,7 @@ Multi-modality utils
import hashlib
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
This function will replace the data-tokens in between with pad_values accordingly
"""
pad_values = [item.pad_value for item in mm_inputs.mm_items]
print(f"{mm_inputs.mm_items=}")
data_token_pairs = self.data_token_id_pairs
mm_inputs.data_offsets = []
if data_token_pairs is None:
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
return ret_input_ids
embedding_cache = None
embedding_cache: Optional[MultiModalCache] = None
def init_embedding_cache(max_size: int):
def init_embedding_cache(max_size: int = 0):
global embedding_cache
embedding_cache = MultiModalCache(max_size)
@@ -255,6 +256,7 @@ def _get_chunked_prefill_embedding(
continue
embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
items_offset = items_offset_list[i]
assert items_offset is not None, items_offset
embedding_items_hash = get_embedding_hash(embedding_items_per_req)
# if all items has been prefixed, we do not need to calculate embedding
if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -380,11 +382,9 @@ def embed_mm_inputs(
extend_seq_lens: List[int],
input_ids: torch.Tensor,
input_embedding: nn.Embedding,
image_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
multimodal_model: nn.Module = None,
data_embedding_func_mapping: Dict[
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
) -> Optional[torch.Tensor]:
@@ -397,8 +397,6 @@ def embed_mm_inputs(
extend_seq_lens: Sequence lengths for each request
input_ids: Input token IDs tensor
input_embedding: Embedding layer for text tokens
image_data_embedding_func: Function to embed image data
audio_data_embedding_func: Function to embed audio data
placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)
Returns:
@@ -415,88 +413,53 @@ def embed_mm_inputs(
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
embeddings, masks = [], []
# 2. Get multimodal embedding separately
# TODO: make this more generic
# Try get image embedding if any
if (
any(True for item in item_flatten_list if item.is_image())
and image_data_embedding_func
):
items = [item for item in item_flatten_list if item.is_image()]
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
# Try get mm embedding if any
for modality in Modality.all():
items = [
item for item in item_flatten_list if item.is_modality(modality=modality)
]
embedder = (
None
if data_embedding_func_mapping is None
else data_embedding_func_mapping.get(modality, None)
)
# calculate per request items length offset
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
items_offsets = []
for i, mm_inputs in enumerate(mm_inputs_list):
image_items = [item for item in mm_inputs.mm_items if item.is_image()]
items_size[i + 1] = len(image_items)
items_offsets.append(
flatten_nested_list(
[
item.image_offsets
for item in mm_inputs.mm_items
if item.is_image()
]
)
if embedder is None:
# "image", "video", etc
modality_id = modality.name.lower()
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
if len(items) != 0 and embedder is not None:
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
items_size = torch.cumsum(items_size, dim=0).tolist()
embedding, mask = get_embedding_and_mask(
data_embedding_func=image_data_embedding_func,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
# Try get audio embedding if any
if (
any(True for item in item_flatten_list if item.is_audio())
and audio_data_embedding_func
):
items = [item for item in item_flatten_list if item.is_audio()]
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
items_offsets = []
# calculate per request items length offset
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
for i, mm_inputs in enumerate(mm_inputs_list):
audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
items_size[i + 1] = len(audio_items)
items_offsets.append(
flatten_nested_list(
[
item.audio_offsets
for item in mm_inputs.mm_items
if item.is_audio()
]
# calculate per request items length offset
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
items_offsets = []
for i, mm_inputs in enumerate(mm_inputs_list):
mm_items = [
item
for item in mm_inputs.mm_items
if item.is_modality(modality=modality)
]
items_size[i + 1] = len(mm_items)
items_offsets.append(
flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
)
)
items_size = torch.cumsum(items_size, dim=0)
items_size = torch.cumsum(items_size, dim=0).tolist()
embedding, mask = get_embedding_and_mask(
data_embedding_func=audio_data_embedding_func,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
embedding, mask = get_embedding_and_mask(
data_embedding_func=embedder,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
# 3. Get input embeddings
vocab_size = input_embedding.num_embeddings
@@ -523,11 +486,9 @@ def general_mm_embed_routine(
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
language_model: nn.Module,
image_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
audio_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
multimodal_model: Optional[nn.Module] = None,
data_embedding_funcs: Dict[
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
**kwargs,
@@ -572,8 +533,8 @@ def general_mm_embed_routine(
extend_seq_lens=extend_seq_lens,
input_ids=input_ids,
input_embedding=embed_tokens,
image_data_embedding_func=image_data_embedding_func,
audio_data_embedding_func=audio_data_embedding_func,
multimodal_model=multimodal_model,
data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens,
)
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models

View File

@@ -185,6 +185,10 @@ class Modality(Enum):
f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
)
@staticmethod
def all():
    """Return the data-carrying modalities, in the order embeddings are gathered.

    NOTE(review): MULTI_IMAGES is deliberately absent from this list —
    confirm that is intended wherever this drives per-modality processing.
    """
    ordered = (Modality.IMAGE, Modality.VIDEO, Modality.AUDIO)
    return list(ordered)
@dataclasses.dataclass
class MultimodalDataItem:
@@ -200,7 +204,7 @@ class MultimodalDataItem:
hash: int = None
pad_value: int = None
image_sizes: Tuple[int, int] = None
image_offsets: Optional[list] = None
offsets: Optional[list] = None
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.ndarray]]
@@ -253,12 +257,17 @@ class MultimodalDataItem:
self.hash = hash_feature(self.audio_features)
elif self.input_features is not None:
self.hash = hash_feature(self.input_features)
elif self.is_video():
self.hash = hash_feature(self.pixel_values_videos)
else:
self.hash = hash_feature(self.pixel_values)
assert self.hash is not None
self.pad_value = self.hash % (1 << 30)
def is_modality(self, modality: Modality) -> bool:
    """Exact equality check of this item's modality against *modality*.

    Unlike `is_image`, this performs no MULTI_IMAGES aliasing and does
    not inspect the item's payload for validity.
    """
    return modality == self.modality
def is_audio(self):
return (self.modality == Modality.AUDIO) and (
self.precomputed_features is not None
@@ -268,7 +277,7 @@ class MultimodalDataItem:
def is_image(self):
return (
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.pixel_values)
@@ -277,7 +286,7 @@ class MultimodalDataItem:
def is_video(self):
return (self.modality == Modality.VIDEO) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.pixel_values)
or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
)
def is_valid(self) -> bool:
@@ -351,6 +360,7 @@ class MultimodalInputs:
"im_token_id",
"im_start_id",
"im_end_id",
"video_token_id",
"slice_start_id",
"slice_end_id",
"audio_start_id",
@@ -364,11 +374,12 @@ class MultimodalInputs:
return ret
def contains_image_inputs(self) -> bool:
    """Whether any attached multimodal item is a valid image input."""
    for candidate in self.mm_items:
        if candidate.is_image():
            return True
    return False
def contains_video_inputs(self) -> bool:
    """Whether any attached multimodal item is a valid video input."""
    video_items = [item for item in self.mm_items if item.is_video()]
    return len(video_items) > 0
def contains_audio_inputs(self) -> bool:
    """Whether any attached multimodal item is a valid audio input."""
    for candidate in self.mm_items:
        if candidate.is_audio():
            return True
    return False
def contains_mm_input(self) -> bool: