model: Minicpmo (#3023)

This commit is contained in:
Mick
2025-03-25 11:08:40 +08:00
committed by GitHub
parent 64129fa632
commit 1e86457c90
40 changed files with 2906 additions and 493 deletions

View File

@@ -0,0 +1,275 @@
import concurrent
import concurrent.futures
import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
import PIL
import transformers
from decord import VideoReader, cpu
from openai import BadRequestError
from PIL import Image
from sglang.srt.utils import load_audio, load_image, logger
global global_processor
def get_global_processor():
global global_processor
return global_processor
@dataclasses.dataclass
class BaseMultiModalProcessorOutput:
# input_text, with each frame of video/image represented with a image_token
input_text: str
mm_data_hashes: Optional[list[int]]
# images
image_sizes: Optional[list[int]]
# frames loaded from image and video, in given order
images: Optional[list[PIL.Image]] = None
# audios
audios: Optional[list[np.ndarray]] = None
def normalize(self):
for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
@dataclasses.dataclass
class MultimodalSpecialTokens:
image_token: Optional[str] = None
video_token: Optional[str] = None
audio_token: Optional[str] = None
def collect(self) -> list[str]:
return [
token
for token in [self.image_token, self.video_token, self.audio_token]
if token
]
class BaseMultimodalProcessor(ABC):
models = []
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.server_args = server_args
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
# Initialize global processor first
init_global_processor(self, server_args)
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(
self,
server_args,
),
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
)
def _build_processor(self, server_args):
"""Init the global processor for multi modal models."""
from sglang.srt.hf_transformers_utils import get_processor
return get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
@abstractmethod
async def process_mm_data_async(
self, image_data, input_text, max_req_input_len, **kwargs
):
pass
def get_estimated_frames_list(self, image_data):
"""
estimate the total frame count from all visual input
"""
# Before processing inputs
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
num_frames = len(vr)
else:
# For images, each contributes one frame
num_frames = 1
estimated_frames_list.append(num_frames)
return estimated_frames_list
@staticmethod
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_indices = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
frame_indices = uniform_sample(frame_indices, frame_count_limit)
frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
def load_mm_data(
self,
input_ids: list[int],
multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int,
image_data: Optional[list] = None,
audio_data: Optional[list] = None,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
) -> BaseMultiModalProcessorOutput:
"""
Each frame of video/image will be replaced by a single image token
Args:
multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
e.g. image token or audio token
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if isinstance(multimodal_tokens.image_token, int):
multimodal_tokens.image_token = (
self._processor.tokenizer.convert_ids_to_tokens(
multimodal_tokens.image_token
)
)
else:
multimodal_tokens.image_token = multimodal_tokens.image_token
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
else:
input_text = input_ids
if return_text:
import re
pattern = (
"("
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
+ ")"
)
# split text into list of normal text and special tokens
text_parts = re.split(pattern, input_text)
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
image_index, audio_index = 0, 0
hashes, image_sizes, images, audios = [], [], [], []
new_text = ""
for index, text_part in enumerate(text_parts):
try:
if text_part == multimodal_tokens.image_token:
# load as image
if len(images) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
estimated_frames = estimated_frames_list[image_index]
frames_to_process = max(
1, int(estimated_frames * scaling_factor)
)
if frames_to_process == 0:
frames = []
else:
image_file = image_data[image_index]
if isinstance(image_file, str) and image_file.startswith(
"video:"
):
# video
path = image_file[len("video:") :]
frames = BaseMultimodalProcessor.encode_video(
path, frame_count_limit=frames_to_process
)
else:
# image
raw_image, _size = load_image(image_file)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
if len(frames) == 0:
continue
image_sizes += frames[0].size * len(frames)
hashes += [hash(image_file)] * len(frames)
images += frames
image_index += 1
if frames_to_process != 0:
new_text += multimodal_tokens.image_token * len(frames)
assert frames_to_process == len(frames)
elif text_part == multimodal_tokens.audio_token:
# load as audio
audio_file = audio_data[audio_index]
audio = load_audio(audio_file)
hashes += [hash(audio_file)]
audios += [audio]
audio_index += 1
new_text += multimodal_tokens.audio_token
else:
# TODO(mick): handle video
# normal text
new_text += text_part
except Exception as e:
logger.error(f"An exception occurred while loading images: {e}")
raise BadRequestError(
f"An exception occurred while loading images: {e}"
)
out = BaseMultiModalProcessorOutput(
mm_data_hashes=hashes,
image_sizes=image_sizes,
images=images,
audios=audios,
input_text=new_text,
)
out.normalize()
return out
def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
"""
Init the global processor for multimodal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = sglang_processor._build_processor(server_args=server_args)

