[Refactor] Multimodal data processing for VLM (#6659)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
@@ -5,7 +5,8 @@ import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple, Union
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -16,16 +17,24 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import encode_video, load_audio, load_image
 
 
+class MultimodalInputFormat(Enum):
+    """Enum for different multimodal input formats."""
+
+    RAW_IMAGES = "raw_images"
+    PRECOMPUTED_FEATURES = "precomputed_features"
+    PIXEL_VALUES = "pixel_values"
+
+
 @dataclasses.dataclass
 class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
     input_text: str
 
     # frames loaded from image and video, in given order
-    images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
+    images: Optional[list[Union[Image.Image, dict]]] = None
 
     # audios
-    audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
+    audios: Optional[list[Union[np.ndarray, dict]]] = None
 
     def normalize(self):
         for field_name in ["images", "audios"]:
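For orientation (not part of the commit): a minimal sketch of the two dict shapes the refactored `images` field can now carry next to raw PIL images. The keys follow the checks in `categorize_mm_inputs` further down; the tensor shapes and file name are placeholders.

import torch
from PIL import Image

# Illustrative only; shapes and file name are made up.
pixel_values_input = {
    "pixel_values": torch.zeros(1, 3, 336, 336),
    "image_grid_thw": torch.tensor([[1, 24, 24]]),
}
precomputed_input = {
    "precomputed_features": torch.zeros(576, 1024),  # vision embeddings computed upstream
}
raw_input = Image.open("example.jpg")  # raw images are still accepted as-is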
@@ -170,8 +179,6 @@ class BaseMultimodalProcessor(ABC):
     ):
         """Static method that can be pickled for multiprocessing"""
         if isinstance(data, dict):
-            return MultimodalDataItem.from_dict(data)
-        if isinstance(data, MultimodalDataItem):
             return data
         try:
             if is_audio:
@@ -370,29 +377,180 @@ class BaseMultimodalProcessor(ABC):
 
         return list(zip(indices_start.tolist(), indices_end.tolist()))
 
-    def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
-        """Returns true if all images are preprocessed, false if all are not, and error otherwise."""
-        if not mm_inputs:
-            return True
-        ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
-        if ret and not all(
-            isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
-        ):
-            raise ValueError(
-                "Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
-            )
-        return ret
-
     @staticmethod
     def _extract_processor_features(
-        items: List[Any], attr_name: str
+        items: List[dict], attr_name: str
     ) -> Optional[torch.Tensor]:
         """
         Helper function to concat extracted attributes from processor output.
         """
-        values = [
-            getattr(item, attr_name)
-            for item in items
-            if getattr(item, attr_name) is not None
-        ]
-        return torch.concat(values) if values else None
+        values = [value for item in items if (value := item.get(attr_name)) is not None]
+        return torch.cat(values) if values else None
+
+    # When we assume that all the items have the same attributes
+    def _extract_processor_features_from_all_attributes(
+        self, items: List[dict]
+    ) -> dict:
+        values = {}
+        # Verify all items have the same keys
+        first_keys = set(items[0].keys())
+        for item in items[1:]:
+            if set(item.keys()) != first_keys:
+                raise ValueError(
+                    f"All items must have the same attributes. "
+                    f"First item has {first_keys}, but found {set(item.keys())}"
+                )
+
+        # Process each attribute
+        for k, v in items[0].items():
+            if isinstance(v, list):
+                values[k] = self._extract_processor_features(items, k)
+            else:
+                # Verify all items have the same value for non-list attributes
+                for item in items[1:]:
+                    if item[k] != v:
+                        raise ValueError(
+                            f"All items must have the same value for attribute {k}. "
+                            f"First item has {v}, but found {item[k]}"
+                        )
+                values[k] = v
+        return values
+
+    def process_and_combine_mm_data(
+        self, base_output: BaseMultiModalProcessorOutput
+    ) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
+        """
+        Process multimodal data and return the combined multimodal item and input_ids.
+        Handles all three input formats at the same abstraction level.
+
+        Returns:
+            Tuple of (combined_mm_item, input_ids)
+        """
+
+        def tokenize_text(input_text: str) -> torch.Tensor:
+            """Tokenize input text."""
+            return self._processor.tokenizer(
+                input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+
+        def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
+            """Categorize multimodal inputs and validate consistency."""
+            try:
+                has_image = False
+                has_pixel_values = False
+                has_precomputed_features = False
+
+                for mm_input in mm_inputs:
+                    if isinstance(mm_input, Image.Image):
+                        has_image = True
+                    elif isinstance(mm_input, dict):
+                        if mm_input.get("precomputed_features", None) is not None:
+                            has_precomputed_features = True
+                        elif mm_input.get("pixel_values", None) is not None:
+                            has_pixel_values = True
+                        else:
+                            raise ValueError(
+                                f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
+                            )
+                    else:
+                        raise ValueError(
+                            f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
+                        )
+
+                # Validate format consistency
+                format_count = sum(
+                    [has_image, has_pixel_values, has_precomputed_features]
+                )
+                if format_count > 1:
+                    raise ValueError(
+                        "Unsupported: mixture of multimodal input formats. "
+                        f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
+                        f"precomputed_features={has_precomputed_features}"
+                    )
+
+                if has_image:
+                    return MultimodalInputFormat.RAW_IMAGES
+                elif has_precomputed_features:
+                    return MultimodalInputFormat.PRECOMPUTED_FEATURES
+                elif has_pixel_values:
+                    return MultimodalInputFormat.PIXEL_VALUES
+                else:
+                    raise ValueError("No valid multimodal input format found")
+            except Exception as e:
+                raise ValueError(f"Failed to categorize inputs: {e}")
+
+        def process_raw_images(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process raw Image.Image objects using transformers processor."""
+            ret = self.process_mm_data(
+                input_text=base_output.input_text,
+                images=base_output.images,
+            )
+            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
+
+            # Copy all fields from processor output except input_ids
+            for key, value in ret.items():
+                if key != "input_ids" and hasattr(combined_mm_item, key):
+                    setattr(combined_mm_item, key, value)
+
+            input_ids = ret["input_ids"].flatten()
+            return combined_mm_item, input_ids
+
+        def process_precomputed_features(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process inputs with precomputed features."""
+            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
+            combined_mm_item.precomputed_features = self._extract_processor_features(
+                base_output.images, "precomputed_features"
+            )
+            input_ids = tokenize_text(base_output.input_text)
+            return combined_mm_item, input_ids
+
+        def process_pixel_values(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process inputs with pixel values."""
+            values = self._extract_processor_features_from_all_attributes(
+                base_output.images
+            )
+            combined_mm_item = MultimodalDataItem.from_dict(values)
+            input_ids = tokenize_text(base_output.input_text)
+            return combined_mm_item, input_ids
+
+        def finalize_mm_item(
+            combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
+        ) -> MultimodalDataItem:
+            """Apply common post-processing to the multimodal item."""
+            combined_mm_item.image_offsets = self.get_mm_items_offset(
+                input_ids=input_ids,
+                mm_token_id=self.IM_TOKEN_ID,
+            )
+            return combined_mm_item
+
+        # Main logic
+        mm_inputs = base_output.images
+        if not mm_inputs:
+            # Return text-only case
+            input_ids = tokenize_text(base_output.input_text)
+            return None, input_ids
+
+        # Categorize input formats
+        input_format = categorize_mm_inputs(mm_inputs)
+
+        # Process based on format
+        if input_format == MultimodalInputFormat.RAW_IMAGES:
+            combined_mm_item, input_ids = process_raw_images(base_output)
+        elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
+            combined_mm_item, input_ids = process_precomputed_features(base_output)
+        elif input_format == MultimodalInputFormat.PIXEL_VALUES:
+            combined_mm_item, input_ids = process_pixel_values(base_output)
+        else:
+            raise ValueError(f"Unknown input format: {input_format}")
+
+        # Finalize with common processing
+        combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
+        return combined_mm_item, input_ids
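As a quick illustration (not part of the commit), the dict-based `_extract_processor_features` helper above simply gathers one key across the per-image dicts and concatenates whatever is present; a standalone sketch with made-up tensors:

import torch

items = [
    {"pixel_values": torch.zeros(2, 3), "precomputed_features": None},
    {"pixel_values": torch.ones(1, 3), "precomputed_features": None},
]
values = [v for item in items if (v := item.get("pixel_values")) is not None]
print(torch.cat(values).shape)  # torch.Size([3, 3]); None entries are skipped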
@@ -27,6 +27,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         )
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index
+        self.IM_TOKEN_ID = hf_config.image_token_index
 
     async def process_mm_data_async(
         self,
@@ -42,49 +43,21 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         if isinstance(image_data, str):
             image_data = [image_data]
 
-        image_token = self.IMAGE_TOKEN
-        image_token_regex = self.IMAGE_TOKEN_REGEX
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
-                image_token=image_token, image_token_regex=image_token_regex
+                image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
             ),
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )
 
-        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
-        ret = self.process_mm_data(
-            input_text=base_output.input_text,
-            images=None if images_are_preprocessed else base_output.images,
-        )
-
-        items = []
-        input_ids = ret["input_ids"].flatten()
-        image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids,
-            mm_token_id=self.hf_config.image_token_index,
-        )
-        for i, image in enumerate(base_output.images):
-            if images_are_preprocessed:
-                pixel_values = image.pixel_values
-                precomputed_features = image.precomputed_features
-            else:
-                pixel_values = ret["pixel_values"][i]
-                precomputed_features = None
-
-            item = MultimodalDataItem(
-                pixel_values=pixel_values,
-                precomputed_features=precomputed_features,
-                modality=Modality.IMAGE,
-                image_offsets=image_offsets[i],
-            )
-            items += [item]
+        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
 
         return {
-            "mm_items": items,
             "input_ids": input_ids.tolist(),
+            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
@@ -21,7 +21,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<|media_pad|>"
         self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
-        self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
+        self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
 
     async def process_mm_data_async(
         self,
@@ -46,48 +46,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             max_req_input_len=max_req_input_len,
         )
 
-        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
-        if not images_are_preprocessed:
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                images=base_output.images,
-            )
-            input_ids = ret["input_ids"].flatten()
-            image_grid_thws = ret["image_grid_hws"]
-            pixel_values = ret["pixel_values"]
-            precomputed_features = None
-        else:
-            input_ids = self._processor.tokenizer(
-                base_output.input_text,
-                return_tensors="pt",
-                add_special_tokens=True,
-            ).input_ids.flatten()
-
-            image_grid_thws = self._extract_processor_features(
-                base_output.images, "image_grid_thws"
-            )
-            precomputed_features = self._extract_processor_features(
-                base_output.images, "precomputed_features"
-            )
-            pixel_values = self._extract_processor_features(
-                base_output.images, "pixel_values"
-            )
-
-        image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids,
-            mm_token_id=self.im_token_id,
-        )
+        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [
-                MultimodalDataItem(
-                    pixel_values=pixel_values,
-                    image_grid_thws=image_grid_thws,
-                    precomputed_features=precomputed_features,
-                    modality=Modality.IMAGE,
-                    image_offsets=image_offsets,
-                )
-            ],
-            "im_token_id": self.im_token_id,
+            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "im_token_id": self.IM_TOKEN_ID,
         }
@@ -32,8 +32,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         )
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
-        self.image_token_id = hf_config.image_token_id
-        self.video_token_id = hf_config.video_token_id
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.VIDEO_TOKEN_ID = hf_config.video_token_id
         self.vision_start_token_id = hf_config.vision_start_token_id
         self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
@@ -125,72 +125,45 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)
 
-        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
-        if base_output.images and not images_are_preprocessed:
+        # Qwen-specific: resize images if they are raw Image objects
+        if base_output.images and isinstance(base_output.images[0], Image.Image):
             resize_tasks = [resize_image_async(image) for image in base_output.images]
             base_output.images = await asyncio.gather(*resize_tasks)
 
-        ret = self.process_mm_data(
-            input_text=base_output.input_text,
-            images=None if images_are_preprocessed else base_output.images,
-        )
-        input_ids = ret["input_ids"].flatten().tolist()
-        image_offsets = self.get_mm_items_offset(
-            input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
-        )
-        image_grid_thw = None
-        video_grid_thw = None  # TODO
-        items = []
-
-        if base_output.images:
-            if images_are_preprocessed:
-                image_grid_thw = self._extract_processor_features(
-                    base_output.images, "image_grid_thws"
-                )
-                precomputed_features = self._extract_processor_features(
-                    base_output.images, "precomputed_features"
-                )
-                pixel_values = self._extract_processor_features(
-                    base_output.images, "pixel_values"
-                )
-            else:
-                image_grid_thw = ret["image_grid_thw"]
-                pixel_values = ret["pixel_values"]
-                precomputed_features = None
-            items += [
-                MultimodalDataItem(
-                    pixel_values=pixel_values,
-                    image_grid_thws=image_grid_thw,
-                    video_grid_thws=video_grid_thw,
-                    precomputed_features=precomputed_features,
-                    image_offsets=image_offsets,
-                    modality=Modality.IMAGE,
-                )
-            ]
+        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+
+        if combined_mm_item is None:
+            # Note(Xinyuan): This is the case where image loading fails.
+            return None
+
+        video_grid_thw = None  # TODO
+        second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
 
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
             spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
-            image_token_id=self.image_token_id,
-            video_token_id=self.video_token_id,
+            image_token_id=self.IM_TOKEN_ID,
+            video_token_id=self.VIDEO_TOKEN_ID,
             vision_start_token_id=self.vision_start_token_id,
             model_type=self.hf_config.model_type,
             tokens_per_second=getattr(
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
-            input_ids=torch.tensor(input_ids).unsqueeze(0),
-            image_grid_thw=image_grid_thw,
+            input_ids=input_ids.unsqueeze(0),
+            image_grid_thw=combined_mm_item.image_grid_thw,
             video_grid_thw=video_grid_thw,
-            second_per_grid_ts=ret.get("second_per_grid_ts", None),
+            second_per_grid_ts=second_per_grid_ts,
         )
         mrope_positions = mrope_positions.squeeze(1)
 
         return {
-            "input_ids": input_ids,
-            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "mm_items": [combined_mm_item],
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
-            "im_token_id": self.image_token_id,
-            "video_token_id": self.video_token_id,
+            "im_token_id": self.IM_TOKEN_ID,
+            "video_token_id": self.VIDEO_TOKEN_ID,
             "mrope_positions": mrope_positions,
             "mrope_position_delta": mrope_position_delta,
         }
@@ -188,7 +188,7 @@ class MultimodalDataItem:
     # the real data, pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.ndarray]]
     pixel_values: Union[torch.Tensor, np.ndarray] = None
-    image_grid_thws: Union[torch.Tensor, np.ndarray] = None
+    image_grid_thw: Union[torch.Tensor, np.ndarray] = None
     video_grid_thws: Union[torch.Tensor, np.ndarray] = None
 
     image_emb_mask: Optional[torch.Tensor] = None
@@ -198,6 +198,9 @@ class MultimodalDataItem:
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None
 
+    # kimi-vl related
+    image_grid_hws: Optional[List[torch.Tensor]] = None
+
     audio_features: Union[torch.Tensor, np.ndarray] = None
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None
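For reference (not part of the commit), a small sketch of constructing an item with the renamed `image_grid_thw` field; tensor shapes are placeholders:

import torch
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

item = MultimodalDataItem(
    pixel_values=torch.zeros(1, 3, 336, 336),    # placeholder shape
    image_grid_thw=torch.tensor([[1, 24, 24]]),  # was image_grid_thws before this change
    modality=Modality.IMAGE,
)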
@@ -286,14 +286,26 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
         vision_outputs_list = []
 
-        for pixel_value in all_pixel_values:
-            # Add batch dimension for single image processing
-            pixel_value_batch = pixel_value.unsqueeze(0)
-            pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
-            pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
+        for pixel_values_batch in all_pixel_values:
+            # Normalize input shape to [batch_size, channels, height, width]
+            if pixel_values_batch.dim() == 5:
+                pixel_values_batch = pixel_values_batch.squeeze(0)
+            elif pixel_values_batch.dim() == 3:
+                pixel_values_batch = pixel_values_batch.unsqueeze(0)
+            elif pixel_values_batch.dim() != 4:
+                raise ValueError(
+                    f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
+                )
 
-            vision_output = self.vision_tower(pixel_values=pixel_value_batch)
-            vision_outputs_list.append(vision_output)
+            # Process each image in the batch
+            batch_size = pixel_values_batch.shape[0]
+            for i in range(batch_size):
+                pixel_value = pixel_values_batch[i : i + 1]  # Keep batch dimension as 1
+                pixel_value = pixel_value.to(
+                    device=self.vision_tower.device, dtype=self.language_model.dtype()
+                )
+                vision_output = self.vision_tower(pixel_values=pixel_value)
+                vision_outputs_list.append(vision_output)
 
         # Concatenate all vision outputs
         vision_outputs = torch.cat(vision_outputs_list, dim=0)
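A standalone sketch (not part of the commit) of the shape handling above: 3D, 4D, and 5D `pixel_values` are all normalized to `[batch_size, channels, height, width]` before the per-image loop; the sizes below are placeholders:

import torch

def normalize_pixel_values(pixel_values_batch: torch.Tensor) -> torch.Tensor:
    # Mirrors the branching in the hunk above.
    if pixel_values_batch.dim() == 5:    # e.g. [1, B, C, H, W] -> [B, C, H, W]
        pixel_values_batch = pixel_values_batch.squeeze(0)
    elif pixel_values_batch.dim() == 3:  # [C, H, W] -> [1, C, H, W]
        pixel_values_batch = pixel_values_batch.unsqueeze(0)
    elif pixel_values_batch.dim() != 4:
        raise ValueError(f"Unexpected pixel_values shape: {pixel_values_batch.shape}")
    return pixel_values_batch

print(normalize_pixel_values(torch.zeros(3, 224, 224)).shape)        # torch.Size([1, 3, 224, 224])
print(normalize_pixel_values(torch.zeros(1, 2, 3, 224, 224)).shape)  # torch.Size([2, 3, 224, 224])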
@@ -144,10 +144,10 @@ class KimiVLForConditionalGeneration(nn.Module):
             .type(self.vision_tower.dtype)
             .to(self.vision_tower.device)
         )
-        image_grid_thws = torch.concat(
-            [item.image_grid_thws for item in items], dim=0
-        ).to(self.vision_tower.device)
-        image_features = self.vision_tower(pixel_values, image_grid_thws)
+        image_grid_hws = torch.cat([item.image_grid_hws for item in items], dim=0).to(
+            self.vision_tower.device
+        )
+        image_features = self.vision_tower(pixel_values, image_grid_hws)
         assert isinstance(image_features, list)
         # lengths = [x.shape[0] for x in image_features]
         res = self.multi_modal_projector(torch.cat(image_features))  # .split(lengths)
@@ -503,10 +503,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
         )
-        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
-        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
-        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -490,10 +490,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
         )
-        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
-        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
-        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: