790 lines
33 KiB
Python
790 lines
33 KiB
Python
|
|
import copy
|
||
|
|
import math
|
||
|
|
import os
|
||
|
|
from typing import Dict, List, Optional, Union
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
from PIL import Image
|
||
|
|
from transformers.feature_extraction_utils import BatchFeature
|
||
|
|
from transformers.image_processing_utils import (
|
||
|
|
BaseImageProcessor,
|
||
|
|
get_size_dict,
|
||
|
|
)
|
||
|
|
from transformers.image_transforms import (
|
||
|
|
convert_to_rgb,
|
||
|
|
get_resize_output_image_size,
|
||
|
|
resize,
|
||
|
|
to_channel_dimension_format,
|
||
|
|
)
|
||
|
|
from transformers.image_utils import (
|
||
|
|
OPENAI_CLIP_MEAN,
|
||
|
|
OPENAI_CLIP_STD,
|
||
|
|
ChannelDimension,
|
||
|
|
ImageInput,
|
||
|
|
PILImageResampling,
|
||
|
|
get_image_size,
|
||
|
|
infer_channel_dimension_format,
|
||
|
|
is_scaled_image,
|
||
|
|
make_list_of_images,
|
||
|
|
to_numpy_array,
|
||
|
|
valid_images,
|
||
|
|
)
|
||
|
|
from transformers.utils import TensorType, logging
|
||
|
|
|
||
|
|
# Module-level logger from transformers' logging utilities (used e.g. for the
# "already rescaled images" warning in `HCXImageProcessor.preprocess`).
logger = logging.get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class HCXImageProcessor(BaseImageProcessor):
    r"""
    Constructs a VLM image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques for
    processing high resolution images.

    Args:
        do_resize (bool): Whether `_preprocess` resizes each (grid) image.
        size (Dict[str, int]): Resize target; defaults to `{"shortest_edge": 336}`.
        anyres (bool): Whether to use the anyres feature (split the image into local grids; for non-video inputs a
            global thumbnail grid is prepended).
        unpad (bool): When anyres is used, whether to use the unpad feature, i.e. drop the visual tokens that cover
            only padded area from the LLM input.
        num_queries_vis_abstractor_image (int): Number of visual queries per grid when a resampler (visual
            abstractor) is used, for images.
        num_queries_vis_abstractor_video_slow (int): Same, for the "slow" path of video frames.
        num_queries_vis_abstractor_video_fast (int): Same, for the "fast" path of video frames.
        first_last_frames_slow_video (bool): Stored on the instance as configuration; `preprocess` itself uses its
            per-call `first_last_frames_slow` argument instead.
        possible_resolutions (List): Candidate resolutions for anyres, e.g. [[336, 336], [336, 672], [672, 336]].
            Defaults to an empty list.
        patch_size (int): ViT patch size.
        pad_to_square (bool): Whether to pad images to a square. If False, non-square images are not padded and
            reach the ViT through the center crop instead.
        resample, do_center_crop, crop_size, do_rescale, rescale_factor, do_normalize, image_mean, image_std,
        do_convert_rgb: Standard CLIP-style preprocessing options.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        anyres: bool = False,
        unpad: bool = False,
        num_queries_vis_abstractor_image: int = 81,
        num_queries_vis_abstractor_video_slow: int = 81,
        num_queries_vis_abstractor_video_fast: int = 9,
        first_last_frames_slow_video: bool = False,
        possible_resolutions: Optional[List] = None,
        patch_size: int = 14,
        pad_to_square: bool = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 336}
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else {"height": 336, "width": 336}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.do_resize = do_resize
        self.size = size
        self.anyres = anyres
        self.unpad = unpad
        self.num_queries_vis_abstractor_image = num_queries_vis_abstractor_image
        self.num_queries_vis_abstractor_video_slow = num_queries_vis_abstractor_video_slow
        self.num_queries_vis_abstractor_video_fast = num_queries_vis_abstractor_video_fast
        self.first_last_frames_slow_video = first_last_frames_slow_video
        # Shallow copy: the processor never shares the caller's list object (and no mutable default is shared).
        self.possible_resolutions = list(possible_resolutions) if possible_resolutions is not None else []
        self.patch_size = patch_size
        self.pad_to_square = pad_to_square
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Resize `image` by its shortest edge (aspect preserved) or to an exact height/width.

        Args:
            image: Image to resize.
            size: Either `{"shortest_edge": int}` or `{"height": int, "width": int}`.
            resample: Resampling filter.
            data_format: Channel layout of the output (defaults to the input layout).
            input_data_format: Channel layout of the input (inferred when None).

        Raises:
            ValueError: if `size` has neither supported key set.
        """
        default_to_square = True
        if "shortest_edge" in size:
            size = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")

        output_size = get_resize_output_image_size(
            image,
            size=size,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )

        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def _preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> List[np.ndarray]:
        """Apply the CLIP-style pipeline (resize -> center crop -> rescale -> normalize) to each image.

        Unlike `preprocess`, this does not fall back to instance defaults: each step runs only if its flag is
        truthy and uses exactly the arguments given.

        Returns:
            List[np.ndarray]: The processed images, in input order, converted to `data_format`.
        """
        images = make_list_of_images(images)

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_center_crop:
            images = [
                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        return images

    def _resize_for_local_grids(
        self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension
    ) -> np.array:
        """Resize `image` so it fits inside `target_resolution` (height, width) while keeping its aspect ratio."""
        new_height, new_width = _get_local_grids_output_size(image, target_resolution, input_data_format)

        resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)

        return resized_image

    def _pad_for_patching(
        self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
    ) -> np.array:
        """
        Center-pad an (already resized) image to `target_resolution` (height, width).

        The padding color is the instance `image_mean` scaled to the 0-255 range.
        """
        target_height, target_width = target_resolution

        background_color = tuple(int(x * 255) for x in self.image_mean)
        padded_image = pad(
            image,
            target_size=(target_height, target_width),
            background_color=background_color,
            input_data_format=input_data_format,
        )

        return padded_image

    def get_image_grids(
        self,
        image: np.array,
        possible_resolutions,
        grid_size: int,
        resample: PILImageResampling,
        data_format: ChannelDimension,
        input_data_format: ChannelDimension,
    ) -> List[np.array]:
        """Split a high-resolution image into square local grids for anyres processing.

        The image is resized to fit the best-matching candidate resolution, center-padded to it, and cut into
        `grid_size` x `grid_size` tiles (row-major order).

        Args:
            image: Input image.
            possible_resolutions: List of candidate (height, width) resolutions.
            grid_size: Side length of each grid in pixels.
            resample: Resampling filter for the resize.
            data_format: Channel layout of the returned grids.
            input_data_format: Channel layout of `image`.

        Raises:
            ValueError: if `possible_resolutions` is not a list.
        """
        if not isinstance(possible_resolutions, list):
            raise ValueError("possible_resolutions must be a list of possible resolutions.")

        image_size = get_image_size(image, channel_dim=input_data_format)
        best_resolution = select_best_resolution(image_size, possible_resolutions)
        resized_image = self._resize_for_local_grids(
            image, best_resolution, resample=resample, input_data_format=input_data_format
        )
        padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
        local_grids = divide_to_grids(padded_image, grid_size=grid_size, input_data_format=input_data_format)

        # Convert every grid to the requested output layout.
        local_grids = [
            to_channel_dimension_format(grid, channel_dim=data_format, input_channel_dim=input_data_format)
            for grid in local_grids
        ]

        return local_grids

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        anyres: bool = None,
        unpad: bool = None,
        is_video: bool = False,
        num_queries_vis_abstractor_image: int = None,
        num_queries_vis_abstractor_video_slow: int = None,
        num_queries_vis_abstractor_video_fast: int = None,
        first_last_frames_slow_video: bool = None,
        possible_resolutions: List = None,
        patch_size: int = None,
        pad_to_square: bool = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: int = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        return_dummy_image: bool = False,
        first_last_frames_slow: bool = False,
        is_first_or_last_frames: bool = False,
        **kwargs,
    ):
        """
        Preprocess images into model inputs: image tensors, original image sizes, and visual-token counts.

        Arguments left as None fall back to the values stored on the processor. `first_last_frames_slow_video` is
        accepted for API compatibility but is not consulted here; the per-call `first_last_frames_slow` flag is what
        reaches the token count.

        :return pixel_values: List of 4D tensors, one per image (its grids stacked along the first axis).
        :return image_sizes: List of dicts with each original image's size,
            [{"width": width of image 1, "height": height of image 1}, {"width": ..., "height": ...}, ...]
        :return vision_query_lengths: List of ints, the number of visual tokens each image becomes in the LLM input
            (empty when `return_dummy_image` is set).
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, param_name="size", default_to_square=False)
        anyres = anyres if anyres is not None else self.anyres
        unpad = unpad if unpad is not None else self.unpad
        num_queries_vis_abstractor_image = (
            num_queries_vis_abstractor_image
            if num_queries_vis_abstractor_image is not None
            else self.num_queries_vis_abstractor_image
        )
        num_queries_vis_abstractor_video_slow = (
            num_queries_vis_abstractor_video_slow
            if num_queries_vis_abstractor_video_slow is not None
            else self.num_queries_vis_abstractor_video_slow
        )
        num_queries_vis_abstractor_video_fast = (
            num_queries_vis_abstractor_video_fast
            if num_queries_vis_abstractor_video_fast is not None
            else self.num_queries_vis_abstractor_video_fast
        )
        possible_resolutions = possible_resolutions if possible_resolutions is not None else self.possible_resolutions
        patch_size = patch_size if patch_size is not None else self.patch_size
        pad_to_square = pad_to_square if pad_to_square is not None else self.pad_to_square
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        if is_video:
            num_queries_vis_abstractor = num_queries_vis_abstractor_video_fast
            num_queries_vis_abstractor_slow = num_queries_vis_abstractor_video_slow
            unpad = False
        else:
            num_queries_vis_abstractor = num_queries_vis_abstractor_image
            num_queries_vis_abstractor_slow = 0

        if return_dummy_image:
            images = Image.new("RGB", (224, 224), (0, 0, 0))

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        new_images = []
        # Original sizes as (height, width), taken before any resizing/padding below.
        image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
        vision_query_lengths = []

        # Grids are square, so the crop must be square too.
        assert crop_size["height"] == crop_size["width"]

        # Padding the global image can become a bottleneck when the original width/height is large,
        # so the long side is first resized to size["shortest_edge"] and only then padded.
        # None of the transforms below modify their input in place, so no defensive copies are needed.
        if anyres:
            if pad_to_square:
                background_color = tuple(int(x * 255) for x in self.image_mean)
                anyres_global_images = [
                    resize_longside(
                        image,
                        size["shortest_edge"],
                        resample,
                        data_format=input_data_format,
                        input_data_format=input_data_format,
                    )
                    for image in images
                ]
                anyres_global_images = [
                    expand2square(image, background_color=background_color, input_data_format=input_data_format)[0]
                    for image in anyres_global_images
                ]
            else:
                anyres_global_images = [
                    self.resize(
                        image=image,
                        size={"height": size["shortest_edge"], "width": size["shortest_edge"]},
                        resample=resample,
                        input_data_format=input_data_format,
                    )
                    for image in images
                ]
        else:
            anyres_global_images = [None] * len(images)
            if pad_to_square:
                background_color = tuple(int(x * 255) for x in self.image_mean)
                images = [
                    resize_longside(
                        image,
                        size["shortest_edge"],
                        resample,
                        data_format=input_data_format,
                        input_data_format=input_data_format,
                    )
                    for image in images
                ]
                images = [
                    expand2square(image, background_color=background_color, input_data_format=input_data_format)[0]
                    for image in images
                ]

        for image, anyres_global_image, image_size in zip(images, anyres_global_images, image_sizes):
            if anyres:
                # Convert the image into a list of grids; we intentionally keep the input data format here
                # and let `_preprocess` convert to `data_format`.
                image_grids = self.get_image_grids(
                    image,
                    possible_resolutions,
                    grid_size=crop_size["height"],
                    resample=resample,
                    data_format=input_data_format,
                    input_data_format=input_data_format,
                )
                # Videos do not use the global image (thumbnail).
                if not is_video:
                    image_grids = [anyres_global_image] + image_grids
            else:
                image_grids = [image]

            pixel_values = self._preprocess(
                image_grids,
                do_resize=do_resize,
                size=size,
                resample=resample,
                do_center_crop=do_center_crop,
                crop_size=crop_size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                input_data_format=input_data_format,
            )

            # Stack this image's grids into one array (num_grids first).
            new_images.append(np.array(pixel_values))

            vision_query_length = determine_anyres_num_vision_patches(
                image_size=image_size,
                grid_size=crop_size["height"],
                patch_size=patch_size,
                possible_resolutions=possible_resolutions,
                anyres=anyres,
                unpad=unpad,
                num_queries_vis_abstractor=num_queries_vis_abstractor,
                num_queries_vis_abstractor_slow=num_queries_vis_abstractor_slow,
                is_video=is_video,
                first_last_frames_slow=first_last_frames_slow,
                is_first_or_last_frames=is_first_or_last_frames,
            )
            vision_query_lengths.append(vision_query_length)

        if return_dummy_image:
            vision_query_lengths = []

        data = {
            "pixel_values": [torch.tensor(new_image) for new_image in new_images],
            "image_sizes": [{"width": image_size[1], "height": image_size[0]} for image_size in image_sizes],
            "vision_query_lengths": vision_query_lengths,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        *args,
        **kwargs,
    ):
        """Register this class for the auto-class mechanism, then delegate to `BaseImageProcessor.save_pretrained`."""
        self.register_for_auto_class()
        super().save_pretrained(save_directory, *args, **kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
def determine_anyres_num_vision_patches(
    image_size,
    grid_size,
    patch_size,
    possible_resolutions,
    anyres=False,
    unpad=True,
    num_queries_vis_abstractor=0,
    num_queries_vis_abstractor_slow=0,
    is_video=False,
    first_last_frames_slow=False,  # sample-wise option
    is_first_or_last_frames=False,  # grid-wise option
):
    """
    Count the visual tokens an image (or video frame) contributes to the LLM input.

    Without anyres a single grid is used, so the count is either the visual-abstractor query count or the full
    ViT patch grid `(grid_size // patch_size) ** 2`. With anyres, the local grids laid out on the best-fitting
    candidate resolution are counted (optionally trimming padding rows/columns via `unpad`). Then the slow-path
    adjustment is applied when configured, and a global-thumbnail grid is added for non-video inputs.

    Args:
        image_size (tuple): Original image size as (height, width).
        grid_size (int): Side length of each square grid in pixels (e.g. 336).
        patch_size (int): ViT patch size in pixels (e.g. 14).
        possible_resolutions (list): Candidate resolutions [(h1, w1), (h2, w2), ...]; only used when `anyres`.
        anyres (bool, optional): Whether any-resolution mode is on. Defaults to False.
        unpad (bool, optional): Whether to drop token rows/columns that cover only padding. Defaults to True.
        num_queries_vis_abstractor (int, optional): Queries per grid for the (fast-path) visual abstractor;
            0 falls back to `grid_size // patch_size` tokens per side.
        num_queries_vis_abstractor_slow (int, optional): Queries per grid for the slow path; 0 disables it.
        is_video (bool, optional): Whether the input is a video frame (no global thumbnail). Defaults to False.
        first_last_frames_slow (bool, optional): If True, the slow path applies only when
            `is_first_or_last_frames` is True. Defaults to False.
        is_first_or_last_frames (bool, optional): Whether this is the first or last frame. Defaults to False.

    Returns:
        int: Total number of visual tokens.

    Raises:
        AssertionError: if the slow path is enabled together with `unpad`.
    """
    if not anyres:
        if num_queries_vis_abstractor > 0:
            return num_queries_vis_abstractor
        return (grid_size // patch_size) ** 2

    # Tokens along one side of a single grid (abstractor queries are laid out as a square).
    if num_queries_vis_abstractor > 0:
        tokens_per_side = int(num_queries_vis_abstractor**0.5)
    else:
        tokens_per_side = grid_size // patch_size

    # Video inputs carry no global image, so a single local grid is possible; there is no "num_grids > 1" check.
    canvas_height, canvas_width = select_best_resolution(image_size, possible_resolutions)
    rows = (canvas_height // grid_size) * tokens_per_side
    cols = (canvas_width // grid_size) * tokens_per_side

    if unpad:
        orig_height, orig_width = image_size
        if orig_width / orig_height > cols / rows:
            # Image is relatively wider than the token canvas: trim rows equally from top and bottom.
            content_rows = int(orig_height * (cols / orig_width))
            rows -= ((rows - content_rows) // 2) * 2
        else:
            # Image is relatively taller (or equal): trim columns equally from left and right.
            content_cols = int(orig_width * (rows / orig_height))
            cols -= ((cols - content_cols) // 2) * 2
        # One extra token per remaining row (presumably a row separator, as in LLaVA-NeXT unpad; TODO confirm).
        total = rows * cols + rows
    else:
        total = rows * cols

    if num_queries_vis_abstractor_slow > 0:
        # With first_last_frames_slow set, only the first/last frames take the slow path.
        if is_first_or_last_frames or not first_last_frames_slow:
            total += num_queries_vis_abstractor_slow - num_queries_vis_abstractor
        # The slow/fast scheme is only supported when unpad is disabled.
        assert unpad is False

    # Non-video inputs also carry one global (thumbnail) grid.
    if not is_video:
        total += tokens_per_side**2

    return total
|
||
|
|
|
||
|
|
|
||
|
|
def divide_to_grids(image: np.array, grid_size: int, input_data_format=None) -> List[np.array]:
    """
    Cut an image into square tiles of side `grid_size`, scanning rows top-to-bottom and columns left-to-right.

    Tiles are plain NumPy slices (views) of `image`. When a dimension is not a multiple of `grid_size`, the
    last tile along that axis is smaller.

    Args:
        image (np.array): Image to split.
        grid_size (int): Tile side length in pixels.
        input_data_format (optional): Channel layout of `image`. Only `ChannelDimension.LAST` selects
            channels-last slicing; any other value (including None) slices as channels-first, although
            `get_image_size` still infers the spatial size when None.

    Returns:
        List[np.array]: The tiles in row-major order.
    """
    height, width = get_image_size(image, channel_dim=input_data_format)
    channels_last = input_data_format == ChannelDimension.LAST

    return [
        image[top : top + grid_size, left : left + grid_size]
        if channels_last
        else image[:, top : top + grid_size, left : left + grid_size]
        for top in range(0, height, grid_size)
        for left in range(0, width, grid_size)
    ]
|
||
|
|
|
||
|
|
|
||
|
|
def pad(
    image: np.array,
    target_size: tuple,
    background_color=(127, 127, 127),
    input_data_format=None,
) -> np.array:
    """
    Center the image on a canvas of `target_size`, filling the surrounding area with `background_color`.

    Both channels-last (H, W, C) and channels-first (C, H, W) inputs are supported; the output uses the same
    layout as the input. Odd leftover padding goes to the bottom/right (integer division).

    Args:
        image (np.array): Input image as a NumPy array.
        target_size (tuple): Target size as (target_height, target_width).
        background_color (tuple, optional): Per-channel fill value; needs at least one entry per channel.
            Defaults to (127, 127, 127).
        input_data_format (optional): Channel layout of `image` ("channels_first"/"channels_last" or
            `ChannelDimension`); inferred from the array when None.

    Returns:
        np.array: The padded image with spatial size `target_size`, same dtype and layout as `image`.

    Raises:
        ValueError: if the image is larger than `target_size` along either spatial axis.
    """
    target_height, target_width = target_size
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    height, width = get_image_size(image, channel_dim=input_data_format)

    if height > target_height or width > target_width:
        raise ValueError(
            f"Cannot pad an image of size {(height, width)} to a smaller target size {(target_height, target_width)}."
        )

    channels_first = input_data_format == ChannelDimension.FIRST
    # Work in (H, W, C) internally; np.moveaxis returns a view, so this does not copy the input.
    hwc_image = np.moveaxis(image, 0, -1) if channels_first else image
    num_channels = hwc_image.shape[2]

    # Filling per channel is cheaper than building np.ones(...) * background_color and keeps the input dtype.
    result = np.empty((target_height, target_width, num_channels), dtype=image.dtype)
    for i in range(num_channels):
        result[..., i].fill(background_color[i])

    paste_x = (target_width - width) // 2
    paste_y = (target_height - height) // 2
    result[paste_y : paste_y + height, paste_x : paste_x + width, :] = hwc_image

    if channels_first:
        result = np.ascontiguousarray(np.moveaxis(result, -1, 0))
    return result
|
||
|
|
|
||
|
|
|
||
|
|
def expand2square(
    image: np.array,
    bboxes_dict=None,
    background_color=(127, 127, 127),
    input_data_format=None,
) -> tuple:
    """
    Expand the input image to a square by centering it on a new square canvas and padding the shorter side.

    Padding is split symmetrically (any odd pixel goes to the bottom/right). Both channels-last (H, W, C) and
    channels-first (C, H, W) inputs are supported; the output keeps the input layout. An already-square image
    is returned unchanged (same object).

    Args:
        image (np.array): Input image as a NumPy array.
        bboxes_dict (dict, optional): A dictionary of bounding boxes, where each value is an NDArray of shape (N, 4, 2)
            with box coordinates in the format [[xtl, ytl], [xtr, ytr], [xbr, ybr], [xbl, ybl]].
            Supports multiple categories (e.g., "ocr", "html") simultaneously. The arrays are shifted IN PLACE
            by the padding offset.
        background_color (tuple, optional): Per-channel fill value for the padding area. Defaults to (127, 127, 127).
        input_data_format (optional): Channel layout of `image` ("channels_first"/"channels_last" or
            `ChannelDimension`); inferred from the array when None.

    Returns:
        tuple: `(square_image, bboxes_dict)`.

    Example:
        >>> _img = np.full((80, 100, 3), 100, dtype=np.uint8)
        >>> _bboxes_dict = {"words": np.array([[[10, 10], [20, 10], [20, 20], [10, 20]],
        ...                                    [[30, 30], [40, 30], [40, 40], [30, 40]]])}
        >>> _img, _bboxes_dict = expand2square(_img, _bboxes_dict, (255, 255, 255), input_data_format="channels_last")
        >>> _img.shape
        (100, 100, 3)
        >>> guessed_ocr_bboxes = np.array([[[10, 20], [20, 20], [20, 30], [10, 30]],
        ...                                [[30, 40], [40, 40], [40, 50], [30, 50]]])
        >>> np.testing.assert_array_equal(_bboxes_dict["words"], guessed_ocr_bboxes) is None
        True
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    height, width = get_image_size(image, channel_dim=input_data_format)
    if width == height:
        return image, bboxes_dict

    channels_first = input_data_format == ChannelDimension.FIRST
    # Work in (H, W, C) internally; np.moveaxis returns a view, so this does not copy the input.
    hwc_image = np.moveaxis(image, 0, -1) if channels_first else image
    side = max(width, height)

    result = np.empty((side, side, hwc_image.shape[2]), dtype=image.dtype)
    for i in range(hwc_image.shape[2]):
        result[..., i].fill(background_color[i])

    if width > height:
        # Landscape: pad top/bottom, so y-coordinates (index 1) shift.
        offset = (width - height) // 2
        result[offset : offset + height, :] = hwc_image
        shifted_axis = 1
    else:
        # Portrait: pad left/right, so x-coordinates (index 0) shift.
        offset = (height - width) // 2
        result[:, offset : offset + width] = hwc_image
        shifted_axis = 0

    if bboxes_dict is not None:
        for key in bboxes_dict:
            bboxes_dict[key][:, :, shifted_axis] += offset

    if channels_first:
        result = np.ascontiguousarray(np.moveaxis(result, -1, 0))
    return result, bboxes_dict
