[Fix] Address remaining issues of supporting MiniCPMV (#2977)
This commit is contained in:
@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
|
||||
class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
    """Set up the MiniCPMV image processor.

    All three arguments are forwarded unchanged to the base-class
    constructor; this subclass only adds the image placeholder token.
    """
    super().__init__(hf_config, server_args, _processor)
    # Placeholder string marking one image in the prompt text. The async
    # processing path splits the prompt on this token and emits one copy
    # of it per frame that is kept.
    self.IMAGE_TOKEN = "(<image>./</image>)"
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text):
|
||||
@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
input_ids,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
):
|
||||
@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
image_data = [image_data]
|
||||
|
||||
image_hashes, image_sizes = [], []
|
||||
raw_images = []
|
||||
IMAGE_TOKEN = "(<image>./</image>)"
|
||||
all_frames = []
|
||||
|
||||
# roughly calculate the max number of frames
|
||||
# TODO: the process should be applied to all the visual inputs
|
||||
# roughly calculate the max number of frames under the max_req_input_len limit
|
||||
def calculate_max_num_frames() -> int:
    """Roughly bound how many frames fit in the remaining request budget.

    Reads ``max_req_input_len`` and ``input_ids`` from the enclosing scope.
    The budget left after the prompt is divided by a fixed per-frame token
    cost and capped.

    Returns:
        An integer in ``[0, 100]``. Zero means the prompt alone already
        uses up (or exceeds) ``max_req_input_len``.
    """
    # Model-specific
    NUM_TOKEN_PER_FRAME = 330
    # Upper bound on frames regardless of how large the budget is.
    MAX_FRAMES_CAP = 100

    # The budget is measured against input_ids only; the earlier
    # character-based estimate from input_text was dead code (it was
    # overwritten immediately) and has been dropped.
    # NOTE(review): len(input_ids) is a token count only when input_ids is a
    # list of ids; the caller also appears to accept a plain string - confirm.
    remaining_budget = max_req_input_len - len(input_ids)

    # Floor at zero so an over-long prompt yields "no frames" rather than a
    # negative frame count.
    return max(0, min(remaining_budget // NUM_TOKEN_PER_FRAME, MAX_FRAMES_CAP))
|
||||
|
||||
# if cuda OOM set a smaller number
|
||||
MAX_NUM_FRAMES = calculate_max_num_frames()
|
||||
print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
|
||||
|
||||
def encode_video(video_path):
|
||||
# print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
|
||||
|
||||
def get_estimated_frames_list():
    """
    Estimate, before any decoding, how many frames each visual input holds.

    Iterates ``image_data`` from the enclosing scope and returns a list of
    the same length and order: the reader-reported frame count for entries
    given as ``"video:<path>"`` strings, and 1 for every other entry.
    """
    video_prefix = "video:"

    def count_frames(item):
        is_video = isinstance(item, str) and item.startswith(video_prefix)
        if not is_video:
            # For images, each contributes one frame
            return 1
        # Only the length of the reader is used here; no frames are fetched.
        reader = VideoReader(item[len(video_prefix) :], ctx=cpu(0))
        return len(reader)

    return [count_frames(item) for item in image_data]
|
||||
|
||||
estimated_frames_list = get_estimated_frames_list()
|
||||
total_frame_count = sum(estimated_frames_list)
|
||||
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
|
||||
|
||||
def encode_video(video_path, frame_count_limit=None):
|
||||
if not os.path.exists(video_path):
|
||||
logger.error(f"Video {video_path} does not exist")
|
||||
return []
|
||||
|
||||
if MAX_NUM_FRAMES == 0:
|
||||
if frame_count_limit == 0:
|
||||
return []
|
||||
|
||||
def uniform_sample(l, n):
|
||||
@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
vr = VideoReader(video_path, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||
if len(frame_idx) > MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
|
||||
if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
|
||||
frame_idx = uniform_sample(frame_idx, frame_count_limit)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
||||
return frames
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
if isinstance(input_ids, list):
|
||||
assert len(input_ids) and isinstance(input_ids[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_ids)
|
||||
else:
|
||||
input_text = input_ids
|
||||
# MiniCPMV requires each frame of video as a single image token
|
||||
text_parts = input_text.split(IMAGE_TOKEN)
|
||||
text_parts = input_text.split(self.IMAGE_TOKEN)
|
||||
new_text_parts = []
|
||||
|
||||
for image_index, image in enumerate(image_data):
|
||||
try:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
path = image[len("video:") :]
|
||||
frames = encode_video(path)
|
||||
else:
|
||||
raw_image, size = load_image(image)
|
||||
frames = [raw_image]
|
||||
if len(frames) == 0:
|
||||
continue
|
||||
except FileNotFoundError as e:
|
||||
print(e)
|
||||
return None
|
||||
# Process each input with allocated frames
|
||||
for image_index, (image, estimated_frames) in enumerate(
|
||||
zip(image_data, estimated_frames_list)
|
||||
):
|
||||
if len(all_frames) >= MAX_NUM_FRAMES:
|
||||
frames_to_process = 0
|
||||
else:
|
||||
frames_to_process = max(1, int(estimated_frames * scaling_factor))
|
||||
|
||||
if frames_to_process == 0:
|
||||
frames = []
|
||||
else:
|
||||
try:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
path = image[len("video:") :]
|
||||
frames = encode_video(path, frame_count_limit=frames_to_process)
|
||||
else:
|
||||
raw_image, _size = load_image(image)
|
||||
frames = [raw_image]
|
||||
if len(frames) == 0:
|
||||
continue
|
||||
except FileNotFoundError as e:
|
||||
print(e)
|
||||
return None
|
||||
image_sizes += frames[0].size * len(frames)
|
||||
image_hashes += [hash(image)] * len(frames)
|
||||
all_frames += frames
|
||||
|
||||
assert frames_to_process == len(frames)
|
||||
|
||||
image_sizes += frames[0].size * len(frames)
|
||||
image_hashes += [hash(image)] * len(frames)
|
||||
raw_images += frames
|
||||
new_text_parts.append(text_parts[image_index])
|
||||
new_text_parts.append(IMAGE_TOKEN * len(frames))
|
||||
|
||||
if frames_to_process != 0:
|
||||
new_text_parts.append(self.IMAGE_TOKEN * len(frames))
|
||||
|
||||
new_text_parts.append(text_parts[-1])
|
||||
|
||||
input_text = "".join(new_text_parts)
|
||||
if len(raw_images) == 0:
|
||||
|
||||
if len(all_frames) == 0:
|
||||
return None
|
||||
res = await self._process_images(images=raw_images, input_text=input_text)
|
||||
res = await self._process_images(images=all_frames, input_text=input_text)
|
||||
pixel_values = res["pixel_values"]
|
||||
tgt_sizes = res["tgt_sizes"]
|
||||
input_ids = res["input_ids"]
|
||||
@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
if tokenizer.slice_start_id:
|
||||
slice_start_id = [tokenizer.slice_start_id]
|
||||
slice_end_id = [tokenizer.slice_end_id]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.flatten().tolist(),
|
||||
"pixel_values": pixel_values,
|
||||
|
||||
Reference in New Issue
Block a user