Llava-hd Support (#92)
Co-authored-by: Haotian Liu <liuhaotian.cn@gmail.com>
This commit is contained in:
251
python/sglang/srt/mm_utils.py
Normal file
251
python/sglang/srt/mm_utils.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
|
||||
import ast
|
||||
import base64
|
||||
import math
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that keeps the most detail of an image.

    Each candidate is scored by the pixel count the image occupies once it is
    scaled (aspect ratio kept) to fit inside the candidate, capped at the
    image's own pixel count. Among equal scores the candidate with the least
    unused area wins; on a complete tie the earliest candidate is kept.

    Args:
        original_size (tuple): The image size in the format (width, height).
        possible_resolutions (list): Candidate resolutions in the format
            [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The selected resolution as (width, height), or None when
            `possible_resolutions` is empty.
    """
    src_w, src_h = original_size
    src_area = src_w * src_h

    winner = None
    # A score is (effective pixels, -wasted pixels); the larger tuple is better.
    winner_score = (0, -float("inf"))

    for cand_w, cand_h in possible_resolutions:
        ratio = min(cand_w / src_w, cand_h / src_h)
        fitted_area = int(src_w * ratio) * int(src_h * ratio)
        # Upscaling adds no information, so never count more than the source.
        effective = min(fitted_area, src_area)
        wasted = (cand_w * cand_h) - effective

        score = (effective, -wasted)
        if score > winner_score:
            winner, winner_score = (cand_w, cand_h), score

    return winner
|
||||
|
||||
|
||||
def resize_and_pad_image(image, target_resolution):
    """
    Fit an image inside a target resolution and centre it on a black canvas.

    The image is scaled with its aspect ratio preserved so that the more
    constrained axis exactly matches the target; the other axis is rounded up
    and clamped to the target. The result is pasted in the middle of an RGB
    canvas of the target size filled with (0, 0, 0).

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution as (width, height).

    Returns:
        PIL.Image.Image: A new RGB image of exactly the target resolution.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution

    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h

    if ratio_w < ratio_h:
        # Width is the limiting axis: fill it and derive the height.
        fit_w = dst_w
        fit_h = min(math.ceil(src_h * ratio_w), dst_h)
    else:
        # Height is the limiting axis (or both ratios are equal).
        fit_w = min(math.ceil(src_w * ratio_h), dst_w)
        fit_h = dst_h

    scaled = image.resize((fit_w, fit_h))

    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    offset = ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2)
    canvas.paste(scaled, offset)
    return canvas
|
||||
|
||||
|
||||
def divide_to_patches(image, patch_size):
    """
    Cut an image into square patches, scanning rows top to bottom.

    Crop boxes are stepped by `patch_size` along both axes and are not
    clamped to the image bounds, so when a dimension is not a multiple of
    `patch_size` the last box on that axis extends past the image edge.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): Side length of each patch in pixels.

    Returns:
        list: The cropped patches in row-major order (left to right within a
            row, rows from top to bottom).
    """
    img_w, img_h = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, img_h, patch_size)
        for left in range(0, img_w, patch_size)
    ]
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (list | tuple | str): The possible resolutions, either as a
            sequence of (width, height) pairs or as a string holding the Python
            literal for such a sequence.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).

    Raises:
        ValueError: If `grid_pinpoints` is a string that is not a valid Python
            literal (raised by `ast.literal_eval`).
    """
    # isinstance (instead of an exact type check) lets tuples and list
    # subclasses through as-is rather than handing them to literal_eval,
    # which only understands strings and would reject them.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    width, height = select_best_resolution(image_size, possible_resolutions)
    # Floor division: the grid counts whole patches per axis.
    return width // patch_size, height // patch_size
|
||||
|
||||
|
||||
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The image is resized and padded to the best-matching resolution, split
    into patches of `processor.crop_size["height"]` pixels, and a square
    thumbnail of the whole image (side `processor.size["shortest_edge"]`) is
    placed in front of the patches. Every entry is then run through
    `processor.preprocess`.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object. It must expose `crop_size`,
            `size` and `preprocess` as used above.
        grid_pinpoints (list | tuple | str): The possible resolutions, either as a
            sequence of (width, height) pairs or as a string holding the Python
            literal for such a sequence.

    Returns:
        np.array: The processed thumbnail followed by the processed patches,
            stacked along a new leading axis.

    Raises:
        ValueError: If `grid_pinpoints` is a string that is not a valid Python
            literal (raised by `ast.literal_eval`).
    """
    # isinstance (instead of an exact type check) lets tuples and list
    # subclasses through as-is rather than handing them to literal_eval,
    # which only understands strings and would reject them.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size["height"])

    # Low-resolution view of the full image; aspect ratio is not preserved.
    shortest_edge = processor.size["shortest_edge"]
    image_original_resize = image.resize((shortest_edge, shortest_edge))

    image_patches = [
        processor.preprocess(image_patch)["pixel_values"][0]
        for image_patch in [image_original_resize] + patches
    ]
    return np.stack(image_patches, axis=0)
