Online serving benchmarks of real datasets for hierarchical KV caching (#3211)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
Yueyang Pan
2025-03-06 01:16:43 +01:00
committed by GitHub
parent 62b362b1f1
commit 25482edb5c
6 changed files with 1914 additions and 1 deletions

View File

@@ -22,4 +22,70 @@ python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
Note: The performance gain of hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens to be reused, the larger the model, and the more constrained the GPU memory size, the greater the benefit one can expect from hierarchical caching.
## More benchmarks to be added
# Benchmark with more datasets
## Download Dataset
```bash
./download.sh {sharegpt|ultragpt|loogle|nextqa|all}
```
This script will automatically download the required dataset to the current working directory
## Multiturn Benchmark
### Supported Datasets
- sharegpt
- ultrachat
- loogle
### Example Usage:
```bash
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
--port 8001 --enable-multiturn --disable-shuffle
```
This uses `mistralai/Mistral-7B-Instruct-v0.3` model with `sglang` as backend. The dataset
is `longdep_qa.json`. We send `10 conversations` with `10 req/s` to port 8001. We enable
multiturn chat without shuffling the order of conversations (i.e. following the original
order in the dataset file).
### Note:
The requests of multiple conversations are sent in a round robin fashion.
For example, if we have 3 conversations A, B, C whose rounds are `[2, 3, 4]` correspondingly,
multiturn chat will send the requests to the backend in the following order: `[A1, B1, C1, A2, B2, C2, B3, C3, C4]`
This has implications on the cache reuse patterns: the cache reuse distance is the largest
under this request pattern (which means a prefix-aware local scheduler in the backend can
yield the most benefit compared to a FIFO scheduler)
## Shared Prefix Benchmark
### Supported Datasets
- loogle
### Example Usage:
```bash
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
--port 8001 --enable-shared-prefix --disable-shuffle
```
### Note:
Shared Prefix benchmark sends the questions for the same prompt together. For example,
if we have 3 shared prefix A, B, C, which have [2, 3, 4] questions correspondingly,
the shared prefix benchmark will send the requests to the
backend in the following order: `[A+Q1, A+Q2, B+Q1, B+Q2, B+Q3, C+Q1, C+Q2, C+Q3]`.
## Multi Modality Benchmark (WIP)
### Supported Datasets:
- nextqa
### Example Usage:
```bash
Server:
python3 -m sglang.launch_server --model-path lmms-lab/LLaVA-NeXT-Video-7B --tp 2 --dp 1 --port 8001 \
--host 0.0.0.0 --mem-fraction-static 0.9 --tokenizer-path llava-hf/llava-1.5-7b-hf \
--json-model-override-args "{\"architectures\": [\"LlavaVidForCausalLM\"], \"model_type\":\"llava\", \"mm_spatial_pool_stride\":2}"
Client:
python3 bench_serving.py --model lmms-lab/LLaVA-NeXT-Video-7B --backend sglang --dataset-path \
NExTVideo --dataset-name nextqa --request-rate 10 --num-prompts 1 --disable-shuffle --port 8001 \ --enable-multiturn --max-frames 16 --tokenizer llava-hf/llava-1.5-7b-hf --fixed-output-len 2048
```
Note: for the server args, `tokenizer-path`, overriding architecture are necessary.
## Supported Backend
- sglang (oai)
- vllm (oai)
- lmdeploy (oai)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,589 @@
import json
import os
import pickle
import random
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
from nextqa import NExTQALoader
# from nextqa.video import , VideoPrompt
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
from sglang.bench_serving import (
download_and_cache_file,
gen_prompt,
get_gen_prefix_cache_path,
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64
# type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
# A list of all the conversations. Each conversation is a list of
# tuples. If multiturn is not enabled, the length of list is 1,
# containing only the first Q&A pair.
# For the shared prefix workload (synthetic, loogle, nextqa), it
# is a list of conversations sharing the same prefix (synthetic,
# doc, video)
SampleOutput = List[List[Tuple[MsgContent, int, int]]]
def common_filter_chat(
num_requests: int,
new_dataset: List,
tokenizer: PreTrainedTokenizerBase,
min_prompt_len: Optional[int],
min_output_len: Optional[int],
max_prompt_len: Optional[int],
max_output_len: Optional[int],
fixed_output_len: Optional[int],
) -> SampleOutput:
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = []
l = 0
input_tokens = 0
output_tokens = 0
while l < num_requests:
for i in range(len(new_dataset)):
if l == num_requests:
break
processed = []
for j in new_dataset[i]:
# Tokenize the prompts and completions.
prompt = j[0]
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
completion = j[1]
completion_token_ids = tokenizer.encode(completion)
output_len = (
len(completion_token_ids)
if fixed_output_len is None
else fixed_output_len
)
if (
min_prompt_len is not None
and prompt_len < min_prompt_len
or min_output_len is not None
and output_len < min_output_len
or max_prompt_len is not None
and prompt_len > max_prompt_len
or max_output_len is not None
and output_len > max_output_len
):
# Prune too short sequences.
continue
input_tokens += prompt_len
output_tokens += output_len
processed.append((prompt, prompt_len, output_len))
filtered_dataset.append(processed)
l += 1
print(f"#Input tokens: {input_tokens}")
print(f"#Output tokens: {output_tokens}")
return filtered_dataset
def sample_sharegpt_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
disable_shuffle: bool = False,
enable_multiturn: bool = True,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Keep one conversation in one list
new_dataset = []
for data in dataset:
if len(data["conversations"]) % 2 != 0:
continue
if data["conversations"][0]["from"] != "human":
continue
chat = []
total_len = 2
if enable_multiturn:
total_len = len(data["conversations"])
for i in range(0, total_len, 2):
# One user One Assistant
chat.append(
(
data["conversations"][i]["value"],
data["conversations"][i + 1]["value"],
)
)
new_dataset.append(chat)
if not disable_shuffle:
# Shuffle the dataset.
random.shuffle(new_dataset)
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = common_filter_chat(
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
)
return filtered_dataset
def sample_ultrachat_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
disable_shuffle: bool = False,
enable_multiturn: bool = True,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset
dataset = []
with open(dataset_path) as f:
while True:
line = f.readline()
if not line:
break
dataset.append(json.loads(line))
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["data"]) >= 2]
# Keep one conversation in one list
new_dataset = []
for data in dataset:
if len(data["data"]) % 2 != 0:
continue
chat = []
total_len = 2
if enable_multiturn:
total_len = len(data["data"])
for i in range(0, total_len, 2):
# One user One Assistant
chat.append((data["data"][i], data["data"][i + 1]))
new_dataset.append(chat)
# Shuffle the dataset.
if not disable_shuffle:
random.shuffle(new_dataset)
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = common_filter_chat(
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
)
return filtered_dataset
def sample_loogle_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
disable_shuffle: bool = False,
enable_multiturn: bool = True,
enable_shared_prefix: bool = False,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset
dataset = []
with open(dataset_path) as f:
while True:
line = f.readline()
if not line:
break
dataset.append(json.loads(line))
# Keep one conversation in one list
new_dataset = []
# TODO: Add shared prefix support for loogle
# NOTE: Now we preprocess it only for chat
for data in dataset:
chat = []
if (
"qa_pairs" not in data
or data["qa_pairs"] == "none"
or len(data["qa_pairs"]) == 0
):
# If Q is none (for summarization),
# We add a question for summarization
# And keep the summary up to 1024 words
chat.append(
(
"Input: "
+ data["input"]
+ " Question: "
+ "Please summarize the input",
data["input"][:1024],
)
)
new_dataset.append(chat)
else:
qa_pairs = eval(data["qa_pairs"])
for i, qa in enumerate(qa_pairs):
if i == 0 or enable_shared_prefix:
# Combine input with the first Q
chat.append(
("Input: " + data["input"] + " Question: " + qa["Q"], qa["A"])
)
elif enable_multiturn:
chat.append((qa["Q"], qa["A"]))
new_dataset.append(chat)
# Shuffle the dataset.
if not disable_shuffle:
random.shuffle(new_dataset)
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = common_filter_chat(
num_requests, new_dataset, tokenizer, 4, None, None, None, fixed_output_len
)
return filtered_dataset
def sample_nextqa_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
max_frames: int, # Specific for video
model_path: str,
disable_shuffle: bool = False,
enable_multiturn: bool = True, # No multiturn support for now
backend: str = "sglang-oai",
chat_template_name: Optional[str] = None,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
"""
Example of messages:
message = {
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": base64_data}},
{"type": "text", "text": video.prompt},
],
}
"""
if fixed_output_len is None:
fixed_output_len = 4096
# TODO: Check for multiturn
dataset = NExTQALoader(video_dir=dataset_path, max_frames=max_frames)
new_dataset = []
for v in dataset:
new_dataset.append(v)
if not disable_shuffle:
random.shuffle(new_dataset)
# TODO: prompt len can get from server side
filtered_dataset = []
l = 0
while l < num_requests:
for i in range(len(new_dataset)):
if l == num_requests:
break
video = new_dataset[i]
# text prompt
prompt = video.prompt
# NOTE: Chat Template is a must for video benchmark because we have to
# add special image token for later expansion
if backend == "sglang" or backend == "sglang-native":
if "chat_template" in tokenizer.init_kwargs:
chat_template = get_chat_template(tokenizer.get_chat_template())
elif chat_template_name is not None:
chat_template = get_chat_template(chat_template_name)
else:
chat_template = get_chat_template_by_model_path(model_path)
prompt = chat_template.image_token + prompt
prompt_token_ids = tokenizer(prompt).input_ids
prompt_len = len(prompt_token_ids)
output_len = fixed_output_len # max output len, not real output len
# video input
base64_data = encode_video_base64(video.path, video.num_frames)
# NOTE: This will be replaced by the expanded length from the server
prompt_len += video.num_frames
# add to content
content = [
{"type": "image_url", "image_url": {"url": base64_data}},
{"type": "text", "text": prompt},
]
filtered_dataset.append([(content, prompt_len, output_len)])
l += 1
return filtered_dataset
def sample_random_requests(
input_len: int,
output_len: int,
num_prompts: int,
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
disable_shuffle: bool = False,
) -> SampleOutput:
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1),
input_len + 1,
size=num_prompts,
)
output_lens = np.random.randint(
int(output_len * range_ratio),
output_len + 1,
size=num_prompts,
)
if True:
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
if not disable_shuffle:
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
input_requests: SampleOutput = []
for data in dataset:
i = len(input_requests)
if i == num_prompts:
break
# Tokenize the prompts and completions.
prompt = data[0]
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
# Skip empty prompt
if prompt_len == 0:
continue
if prompt_len > input_lens[i]:
input_ids = prompt_token_ids[: input_lens[i]]
else:
ratio = (input_lens[i] + prompt_len - 1) // prompt_len
input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
prompt = tokenizer.decode(input_ids)
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
else:
# Sample token ids from random integers. This can cause some NaN issues.
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
input_requests = []
for i in range(num_prompts):
prompt = tokenizer.decode(
[
(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])
]
)
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
print(f"#Input tokens: {np.sum(input_lens)}")
print(f"#Output tokens: {np.sum(output_lens)}")
return input_requests
def gen_prompt(tokenizer, token_num):
"""Generate a random prompt of specified token length using tokenizer vocabulary."""
all_available_tokens = list(tokenizer.get_vocab().values())
selected_tokens = random.choices(all_available_tokens, k=token_num)
return tokenizer.decode(selected_tokens)
def get_gen_prefix_cache_path(args, tokenizer):
"""Create cache directory under ~/.cache/sglang/benchmark"""
cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"
# Create a unique cache filename based on the generation parameters
cache_key = (
f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
f"{tokenizer.__class__.__name__}.pkl"
)
return cache_dir / cache_key
def sample_generated_shared_prefix_requests(
num_groups: int,
prompts_per_group: int,
system_prompt_len: int,
question_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
args,
disable_shuffle: bool = False,
) -> SampleOutput:
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
cache_path = get_gen_prefix_cache_path(args, tokenizer)
# Try to load from cache first
if cache_path.exists():
print(f"\nLoading cached generated input data from {cache_path}")
with open(cache_path, "rb") as f:
return pickle.load(f)
print("\nGenerating new input data...")
# Generate system prompts for each group
system_prompts = []
for _ in range(num_groups):
system_prompt = gen_prompt(tokenizer, system_prompt_len)
system_prompts.append(system_prompt)
# Generate questions
questions = []
for _ in range(num_groups * prompts_per_group):
question = gen_prompt(tokenizer, question_len)
questions.append(question)
# Combine system prompts with questions
input_requests = []
total_input_tokens = 0
total_output_tokens = 0
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
system_prompt = system_prompts[group_idx]
input_requests.append([])
for prompt_idx in tqdm(
range(prompts_per_group), desc="Generating questions", leave=False
):
question = questions[group_idx * prompts_per_group + prompt_idx]
full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt))
input_requests[-1].append((full_prompt, prompt_len, output_len))
total_input_tokens += prompt_len
total_output_tokens += output_len
if not disable_shuffle:
# Shuffle questions
random.shuffle(input_requests)
# Print statistics
print(f"\nGenerated shared prefix dataset statistics:")
print(f"Number of groups: {num_groups}")
print(f"Prompts per group: {prompts_per_group}")
print(f"Total prompts: {len(input_requests) * prompts_per_group}")
print(f"Total input tokens: {total_input_tokens}")
print(f"Total output tokens: {total_output_tokens}")
print(
f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
)
print(
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
)
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
print(f"Caching generated input data to {cache_path}")
with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
return input_requests
def get_dataset(args, tokenizer):
if args.dataset_name == "sharegpt":
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "ultrachat":
input_requests = sample_ultrachat_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "loogle":
input_requests = sample_loogle_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
enable_shared_prefix=args.enable_shared_prefix,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "nextqa":
input_requests = sample_nextqa_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
max_frames=args.max_frames,
model_path=args.model,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
backend=args.backend,
chat_template_name=args.chat_template,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
dataset_path=args.dataset_path,
)
elif args.dataset_name == "generated-shared-prefix":
input_requests = sample_generated_shared_prefix_requests(
num_groups=args.gen_num_groups,
prompts_per_group=args.gen_prompts_per_group,
system_prompt_len=args.gen_system_prompt_len,
question_len=args.gen_question_len,
output_len=args.gen_output_len,
args=args,
tokenizer=tokenizer,
)
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")
return input_requests

