[Refactor] Multimodal data processing for VLM (#6659)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -132,7 +132,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"mm_item = dict(\n",
|
"mm_item = dict(\n",
|
||||||
" modality=\"IMAGE\",\n",
|
" modality=\"IMAGE\",\n",
|
||||||
" image_grid_thws=processed_prompt[\"image_grid_thw\"],\n",
|
" image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
|
||||||
" precomputed_features=precomputed_features,\n",
|
" precomputed_features=precomputed_features,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
|
"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ import multiprocessing as mp
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -16,16 +17,24 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
|||||||
from sglang.srt.utils import encode_video, load_audio, load_image
|
from sglang.srt.utils import encode_video, load_audio, load_image
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalInputFormat(Enum):
|
||||||
|
"""Enum for different multimodal input formats."""
|
||||||
|
|
||||||
|
RAW_IMAGES = "raw_images"
|
||||||
|
PRECOMPUTED_FEATURES = "precomputed_features"
|
||||||
|
PIXEL_VALUES = "pixel_values"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class BaseMultiModalProcessorOutput:
|
class BaseMultiModalProcessorOutput:
|
||||||
# input_text, with each frame of video/image represented with a image_token
|
# input_text, with each frame of video/image represented with a image_token
|
||||||
input_text: str
|
input_text: str
|
||||||
|
|
||||||
# frames loaded from image and video, in given order
|
# frames loaded from image and video, in given order
|
||||||
images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
|
images: Optional[list[Union[Image.Image, dict]]] = None
|
||||||
|
|
||||||
# audios
|
# audios
|
||||||
audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
|
audios: Optional[list[Union[np.ndarray, dict]]] = None
|
||||||
|
|
||||||
def normalize(self):
|
def normalize(self):
|
||||||
for field_name in ["images", "audios"]:
|
for field_name in ["images", "audios"]:
|
||||||
@@ -170,8 +179,6 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
):
|
):
|
||||||
"""Static method that can be pickled for multiprocessing"""
|
"""Static method that can be pickled for multiprocessing"""
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
return MultimodalDataItem.from_dict(data)
|
|
||||||
if isinstance(data, MultimodalDataItem):
|
|
||||||
return data
|
return data
|
||||||
try:
|
try:
|
||||||
if is_audio:
|
if is_audio:
|
||||||
@@ -370,29 +377,180 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
|
|
||||||
return list(zip(indices_start.tolist(), indices_end.tolist()))
|
return list(zip(indices_start.tolist(), indices_end.tolist()))
|
||||||
|
|
||||||
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
|
|
||||||
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
|
|
||||||
if not mm_inputs:
|
|
||||||
return True
|
|
||||||
ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
|
|
||||||
if ret and not all(
|
|
||||||
isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
|
|
||||||
)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_processor_features(
|
def _extract_processor_features(
|
||||||
items: List[Any], attr_name: str
|
items: List[dict], attr_name: str
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Helper function to concat extracted attributes from processor output.
|
Helper function to concat extracted attributes from processor output.
|
||||||
"""
|
"""
|
||||||
values = [
|
values = [value for item in items if (value := item.get(attr_name)) is not None]
|
||||||
getattr(item, attr_name)
|
return torch.cat(values) if values else None
|
||||||
for item in items
|
|
||||||
if getattr(item, attr_name) is not None
|
# When we assume that all the items have the same attributes
|
||||||
]
|
def _extract_processor_features_from_all_attributes(
|
||||||
return torch.concat(values) if values else None
|
self, items: List[dict]
|
||||||
|
) -> dict:
|
||||||
|
values = {}
|
||||||
|
# Verify all items have the same keys
|
||||||
|
first_keys = set(items[0].keys())
|
||||||
|
for item in items[1:]:
|
||||||
|
if set(item.keys()) != first_keys:
|
||||||
|
raise ValueError(
|
||||||
|
f"All items must have the same attributes. "
|
||||||
|
f"First item has {first_keys}, but found {set(item.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process each attribute
|
||||||
|
for k, v in items[0].items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
values[k] = self._extract_processor_features(items, k)
|
||||||
|
else:
|
||||||
|
# Verify all items have the same value for non-list attributes
|
||||||
|
for item in items[1:]:
|
||||||
|
if item[k] != v:
|
||||||
|
raise ValueError(
|
||||||
|
f"All items must have the same value for attribute {k}. "
|
||||||
|
f"First item has {v}, but found {item[k]}"
|
||||||
|
)
|
||||||
|
values[k] = v
|
||||||
|
return values
|
||||||
|
|
||||||
|
def process_and_combine_mm_data(
|
||||||
|
self, base_output: BaseMultiModalProcessorOutput
|
||||||
|
) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Process multimodal data and return the combined multimodal item and input_ids.
|
||||||
|
Handles all three input formats at the same abstraction level.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (combined_mm_item, input_ids)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def tokenize_text(input_text: str) -> torch.Tensor:
|
||||||
|
"""Tokenize input text."""
|
||||||
|
return self._processor.tokenizer(
|
||||||
|
input_text,
|
||||||
|
return_tensors="pt",
|
||||||
|
add_special_tokens=True,
|
||||||
|
).input_ids.flatten()
|
||||||
|
|
||||||
|
def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
|
||||||
|
"""Categorize multimodal inputs and validate consistency."""
|
||||||
|
try:
|
||||||
|
has_image = False
|
||||||
|
has_pixel_values = False
|
||||||
|
has_precomputed_features = False
|
||||||
|
|
||||||
|
for mm_input in mm_inputs:
|
||||||
|
if isinstance(mm_input, Image.Image):
|
||||||
|
has_image = True
|
||||||
|
elif isinstance(mm_input, dict):
|
||||||
|
if mm_input.get("precomputed_features", None) is not None:
|
||||||
|
has_precomputed_features = True
|
||||||
|
elif mm_input.get("pixel_values", None) is not None:
|
||||||
|
has_pixel_values = True
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate format consistency
|
||||||
|
format_count = sum(
|
||||||
|
[has_image, has_pixel_values, has_precomputed_features]
|
||||||
|
)
|
||||||
|
if format_count > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported: mixture of multimodal input formats. "
|
||||||
|
f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
|
||||||
|
f"precomputed_features={has_precomputed_features}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_image:
|
||||||
|
return MultimodalInputFormat.RAW_IMAGES
|
||||||
|
elif has_precomputed_features:
|
||||||
|
return MultimodalInputFormat.PRECOMPUTED_FEATURES
|
||||||
|
elif has_pixel_values:
|
||||||
|
return MultimodalInputFormat.PIXEL_VALUES
|
||||||
|
else:
|
||||||
|
raise ValueError("No valid multimodal input format found")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to categorize inputs: {e}")
|
||||||
|
|
||||||
|
def process_raw_images(
|
||||||
|
base_output: BaseMultiModalProcessorOutput,
|
||||||
|
) -> Tuple[MultimodalDataItem, torch.Tensor]:
|
||||||
|
"""Process raw Image.Image objects using transformers processor."""
|
||||||
|
ret = self.process_mm_data(
|
||||||
|
input_text=base_output.input_text,
|
||||||
|
images=base_output.images,
|
||||||
|
)
|
||||||
|
combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
|
||||||
|
|
||||||
|
# Copy all fields from processor output except input_ids
|
||||||
|
for key, value in ret.items():
|
||||||
|
if key != "input_ids" and hasattr(combined_mm_item, key):
|
||||||
|
setattr(combined_mm_item, key, value)
|
||||||
|
|
||||||
|
input_ids = ret["input_ids"].flatten()
|
||||||
|
return combined_mm_item, input_ids
|
||||||
|
|
||||||
|
def process_precomputed_features(
|
||||||
|
base_output: BaseMultiModalProcessorOutput,
|
||||||
|
) -> Tuple[MultimodalDataItem, torch.Tensor]:
|
||||||
|
"""Process inputs with precomputed features."""
|
||||||
|
combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
|
||||||
|
combined_mm_item.precomputed_features = self._extract_processor_features(
|
||||||
|
base_output.images, "precomputed_features"
|
||||||
|
)
|
||||||
|
input_ids = tokenize_text(base_output.input_text)
|
||||||
|
return combined_mm_item, input_ids
|
||||||
|
|
||||||
|
def process_pixel_values(
|
||||||
|
base_output: BaseMultiModalProcessorOutput,
|
||||||
|
) -> Tuple[MultimodalDataItem, torch.Tensor]:
|
||||||
|
"""Process inputs with pixel values."""
|
||||||
|
values = self._extract_processor_features_from_all_attributes(
|
||||||
|
base_output.images
|
||||||
|
)
|
||||||
|
combined_mm_item = MultimodalDataItem.from_dict(values)
|
||||||
|
input_ids = tokenize_text(base_output.input_text)
|
||||||
|
return combined_mm_item, input_ids
|
||||||
|
|
||||||
|
def finalize_mm_item(
|
||||||
|
combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
|
||||||
|
) -> MultimodalDataItem:
|
||||||
|
"""Apply common post-processing to the multimodal item."""
|
||||||
|
combined_mm_item.image_offsets = self.get_mm_items_offset(
|
||||||
|
input_ids=input_ids,
|
||||||
|
mm_token_id=self.IM_TOKEN_ID,
|
||||||
|
)
|
||||||
|
return combined_mm_item
|
||||||
|
|
||||||
|
# Main logic
|
||||||
|
mm_inputs = base_output.images
|
||||||
|
if not mm_inputs:
|
||||||
|
# Return text-only case
|
||||||
|
input_ids = tokenize_text(base_output.input_text)
|
||||||
|
return None, input_ids
|
||||||
|
|
||||||
|
# Categorize input formats
|
||||||
|
input_format = categorize_mm_inputs(mm_inputs)
|
||||||
|
|
||||||
|
# Process based on format
|
||||||
|
if input_format == MultimodalInputFormat.RAW_IMAGES:
|
||||||
|
combined_mm_item, input_ids = process_raw_images(base_output)
|
||||||
|
elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
|
||||||
|
combined_mm_item, input_ids = process_precomputed_features(base_output)
|
||||||
|
elif input_format == MultimodalInputFormat.PIXEL_VALUES:
|
||||||
|
combined_mm_item, input_ids = process_pixel_values(base_output)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown input format: {input_format}")
|
||||||
|
|
||||||
|
# Finalize with common processing
|
||||||
|
combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
|
||||||
|
return combined_mm_item, input_ids
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
|||||||
)
|
)
|
||||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||||
|
self.IM_TOKEN_ID = hf_config.image_token_index
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
@@ -42,49 +43,21 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
|||||||
if isinstance(image_data, str):
|
if isinstance(image_data, str):
|
||||||
image_data = [image_data]
|
image_data = [image_data]
|
||||||
|
|
||||||
image_token = self.IMAGE_TOKEN
|
|
||||||
image_token_regex = self.IMAGE_TOKEN_REGEX
|
|
||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
prompt=input_text,
|
prompt=input_text,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
multimodal_tokens=MultimodalSpecialTokens(
|
multimodal_tokens=MultimodalSpecialTokens(
|
||||||
image_token=image_token, image_token_regex=image_token_regex
|
image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
|
||||||
),
|
),
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
discard_alpha_channel=True,
|
discard_alpha_channel=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
|
||||||
ret = self.process_mm_data(
|
|
||||||
input_text=base_output.input_text,
|
|
||||||
images=None if images_are_preprocessed else base_output.images,
|
|
||||||
)
|
|
||||||
|
|
||||||
items = []
|
|
||||||
input_ids = ret["input_ids"].flatten()
|
|
||||||
image_offsets = self.get_mm_items_offset(
|
|
||||||
input_ids=input_ids,
|
|
||||||
mm_token_id=self.hf_config.image_token_index,
|
|
||||||
)
|
|
||||||
for i, image in enumerate(base_output.images):
|
|
||||||
if images_are_preprocessed:
|
|
||||||
pixel_values = image.pixel_values
|
|
||||||
precomputed_features = image.precomputed_features
|
|
||||||
else:
|
|
||||||
pixel_values = ret["pixel_values"][i]
|
|
||||||
precomputed_features = None
|
|
||||||
|
|
||||||
item = MultimodalDataItem(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
precomputed_features=precomputed_features,
|
|
||||||
modality=Modality.IMAGE,
|
|
||||||
image_offsets=image_offsets[i],
|
|
||||||
)
|
|
||||||
items += [item]
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mm_items": items,
|
|
||||||
"input_ids": input_ids.tolist(),
|
"input_ids": input_ids.tolist(),
|
||||||
|
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
|
||||||
"im_start_id": self.IM_START_TOKEN_ID,
|
"im_start_id": self.IM_START_TOKEN_ID,
|
||||||
"im_end_id": self.IM_END_TOKEN_ID,
|
"im_end_id": self.IM_END_TOKEN_ID,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
|||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor)
|
||||||
self.IMAGE_TOKEN = "<|media_pad|>"
|
self.IMAGE_TOKEN = "<|media_pad|>"
|
||||||
self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
|
self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
|
||||||
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
|
self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
@@ -46,48 +46,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
|||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
|
||||||
if not images_are_preprocessed:
|
|
||||||
ret = self.process_mm_data(
|
|
||||||
input_text=base_output.input_text,
|
|
||||||
images=base_output.images,
|
|
||||||
)
|
|
||||||
input_ids = ret["input_ids"].flatten()
|
|
||||||
image_grid_thws = ret["image_grid_hws"]
|
|
||||||
pixel_values = ret["pixel_values"]
|
|
||||||
precomputed_features = None
|
|
||||||
else:
|
|
||||||
input_ids = self._processor.tokenizer(
|
|
||||||
base_output.input_text,
|
|
||||||
return_tensors="pt",
|
|
||||||
add_special_tokens=True,
|
|
||||||
).input_ids.flatten()
|
|
||||||
|
|
||||||
image_grid_thws = self._extract_processor_features(
|
|
||||||
base_output.images, "image_grid_thws"
|
|
||||||
)
|
|
||||||
precomputed_features = self._extract_processor_features(
|
|
||||||
base_output.images, "precomputed_features"
|
|
||||||
)
|
|
||||||
pixel_values = self._extract_processor_features(
|
|
||||||
base_output.images, "pixel_values"
|
|
||||||
)
|
|
||||||
|
|
||||||
image_offsets = self.get_mm_items_offset(
|
|
||||||
input_ids=input_ids,
|
|
||||||
mm_token_id=self.im_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids.tolist(),
|
"input_ids": input_ids.tolist(),
|
||||||
"mm_items": [
|
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
|
||||||
MultimodalDataItem(
|
"im_token_id": self.IM_TOKEN_ID,
|
||||||
pixel_values=pixel_values,
|
|
||||||
image_grid_thws=image_grid_thws,
|
|
||||||
precomputed_features=precomputed_features,
|
|
||||||
modality=Modality.IMAGE,
|
|
||||||
image_offsets=image_offsets,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
"im_token_id": self.im_token_id,
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
)
|
)
|
||||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||||
self.image_token_id = hf_config.image_token_id
|
self.IM_TOKEN_ID = hf_config.image_token_id
|
||||||
self.video_token_id = hf_config.video_token_id
|
self.VIDEO_TOKEN_ID = hf_config.video_token_id
|
||||||
self.vision_start_token_id = hf_config.vision_start_token_id
|
self.vision_start_token_id = hf_config.vision_start_token_id
|
||||||
self.vision_end_token_id = hf_config.vision_end_token_id
|
self.vision_end_token_id = hf_config.vision_end_token_id
|
||||||
self.NUM_TOKEN_PER_FRAME = 770
|
self.NUM_TOKEN_PER_FRAME = 770
|
||||||
@@ -125,72 +125,45 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
async def resize_image_async(image):
|
async def resize_image_async(image):
|
||||||
return resize_image(image)
|
return resize_image(image)
|
||||||
|
|
||||||
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
# Qwen-specific: resize images if they are raw Image objects
|
||||||
if base_output.images and not images_are_preprocessed:
|
if base_output.images and isinstance(base_output.images[0], Image.Image):
|
||||||
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
||||||
base_output.images = await asyncio.gather(*resize_tasks)
|
base_output.images = await asyncio.gather(*resize_tasks)
|
||||||
|
|
||||||
ret = self.process_mm_data(
|
|
||||||
input_text=base_output.input_text,
|
|
||||||
images=None if images_are_preprocessed else base_output.images,
|
|
||||||
)
|
|
||||||
input_ids = ret["input_ids"].flatten().tolist()
|
|
||||||
image_offsets = self.get_mm_items_offset(
|
|
||||||
input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
|
|
||||||
)
|
|
||||||
image_grid_thw = None
|
|
||||||
video_grid_thw = None # TODO
|
video_grid_thw = None # TODO
|
||||||
items = []
|
|
||||||
|
|
||||||
if base_output.images:
|
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
|
||||||
if images_are_preprocessed:
|
|
||||||
image_grid_thw = self._extract_processor_features(
|
if combined_mm_item is None:
|
||||||
base_output.images, "image_grid_thws"
|
# Note(Xinyuan): This is the case where image loading fails.
|
||||||
)
|
return None
|
||||||
precomputed_features = self._extract_processor_features(
|
|
||||||
base_output.images, "precomputed_features"
|
video_grid_thw = None # TODO
|
||||||
)
|
second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
|
||||||
pixel_values = self._extract_processor_features(
|
|
||||||
base_output.images, "pixel_values"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image_grid_thw = ret["image_grid_thw"]
|
|
||||||
pixel_values = ret["pixel_values"]
|
|
||||||
precomputed_features = None
|
|
||||||
items += [
|
|
||||||
MultimodalDataItem(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
image_grid_thws=image_grid_thw,
|
|
||||||
video_grid_thws=video_grid_thw,
|
|
||||||
precomputed_features=precomputed_features,
|
|
||||||
image_offsets=image_offsets,
|
|
||||||
modality=Modality.IMAGE,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
|
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
|
||||||
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
|
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
|
||||||
image_token_id=self.image_token_id,
|
image_token_id=self.IM_TOKEN_ID,
|
||||||
video_token_id=self.video_token_id,
|
video_token_id=self.VIDEO_TOKEN_ID,
|
||||||
vision_start_token_id=self.vision_start_token_id,
|
vision_start_token_id=self.vision_start_token_id,
|
||||||
model_type=self.hf_config.model_type,
|
model_type=self.hf_config.model_type,
|
||||||
tokens_per_second=getattr(
|
tokens_per_second=getattr(
|
||||||
self.hf_config.vision_config, "tokens_per_second", None
|
self.hf_config.vision_config, "tokens_per_second", None
|
||||||
),
|
),
|
||||||
input_ids=torch.tensor(input_ids).unsqueeze(0),
|
input_ids=input_ids.unsqueeze(0),
|
||||||
image_grid_thw=image_grid_thw,
|
image_grid_thw=combined_mm_item.image_grid_thw,
|
||||||
video_grid_thw=video_grid_thw,
|
video_grid_thw=video_grid_thw,
|
||||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
)
|
)
|
||||||
mrope_positions = mrope_positions.squeeze(1)
|
mrope_positions = mrope_positions.squeeze(1)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids.tolist(),
|
||||||
"mm_items": items,
|
"mm_items": [combined_mm_item],
|
||||||
"im_start_id": self.IM_START_TOKEN_ID,
|
"im_start_id": self.IM_START_TOKEN_ID,
|
||||||
"im_end_id": self.IM_END_TOKEN_ID,
|
"im_end_id": self.IM_END_TOKEN_ID,
|
||||||
"im_token_id": self.image_token_id,
|
"im_token_id": self.IM_TOKEN_ID,
|
||||||
"video_token_id": self.video_token_id,
|
"video_token_id": self.VIDEO_TOKEN_ID,
|
||||||
"mrope_positions": mrope_positions,
|
"mrope_positions": mrope_positions,
|
||||||
"mrope_position_delta": mrope_position_delta,
|
"mrope_position_delta": mrope_position_delta,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ class MultimodalDataItem:
|
|||||||
# the real data, pixel_values or audio_features
|
# the real data, pixel_values or audio_features
|
||||||
# data: Union[List[torch.Tensor], List[np.ndarray]]
|
# data: Union[List[torch.Tensor], List[np.ndarray]]
|
||||||
pixel_values: Union[torch.Tensor, np.ndarray] = None
|
pixel_values: Union[torch.Tensor, np.ndarray] = None
|
||||||
image_grid_thws: Union[torch.Tensor, np.ndarray] = None
|
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
|
||||||
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
|
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
|
||||||
|
|
||||||
image_emb_mask: Optional[torch.Tensor] = None
|
image_emb_mask: Optional[torch.Tensor] = None
|
||||||
@@ -198,6 +198,9 @@ class MultimodalDataItem:
|
|||||||
# [num_images, (n, w, h)]
|
# [num_images, (n, w, h)]
|
||||||
tgt_size: Tuple[int, int] = None
|
tgt_size: Tuple[int, int] = None
|
||||||
|
|
||||||
|
# kimi-vl related
|
||||||
|
image_grid_hws: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
audio_features: Union[torch.Tensor, np.ndarray] = None
|
audio_features: Union[torch.Tensor, np.ndarray] = None
|
||||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||||
audio_offsets: Optional[List[Tuple[int, int]]] = None
|
audio_offsets: Optional[List[Tuple[int, int]]] = None
|
||||||
|
|||||||
@@ -286,14 +286,26 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
|
all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
|
||||||
vision_outputs_list = []
|
vision_outputs_list = []
|
||||||
|
|
||||||
for pixel_value in all_pixel_values:
|
for pixel_values_batch in all_pixel_values:
|
||||||
# Add batch dimension for single image processing
|
# Normalize input shape to [batch_size, channels, height, width]
|
||||||
pixel_value_batch = pixel_value.unsqueeze(0)
|
if pixel_values_batch.dim() == 5:
|
||||||
pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
|
pixel_values_batch = pixel_values_batch.squeeze(0)
|
||||||
pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
|
elif pixel_values_batch.dim() == 3:
|
||||||
|
pixel_values_batch = pixel_values_batch.unsqueeze(0)
|
||||||
|
elif pixel_values_batch.dim() != 4:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
vision_output = self.vision_tower(pixel_values=pixel_value_batch)
|
# Process each image in the batch
|
||||||
vision_outputs_list.append(vision_output)
|
batch_size = pixel_values_batch.shape[0]
|
||||||
|
for i in range(batch_size):
|
||||||
|
pixel_value = pixel_values_batch[i : i + 1] # Keep batch dimension as 1
|
||||||
|
pixel_value = pixel_value.to(
|
||||||
|
device=self.vision_tower.device, dtype=self.language_model.dtype()
|
||||||
|
)
|
||||||
|
vision_output = self.vision_tower(pixel_values=pixel_value)
|
||||||
|
vision_outputs_list.append(vision_output)
|
||||||
|
|
||||||
# Concatenate all vision outputs
|
# Concatenate all vision outputs
|
||||||
vision_outputs = torch.cat(vision_outputs_list, dim=0)
|
vision_outputs = torch.cat(vision_outputs_list, dim=0)
|
||||||
|
|||||||
@@ -144,10 +144,10 @@ class KimiVLForConditionalGeneration(nn.Module):
|
|||||||
.type(self.vision_tower.dtype)
|
.type(self.vision_tower.dtype)
|
||||||
.to(self.vision_tower.device)
|
.to(self.vision_tower.device)
|
||||||
)
|
)
|
||||||
image_grid_thws = torch.concat(
|
image_grid_hws = torch.cat([item.image_grid_hws for item in items], dim=0).to(
|
||||||
[item.image_grid_thws for item in items], dim=0
|
self.vision_tower.device
|
||||||
).to(self.vision_tower.device)
|
)
|
||||||
image_features = self.vision_tower(pixel_values, image_grid_thws)
|
image_features = self.vision_tower(pixel_values, image_grid_hws)
|
||||||
assert isinstance(image_features, list)
|
assert isinstance(image_features, list)
|
||||||
# lengths = [x.shape[0] for x in image_features]
|
# lengths = [x.shape[0] for x in image_features]
|
||||||
res = self.multi_modal_projector(torch.cat(image_features)) # .split(lengths)
|
res = self.multi_modal_projector(torch.cat(image_features)) # .split(lengths)
|
||||||
|
|||||||
@@ -503,10 +503,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
||||||
self.visual.dtype
|
self.visual.dtype
|
||||||
)
|
)
|
||||||
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
|
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
||||||
assert pixel_values.dim() == 2, pixel_values.dim()
|
assert pixel_values.dim() == 2, pixel_values.dim()
|
||||||
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
|
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
return image_embeds
|
return image_embeds
|
||||||
|
|
||||||
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
||||||
|
|||||||
@@ -490,10 +490,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
||||||
self.visual.dtype
|
self.visual.dtype
|
||||||
)
|
)
|
||||||
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
|
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
||||||
assert pixel_values.dim() == 2, pixel_values.dim()
|
assert pixel_values.dim() == 2, pixel_values.dim()
|
||||||
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
|
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
return image_embeds
|
return image_embeds
|
||||||
|
|
||||||
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ class TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestC
|
|||||||
def _pixel_values_image_data(self, processor_output):
|
def _pixel_values_image_data(self, processor_output):
|
||||||
return dict(
|
return dict(
|
||||||
modality="IMAGE",
|
modality="IMAGE",
|
||||||
image_grid_thws=processor_output["image_grid_thw"],
|
image_grid_thw=processor_output["image_grid_thw"],
|
||||||
pixel_values=processor_output["pixel_values"],
|
pixel_values=processor_output["pixel_values"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -207,8 +207,8 @@ class TestKimiVLImageUnderstandsImage(
|
|||||||
def _pixel_values_image_data(self, processor_output):
|
def _pixel_values_image_data(self, processor_output):
|
||||||
return dict(
|
return dict(
|
||||||
modality="IMAGE",
|
modality="IMAGE",
|
||||||
image_grid_thws=processor_output["image_grid_hws"],
|
|
||||||
pixel_values=processor_output["pixel_values"],
|
pixel_values=processor_output["pixel_values"],
|
||||||
|
image_grid_hws=processor_output["image_grid_hws"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user