refactor: rewrite bench-mmmu-sglang (#4458)

This commit is contained in:
Mick
2025-03-18 09:11:47 +08:00
committed by GitHub
parent a98290aea3
commit 98be3bd306
4 changed files with 66 additions and 108 deletions

View File

@@ -19,11 +19,11 @@ from data_utils import (
process_single_sample,
)
from datasets import concatenate_datasets, load_dataset
from tqdm import tqdm
@dataclasses.dataclass
class EvalArgs:
backend: str = "engine"
seed: int = 42
split: str = "validation"
# Default setting to make the benchmark available on A100 for most 7B models
@@ -35,7 +35,6 @@ class EvalArgs:
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--backend", type=str, default=EvalArgs.backend)
parser.add_argument(
"--result-filename", type=str, default=EvalArgs.result_filename
)
@@ -108,7 +107,7 @@ def prepare_samples(eval_args: EvalArgs):
# run for each subject
sub_dataset_list = []
for subject in CAT_SHORT2LONG.values():
for subject in tqdm(CAT_SHORT2LONG.values()):
sub_dataset = load_dataset(
eval_args.dataset_path, subject, split=eval_args.split
)
@@ -121,19 +120,31 @@ def prepare_samples(eval_args: EvalArgs):
## prepare images
samples = []
skip_count = 0
for i, sample in enumerate(dataset):
# use image files as input to ensure consistency between sglang and hf
images_path = os.path.expanduser("~/.cache/mmmu/images")
os.makedirs(images_path, exist_ok=True)
print(f"Saving images to: {images_path}")
for i, sample in enumerate(tqdm(dataset)):
sample = process_single_sample(sample)
sample = construct_prompt(sample, eval_args.config)
image = sample["image"]
width, height = image.size
if width * height >= eval_args.image_pixels_limit:
skip_count += 1
continue
image_path = f"{images_path}/image_{i}.png"
if not os.path.exists(image_path):
image.save(image_path)
sample["image_path"] = image_path
samples.append(sample)
print(
f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
)
print("samples have been prepared")
return samples