66
benchmark/hicache/download.sh Executable file
View File

@@ -0,0 +1,66 @@
#!/usr/bin/bash
# The usage function
usage() {
echo "Usage: $0 {sharegpt|ultragpt|loogle|nextqa|all}"
exit 1
}
# The download function
download() {
case "$1" in
sharegpt)
echo $1
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
;;
ultragpt)
echo $1
# Questions about the world
wget https://cloud.tsinghua.edu.cn/seafhttp/files/be1d7b87-22ca-449e-a6a7-c61d1ea7e010/ultrachat_release_230407.json
# Writing and Creation
wget https://cloud.tsinghua.edu.cn/seafhttp/files/61742d2a-25e2-4d08-b2b9-15f47ae50ace/ultrachat_material_release_230417.json
wget https://cloud.tsinghua.edu.cn/seafhttp/files/f71f6aa6-d346-4b16-85b7-8502efa3d608/ultrachat_material_release_230412.json
# External materials
wget https://cloud.tsinghua.edu.cn/seafhttp/files/42d22e28-e899-4975-a70f-5eda163e265d/ultrachat_existent_material_release_230420.json.gz
gunzip ultrachat_existent_material_release_230420.json.gz
;;
loogle)
echo $1
git lfs install
git clone git@hf.co:datasets/bigainlco/LooGLE
unzip LooGLE/data.zip
;;
nextqa)
echo $1
git lfs install
git clone https://huggingface.co/datasets/lmms-lab/NExTQA
unzip NExTQA/videos.zip
;;
*)
usage
exit 1
;;
esac
}
# Arg check
if [ "$#" -ne 1 ]; then
usage
fi
# Invoke
case "$1" in
sharegpt|ultragpt|loogle|nextqa)
download "$1"
;;
all)
download sharegpt
download ultragpt
download loogle
download nextqa
;;
*)
usage
;;
esac

