vlm: support video as an input modality (#5888)
This commit is contained in:
@@ -728,33 +728,6 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
|
||||
return audio
|
||||
|
||||
|
||||
def encode_video(video_path, frame_count_limit=None):
|
||||
# Lazy import because decord is not available on some arm platforms.
|
||||
from decord import VideoReader, cpu
|
||||
|
||||
if not os.path.exists(video_path):
|
||||
logger.error(f"Video {video_path} does not exist")
|
||||
return []
|
||||
|
||||
if frame_count_limit == 0:
|
||||
return []
|
||||
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||
return [l[i] for i in idxs]
|
||||
|
||||
vr = VideoReader(video_path, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||
frame_indices = [i for i in range(0, len(vr), sample_fps)]
|
||||
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
|
||||
frame_indices = uniform_sample(frame_indices, frame_count_limit)
|
||||
|
||||
frames = vr.get_batch(frame_indices).asnumpy()
|
||||
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
||||
return frames
|
||||
|
||||
|
||||
def load_image(
|
||||
image_file: Union[Image.Image, str, bytes],
|
||||
) -> tuple[Image.Image, tuple[int, int]]:
|
||||
@@ -774,9 +747,6 @@ def load_image(
|
||||
elif image_file.startswith("data:"):
|
||||
image_file = image_file.split(",")[1]
|
||||
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
|
||||
elif image_file.startswith("video:"):
|
||||
image_file = image_file.replace("video:", "")
|
||||
image, image_size = decode_video_base64(image_file)
|
||||
elif isinstance(image_file, str):
|
||||
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
|
||||
else:
|
||||
@@ -785,6 +755,61 @@ def load_image(
|
||||
return image, image_size
|
||||
|
||||
|
||||
def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
||||
# We import decord here to avoid a strange Segmentation fault (core dumped) issue.
|
||||
from decord import VideoReader, cpu, gpu
|
||||
|
||||
try:
|
||||
from decord.bridge import decord_bridge
|
||||
|
||||
ctx = gpu(0)
|
||||
_ = decord_bridge.get_ctx_device(ctx)
|
||||
except Exception:
|
||||
ctx = cpu(0)
|
||||
|
||||
tmp_file = None
|
||||
vr = None
|
||||
try:
|
||||
if isinstance(video_file, bytes):
|
||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||
tmp_file.write(video_file)
|
||||
tmp_file.close()
|
||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
||||
elif isinstance(video_file, str):
|
||||
if video_file.startswith(("http://", "https://")):
|
||||
timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
|
||||
response = requests.get(video_file, stream=True, timeout=timeout)
|
||||
response.raise_for_status()
|
||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
tmp_file.write(chunk)
|
||||
tmp_file.close()
|
||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
||||
elif video_file.startswith("data:"):
|
||||
_, encoded = video_file.split(",", 1)
|
||||
video_bytes = base64.b64decode(encoded)
|
||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||
tmp_file.write(video_bytes)
|
||||
tmp_file.close()
|
||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
||||
elif os.path.isfile(video_file):
|
||||
vr = VideoReader(video_file, ctx=ctx)
|
||||
else:
|
||||
video_bytes = base64.b64decode(video_file)
|
||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||
tmp_file.write(video_bytes)
|
||||
tmp_file.close()
|
||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video input type: {type(video_file)}")
|
||||
|
||||
return vr
|
||||
|
||||
finally:
|
||||
if tmp_file and os.path.exists(tmp_file.name):
|
||||
os.unlink(tmp_file.name)
|
||||
|
||||
|
||||
def suppress_other_loggers():
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=UserWarning, message="The given NumPy array is not writable"
|
||||
|
||||
Reference in New Issue
Block a user