170 lines
6.9 KiB
Python
170 lines
6.9 KiB
Python
import asyncio
|
|
import math
|
|
import re
|
|
from typing import Dict, List, Union
|
|
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
|
BaseMultimodalProcessor as SGLangBaseProcessor,
|
|
)
|
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
|
MultimodalSpecialTokens,
|
|
)
|
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
|
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
|
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
|
|
|
|
|
# Compatible with Qwen2VL and Qwen2_5VL
|
|
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
|
|
|
|
def __init__(self, hf_config, server_args, _processor):
|
|
super().__init__(hf_config, server_args, _processor)
|
|
# The single, pre-expanded image token.
|
|
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
|
|
# The regex that matches expanded image tokens.
|
|
self.IMAGE_TOKEN_REGEX = re.compile(
|
|
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
|
|
)
|
|
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
|
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
|
self.IM_TOKEN_ID = hf_config.image_token_id
|
|
self.VIDEO_TOKEN_ID = hf_config.video_token_id
|
|
self.vision_start_token_id = hf_config.vision_start_token_id
|
|
self.vision_end_token_id = hf_config.vision_end_token_id
|
|
self.NUM_TOKEN_PER_FRAME = 770
|
|
self.IMAGE_FACTOR = 28
|
|
self.MIN_PIXELS = 4 * 28 * 28
|
|
self.MAX_PIXELS = 16384 * 28 * 28
|
|
self.MAX_RATIO = 200
|
|
|
|
async def process_mm_data_async(
|
|
self,
|
|
image_data: List[Union[str, bytes, Dict]],
|
|
input_text,
|
|
request_obj,
|
|
max_req_input_len,
|
|
*args,
|
|
**kwargs,
|
|
):
|
|
if isinstance(image_data, str):
|
|
image_data = [image_data]
|
|
|
|
base_output = self.load_mm_data(
|
|
prompt=input_text,
|
|
image_data=image_data,
|
|
multimodal_tokens=MultimodalSpecialTokens(
|
|
image_token=self.IMAGE_TOKEN,
|
|
image_token_regex=self.IMAGE_TOKEN_REGEX,
|
|
),
|
|
max_req_input_len=max_req_input_len,
|
|
)
|
|
|
|
def smart_resize(
|
|
height: int,
|
|
width: int,
|
|
factor: int = self.IMAGE_FACTOR,
|
|
min_pixels: int = self.MIN_PIXELS,
|
|
max_pixels: int = self.MAX_PIXELS,
|
|
) -> tuple[int, int]:
|
|
"""
|
|
Rescales the image so that the following conditions are met:
|
|
|
|
1. Both dimensions (height and width) are divisible by 'factor'.
|
|
|
|
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
|
|
|
3. The aspect ratio of the image is maintained as closely as possible.
|
|
"""
|
|
if max(height, width) / min(height, width) > self.MAX_RATIO:
|
|
raise ValueError(
|
|
f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
|
)
|
|
h_bar = max(factor, round_by_factor(height, factor))
|
|
w_bar = max(factor, round_by_factor(width, factor))
|
|
if h_bar * w_bar > max_pixels:
|
|
beta = math.sqrt((height * width) / max_pixels)
|
|
h_bar = floor_by_factor(height / beta, factor)
|
|
w_bar = floor_by_factor(width / beta, factor)
|
|
elif h_bar * w_bar < min_pixels:
|
|
beta = math.sqrt(min_pixels / (height * width))
|
|
h_bar = ceil_by_factor(height * beta, factor)
|
|
w_bar = ceil_by_factor(width * beta, factor)
|
|
return h_bar, w_bar
|
|
|
|
def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
|
|
width, height = image.size
|
|
min_pixels = self.MIN_PIXELS
|
|
max_pixels = self.MAX_PIXELS
|
|
resized_height, resized_width = smart_resize(
|
|
height,
|
|
width,
|
|
factor=size_factor,
|
|
min_pixels=min_pixels,
|
|
max_pixels=max_pixels,
|
|
)
|
|
image = image.resize((resized_width, resized_height))
|
|
return image
|
|
|
|
def round_by_factor(number: int, factor: int) -> int:
|
|
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
|
return round(number / factor) * factor
|
|
|
|
def ceil_by_factor(number: int, factor: int) -> int:
|
|
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
|
return math.ceil(number / factor) * factor
|
|
|
|
def floor_by_factor(number: int, factor: int) -> int:
|
|
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
|
return math.floor(number / factor) * factor
|
|
|
|
async def resize_image_async(image):
|
|
return resize_image(image)
|
|
|
|
# Qwen-specific: resize images if they are raw Image objects
|
|
if base_output.images and isinstance(base_output.images[0], Image.Image):
|
|
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
|
base_output.images = await asyncio.gather(*resize_tasks)
|
|
|
|
video_grid_thw = None # TODO
|
|
|
|
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
|
|
|
|
if combined_mm_item is None:
|
|
# Note(Xinyuan): This is the case where image loading fails.
|
|
return None
|
|
|
|
video_grid_thw = None # TODO
|
|
second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
|
|
|
|
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
|
|
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
|
|
image_token_id=self.IM_TOKEN_ID,
|
|
video_token_id=self.VIDEO_TOKEN_ID,
|
|
vision_start_token_id=self.vision_start_token_id,
|
|
model_type=self.hf_config.model_type,
|
|
tokens_per_second=getattr(
|
|
self.hf_config.vision_config, "tokens_per_second", None
|
|
),
|
|
input_ids=input_ids.unsqueeze(0),
|
|
image_grid_thw=combined_mm_item.image_grid_thw,
|
|
video_grid_thw=video_grid_thw,
|
|
second_per_grid_ts=second_per_grid_ts,
|
|
)
|
|
mrope_positions = mrope_positions.squeeze(1)
|
|
|
|
return {
|
|
"input_ids": input_ids.tolist(),
|
|
"mm_items": [combined_mm_item],
|
|
"im_start_id": self.IM_START_TOKEN_ID,
|
|
"im_end_id": self.IM_END_TOKEN_ID,
|
|
"im_token_id": self.IM_TOKEN_ID,
|
|
"video_token_id": self.VIDEO_TOKEN_ID,
|
|
"mrope_positions": mrope_positions,
|
|
"mrope_position_delta": mrope_position_delta,
|
|
}
|