From 98be3bd306e018770e43c0f8ab0fd476212ae6a6 Mon Sep 17 00:00:00 2001
From: Mick
Date: Tue, 18 Mar 2025 09:11:47 +0800
Subject: [PATCH] refactor: rewrite bench-mmmu-sglang (#4458)

---
 benchmark/mmmu/README.md       |  6 ++-
 benchmark/mmmu/bench_hf.py     | 77 +++++++++++-----------------------
 benchmark/mmmu/bench_sglang.py | 72 ++++++++++---------------------
 benchmark/mmmu/eval_utils.py   | 19 +++++++--
 4 files changed, 66 insertions(+), 108 deletions(-)

diff --git a/benchmark/mmmu/README.md b/benchmark/mmmu/README.md
index be3f1e043..fa23751f3 100644
--- a/benchmark/mmmu/README.md
+++ b/benchmark/mmmu/README.md
@@ -3,7 +3,11 @@
 ### Evaluate sglang
 
 ```
-python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl
+python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000
+```
+
+```
+python benchmark/mmmu/bench_sglang.py --port 30000
 ```
 
 It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the server launch command above.

diff --git a/benchmark/mmmu/bench_hf.py b/benchmark/mmmu/bench_hf.py
index 0a237b07b..60bc15bc2 100644
--- a/benchmark/mmmu/bench_hf.py
+++ b/benchmark/mmmu/bench_hf.py
@@ -1,14 +1,4 @@
-"""
-    Bench the huggingface vLM with benchmark MMMU
-
-    Usage:
-        python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct
-
-    The eval output will be logged
-"""
-
 import argparse
-import random
 
 import torch
 from data_utils import save_json
@@ -53,48 +43,31 @@ def eval_mmmu(args):
         image = sample["image"]
         prefix = prompt.split("<")[0]
         suffix = prompt.split(">")[1]
-        if image is not None:
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": prefix},
-                        {
-                            "type": "image",
-                            "image": image,
-                        },
-                        {"type": "text", "text": suffix},
-                    ],
-                }
-            ]
-            text = processor.apply_chat_template(
-                messages, tokenize=False, add_generation_prompt=True
-            )
-            inputs = processor(
-                text=[text],
-                images=[image],
-                padding=True,
-                return_tensors="pt",
-            ).to(model.device)
-
-            generated_ids = model.generate(
-                **inputs, generation_config=generation_config
-            )
-
-            response = processor.decode(
-                generated_ids[0],
-                skip_special_tokens=True,
-                clean_up_tokenization_spaces=False,
-            )[len(text) :]
-            print(f"response: {response}")
-        else:  # multiple images actually
-            if sample["question_type"] == "multiple-choice":
-                all_choices = sample["all_choices"]
-                response = random.choice(all_choices)
-
-            else:
-                response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
-
+        assert image is not None
+        contents = []
+        if prefix:
+            contents += [{"type": "text", "text": prefix}]
+        contents += [
+            {
+                "type": "image",
+                "image": sample["image_path"],
+            }
+        ]
+        if suffix:
+            contents += [{"type": "text", "text": suffix}]
+        messages = [{"role": "user", "content": contents}]
+        model_inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            return_dict=True,
+            add_generation_prompt=True,
+            return_tensors="pt",
+        ).to(model.device)
+        input_len = model_inputs["input_ids"].shape[-1]
+        generation = model.generate(**model_inputs, generation_config=generation_config)
+        generation = generation[0][input_len:]
+        response = processor.decode(generation, skip_special_tokens=True)
+        print(f"response: {response}")
         process_result(response, sample, answer_dict, out_samples)
 
     args.output_path = f"{args.model_path}_val_hf.json"
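For reference, the refactored `bench_hf.py` loop body reduces to the following standalone sketch. This is a minimal illustration, not the benchmark file itself: the model name comes from the README, the image path and question are hypothetical, and it assumes a recent `transformers` release whose processors accept image entries directly in `apply_chat_template(..., tokenize=True, return_dict=True)`:

```python
# Minimal sketch of the new HF generation path (assumptions noted above).
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_path = "Qwen/Qwen2-VL-7B-Instruct"  # example model from the README
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_path)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "/tmp/mmmu/image_0.png"},  # hypothetical path
            {"type": "text", "text": "What is shown in this image?"},
        ],
    }
]
# One call tokenizes the text and preprocesses the image together, replacing
# the old two-step apply_chat_template(tokenize=False) + processor(...) flow.
model_inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

input_len = model_inputs["input_ids"].shape[-1]
generated = model.generate(**model_inputs, max_new_tokens=64)
# Slice off the prompt tokens by position instead of trimming the decoded string.
print(processor.decode(generated[0][input_len:], skip_special_tokens=True))
```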
diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py
index ba03dced3..643a4b6d0 100644
--- a/benchmark/mmmu/bench_sglang.py
+++ b/benchmark/mmmu/bench_sglang.py
@@ -8,11 +8,8 @@
 """
 
 import argparse
-import base64
-import dataclasses
-import random
-from io import BytesIO
 
+import openai
 from data_utils import save_json
 from eval_utils import (
     EvalArgs,
@@ -23,21 +20,12 @@ from eval_utils import (
 )
 from tqdm import tqdm
 
-from sglang import Engine
-from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
-from sglang.srt.server_args import ServerArgs
+from sglang.test.test_utils import add_common_sglang_args_and_parse
 
 
 def eval_mmmu(args):
-    server_args = ServerArgs.from_cli_args(args)
     eval_args = EvalArgs.from_cli_args(args)
 
-    if server_args.chat_template is None:
-        raise ValueError("Chat template must be provided for this benchmark")
-
-    backend = Engine(**dataclasses.asdict(server_args))
-
     out_samples = dict()
 
     sampling_params = get_sampling_params(eval_args)
@@ -46,17 +34,20 @@ def eval_mmmu(args):
 
     answer_dict = {}
 
-    for sample in tqdm(samples):
+    # Had to use an OpenAI-compatible server, since SglImage doesn't support image data
+    client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
+
+    for i, sample in enumerate(tqdm(samples)):
         prompt = sample["final_input_prompt"]
-        image = sample["image"]
-        buff = BytesIO()
-        image.save(buff, format="PNG")
-        base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
         prefix = prompt.split("<")[0]
         suffix = prompt.split(">")[1]
-        request_dict = {
-            "model": "",
-            "messages": [
+        image = sample["image"]
+        assert image is not None
+        image_path = sample["image_path"]
+        # TODO: batch
+        response = client.chat.completions.create(
+            model="default",
+            messages=[
                 {
                     "role": "user",
                     "content": [
@@ -66,9 +57,7 @@ def eval_mmmu(args):
                         },
                         {
                             "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{base64_str}"
-                            },
+                            "image_url": {"url": image_path},
                         },
                         {
                             "type": "text",
@@ -77,40 +66,21 @@ def eval_mmmu(args):
                     ],
                 }
             ],
-        }
-
-        conv = generate_chat_conv(
-            ChatCompletionRequest(**request_dict),
-            template_name=server_args.chat_template,
+            temperature=0,
+            max_completion_tokens=sampling_params["max_new_tokens"],
+            max_tokens=sampling_params["max_new_tokens"],
         )
-        prompt = conv.get_prompt()
 
-        if image is not None:
-            gen_out = backend.generate(
-                prompt=prompt,
-                image_data=conv.image_data,
-                sampling_params=sampling_params,
-            )["text"]
-
-            response = gen_out
-
-        else:  # multiple images actually
-            if sample["question_type"] == "multiple-choice":
-                all_choices = sample["all_choices"]
-                response = random.choice(all_choices)
-            else:
-                response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
-
+        response = response.choices[0].message.content
         process_result(response, sample, answer_dict, out_samples)
-    args.output_path = f"{args.model_path}_val_sglang.json"
+
+    args.output_path = "./val_sglang.json"
     save_json(args.output_path, out_samples)
     eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
 
-    backend.shutdown()
-
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ServerArgs.add_cli_args(parser)
     EvalArgs.add_cli_args(parser)
-    args = parser.parse_args()
+    # add_common_sglang_args_and_parse() calls parser.parse_args() itself,
+    # so register the EvalArgs options before invoking it.
+    args = add_common_sglang_args_and_parse(parser)
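The server-side counterpart is now a plain OpenAI-compatible request. Below is a minimal sketch of one benchmark call, assuming a server launched as in the README on port 30000 and a hypothetical local image path; the benchmark relies on sglang's OpenAI-compatible endpoint accepting a local file path in `image_url`:

```python
# Minimal sketch of one request from the rewritten bench_sglang.py loop.
import openai

# The API key is a placeholder; the local sglang server does not check it.
client = openai.Client(api_key="sk", base_url="http://127.0.0.1:30000/v1")

response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Answer with the option letter."},
                # Hypothetical local path, as written by prepare_samples().
                {"type": "image_url", "image_url": {"url": "/tmp/mmmu/image_0.png"}},
                {"type": "text", "text": "Which option matches the figure?"},
            ],
        }
    ],
    temperature=0,
    max_tokens=128,
)
print(response.choices[0].message.content)
```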
diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py
index 6daf5db6f..046bed6b2 100644
--- a/benchmark/mmmu/eval_utils.py
+++ b/benchmark/mmmu/eval_utils.py
@@ -19,11 +19,11 @@ from data_utils import (
     process_single_sample,
 )
 from datasets import concatenate_datasets, load_dataset
+from tqdm import tqdm
 
 
 @dataclasses.dataclass
 class EvalArgs:
-    backend: str = "engine"
     seed: int = 42
     split: str = "validation"
     # Default setting to make the benchmark available on A100 for most 7B models
@@ -35,7 +35,6 @@ class EvalArgs:
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--backend", type=str, default=EvalArgs.backend)
         parser.add_argument(
            "--result-filename", type=str, default=EvalArgs.result_filename
         )
@@ -108,7 +107,7 @@ def prepare_samples(eval_args: EvalArgs):
 
     # run for each subject
     sub_dataset_list = []
-    for subject in CAT_SHORT2LONG.values():
+    for subject in tqdm(CAT_SHORT2LONG.values()):
         sub_dataset = load_dataset(
             eval_args.dataset_path, subject, split=eval_args.split
         )
@@ -121,19 +120,31 @@ def prepare_samples(eval_args: EvalArgs):
 
     ## prepare images
     samples = []
     skip_count = 0
-    for i, sample in enumerate(dataset):
+
+    # use image files as input to ensure consistency between sglang and hf
+    images_path = os.path.expanduser("~/.cache/mmmu/images")
+    os.makedirs(images_path, exist_ok=True)
+    print(f"Saving images to: {images_path}")
+
+    for i, sample in enumerate(tqdm(dataset)):
         sample = process_single_sample(sample)
         sample = construct_prompt(sample, eval_args.config)
         image = sample["image"]
+
         width, height = image.size
         if width * height >= eval_args.image_pixels_limit:
             skip_count += 1
             continue
+        image_path = f"{images_path}/image_{i}.png"
+        if not os.path.exists(image_path):
+            image.save(image_path)
+        sample["image_path"] = image_path
         samples.append(sample)
 
     print(
         f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
     )
+    print("samples have been prepared")
 
     return samples
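The `eval_utils.py` change above is what makes the two entry points comparable: each sample's image is materialized once under `~/.cache/mmmu/images`, and both `bench_hf.py` and `bench_sglang.py` then read the same file from disk. A sketch of that caching pattern in isolation, with a hypothetical helper name:

```python
# Sketch of the image-caching scheme used by prepare_samples: write each PIL
# image to a deterministic path once, then reuse the file on later runs.
import os

from PIL import Image

IMAGES_PATH = os.path.expanduser("~/.cache/mmmu/images")
os.makedirs(IMAGES_PATH, exist_ok=True)


def cache_image(image: Image.Image, index: int) -> str:
    """Save `image` as image_<index>.png if absent and return its path."""
    image_path = f"{IMAGES_PATH}/image_{index}.png"
    if not os.path.exists(image_path):
        image.save(image_path)
    return image_path
```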