import asyncio
import math
import re
from typing import Dict, List, Union

import torch
from PIL import Image

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
    MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration


def _round_by_factor(number: float, factor: int) -> int:
    """Return the multiple of `factor` closest to `number`."""
    return round(number / factor) * factor


def _ceil_by_factor(number: float, factor: int) -> int:
    """Return the smallest multiple of `factor` greater than or equal to `number`."""
    return math.ceil(number / factor) * factor


def _floor_by_factor(number: float, factor: int) -> int:
    """Return the largest multiple of `factor` less than or equal to `number`."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 4 * 28 * 28,
    max_pixels: int = 16384 * 28 * 28,
    max_ratio: int = 200,
) -> tuple[int, int]:
    """Rescale (height, width) so that:

    1. Both dimensions are divisible by `factor`.
    2. The total pixel count lies within [`min_pixels`, `max_pixels`].
    3. The aspect ratio is preserved as closely as possible.

    Raises:
        ValueError: if the aspect ratio exceeds `max_ratio`.
    """
    if max(height, width) / min(height, width) > max_ratio:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {max_ratio}, "
            f"got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, _round_by_factor(height, factor))
    w_bar = max(factor, _round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too many pixels: shrink both sides by the same scale, rounding down
        # so the result stays under the budget.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = _floor_by_factor(height / beta, factor)
        w_bar = _floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Too few pixels: grow both sides, rounding up to stay above the floor.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = _ceil_by_factor(height * beta, factor)
        w_bar = _ceil_by_factor(width * beta, factor)
    return h_bar, w_bar


def _resize_image(
    image: Image.Image,
    size_factor: int = 28,
    min_pixels: int = 4 * 28 * 28,
    max_pixels: int = 16384 * 28 * 28,
    max_ratio: int = 200,
) -> Image.Image:
    """Resize a PIL image to the dimensions chosen by `smart_resize`."""
    width, height = image.size
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=size_factor,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
        max_ratio=max_ratio,
    )
    return image.resize((resized_width, resized_height))


# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
    """Multimodal processor for Qwen2-VL / Qwen2.5-VL image inputs.

    Loads and resizes images, runs the HF processor, and computes the
    multimodal item offsets and M-RoPE positions expected by the model.
    """

    models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        # The single, pre-expanded image token.
        self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
        # The regex that matches expanded image tokens.
        self.IMAGE_TOKEN_REGEX = re.compile(
            r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
        )
        self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
        self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
        self.image_token_id = hf_config.image_token_id
        self.video_token_id = hf_config.video_token_id
        self.vision_start_token_id = hf_config.vision_start_token_id
        self.vision_end_token_id = hf_config.vision_end_token_id
        self.NUM_TOKEN_PER_FRAME = 770
        # Resizing constraints fed to `smart_resize` (28 = patch size * merge size).
        self.IMAGE_FACTOR = 28
        self.MIN_PIXELS = 4 * 28 * 28
        self.MAX_PIXELS = 16384 * 28 * 28
        self.MAX_RATIO = 200

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes, Dict]],
        input_text,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        """Prepare image inputs for a request.

        Args:
            image_data: one image or a list of images (URL/path, raw bytes,
                or already-preprocessed items).
            input_text: the request prompt containing image placeholders.
            request_obj: the originating request (unused here, kept for the
                processor interface).
            max_req_input_len: cap on the tokenized request length.

        Returns:
            A dict with tokenized ``input_ids``, the ``mm_items`` list,
            special-token ids, and the M-RoPE positions/delta.
        """
        if isinstance(image_data, str):
            image_data = [image_data]

        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(
                image_token=self.IMAGE_TOKEN,
                image_token_regex=self.IMAGE_TOKEN_REGEX,
            ),
            max_req_input_len=max_req_input_len,
        )

        async def resize_image_async(image):
            # Thin async wrapper so resizes can be gathered concurrently.
            return _resize_image(
                image,
                size_factor=self.IMAGE_FACTOR,
                min_pixels=self.MIN_PIXELS,
                max_pixels=self.MAX_PIXELS,
                max_ratio=self.MAX_RATIO,
            )

        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
        if base_output.images and not images_are_preprocessed:
            resize_tasks = [resize_image_async(image) for image in base_output.images]
            base_output.images = await asyncio.gather(*resize_tasks)

        ret = self.process_mm_data(
            input_text=base_output.input_text,
            # Already-preprocessed items carry their own pixel values/features.
            images=None if images_are_preprocessed else base_output.images,
        )

        input_ids = ret["input_ids"].flatten().tolist()
        image_offsets = self.get_mm_items_offset(
            input_ids=ret["input_ids"].flatten(),
            mm_token_id=self.image_token_id,
        )

        image_grid_thw = None
        video_grid_thw = None  # TODO
        items = []
        if base_output.images:
            if images_are_preprocessed:
                # Merge the per-item tensors that were computed upstream.
                image_grid_thw = torch.concat(
                    [
                        torch.as_tensor(item.image_grid_thws)
                        for item in base_output.images
                    ]
                )
                all_pixel_values = [
                    item.pixel_values
                    for item in base_output.images
                    if item.pixel_values is not None
                ]
                all_precomputed_features = [
                    item.precomputed_features
                    for item in base_output.images
                    if item.precomputed_features is not None
                ]
                pixel_values = (
                    torch.concat(all_pixel_values) if all_pixel_values else None
                )
                precomputed_features = (
                    torch.concat(all_precomputed_features)
                    if all_precomputed_features
                    else None
                )
            else:
                image_grid_thw = ret["image_grid_thw"]
                pixel_values = ret["pixel_values"]
                precomputed_features = None

            items += [
                MultimodalDataItem(
                    pixel_values=pixel_values,
                    image_grid_thws=image_grid_thw,
                    video_grid_thws=video_grid_thw,
                    precomputed_features=precomputed_features,
                    image_offsets=image_offsets,
                    modality=Modality.IMAGE,
                )
            ]

        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
            image_token_id=self.image_token_id,
            video_token_id=self.video_token_id,
            vision_start_token_id=self.vision_start_token_id,
            model_type=self.hf_config.model_type,
            tokens_per_second=getattr(
                self.hf_config.vision_config, "tokens_per_second", None
            ),
            input_ids=torch.tensor(input_ids).unsqueeze(0),
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=ret.get("second_per_grid_ts", None),
        )
        mrope_positions = mrope_positions.squeeze(1)

        return {
            "input_ids": input_ids,
            "mm_items": items,
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
            "im_token_id": self.image_token_id,
            "video_token_id": self.video_token_id,
            "mrope_positions": mrope_positions,
            "mrope_position_delta": mrope_position_delta,
        }