Online serving benchmarks of real datasets for hierarchical KV caching (#3211)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
159
benchmark/hicache/nextqa.py
Normal file
159
benchmark/hicache/nextqa.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import av
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def find_video_files(video_dir) -> List[str]:
|
||||
if os.path.isfile(video_dir):
|
||||
return [video_dir]
|
||||
|
||||
video_files = []
|
||||
for root, dirs, files in os.walk(video_dir):
|
||||
for file in files:
|
||||
if file.endswith((".mp4", ".avi", ".mov")):
|
||||
video_files.append(os.path.join(root, file))
|
||||
# if file is dir
|
||||
elif os.path.isdir(file):
|
||||
video_files.extend(find_video_files(file))
|
||||
return video_files
|
||||
|
||||
|
||||
def video_frames(video_path, max_frames) -> int:
|
||||
container = av.open(video_path)
|
||||
total_frames = container.streams.video[0].frames
|
||||
return min(total_frames, max_frames)
|
||||
|
||||
|
||||
class Video:
|
||||
def __init__(self, video_path, num_frames):
|
||||
self.path = video_path
|
||||
self.num_frames = num_frames
|
||||
|
||||
def __str__(self):
|
||||
return f"Video({self.path}, {self.num_frames})"
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.path, self.num_frames))
|
||||
|
||||
|
||||
class VideoPrompt(Video):
|
||||
def __init__(self, video_path, num_frames, prompt):
|
||||
super().__init__(video_path, num_frames)
|
||||
self.prompt = prompt
|
||||
|
||||
def __str__(self):
|
||||
return f"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})"
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.path, self.num_frames, self.prompt))
|
||||
|
||||
|
||||
class VideoLoader:
|
||||
pass
|
||||
|
||||
|
||||
class VideoFileLoader(VideoLoader):
|
||||
"""
|
||||
Load all the videos in a directory
|
||||
"""
|
||||
|
||||
def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):
|
||||
super().__init__()
|
||||
self.video_dir = video_dir
|
||||
self.video_files = find_video_files(video_dir)
|
||||
self.batch_size = batch_size
|
||||
self.max_frames = max_frames
|
||||
print(f"batch_size: {batch_size}, max_frames: {max_frames}")
|
||||
|
||||
def __iter__(self): # (file, number of frames)
|
||||
if self.batch_size == 1:
|
||||
for video_file in self.video_files:
|
||||
yield Video(video_file, video_frames(video_file, self.max_frames))
|
||||
else:
|
||||
batch = []
|
||||
for video_file in self.video_files:
|
||||
video = Video(video_file, video_frames(video_file, self.max_frames))
|
||||
batch.append(video)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
|
||||
class NExTQALoader(VideoLoader):
|
||||
"""
|
||||
Load vdideos and prompts from NExT dataset
|
||||
set: train, test or validation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, video_dir, batch_size=1, max_frames=sys.maxsize, dset="test", task="OE"
|
||||
):
|
||||
"""
|
||||
task: 'MV' or 'OE'
|
||||
"""
|
||||
super().__init__()
|
||||
self.task = task
|
||||
print(f"Loading the {dset} data of {task} from lmms-lab/NExTQA")
|
||||
self.ds = load_dataset("lmms-lab/NExTQA", task)
|
||||
self.ds = self.ds[dset]
|
||||
|
||||
# self.n = ds.num_rows
|
||||
self.video_dir = video_dir
|
||||
self.video_files = find_video_files(video_dir)
|
||||
self.video_to_path = dict()
|
||||
for video_file in self.video_files:
|
||||
video_id = video_file.split("/")[-1].split(".")[0]
|
||||
self.video_to_path[video_id] = video_file
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.max_frames = max_frames
|
||||
|
||||
def get_video_prompt(self, entry, max_frames) -> VideoPrompt:
|
||||
# Get video
|
||||
video_id = entry["video"]
|
||||
video_path = self.video_to_path[video_id]
|
||||
assert os.path.exists(video_path), f"Video not found: {video_path}"
|
||||
num_frames = min(entry["frame_count"], max_frames)
|
||||
video = Video(video_path, num_frames)
|
||||
prompt = entry["question"] + "?"
|
||||
if self.task == "MC": # add choices
|
||||
prompt += f' a0: {entry["a0"]}, a1: {entry["a1"]}, a2: {entry["a2"]}, a3: {entry["a3"]}'
|
||||
return VideoPrompt(video_path, num_frames, prompt)
|
||||
|
||||
def __iter__(self):
|
||||
if self.batch_size == 1:
|
||||
for entry in self.ds:
|
||||
yield self.get_video_prompt(entry, self.max_frames)
|
||||
else:
|
||||
batch = []
|
||||
for entry in self.ds:
|
||||
video = self.get_video_prompt(entry, self.max_frames)
|
||||
batch.append(video)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
|
||||
# main
|
||||
if __name__ == "__main__":
|
||||
video_dir = "./videos"
|
||||
# video_loader = VideoFileLoader(video_dir, batch_size=16)
|
||||
# for batch in video_loader:
|
||||
# print(f"Number of videos in batch: {len(batch)}")
|
||||
# for video_file, num_frames in batch:
|
||||
# print(f"Video: {video_file} number of frames: {num_frames}")
|
||||
|
||||
video_loader = NExTQALoader(video_dir, batch_size=16, dset="test", task="OE")
|
||||
for batch in video_loader:
|
||||
print(f"Number of videos in batch: {len(batch)}")
|
||||
for video_file, num_frames, prompt in batch:
|
||||
print(
|
||||
f"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}"
|
||||
)
|
||||
# break
|
||||
# for video_file, prompt in batch:
|
||||
# print(f"Video: {video_file} prompt: {prompt}")
|
||||
# break
|
||||
Reference in New Issue
Block a user