feat: replace Decord with video_reader-rs (#5163)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -21,6 +21,7 @@ runtime_common = [
|
|||||||
"build",
|
"build",
|
||||||
"compressed-tensors",
|
"compressed-tensors",
|
||||||
"datasets",
|
"datasets",
|
||||||
|
"video-reader-rs",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"hf_transfer",
|
"hf_transfer",
|
||||||
"huggingface_hub",
|
"huggingface_hub",
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ PACKAGE_LIST = [
|
|||||||
"tiktoken",
|
"tiktoken",
|
||||||
"anthropic",
|
"anthropic",
|
||||||
"litellm",
|
"litellm",
|
||||||
"decord",
|
"video-reader-rs",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -206,7 +206,7 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
estimate the total frame count from all visual input
|
estimate the total frame count from all visual input
|
||||||
"""
|
"""
|
||||||
# Lazy import because decord is not available on some arm platforms.
|
# Lazy import, originally added because decord is not available on some arm platforms. NOTE(review): confirm whether video_reader still needs to be imported lazily.
|
||||||
from decord import VideoReader, cpu
|
from video_reader import PyVideoReader, cpu
|
||||||
|
|
||||||
# Before processing inputs
|
# Before processing inputs
|
||||||
if not image_data or len(image_data) == 0:
|
if not image_data or len(image_data) == 0:
|
||||||
@@ -216,7 +216,7 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
if isinstance(image, str) and image.startswith("video:"):
|
if isinstance(image, str) and image.startswith("video:"):
|
||||||
path = image[len("video:") :]
|
path = image[len("video:") :]
|
||||||
# Estimate frames for the video
|
# Estimate frames for the video
|
||||||
vr = VideoReader(path, ctx=cpu(0))
|
vr = PyVideoReader(path, threads=0)
|
||||||
num_frames = len(vr)
|
num_frames = len(vr)
|
||||||
else:
|
else:
|
||||||
# For images, each contributes one frame
|
# For images, each contributes one frame
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
|||||||
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
|
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
|
||||||
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||||
max_frame = len(vr) - 1
|
max_frame = len(vr) - 1
|
||||||
fps = float(vr.get_avg_fps())
|
fps = float(vr.get_fps())
|
||||||
|
|
||||||
pixel_values_list, num_patches_list = [], []
|
pixel_values_list, num_patches_list = [], []
|
||||||
transform = InternVLImageProcessor.build_transform(input_size=input_size)
|
transform = InternVLImageProcessor.build_transform(input_size=input_size)
|
||||||
@@ -158,7 +158,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
|||||||
bound, fps, max_frame, first_idx=0, num_segments=num_segments
|
bound, fps, max_frame, first_idx=0, num_segments=num_segments
|
||||||
)
|
)
|
||||||
for frame_index in frame_indices:
|
for frame_index in frame_indices:
|
||||||
img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
|
img = Image.fromarray(vr[frame_index]).convert("RGB")
|
||||||
img = InternVLImageProcessor.dynamic_preprocess(
|
img = InternVLImageProcessor.dynamic_preprocess(
|
||||||
img, image_size=input_size, use_thumbnail=True, max_num=max_num
|
img, image_size=input_size, use_thumbnail=True, max_num=max_num
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -156,10 +156,10 @@ async def preprocess_video(
|
|||||||
# vr: VideoReader, image_factor: int = IMAGE_FACTOR
|
# vr: VideoReader, image_factor: int = IMAGE_FACTOR
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
ele = {}
|
ele = {}
|
||||||
total_frames, video_fps = len(vr), vr.get_avg_fps()
|
total_frames, video_fps = len(vr), vr.get_fps()
|
||||||
nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
|
nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
|
||||||
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
|
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
|
||||||
video = vr.get_batch(idx).asnumpy()
|
video = vr.get_batch(idx)
|
||||||
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
|
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
|
||||||
nframes, _, height, width = video.shape
|
nframes, _, height, width = video.shape
|
||||||
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
|
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ from torch.library import Library
|
|||||||
from torch.profiler import ProfilerActivity, profile, record_function
|
from torch.profiler import ProfilerActivity, profile, record_function
|
||||||
from torch.utils._contextlib import _DecoratorContextManager
|
from torch.utils._contextlib import _DecoratorContextManager
|
||||||
from triton.runtime.cache import FileCacheManager
|
from triton.runtime.cache import FileCacheManager
|
||||||
|
from video_reader import PyVideoReader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -757,16 +758,9 @@ def load_image(
|
|||||||
|
|
||||||
def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
||||||
# We import decord here to avoid a strange Segmentation fault (core dumped) issue.
|
# decord was imported here to avoid a strange Segmentation fault (core dumped) issue. NOTE(review): confirm whether video_reader still needs this function-local import.
|
||||||
from decord import VideoReader, cpu, gpu
|
from video_reader import PyVideoReader
|
||||||
|
|
||||||
try:
|
|
||||||
from decord.bridge import decord_bridge
|
|
||||||
|
|
||||||
ctx = gpu(0)
|
|
||||||
_ = decord_bridge.get_ctx_device(ctx)
|
|
||||||
except Exception:
|
|
||||||
ctx = cpu(0)
|
|
||||||
|
|
||||||
|
device = "cuda" if use_gpu and torch.cuda.is_available() else None
|
||||||
tmp_file = None
|
tmp_file = None
|
||||||
vr = None
|
vr = None
|
||||||
try:
|
try:
|
||||||
@@ -774,7 +768,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
|||||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||||
tmp_file.write(video_file)
|
tmp_file.write(video_file)
|
||||||
tmp_file.close()
|
tmp_file.close()
|
||||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
|
||||||
elif isinstance(video_file, str):
|
elif isinstance(video_file, str):
|
||||||
if video_file.startswith(("http://", "https://")):
|
if video_file.startswith(("http://", "https://")):
|
||||||
timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
|
timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
|
||||||
@@ -784,22 +778,22 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
|||||||
for chunk in response.iter_content(chunk_size=8192):
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
tmp_file.write(chunk)
|
tmp_file.write(chunk)
|
||||||
tmp_file.close()
|
tmp_file.close()
|
||||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
|
||||||
elif video_file.startswith("data:"):
|
elif video_file.startswith("data:"):
|
||||||
_, encoded = video_file.split(",", 1)
|
_, encoded = video_file.split(",", 1)
|
||||||
video_bytes = base64.b64decode(encoded)
|
video_bytes = base64.b64decode(encoded)
|
||||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||||
tmp_file.write(video_bytes)
|
tmp_file.write(video_bytes)
|
||||||
tmp_file.close()
|
tmp_file.close()
|
||||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
|
||||||
elif os.path.isfile(video_file):
|
elif os.path.isfile(video_file):
|
||||||
vr = VideoReader(video_file, ctx=ctx)
|
vr = PyVideoReader(video_file, device=device, threads=0)
|
||||||
else:
|
else:
|
||||||
video_bytes = base64.b64decode(video_file)
|
video_bytes = base64.b64decode(video_file)
|
||||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||||
tmp_file.write(video_bytes)
|
tmp_file.write(video_bytes)
|
||||||
tmp_file.close()
|
tmp_file.close()
|
||||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported video input type: {type(video_file)}")
|
raise ValueError(f"Unsupported video input type: {type(video_file)}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user