# TODO: also move pad_input_ids into this module
import asyncio
import concurrent.futures
import logging
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import numpy as np
import transformers

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

global global_processor


def init_global_processor(server_args: ServerArgs):
    """Initialize the global processor for multimodal models."""
    global global_processor
    transformers.logging.set_verbosity_error()
    global_processor = get_processor(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )

class BaseImageProcessor(ABC):
    def __init__(self, hf_config, server_args, _processor):
        self.hf_config = hf_config
        self._processor = _processor
        # Fork-based process pool for CPU-bound preprocessing. The worker
        # count can be capped with the SGLANG_CPU_COUNT environment variable.
        self.executor = concurrent.futures.ProcessPoolExecutor(
            initializer=init_global_processor,
            mp_context=mp.get_context("fork"),
            initargs=(server_args,),
            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
        )

    @abstractmethod
    async def process_images_async(self, image_data, input_text, **kwargs):
        pass

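# The subclasses below all follow the same offload pattern: CPU-heavy
# preprocessing runs in the process pool via run_in_executor, and multiple
# images are fanned out concurrently with asyncio.gather. A minimal sketch of
# that pattern (illustrative only; `heavy_fn` and `images` are placeholders,
# not part of this module):
#
#   loop = asyncio.get_event_loop()
#   futures = [loop.run_in_executor(executor, heavy_fn, img) for img in images]
#   results = await asyncio.gather(*futures)
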
class DummyImageProcessor(BaseImageProcessor):
    """A no-op processor used when no image preprocessing is needed."""

    def __init__(self):
        pass

    async def process_images_async(self, *args, **kwargs):
        return None

class LlavaImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_single_image_task(
        image_data: Union[str, bytes],
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[str] = None,
        image_processor=None,
    ):
        image_processor = image_processor or global_processor.image_processor

        try:
            image, image_size = load_image(image_data)
            if image_size is not None:
                # It is a video with multiple frames
                image_hash = hash(image_data)
                pixel_values = image_processor(image)["pixel_values"]
                pixel_values = np.stack(
                    [frame.astype(np.float16) for frame in pixel_values], axis=0
                )
                return pixel_values, image_hash, image_size
            else:
                # It is a single image
                image_hash = hash(image_data)
                if image_aspect_ratio == "pad":
                    image = expand2square(
                        image,
                        tuple(int(x * 255) for x in image_processor.image_mean),
                    )
                    pixel_values = image_processor(image.convert("RGB"))[
                        "pixel_values"
                    ][0]
                elif image_aspect_ratio == "anyres" or (
                    image_aspect_ratio is not None
                    and "anyres_max" in image_aspect_ratio
                ):
                    pixel_values = process_anyres_image(
                        image, image_processor, image_grid_pinpoints
                    )
                else:
                    pixel_values = image_processor(image)["pixel_values"][0]

                if isinstance(pixel_values, np.ndarray):
                    pixel_values = pixel_values.astype(np.float16)

                return pixel_values, image_hash, image.size
        except Exception:
            # On failure, log the traceback; the caller receives None.
            logger.error(
                "Exception in LlavaImageProcessor:\n" + get_exception_traceback()
            )
    async def _process_single_image(
        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
    ):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                LlavaImageProcessor._process_single_image_task,
                image_data,
                aspect_ratio,
                grid_pinpoints,
            )
        else:
            return self._process_single_image_task(
                image_data, aspect_ratio, grid_pinpoints
            )

    async def process_images_async(
        self, image_data: List[Union[str, bytes]], input_text, request_obj
    ):
        if not image_data:
            return None

        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints")
            and aspect_ratio is not None
            and "anyres" in aspect_ratio
            else None
        )

        if isinstance(image_data, list) and len(image_data) > 0:
            # Multiple images
            if len(image_data) > 1:
                # LLaVA-OneVision handling: more than one image means
                # interleaved-image or video mode; anyres is not used there.
                aspect_ratio = "pad"
                pixel_values, image_hashes, image_sizes = [], [], []
                res = []
                for img_data in image_data:
                    res.append(
                        self._process_single_image(
                            img_data, aspect_ratio, grid_pinpoints
                        )
                    )
                res = await asyncio.gather(*res)
                for pixel_v, image_h, image_s in res:
                    pixel_values.append(pixel_v)
                    image_hashes.append(image_h)
                    image_sizes.append(image_s)

                if isinstance(pixel_values[0], np.ndarray):
                    pixel_values = np.stack(pixel_values, axis=0)
            else:
                # A single image
                pixel_values, image_hash, image_size = await self._process_single_image(
                    image_data[0], aspect_ratio, grid_pinpoints
                )
                image_hashes = [image_hash]
                image_sizes = [image_size]
        elif isinstance(image_data, str):
            # A single image
            pixel_values, image_hash, image_size = await self._process_single_image(
                image_data, aspect_ratio, grid_pinpoints
            )
            image_hashes = [image_hash]
            image_sizes = [image_size]
        else:
            raise ValueError(f"Invalid image data: {image_data}")

        return {
            "pixel_values": pixel_values,
            "image_hashes": image_hashes,
            "image_sizes": image_sizes,
            "modalities": request_obj.modalities,
        }

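# For reference, a successful LlavaImageProcessor.process_images_async call
# resolves to a dict shaped like the following (illustrative):
#
#   {
#       "pixel_values": np.ndarray,      # stacked float16 tensors (or a list)
#       "image_hashes": [int, ...],      # one hash per input image
#       "image_sizes": [(w, h), ...],    # original PIL image sizes
#       "modalities": request_obj.modalities,
#   }
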
class MllamaImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_single_image_task(images, input_text):
        # The HF processor output contains: input_ids, attention_mask,
        # pixel_values, aspect_ratio_ids, aspect_ratio_mask, and
        # cross_attention_mask.
        return global_processor(images, input_text, return_tensors="pt")

    async def _process_single_image(self, images, input_text):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
                MllamaImageProcessor._process_single_image_task,
                images,
                input_text,
            )
        else:
            image_inputs = self._processor(images, input_text, return_tensors="pt")

        return image_inputs

    async def process_images_async(
        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
    ):
        if not image_data:
            return None

        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        if not isinstance(image_data, list):
            image_data = [image_data]

        # image_data is guaranteed non-empty here (checked above),
        # so load every entry.
        images = [load_image(image)[0] for image in image_data]

        image_inputs = await self._process_single_image(images, input_text)
        image_inputs["image_hashes"] = [hash(str(image_data))]
        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]

        return image_inputs

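# Note: unlike the other processors, Mllama returns the full HF processor
# output, including input_ids re-tokenized together with the image tokens;
# the flattened input_ids list is presumably consumed downstream in place of
# the original prompt tokens.
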
class Qwen2VLImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _image_processor):
        # Does not call super().__init__ because it stores only the
        # image-processor component, not the full processor.
        self.hf_config = hf_config
        self._image_processor = _image_processor
        self.executor = concurrent.futures.ProcessPoolExecutor(
            initializer=init_global_processor,
            mp_context=mp.get_context("fork"),
            initargs=(server_args,),
            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
        )

    @staticmethod
    def _process_single_image_task(
        image_data: Union[str, bytes],
        image_processor=None,
    ):
        image_processor = image_processor or global_processor.image_processor

        try:
            image, image_size = load_image(image_data)
            if image_size is not None:
                # It is a video with multiple frames
                image_hash = hash(image_data)
                process_result = image_processor(image)
                pixel_values, image_grid_thws = (
                    process_result["pixel_values"],
                    process_result["image_grid_thw"][0],
                )
                pixel_values = np.stack(
                    [frame.astype(np.float16) for frame in pixel_values], axis=0
                )
                image_grid_thws = np.stack(image_grid_thws, axis=0)
                return pixel_values, image_hash, image_size, image_grid_thws
            else:
                # It is a single image
                image_hash = hash(image_data)
                process_result = image_processor(image)
                pixel_values, image_grid_thws = (
                    process_result["pixel_values"],
                    process_result["image_grid_thw"][0],
                )
                if isinstance(pixel_values, np.ndarray):
                    pixel_values = pixel_values.astype(np.float16)

                return pixel_values, image_hash, image.size, image_grid_thws
        except Exception:
            # On failure, log the traceback; the caller receives None.
            logger.error(
                "Exception in Qwen2VLImageProcessor:\n" + get_exception_traceback()
            )
    async def _process_single_image(self, image_data: Union[bytes, str]):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                Qwen2VLImageProcessor._process_single_image_task,
                image_data,
            )
        else:
            return self._process_single_image_task(image_data)

    async def process_images_async(
        self, image_data: List[Union[str, bytes]], input_text, request_obj
    ):
        if not image_data:
            return None

        if isinstance(image_data, list) and len(image_data) > 0:
            # Multiple images
            if len(image_data) > 1:
                pixel_values, image_hashes, image_sizes, image_grid_thws = (
                    [],
                    [],
                    [],
                    [],
                )
                res = []
                for img_data in image_data:
                    res.append(self._process_single_image(img_data))
                res = await asyncio.gather(*res)
                for pixel_v, image_h, image_s, image_thw in res:
                    pixel_values.append(pixel_v)
                    image_hashes.append(image_h)
                    image_sizes.append(image_s)
                    image_grid_thws.append(image_thw)

                if isinstance(pixel_values[0], np.ndarray):
                    pixel_values = np.concatenate(pixel_values, axis=0)
            else:
                # A single image
                pixel_values, image_hash, image_size, image_grid_thw = (
                    await self._process_single_image(image_data[0])
                )
                image_hashes = [image_hash]
                image_sizes = [image_size]
                image_grid_thws = [image_grid_thw]
        elif isinstance(image_data, str):
            # A single image
            pixel_values, image_hash, image_size, image_grid_thw = (
                await self._process_single_image(image_data)
            )
            image_hashes = [image_hash]
            image_sizes = [image_size]
            image_grid_thws = [image_grid_thw]
        else:
            raise ValueError(f"Invalid image data: {image_data}")

        return {
            "pixel_values": pixel_values,
            "image_hashes": image_hashes,
            "image_sizes": image_sizes,
            "modalities": request_obj.modalities,
            "image_grid_thws": image_grid_thws,
        }

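# Note: Qwen2-VL flattens visual patches, so per-image pixel_values are
# concatenated along axis 0 (rather than stacked, as in the LLaVA path);
# image_grid_thws carries the (t, h, w) grid sizes needed to recover
# per-image boundaries downstream.
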
def get_image_processor(
    hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor:
    if "MllamaForConditionalGeneration" in hf_config.architectures:
        return MllamaImageProcessor(hf_config, server_args, processor)
    elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
        return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
    else:
        return LlavaImageProcessor(hf_config, server_args, processor.image_processor)


def get_dummy_image_processor():
    return DummyImageProcessor()
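

# A minimal end-to-end sketch of how this module is typically wired up
# (illustrative only: the model path, prompt, and request object are
# placeholders, and error handling is omitted):
#
#   server_args = ServerArgs(model_path="llava-hf/llava-1.5-7b-hf")
#   processor = get_processor(server_args.tokenizer_path)
#   image_processor = get_image_processor(hf_config, server_args, processor)
#   out = await image_processor.process_images_async(
#       image_data=["https://example.com/cat.png"],
#       input_text="Describe this image.",
#       request_obj=request_obj,
#   )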