bench: Add MMMU benchmark for vLM (#3562)
This commit is contained in:
101
benchmark/mmmu/bench_sglang.py
Normal file
101
benchmark/mmmu/bench_sglang.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Bench the sglang-hosted vLM with benchmark MMMU
|
||||
|
||||
Usage:
|
||||
python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl
|
||||
|
||||
The eval output will be logged
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import random
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
from data_utils import save_json
|
||||
from eval_utils import (
|
||||
EvalArgs,
|
||||
eval_result,
|
||||
get_sampling_params,
|
||||
parse_multi_choice_response,
|
||||
prepare_samples,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang import Engine
|
||||
from sglang.srt.conversation import chat_templates
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
def eval_mmmu(args):
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
eval_args = EvalArgs.from_cli_args(args)
|
||||
|
||||
if server_args.chat_template is None:
|
||||
raise ValueError("Chat template must be provided for this benchmark")
|
||||
|
||||
samples = prepare_samples(eval_args)
|
||||
|
||||
backend = Engine(**dataclasses.asdict(server_args))
|
||||
|
||||
out_samples = dict()
|
||||
|
||||
sampling_params = get_sampling_params(eval_args)
|
||||
|
||||
conv = chat_templates[server_args.chat_template].copy()
|
||||
image_token = conv.image_token
|
||||
answer_dict = {}
|
||||
for sample in tqdm(samples):
|
||||
prompt = sample["final_input_prompt"]
|
||||
image = sample["image"]
|
||||
bytes_io = BytesIO()
|
||||
image.save(bytes_io, format="PNG")
|
||||
png_bytes = bytes_io.getvalue()
|
||||
|
||||
prompt = re.sub(r"<[^>]*>", image_token, prompt)
|
||||
|
||||
if image is not None:
|
||||
gen_out = backend.generate(
|
||||
prompt=prompt, image_data=[png_bytes], sampling_params=sampling_params
|
||||
)["text"]
|
||||
|
||||
response = gen_out
|
||||
else: # multiple images actually
|
||||
if sample["question_type"] == "multiple-choice":
|
||||
all_choices = sample["all_choices"]
|
||||
response = random.choice(all_choices)
|
||||
|
||||
else:
|
||||
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
|
||||
|
||||
if sample["question_type"] == "multiple-choice":
|
||||
pred_ans = parse_multi_choice_response(
|
||||
response, sample["all_choices"], sample["index2ans"]
|
||||
)
|
||||
else: # open question
|
||||
pred_ans = response
|
||||
out_samples[sample["id"]] = pred_ans
|
||||
|
||||
# set ground truth answer
|
||||
answer_dict[sample["id"]] = {
|
||||
"question_type": sample["question_type"],
|
||||
"ground_truth": (
|
||||
sample["correct_choice"]
|
||||
if "correct_choice" in samples
|
||||
else sample["answer"]
|
||||
),
|
||||
}
|
||||
|
||||
args.output_path = f"{args.model_path}_val_sglang.json"
|
||||
save_json(args.output_path, out_samples)
|
||||
eval_result(output_path=args.output_path, answer_dict=answer_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
EvalArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
eval_mmmu(args)
|
||||
Reference in New Issue
Block a user