refactor: rewrite bench-mmmu-sglang (#4458)

Mick
2025-03-18 09:11:47 +08:00
committed by GitHub
parent a98290aea3
commit 98be3bd306
4 changed files with 66 additions and 108 deletions

benchmark/mmmu/bench_hf.py

@@ -1,14 +1,4 @@
"""
Benchmark the HuggingFace VLM on the MMMU benchmark.
Usage:
python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct
The evaluation output will be logged.
"""
import argparse
import random
import torch
from data_utils import save_json
@@ -53,48 +43,31 @@ def eval_mmmu(args):
image = sample["image"]
prefix = prompt.split("<")[0]
suffix = prompt.split(">")[1]
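# split the prompt at its first "<...>" image placeholder into prefix/suffix text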
if image is not None:
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prefix},
{
"type": "image",
"image": image,
},
{"type": "text", "text": suffix},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = processor(
text=[text],
images=[image],
padding=True,
return_tensors="pt",
).to(model.device)
generated_ids = model.generate(
**inputs, generation_config=generation_config
)
response = processor.decode(
generated_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[len(text) :]
print(f"response: {response}")
else:  # image is None, which means the sample contains multiple images
if sample["question_type"] == "multiple-choice":
all_choices = sample["all_choices"]
response = random.choice(all_choices)
else:
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
assert image is not None
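# build the interleaved text/image content list for the chat template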
contents = []
if prefix:
contents += [{"type": "text", "text": prefix}]
contents += [
{
"type": "image",
"image": sample["image_path"],
}
]
if suffix:
contents += [{"type": "text", "text": suffix}]
messages = [{"role": "user", "content": contents}]
model_inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
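# keep the prompt length so only the newly generated tokens are decoded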
input_len = model_inputs["input_ids"].shape[-1]
generation = model.generate(**model_inputs, generation_config=generation_config)
generation = generation[0][input_len:]
response = processor.decode(generation, skip_special_tokens=True)
print(f"response: {response}")
process_result(response, sample, answer_dict, out_samples)
args.output_path = f"{args.model_path}_val_hf.json"
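For context, the new path above collapses text templating and image preprocessing into a single `apply_chat_template` call. Below is a minimal standalone sketch of that flow, assuming a recent `transformers` release whose processors can load images directly inside `apply_chat_template`; the model id, image path, and `max_new_tokens` value are illustrative stand-ins, not taken from this commit.

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "Qwen/Qwen2-VL-7B-Instruct"  # illustrative; any chat-template VLM should work
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # requires accelerate; use .to("cuda") otherwise
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Question: what does"},
            {"type": "image", "image": "/path/to/sample.png"},  # local path, as in the new code
            {"type": "text", "text": "the image show? Answer briefly."},
        ],
    }
]

# Tokenize the prompt and preprocess the image in one call.
model_inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

input_len = model_inputs["input_ids"].shape[-1]
generated_ids = model.generate(**model_inputs, max_new_tokens=64)
# Decode only the newly generated tokens, mirroring the slicing in the diff.
response = processor.decode(generated_ids[0][input_len:], skip_special_tokens=True)
print(response)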