feat: replace Decord with video_reader-rs (#5163)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
kozo
2025-07-16 09:17:34 +08:00
committed by GitHub
parent f06bd210c0
commit ebff5fcb06
6 changed files with 16 additions and 21 deletions

View File

@@ -206,7 +206,7 @@ class BaseMultimodalProcessor(ABC):
estimate the total frame count from all visual input
"""
# Lazy import because the video backend is not available on some arm platforms.
from decord import VideoReader, cpu
from video_reader import PyVideoReader, cpu
# Before processing inputs
if not image_data or len(image_data) == 0:
@@ -216,7 +216,7 @@ class BaseMultimodalProcessor(ABC):
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
vr = PyVideoReader(path, threads=0)
num_frames = len(vr)
else:
# For images, each contributes one frame

View File

@@ -150,7 +150,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
max_frame = len(vr) - 1
fps = float(vr.get_avg_fps())
fps = float(vr.get_fps())
pixel_values_list, num_patches_list = [], []
transform = InternVLImageProcessor.build_transform(input_size=input_size)
@@ -158,7 +158,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
bound, fps, max_frame, first_idx=0, num_segments=num_segments
)
for frame_index in frame_indices:
img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
img = Image.fromarray(vr[frame_index]).convert("RGB")
img = InternVLImageProcessor.dynamic_preprocess(
img, image_size=input_size, use_thumbnail=True, max_num=max_num
)

View File

@@ -156,10 +156,10 @@ async def preprocess_video(
# vr: VideoReader, image_factor: int = IMAGE_FACTOR
) -> torch.Tensor:
ele = {}
total_frames, video_fps = len(vr), vr.get_avg_fps()
total_frames, video_fps = len(vr), vr.get_fps()
nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = vr.get_batch(idx)
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)