View File

@@ -0,0 +1,119 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import asyncio
import torch
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
models = [DeepseekVL2ForCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image>"
@staticmethod
def _process_images_task(image, input_text, max_req_input_len):
processor = get_global_processor()
res = processor.__call__(
conversations=input_text, images=image, max_req_input_len=max_req_input_len
)
image_token_id = processor.image_token_id
res["im_token_id"] = image_token_id
return res
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def process_mm_data_async(
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
):
if not image_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
images, image_sizes = [], []
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
)
res = await self._process_images(
base_output.images, base_output.input_text, max_req_input_len
)
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
batched_images_spatial_crop = []
batched_images_spatial_crop.append(images_spatial_crop)
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
return {
"input_ids": res["input_ids"].tolist(),
"pixel_values": res["images"],
"im_token_id": res["im_token_id"],
"data_hashes": base_output.mm_data_hashes,
"image_sizes": image_sizes,
"images_emb_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop,
"modalities": request_obj.modalities or ["image"],
}

View File

@@ -0,0 +1,83 @@
from typing import List, Union
from transformers.utils import logging
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
# will be removed in the future
logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
models = [Gemma3ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<start_of_image>"
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
async def _process_single_image(self, images, input_text) -> dict:
if isinstance(images, list) and len(images) == 0:
images = None
processor = get_global_processor()
result = processor.__call__(
text=[input_text],
images=images,
padding=True,
return_tensors="pt",
# if RGBA, this needs to be set
# images_kwargs={
# "input_data_format": ChannelDimension.FIRST
# }
)
pixel_values = getattr(result, "pixel_values", None)
return {
"input_ids": result.input_ids,
"pixel_values": pixel_values,
}
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
ret = await self._process_single_image(
input_text=base_output.input_text, images=base_output.images
)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"data_hashes": base_output.mm_data_hashes,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}

View File

@@ -0,0 +1,84 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
class JanusProImageProcessor(BaseMultimodalProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_images_task(images, input_text):
processor = get_global_processor()
result = processor.__call__(
prompt=input_text, images=images, return_tensors="pt"
)
return {
"input_ids": result["input_ids"],
"pixel_values": result["pixel_values"],
"images_emb_mask": result["images_emb_mask"],
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,
}
async def _process_images(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
JanusProImageProcessor._process_images_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
)
return image_inputs
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
**kwargs,
):
if not image_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
base_out = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token="<image_placeholder>"
),
max_req_input_len=max_req_input_len,
)
images = base_out.images
res = await self._process_images(images=images, input_text=base_out.input_text)
# print(res)
# print(base_out)
# print("", res["images_emb_mask"].shape)
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"images_emb_mask": res["images_emb_mask"],
"data_hashes": base_out.mm_data_hashes,
"im_start_id": res["im_start_id"],
"im_end_id": res["im_end_id"],
"im_token_id": res["im_token_id"],
}

View File

@@ -0,0 +1,147 @@
import asyncio
from typing import List, Optional, Union
import numpy as np
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
from sglang.srt.models.llavavid import LlavaVidForCausalLM
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
class LlavaImageProcessor(BaseMultimodalProcessor):
models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
):
processor = get_global_processor()
image_processor = image_processor or processor.image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in image_processor.image_mean),
)
pixel_values = image_processor(image.convert("RGB"))[
"pixel_values"
][0]
elif image_aspect_ratio == "anyres" or (
image_aspect_ratio is not None
and "anyres_max" in image_aspect_ratio
):
pixel_values = process_anyres_image(
image, image_processor, image_grid_pinpoints
)
else:
pixel_values = image_processor(image)["pixel_values"][0]
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
LlavaImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return self._process_single_image_task(
image_data, aspect_ratio, grid_pinpoints
)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(image_data, list) and len(image_data) > 0:
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, data_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(
self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
)
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
data_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
data_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"data_hashes": data_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}

View File

