refactor: rewrite bench-mmmu-sglang (#4458)
This commit is contained in:
@@ -19,11 +19,11 @@ from data_utils import (
|
||||
process_single_sample,
|
||||
)
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class EvalArgs:
|
||||
backend: str = "engine"
|
||||
seed: int = 42
|
||||
split: str = "validation"
|
||||
# Default setting to make the benchmark available on A100 for most 7B models
|
||||
@@ -35,7 +35,6 @@ class EvalArgs:
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--backend", type=str, default=EvalArgs.backend)
|
||||
parser.add_argument(
|
||||
"--result-filename", type=str, default=EvalArgs.result_filename
|
||||
)
|
||||
@@ -108,7 +107,7 @@ def prepare_samples(eval_args: EvalArgs):
|
||||
# run for each subject
|
||||
sub_dataset_list = []
|
||||
|
||||
for subject in CAT_SHORT2LONG.values():
|
||||
for subject in tqdm(CAT_SHORT2LONG.values()):
|
||||
sub_dataset = load_dataset(
|
||||
eval_args.dataset_path, subject, split=eval_args.split
|
||||
)
|
||||
@@ -121,19 +120,31 @@ def prepare_samples(eval_args: EvalArgs):
|
||||
## prepare images
|
||||
samples = []
|
||||
skip_count = 0
|
||||
for i, sample in enumerate(dataset):
|
||||
|
||||
# use image file as input to ensure the consistency between sglang and hf
|
||||
images_path = os.path.expanduser("~/.cache/mmmu/images")
|
||||
os.makedirs(images_path, exist_ok=True)
|
||||
print(f"Saving images to: {images_path}")
|
||||
|
||||
for i, sample in enumerate(tqdm(dataset)):
|
||||
sample = process_single_sample(sample)
|
||||
sample = construct_prompt(sample, eval_args.config)
|
||||
image = sample["image"]
|
||||
|
||||
width, height = image.size
|
||||
if width * height >= eval_args.image_pixels_limit:
|
||||
skip_count += 1
|
||||
continue
|
||||
image_path = f"{images_path}/image_{i}.png"
|
||||
if not os.path.exists(image_path):
|
||||
image.save(image_path)
|
||||
sample["image_path"] = image_path
|
||||
samples.append(sample)
|
||||
|
||||
print(
|
||||
f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
|
||||
)
|
||||
print("samples have been prepared")
|
||||
return samples
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user