|
||||
|
||||
|
||||
def load_image_from_base64(image):
    """
    Open a base64-encoded image.

    Args:
        image: The base64-encoded image data, as accepted by `base64.b64decode`.

    Returns:
        PIL.Image.Image: The image opened from the decoded bytes.
    """
    raw_bytes = base64.b64decode(image)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
|
||||
|
||||
|
||||
def expand2square(pil_img, background_color):
    """
    Pad an image to a square whose side equals its longer edge.

    Args:
        pil_img (PIL.Image.Image): The input image.
        background_color: Fill colour for the added area, passed to `Image.new`.

    Returns:
        PIL.Image.Image: The input object itself when it is already square;
            otherwise a new image in the same mode with the input centred
            along its shorter axis.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    square = Image.new(pil_img.mode, (side, side), background_color)
    # One of the two offsets is always zero: only the shorter axis is padded.
    square.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return square
|
||||
|
||||
|
||||
def unpad_image(tensor, original_size):
    """
    Strip the padding from a tensor of a padded and resized image.

    The aspect ratio of the original image is compared with that of the
    tensor to decide which axis was padded; an equal margin is then removed
    from both ends of that axis.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image. It is unpacked
            as (width, height).

    Returns:
        torch.Tensor: The unpadded image tensor (a slice of the input).
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # Original is relatively wider: padding sits above and below.
        content_h = int(orig_h * (cur_w / orig_w))
        margin = (cur_h - content_h) // 2
        return tensor[:, margin : cur_h - margin, :]

    # Original is relatively taller (or same ratio): padding is left and right.
    content_w = int(orig_w * (cur_h / orig_h))
    margin = (cur_w - content_w) // 2
    return tensor[:, :, margin : cur_w - margin]
|
||||
|
||||
|
||||
def unpad_image_shape(current_height, current_width, original_size):
    """
    Compute the shape a padded and resized image has once its padding is removed.

    Only the resulting shape is calculated; no tensor is touched. The aspect
    ratio of the original image is compared with the current one to decide
    which axis was padded, and an equal margin is subtracted from both ends
    of that axis.

    Args:
        current_height (int): Height of the padded image.
        current_width (int): Width of the padded image.
        original_size (tuple): The original size of the image. It is unpacked
            as (width, height).

    Returns:
        tuple: The unpadded shape as (height, width).
    """
    orig_w, orig_h = original_size

    if orig_w / orig_h > current_width / current_height:
        # Original is relatively wider: padding sits above and below.
        content_h = int(orig_h * (current_width / orig_w))
        margin = (current_height - content_h) // 2
        return (current_height - 2 * margin, current_width)

    # Original is relatively taller (or same ratio): padding is left and right.
    content_w = int(orig_w * (current_height / orig_h))
    margin = (current_width - content_w) // 2
    return (current_height, current_width - 2 * margin)
|
||||
|
||||
|
||||
def process_images(images, image_processor, model_cfg):
    """
    Preprocess a batch of images according to the model's aspect-ratio mode.

    The mode is read from `model_cfg.image_aspect_ratio` (missing means None):

    * "pad": each image is padded to a square filled with the processor's
      mean colour (`image_mean` scaled by 255) and then preprocessed.
    * "anyres": each image goes through `process_anyres_image` using
      `model_cfg.image_grid_pinpoints`.
    * anything else: the whole batch is passed to `image_processor` directly
      and its "pixel_values" entry is returned unchanged.

    Args:
        images (list): The input images (PIL images for the "pad" and
            "anyres" modes).
        image_processor: The image processor object.
        model_cfg: The model configuration object.

    Returns:
        For "pad" and "anyres": a single stacked np.array when all processed
        images share one shape, otherwise a list of per-image arrays (an
        empty list when `images` is empty). For other modes: whatever
        `image_processor(images)["pixel_values"]` is.
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)

    if image_aspect_ratio == "pad":
        new_images = []
        for image in images:
            # Pad with the dataset mean so the border is neutral after normalisation.
            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean)
            )
            new_images.append(image_processor.preprocess(image)["pixel_values"][0])
    elif image_aspect_ratio == "anyres":
        new_images = [
            process_anyres_image(
                image, image_processor, model_cfg.image_grid_pinpoints
            )
            for image in images
        ]
    else:
        return image_processor(images)["pixel_values"]

    # np.stack raises ValueError on an empty sequence, and all() is vacuously
    # true for it, so an empty batch must bypass the stacking step.
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = np.stack(new_images, axis=0)
    return new_images
|
||||
Reference in New Issue
Block a user