# Files
# sglang/python/sglang/srt/mm_utils.py
#
# 341 lines
# 12 KiB
# Python

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
"""
Utilities for multi-modal models.

This Python file mainly contains utilities used in the image processing
logic of LLaVA-NeXT, including operations such as anyres and anyres_max.
It currently supports the anyres and anyres_max operations for CLIP and
SigLIP. For more information, refer to the paper or the blog:
LLaVA-NeXT: https://llava-vl.github.io/blog/2024-01-30-llava-next/
LLaVA-OneVision: https://arxiv.org/pdf/2408.03326
"""
import ast
import base64
import math
import re
from io import BytesIO
import numpy as np
from PIL import Image
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that keeps the most of the image's detail.

    Every candidate is scored by the pixel area the image would cover after
    being scaled (aspect ratio preserved) to fit inside it, capped at the
    image's native area. The largest effective area wins; ties are broken by
    the smallest unused area, and the earliest candidate wins a complete tie.

    Args:
        original_size (tuple): The original image size as (width, height).
        possible_resolutions (list): Candidate resolutions in the format
            [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The selected resolution as (width, height), or None when
        ``possible_resolutions`` is empty.
    """
    src_w, src_h = original_size
    native_area = src_w * src_h

    chosen = None
    # Scores compare lexicographically: more effective area first, then less
    # wasted area (negated so that "less waste" means "larger score").
    top_score = (0, -float("inf"))

    for cand_w, cand_h in possible_resolutions:
        # Largest uniform scale at which the image still fits the candidate.
        ratio = min(cand_w / src_w, cand_h / src_h)
        fitted_w = int(src_w * ratio)
        fitted_h = int(src_h * ratio)

        # Upscaling adds no information, so cap at the native pixel count.
        effective = min(fitted_w * fitted_h, native_area)
        wasted = cand_w * cand_h - effective

        score = (effective, -wasted)
        if score > top_score:
            top_score = score
            chosen = (cand_w, cand_h)

    return chosen
def resize_and_pad_image(image, target_resolution):
    """
    Fit an image inside a target canvas, preserving aspect ratio, and pad it.

    The image is scaled so that it touches the canvas along its limiting
    dimension, then pasted centered onto a black RGB canvas of the target size.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution as (width, height).

    Returns:
        PIL.Image.Image: A new RGB image of exactly the target resolution.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution

    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h

    if ratio_w < ratio_h:
        # Width is the limiting dimension: fill it and derive the height.
        fitted = (dst_w, min(math.ceil(src_h * ratio_w), dst_h))
    else:
        # Height is the limiting dimension: fill it and derive the width.
        fitted = (min(math.ceil(src_w * ratio_h), dst_w), dst_h)

    scaled = image.resize(fitted)

    # Black canvas; the scaled image is centered, leaving equal-ish borders.
    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    offset = ((dst_w - fitted[0]) // 2, (dst_h - fitted[1]) // 2)
    canvas.paste(scaled, offset)
    return canvas
def divide_to_patches(image, patch_size):
    """
    Cut an image into square tiles, scanning row by row from the top-left.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The side length of each patch in pixels.

    Returns:
        list: The cropped patches, ordered left-to-right within each row and
        rows top-to-bottom. Crop boxes are always ``patch_size`` wide and
        tall, even when they extend past the image's right or bottom edge.
    """
    img_w, img_h = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, img_h, patch_size)
        for left in range(0, img_w, patch_size)
    ]
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str | list | tuple): The candidate resolutions, given as one of:
            * a list/tuple of (width, height) pairs in pixels;
            * a string holding a Python literal of such a list;
            * a range string containing "(AxB)" groups, e.g. "(1x1),...,(6x6)".
              The first and last groups bound an inclusive range of grid
              sizes, which are multiplied by ``patch_size`` to get pixels.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).

    Raises:
        AssertionError: If a range string is used with an unsupported ``patch_size``.
        ValueError: If a range string contains no "(AxB)" group.
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [
            224,
            336,
            384,
            448,
            512,
        ], "patch_size should be in [224, 336, 384, 448, 512]"
        # Use regex to extract the range bounds from the input string
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        if not matches:
            # Without this guard, matches[0] would fail with an opaque IndexError.
            raise ValueError(
                f"Could not parse a resolution range from grid_pinpoints: {grid_pinpoints!r}"
            )
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Every grid size from range_start to range_end (inclusive on both
        # axes), converted from patch units to pixels.
        grid_pinpoints = [
            [i * patch_size, j * patch_size]
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]

    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        # Remaining case: a string literal such as "[[336, 672], [672, 336]]".
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The image is resized and padded to the best-fitting candidate resolution,
    cut into patches, and each patch (plus a resized copy of the whole image)
    is run through ``processor.preprocess``.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object. Its ``size`` may be a sequence
            (first element used as the patch size) or a dict with a
            "shortest_edge" or "height" key.
        grid_pinpoints (str | list | tuple): The candidate resolutions, given as one of:
            * a list/tuple of (width, height) pairs in pixels;
            * a string holding a Python literal of such a list;
            * a range string containing "(AxB)" groups, e.g. "(1x1),...,(6x6)".
              The first and last groups bound an inclusive range of grid
              sizes, which are multiplied by the patch size to get pixels.

    Returns:
        np.array: An np array containing the processed image patches. The
        resized full image comes first, followed by the patches.

    Raises:
        AssertionError: If a range string is used with an unsupported patch size.
        ValueError: If a range string contains no "(AxB)" group.
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        size = processor.size
        try:
            patch_size = size[0]
        except (KeyError, IndexError, TypeError):
            # Dict-style size: prefer "shortest_edge" and fall back to
            # "height", mirroring the shortest_edge lookup further below so
            # processors that only define height/width work here too.
            patch_size = (
                size["shortest_edge"] if "shortest_edge" in size else size["height"]
            )
        assert patch_size in [
            224,
            336,
            384,
            448,
            512,
        ], "patch_size should be in [224, 336, 384, 448, 512]"
        # Use regex to extract the range bounds from the input string
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        if not matches:
            # Without this guard, matches[0] would fail with an opaque IndexError.
            raise ValueError(
                f"Could not parse a resolution range from grid_pinpoints: {grid_pinpoints!r}"
            )
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Every grid size from range_start to range_end (inclusive on both
        # axes), converted from patch units to pixels.
        grid_pinpoints = [
            [i * patch_size, j * patch_size]
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]

    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        # Remaining case: a string literal such as "[[336, 672], [672, 336]]".
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    # For Siglip processor, only have size but no crop size
    crop_size = (
        processor.crop_size["height"]
        if "crop_size" in processor.__dict__
        else processor.size["height"]
    )
    shortest_edge = (
        processor.size["shortest_edge"]
        if "shortest_edge" in processor.size
        else processor.size["height"]
    )

    patches = divide_to_patches(image_padded, crop_size)

    # A square, downsized copy of the whole image goes in front of the patches.
    image_original_resize = image.resize((shortest_edge, shortest_edge))
    image_patches = [image_original_resize] + patches
    image_patches = [
        processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
        for image_patch in image_patches
    ]
    return np.stack(image_patches, axis=0)
def load_image_from_base64(image):
    """
    Decode a base64-encoded image and open it with PIL.

    Args:
        image: The base64-encoded image data, in any form accepted by
            ``base64.b64decode``.

    Returns:
        PIL.Image.Image: The image opened from the decoded bytes.
    """
    raw_bytes = base64.b64decode(image)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
def expand2square(pil_img, background_color):
    """
    Pad an image to a square by centering it on a filled background.

    Args:
        pil_img (PIL.Image.Image): The input image.
        background_color: The fill color for the added border.

    Returns:
        PIL.Image.Image: The input image itself when it is already square;
        otherwise a new square image whose side equals the longer edge, with
        the original centered along the shorter dimension. Non-square
        images in mode "L" are converted to RGB before padding.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    if pil_img.mode == "L":
        pil_img = pil_img.convert("RGB")

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)

    # Center along whichever dimension is shorter; the other offset is zero.
    if width > height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)

    canvas.paste(pil_img, offset)
    return canvas
def unpad_image(tensor, original_size):
    """
    Strip the symmetric padding from a padded and resized image tensor.

    Args:
        tensor (torch.Tensor): The image tensor, indexed as
            (channels, height, width).
        original_size (tuple): The original size of the image, unpacked here
            as (width, height).

    Returns:
        torch.Tensor: The tensor sliced along height (when the original is
        relatively wider than the tensor) or along width (otherwise).
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # Original is relatively wider: content spans the full width and the
        # padding sits above and below it.
        content_h = int(orig_h * (cur_w / orig_w))
        margin = (cur_h - content_h) // 2
        return tensor[:, margin : cur_h - margin, :]

    # Otherwise content spans the full height and the padding is left/right.
    content_w = int(orig_w * (cur_h / orig_h))
    margin = (cur_w - content_w) // 2
    return tensor[:, :, margin : cur_w - margin]
def unpad_image_shape(current_height, current_width, original_size):
    """
    Compute the (height, width) an image would have after its padding is removed.

    Only the resulting shape is calculated; no tensor is touched.

    Args:
        current_height (int): Height of the padded image.
        current_width (int): Width of the padded image.
        original_size (tuple): The original image size, unpacked here as
            (width, height).

    Returns:
        tuple: The unpadded shape as (height, width).
    """
    orig_w, orig_h = original_size

    if orig_w / orig_h > current_width / current_height:
        # Padding is vertical: trim the same margin from top and bottom.
        content_h = int(orig_h * (current_width / orig_w))
        margin = (current_height - content_h) // 2
        return (current_height - 2 * margin, current_width)

    # Padding is horizontal: trim the same margin from left and right.
    content_w = int(orig_w * (current_height / orig_h))
    margin = (current_width - content_w) // 2
    return (current_height, current_width - 2 * margin)
def process_images(images, image_processor, model_cfg):
    """
    Preprocess a batch of images according to the model's aspect-ratio mode.

    The mode is read from ``model_cfg.image_aspect_ratio``:
        * "pad": each image is padded to a square (filled with the
          processor's mean color) and then preprocessed.
        * any value containing "anyres": each image goes through
          ``process_anyres_image`` with ``model_cfg.image_grid_pinpoints``.
        * anything else, including a missing attribute: the images are passed
          to ``image_processor`` directly.

    Args:
        images (list): The input images.
        image_processor: The image processor object.
        model_cfg: The model config object.

    Returns:
        The processor's "pixel_values" in the default mode. Otherwise a
        stacked np array when all per-image results share one shape, or a
        list of per-image results when they differ or when ``images`` is empty.
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == "pad":
        for image in images:
            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean)
            )
            image = image_processor.preprocess(image)["pixel_values"][0]
            new_images.append(image)
    # The None check keeps a missing image_aspect_ratio from raising a
    # TypeError on the "in" test, so it reaches the default branch instead.
    elif image_aspect_ratio is not None and "anyres" in image_aspect_ratio:
        for image in images:
            image = process_anyres_image(
                image, image_processor, model_cfg.image_grid_pinpoints
            )
            new_images.append(image)
    else:
        return image_processor(images)["pixel_values"]

    # np.stack rejects an empty sequence, so only stack non-empty,
    # uniformly-shaped results.
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = np.stack(new_images, axis=0)
    return new_images