Model: Support Qwen 2.5 vl (#3258)

This commit is contained in:
Mick
2025-02-16 16:58:53 +08:00
committed by GitHub
parent 39416e394a
commit bcc213df61
11 changed files with 2000 additions and 262 deletions

View File

@@ -1,6 +1,7 @@
# TODO: also move pad_input_ids into this module
import asyncio
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
@@ -8,6 +9,7 @@ from abc import ABC, abstractmethod
from typing import List, Optional, Union
import numpy as np
import PIL
import transformers
from decord import VideoReader, cpu
from PIL import Image
@@ -34,11 +36,22 @@ def init_global_processor(server_args: ServerArgs):
)
@dataclasses.dataclass
class BaseImageProcessorOutput:
image_hashes: list[int]
image_sizes: list[int]
all_frames: [PIL.Image]
# input_text, with each frame of video/image represented with a image_token
input_text: str
class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.server_args = server_args
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
@@ -48,9 +61,128 @@ class BaseImageProcessor(ABC):
)
@abstractmethod
async def process_images_async(self, image_data, input_text, **kwargs):
async def process_images_async(
self, image_data, input_text, max_req_input_len, **kwargs
):
pass
def get_estimated_frames_list(self, image_data):
"""
estimate the total frame count from all visual input
"""
# Before processing inputs
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
num_frames = len(vr)
else:
# For images, each contributes one frame
num_frames = 1
estimated_frames_list.append(num_frames)
return estimated_frames_list
def encode_video(self, video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
frame_idx = uniform_sample(frame_idx, frame_count_limit)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
def load_images(
self,
max_req_input_len: int,
input_ids: list,
image_data,
image_token: str,
) -> BaseImageProcessorOutput:
"""
Each frame of video/image will be replaced by a single image token
"""
image_hashes, image_sizes = [], []
all_frames = []
new_text_parts = []
if isinstance(input_ids, list):
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
else:
input_text = input_ids
text_parts = input_text.split(image_token)
# roughly calculate the max number of frames under the max_req_input_len limit
def calculate_max_num_frames() -> int:
ret = (max_req_input_len - len(input_ids)) // self.NUM_TOKEN_PER_FRAME
return min(ret, 100)
MAX_NUM_FRAMES = calculate_max_num_frames()
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
# Process each input with allocated frames
for image_index, (image, estimated_frames) in enumerate(
zip(image_data, estimated_frames_list)
):
if len(all_frames) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
frames_to_process = max(1, int(estimated_frames * scaling_factor))
if frames_to_process == 0:
frames = []
else:
try:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames = self.encode_video(
path, frame_count_limit=frames_to_process
)
else:
raw_image, _size = load_image(image)
frames = [raw_image]
if len(frames) == 0:
continue
except FileNotFoundError as e:
print(e)
return None
image_sizes += frames[0].size * len(frames)
image_hashes += [hash(image)] * len(frames)
all_frames += frames
new_text_parts.append(text_parts[image_index])
if frames_to_process != 0:
new_text_parts.append(image_token * len(frames))
assert frames_to_process == len(frames)
new_text_parts.append(text_parts[-1])
input_text = "".join(new_text_parts)
return BaseImageProcessorOutput(
image_hashes, image_sizes, all_frames, input_text
)
class DummyImageProcessor(BaseImageProcessor):
def __init__(self):
@@ -248,9 +380,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
text=input_text, images=images, return_tensors="pt"
)
return {
"input_ids": result["input_ids"],
"pixel_values": result["pixel_values"],
"tgt_sizes": result["tgt_sizes"],
"input_ids": result.input_ids,
"pixel_values": result.pixel_values,
"tgt_sizes": result.tgt_sizes,
}
async def _process_images(self, images, input_text):
@@ -278,124 +410,20 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
):
if not image_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
image_hashes, image_sizes = [], []
all_frames = []
# roughly calculate the max number of frames under the max_req_input_len limit
def calculate_max_num_frames() -> int:
# Model-specific
NUM_TOKEN_PER_FRAME = 330
ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
return min(ret, 100)
MAX_NUM_FRAMES = calculate_max_num_frames()
# print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
def get_estimated_frames_list():
"""
estimate the total frame count from all visual input
"""
# Before processing inputs
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
num_frames = len(vr)
else:
# For images, each contributes one frame
num_frames = 1
estimated_frames_list.append(num_frames)
return estimated_frames_list
estimated_frames_list = get_estimated_frames_list()
total_frame_count = sum(estimated_frames_list)
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
frame_idx = uniform_sample(frame_idx, frame_count_limit)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
if isinstance(input_ids, list):
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
else:
input_text = input_ids
# MiniCPMV requires each frame of video as a single image token
text_parts = input_text.split(self.IMAGE_TOKEN)
new_text_parts = []
# Process each input with allocated frames
for image_index, (image, estimated_frames) in enumerate(
zip(image_data, estimated_frames_list)
):
if len(all_frames) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
frames_to_process = max(1, int(estimated_frames * scaling_factor))
if frames_to_process == 0:
frames = []
else:
try:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames = encode_video(path, frame_count_limit=frames_to_process)
else:
raw_image, _size = load_image(image)
frames = [raw_image]
if len(frames) == 0:
continue
except FileNotFoundError as e:
print(e)
return None
image_sizes += frames[0].size * len(frames)
image_hashes += [hash(image)] * len(frames)
all_frames += frames
assert frames_to_process == len(frames)
new_text_parts.append(text_parts[image_index])
if frames_to_process != 0:
new_text_parts.append(self.IMAGE_TOKEN * len(frames))
new_text_parts.append(text_parts[-1])
input_text = "".join(new_text_parts)
if len(all_frames) == 0:
base_output = self.load_images(
max_req_input_len, input_ids, image_data, self.IMAGE_TOKEN
)
if base_output is None:
return None
res = await self._process_images(images=all_frames, input_text=input_text)
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
input_ids = res["input_ids"]
if len(base_output.all_frames) == 0:
return None
res = await self._process_images(
images=base_output.all_frames, input_text=base_output.input_text
)
# Collect special token ids
tokenizer = self._processor.tokenizer
@@ -405,10 +433,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
slice_start_id = [tokenizer.slice_start_id]
slice_end_id = [tokenizer.slice_end_id]
return {
"input_ids": input_ids.flatten().tolist(),
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"image_hashes": image_hashes,
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"tgt_sizes": res["tgt_sizes"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id,
"im_end_id": im_end_id,
@@ -536,13 +564,80 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
}
class Qwen2_5VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.NUM_TOKEN_PER_FRAME = 770
@staticmethod
def _process_images_task(images, input_text):
result = global_processor.__call__(
text=input_text, images=images, return_tensors="pt"
)
return {
"input_ids": result.input_ids,
"pixel_values": result.pixel_values,
"image_grid_thws": result.image_grid_thw,
}
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2_5VLImageProcessor._process_images_task,
images,
input_text,
)
else:
return self._process_images_task(images, input_text)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
max_req_input_len, input_ids, image_data, image_token
)
ret = await self._process_images(base_output.all_frames, base_output.input_text)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": ret["image_grid_thws"],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
def get_image_processor(
hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor:
if "MllamaForConditionalGeneration" in hf_config.architectures:
return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
return Qwen2VLImageProcessor(hf_config, server_args, processor)
elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures:
return Qwen2_5VLImageProcessor(hf_config, server_args, processor)
elif "MiniCPMV" in hf_config.architectures:
return MiniCPMVImageProcessor(hf_config, server_args, processor)
else: