feat: replace Decord with video_reader-rs (#5163)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
kozo
2025-07-16 09:17:34 +08:00
committed by GitHub
parent f06bd210c0
commit ebff5fcb06
6 changed files with 16 additions and 21 deletions

View File

@@ -84,6 +84,7 @@ from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils._contextlib import _DecoratorContextManager
from triton.runtime.cache import FileCacheManager
from video_reader import PyVideoReader
logger = logging.getLogger(__name__)
@@ -757,16 +758,9 @@ def load_image(
def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
# We import decord here to avoid a strange Segmentation fault (core dumped) issue.
from decord import VideoReader, cpu, gpu
try:
from decord.bridge import decord_bridge
ctx = gpu(0)
_ = decord_bridge.get_ctx_device(ctx)
except Exception:
ctx = cpu(0)
from video_reader import PyVideoReader
device = "cuda" if use_gpu and torch.cuda.is_available() else None
tmp_file = None
vr = None
try:
@@ -774,7 +768,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_file)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
elif isinstance(video_file, str):
if video_file.startswith(("http://", "https://")):
timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
@@ -784,22 +778,22 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
for chunk in response.iter_content(chunk_size=8192):
tmp_file.write(chunk)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
elif video_file.startswith("data:"):
_, encoded = video_file.split(",", 1)
video_bytes = base64.b64decode(encoded)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
elif os.path.isfile(video_file):
vr = VideoReader(video_file, ctx=ctx)
vr = PyVideoReader(video_file, device=device, threads=0)
else:
video_bytes = base64.b64decode(video_file)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
else:
raise ValueError(f"Unsupported video input type: {type(video_file)}")