159
benchmark/hicache/nextqa.py Normal file
View File

@@ -0,0 +1,159 @@
import os
import sys
from typing import List
import av
from datasets import load_dataset
def find_video_files(video_dir) -> List[str]:
if os.path.isfile(video_dir):
return [video_dir]
video_files = []
for root, dirs, files in os.walk(video_dir):
for file in files:
if file.endswith((".mp4", ".avi", ".mov")):
video_files.append(os.path.join(root, file))
# if file is dir
elif os.path.isdir(file):
video_files.extend(find_video_files(file))
return video_files
def video_frames(video_path, max_frames) -> int:
container = av.open(video_path)
total_frames = container.streams.video[0].frames
return min(total_frames, max_frames)
class Video:
def __init__(self, video_path, num_frames):
self.path = video_path
self.num_frames = num_frames
def __str__(self):
return f"Video({self.path}, {self.num_frames})"
def __iter__(self):
return iter((self.path, self.num_frames))
class VideoPrompt(Video):
def __init__(self, video_path, num_frames, prompt):
super().__init__(video_path, num_frames)
self.prompt = prompt
def __str__(self):
return f"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})"
def __iter__(self):
return iter((self.path, self.num_frames, self.prompt))
class VideoLoader:
pass
class VideoFileLoader(VideoLoader):
"""
Load all the videos in a directory
"""
def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):
super().__init__()
self.video_dir = video_dir
self.video_files = find_video_files(video_dir)
self.batch_size = batch_size
self.max_frames = max_frames
print(f"batch_size: {batch_size}, max_frames: {max_frames}")
def __iter__(self): # (file, number of frames)
if self.batch_size == 1:
for video_file in self.video_files:
yield Video(video_file, video_frames(video_file, self.max_frames))
else:
batch = []
for video_file in self.video_files:
video = Video(video_file, video_frames(video_file, self.max_frames))
batch.append(video)
if len(batch) == self.batch_size:
yield batch
batch = []
class NExTQALoader(VideoLoader):
"""
Load vdideos and prompts from NExT dataset
set: train, test or validation
"""
def __init__(
self, video_dir, batch_size=1, max_frames=sys.maxsize, dset="test", task="OE"
):
"""
task: 'MV' or 'OE'
"""
super().__init__()
self.task = task
print(f"Loading the {dset} data of {task} from lmms-lab/NExTQA")
self.ds = load_dataset("lmms-lab/NExTQA", task)
self.ds = self.ds[dset]
# self.n = ds.num_rows
self.video_dir = video_dir
self.video_files = find_video_files(video_dir)
self.video_to_path = dict()
for video_file in self.video_files:
video_id = video_file.split("/")[-1].split(".")[0]
self.video_to_path[video_id] = video_file
self.batch_size = batch_size
self.max_frames = max_frames
def get_video_prompt(self, entry, max_frames) -> VideoPrompt:
# Get video
video_id = entry["video"]
video_path = self.video_to_path[video_id]
assert os.path.exists(video_path), f"Video not found: {video_path}"
num_frames = min(entry["frame_count"], max_frames)
video = Video(video_path, num_frames)
prompt = entry["question"] + "?"
if self.task == "MC": # add choices
prompt += f' a0: {entry["a0"]}, a1: {entry["a1"]}, a2: {entry["a2"]}, a3: {entry["a3"]}'
return VideoPrompt(video_path, num_frames, prompt)
def __iter__(self):
if self.batch_size == 1:
for entry in self.ds:
yield self.get_video_prompt(entry, self.max_frames)
else:
batch = []
for entry in self.ds:
video = self.get_video_prompt(entry, self.max_frames)
batch.append(video)
if len(batch) == self.batch_size:
yield batch
batch = []
# main
if __name__ == "__main__":
video_dir = "./videos"
# video_loader = VideoFileLoader(video_dir, batch_size=16)
# for batch in video_loader:
# print(f"Number of videos in batch: {len(batch)}")
# for video_file, num_frames in batch:
# print(f"Video: {video_file} number of frames: {num_frames}")
video_loader = NExTQALoader(video_dir, batch_size=16, dset="test", task="OE")
for batch in video_loader:
print(f"Number of videos in batch: {len(batch)}")
for video_file, num_frames, prompt in batch:
print(
f"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}"
)
# break
# for video_file, prompt in batch:
# print(f"Video: {video_file} prompt: {prompt}")
# break

View File

@@ -24,10 +24,14 @@ import requests
from IPython.display import HTML, display
from tqdm import tqdm
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.srt.utils import kill_process_tree
logger = logging.getLogger(__name__)
# type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
def get_exception_traceback():
etype, value, tb = sys.exc_info()