@@ -0,0 +1,167 @@
import asyncio
from typing import List, Union
import torch
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.minicpmo import MiniCPMO
from sglang.srt.models.minicpmv import MiniCPMV
# Compatible with both 'O' and 'V'
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
models = [MiniCPMV, MiniCPMO]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.image_token = "(<image>./</image>)"
self.audio_token = "(<audio>./</audio>)"
@staticmethod
def _process_data_task(input_text, images=None, audios=None):
if isinstance(images, list) and len(images) == 0:
images = None
if isinstance(audios, list) and len(audios) == 0:
audios = None
result = get_global_processor().__call__(
text=input_text,
images=images,
audios=audios,
return_tensors="pt",
chunk_input=True,
)
return {
"input_ids": result.input_ids,
"pixel_values": getattr(result, "pixel_values", None),
"tgt_sizes": getattr(result, "tgt_sizes", None),
"audio_features": getattr(result, "audio_features", None),
"audio_feature_lens": getattr(result, "audio_feature_lens", None),
"audio_bounds": getattr(result, "audio_bounds", None),
}
async def _process_data(self, images, input_text, audios=None):
if self.executor is not None:
loop = asyncio.get_event_loop()
multimodal_data_inputs = await loop.run_in_executor(
self.executor,
MiniCPMMultimodalProcessor._process_data_task,
input_text,
images,
audios,
)
else:
multimodal_data_inputs = self._processor(
images=images, text=input_text, audios=audios, return_tensors="pt"
)
return multimodal_data_inputs
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
):
audio_data = request_obj.audio_data
if not image_data and not audio_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
if not isinstance(audio_data, list):
audio_data = [audio_data]
base_output = self.load_mm_data(
input_ids=input_ids,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.image_token, audio_token=self.audio_token
),
)
if base_output is None:
return None
res = await self._process_data(
images=base_output.images,
input_text=base_output.input_text,
audios=base_output.audios,
)
# Collect special token ids
tokenizer = self._processor.tokenizer
slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
None,
None,
None,
None,
)
if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id
if hasattr(tokenizer, "audio_start_id"):
audio_start_id = tokenizer.audio_start_id
audio_end_id = tokenizer.audio_end_id
im_token_id = tokenizer.unk_token_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
)
if len(pixel_values) != len(tgt_sizes):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# per image
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat
if len(tgt_sizes_flat) == 0:
tgt_sizes = None
else:
tgt_sizes = torch.stack(tgt_sizes_flat)
if not isinstance(res["audio_features"], list):
res["audio_features"] = [res["audio_features"]]
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"audio_features": res["audio_features"],
"audio_bounds": res["audio_bounds"],
"audio_feature_lens": res["audio_feature_lens"],
"im_token_id": im_token_id,
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}

View File

@@ -0,0 +1,59 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image
class MllamaImageProcessor(BaseMultimodalProcessor):
models = [MllamaForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return get_global_processor()(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if not image_data:
return None
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
if not isinstance(image_data, list):
image_data = [image_data]
if len(image_data) > 0:
images = [load_image(image)[0] for image in image_data]
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs

View File

@@ -0,0 +1,167 @@
import asyncio
import math
import time
from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.image_token_id = hf_config.image_token_id
self.video_token_id = hf_config.video_token_id
self.NUM_TOKEN_PER_FRAME = 770
self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 4 * 28 * 28
self.MAX_PIXELS = 16384 * 28 * 28
self.MAX_RATIO = 200
@staticmethod
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0:
images = None
result = get_global_processor().__call__(
text=[input_text], images=images, padding=True, return_tensors="pt"
)
return {
"input_ids": result.input_ids,
"pixel_values": getattr(result, "pixel_values", None),
"image_grid_thw": getattr(result, "image_grid_thw", None),
"second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
"video_grid_thws": getattr(result, "video_grid_thws", None),
}
async def _process_single_image(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2_5VLImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
start = time.time()
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
)
def smart_resize(
height: int,
width: int,
factor: int = self.IMAGE_FACTOR,
min_pixels: int = self.MIN_PIXELS,
max_pixels: int = self.MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > self.MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
width, height = image.size
min_pixels = self.MIN_PIXELS
max_pixels = self.MAX_PIXELS
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
images = [resize_image(image) for image in base_output.images]
ret = await self._process_single_image(
images=images, input_text=base_output.input_text
)
image_grid_thws = torch.concat([ret["image_grid_thw"]])
video_grid_thws = None
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": image_grid_thws,
"video_grid_thws": video_grid_thws,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id,
"video_token_id": self.video_token_id,
"second_per_grid_ts": ret["second_per_grid_ts"],
}