|
||
|
|
|
||
|
|
|
||
|
|
def resize_longside(
    image: np.array,
    size: int,
    resample: PILImageResampling = PILImageResampling.BICUBIC,  # type: ignore
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Scale an image so its longer side equals `size`, preserving the aspect ratio.

    The shorter side is scaled proportionally and rounded up with `math.ceil`; a square image becomes
    `size` x `size`.

    Args:
        image (np.array): Image to resize.
        size (int): Target length of the longer side, in pixels.
        resample (PILImageResampling, optional): Resampling filter. Defaults to BICUBIC.
        data_format (str or ChannelDimension, optional): Channel layout of the output.
        input_data_format (str or ChannelDimension, optional): Channel layout of `image`.

    Returns:
        np.array: The resized image.
    """
    src_height, src_width = get_image_size(image, channel_dim=input_data_format)

    if src_width > src_height:
        output_hw = (math.ceil(src_height / src_width * size), size)
    elif src_height > src_width:
        output_hw = (size, math.ceil(src_width / src_height * size))
    else:
        output_hw = (size, size)

    return resize(
        image,
        size=output_hw,
        resample=resample,
        data_format=data_format,
        input_data_format=input_data_format,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _get_local_grids_output_size(image: np.array, target_resolution: tuple, input_data_format=None):
    """
    Compute the pixel size to resize an image to so that it fits inside `target_resolution` with its aspect kept.

    The binding side is set exactly to the target. The other side is scaled by the same factor, rounded up with
    `math.ceil`, and capped at its target.

    Args:
        image (np.array): Input image as a NumPy array.
        target_resolution (tuple): Target resolution as (target_height, target_width).
        input_data_format (optional): Channel layout of `image` (e.g. "channels_first" or "channels_last").

    Returns:
        tuple: `(new_height, new_width)` in pixels.
    """
    src_height, src_width = get_image_size(image, channel_dim=input_data_format)
    dst_height, dst_width = target_resolution

    ratio_w = dst_width / src_width
    ratio_h = dst_height / src_height

    if ratio_h <= ratio_w:
        # Height is the tighter constraint: fill the target height.
        return dst_height, min(math.ceil(src_width * ratio_h), dst_width)

    # Width is the tighter constraint: fill the target width.
    return min(math.ceil(src_height * ratio_w), dst_height), dst_width
|
||
|
|
|
||
|
|
|
||
|
|
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """
    Selects the best-fit resolution from a list of possible resolutions based on the original image size.

    This function, adapted from LLaVA-Next
    (https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llava_next/image_processing_llava_next.py),
    evaluates each resolution by computing its effective and wasted area compared to the original size.
    The chosen resolution maximizes the effective area and, on ties, minimizes the unused (wasted) area.
    Earlier candidates win exact ties.

    Args:
        original_size (tuple): The original image size in the format (height, width).
        possible_resolutions (list): A non-empty list of candidate resolutions in the format
            [(height1, width1), (height2, width2), ...].

    Returns:
        tuple: The best-fit resolution in the format (height, width).

    Raises:
        ValueError: if `possible_resolutions` is empty. Previously this returned None, which made callers fail
            later with an unrelated unpacking error.
    """
    if len(possible_resolutions) == 0:
        raise ValueError("possible_resolutions must contain at least one (height, width) candidate.")

    original_height, original_width = original_size
    original_area = original_width * original_height
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for height, width in possible_resolutions:
        # Largest uniform scale that keeps the whole image inside this candidate canvas.
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        # Upscaling adds no real detail, so the useful area is capped at the original pixel count.
        effective_resolution = min(downscaled_width * downscaled_height, original_area)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (height, width)

    return best_fit
|