support llava video (#426)

This commit is contained in:
Yuanhan Zhang
2024-05-14 07:57:00 +08:00
committed by GitHub
parent 5dc55a5f02
commit 0992d85f92
37 changed files with 1139 additions and 222 deletions

View File

@@ -2,13 +2,16 @@
import base64
import json
import os
import sys
import threading
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
import numpy as np
import requests
@@ -110,6 +113,74 @@ def encode_image_base64(image_path):
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def encode_frame(frame):
import cv2 # pip install opencv-python-headless
from PIL import Image
# Convert the frame to RGB (OpenCV uses BGR by default)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Convert the frame to PIL Image to easily convert to bytes
im_pil = Image.fromarray(frame)
# Convert to bytes
buffered = BytesIO()
# frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))
im_pil.save(buffered, format="PNG")
frame_bytes = buffered.getvalue()
# Return the bytes of the frame
return frame_bytes
def encode_video_base64(video_path, num_frames=16):
import cv2
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Could not open video file:{video_path}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"target_frames: {num_frames}")
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
frames = []
for i in range(total_frames):
ret, frame = cap.read()
if ret:
frames.append(frame)
else:
# Handle the case where the frame could not be read
# print(f"Warning: Could not read frame at index {i}.")
pass
cap.release()
# Safely select frames based on frame_indices, avoiding IndexError
frames = [frames[i] for i in frame_indices if i < len(frames)]
# If there are not enough frames, duplicate the last frame until we reach the target
while len(frames) < num_frames:
frames.append(frames[-1])
# Use ThreadPoolExecutor to process and encode frames in parallel
with ThreadPoolExecutor() as executor:
encoded_frames = list(executor.map(encode_frame, frames))
# encoded_frames = list(map(encode_frame, frames))
# Concatenate all frames bytes
video_bytes = b"".join(encoded_frames)
# Encode the concatenated bytes to base64
video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
return video_base64
def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
@@ -170,4 +241,4 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
if not ret_value:
raise RuntimeError()
return ret_value[0]
return ret_value[0]