refactor: rewrite bench-mmmu-sglang (#4458)
This commit is contained in:
@@ -3,7 +3,11 @@
|
|||||||
### Evaluate sglang
|
### Evaluate sglang
|
||||||
|
|
||||||
```
|
```
|
||||||
python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl
|
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
|
||||||
```
|
```
|
||||||
|
|
||||||
It's recommended to reduce the memory usage by appending something ike `--mem-fraction-static 0.6` to the command above.
|
It's recommended to reduce the memory usage by appending something ike `--mem-fraction-static 0.6` to the command above.
|
||||||
|
|||||||
@@ -1,14 +1,4 @@
|
|||||||
"""
|
|
||||||
Bench the huggingface vLM with benchmark MMMU
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct
|
|
||||||
|
|
||||||
The eval output will be logged
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from data_utils import save_json
|
from data_utils import save_json
|
||||||
@@ -53,48 +43,31 @@ def eval_mmmu(args):
|
|||||||
image = sample["image"]
|
image = sample["image"]
|
||||||
prefix = prompt.split("<")[0]
|
prefix = prompt.split("<")[0]
|
||||||
suffix = prompt.split(">")[1]
|
suffix = prompt.split(">")[1]
|
||||||
if image is not None:
|
assert image is not None
|
||||||
messages = [
|
contents = []
|
||||||
{
|
if prefix:
|
||||||
"role": "user",
|
contents += [{"type": "text", "text": prefix}]
|
||||||
"content": [
|
contents += [
|
||||||
{"type": "text", "text": prefix},
|
{
|
||||||
{
|
"type": "image",
|
||||||
"type": "image",
|
"image": sample["image_path"],
|
||||||
"image": image,
|
}
|
||||||
},
|
]
|
||||||
{"type": "text", "text": suffix},
|
if suffix:
|
||||||
],
|
contents += [{"type": "text", "text": suffix}]
|
||||||
}
|
messages = [{"role": "user", "content": contents}]
|
||||||
]
|
model_inputs = processor.apply_chat_template(
|
||||||
text = processor.apply_chat_template(
|
messages,
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
tokenize=True,
|
||||||
)
|
return_dict=True,
|
||||||
inputs = processor(
|
add_generation_prompt=True,
|
||||||
text=[text],
|
return_tensors="pt",
|
||||||
images=[image],
|
).to(model.device)
|
||||||
padding=True,
|
input_len = model_inputs["input_ids"].shape[-1]
|
||||||
return_tensors="pt",
|
generation = model.generate(**model_inputs, generation_config=generation_config)
|
||||||
).to(model.device)
|
generation = generation[0][input_len:]
|
||||||
|
response = processor.decode(generation, skip_special_tokens=True)
|
||||||
generated_ids = model.generate(
|
print(f"response: {response}")
|
||||||
**inputs, generation_config=generation_config
|
|
||||||
)
|
|
||||||
|
|
||||||
response = processor.decode(
|
|
||||||
generated_ids[0],
|
|
||||||
skip_special_tokens=True,
|
|
||||||
clean_up_tokenization_spaces=False,
|
|
||||||
)[len(text) :]
|
|
||||||
print(f"response: {response}")
|
|
||||||
else: # multiple images actually
|
|
||||||
if sample["question_type"] == "multiple-choice":
|
|
||||||
all_choices = sample["all_choices"]
|
|
||||||
response = random.choice(all_choices)
|
|
||||||
|
|
||||||
else:
|
|
||||||
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
|
|
||||||
|
|
||||||
process_result(response, sample, answer_dict, out_samples)
|
process_result(response, sample, answer_dict, out_samples)
|
||||||
|
|
||||||
args.output_path = f"{args.model_path}_val_hf.json"
|
args.output_path = f"{args.model_path}_val_hf.json"
|
||||||
|
|||||||
@@ -8,11 +8,8 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import base64
|
|
||||||
import dataclasses
|
|
||||||
import random
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
|
import openai
|
||||||
from data_utils import save_json
|
from data_utils import save_json
|
||||||
from eval_utils import (
|
from eval_utils import (
|
||||||
EvalArgs,
|
EvalArgs,
|
||||||
@@ -23,21 +20,12 @@ from eval_utils import (
|
|||||||
)
|
)
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from sglang import Engine
|
from sglang.test.test_utils import add_common_sglang_args_and_parse
|
||||||
from sglang.srt.conversation import generate_chat_conv
|
|
||||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
|
||||||
|
|
||||||
|
|
||||||
def eval_mmmu(args):
|
def eval_mmmu(args):
|
||||||
server_args = ServerArgs.from_cli_args(args)
|
|
||||||
eval_args = EvalArgs.from_cli_args(args)
|
eval_args = EvalArgs.from_cli_args(args)
|
||||||
|
|
||||||
if server_args.chat_template is None:
|
|
||||||
raise ValueError("Chat template must be provided for this benchmark")
|
|
||||||
|
|
||||||
backend = Engine(**dataclasses.asdict(server_args))
|
|
||||||
|
|
||||||
out_samples = dict()
|
out_samples = dict()
|
||||||
|
|
||||||
sampling_params = get_sampling_params(eval_args)
|
sampling_params = get_sampling_params(eval_args)
|
||||||
@@ -46,17 +34,20 @@ def eval_mmmu(args):
|
|||||||
|
|
||||||
answer_dict = {}
|
answer_dict = {}
|
||||||
|
|
||||||
for sample in tqdm(samples):
|
# had to use an openai server, since SglImage doesn't support image data
|
||||||
|
client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
|
||||||
|
|
||||||
|
for i, sample in enumerate(tqdm(samples)):
|
||||||
prompt = sample["final_input_prompt"]
|
prompt = sample["final_input_prompt"]
|
||||||
image = sample["image"]
|
|
||||||
buff = BytesIO()
|
|
||||||
image.save(buff, format="PNG")
|
|
||||||
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
|
|
||||||
prefix = prompt.split("<")[0]
|
prefix = prompt.split("<")[0]
|
||||||
suffix = prompt.split(">")[1]
|
suffix = prompt.split(">")[1]
|
||||||
request_dict = {
|
image = sample["image"]
|
||||||
"model": "",
|
assert image is not None
|
||||||
"messages": [
|
image_path = sample["image_path"]
|
||||||
|
# TODO: batch
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
@@ -66,9 +57,7 @@ def eval_mmmu(args):
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {"url": image_path},
|
||||||
"url": f"data:image/jpeg;base64,{base64_str}"
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
@@ -77,40 +66,21 @@ def eval_mmmu(args):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
temperature=0,
|
||||||
|
max_completion_tokens=sampling_params["max_new_tokens"],
|
||||||
conv = generate_chat_conv(
|
max_tokens=sampling_params["max_new_tokens"],
|
||||||
ChatCompletionRequest(**request_dict),
|
|
||||||
template_name=server_args.chat_template,
|
|
||||||
)
|
)
|
||||||
prompt = conv.get_prompt()
|
response = response.choices[0].message.content
|
||||||
if image is not None:
|
|
||||||
gen_out = backend.generate(
|
|
||||||
prompt=prompt,
|
|
||||||
image_data=conv.image_data,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
)["text"]
|
|
||||||
|
|
||||||
response = gen_out
|
|
||||||
|
|
||||||
else: # multiple images actually
|
|
||||||
if sample["question_type"] == "multiple-choice":
|
|
||||||
all_choices = sample["all_choices"]
|
|
||||||
response = random.choice(all_choices)
|
|
||||||
else:
|
|
||||||
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
|
|
||||||
|
|
||||||
process_result(response, sample, answer_dict, out_samples)
|
process_result(response, sample, answer_dict, out_samples)
|
||||||
args.output_path = f"{args.model_path}_val_sglang.json"
|
|
||||||
|
args.output_path = f"./val_sglang.json"
|
||||||
save_json(args.output_path, out_samples)
|
save_json(args.output_path, out_samples)
|
||||||
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
|
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
|
||||||
|
|
||||||
backend.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
ServerArgs.add_cli_args(parser)
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
EvalArgs.add_cli_args(parser)
|
EvalArgs.add_cli_args(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ from data_utils import (
|
|||||||
process_single_sample,
|
process_single_sample,
|
||||||
)
|
)
|
||||||
from datasets import concatenate_datasets, load_dataset
|
from datasets import concatenate_datasets, load_dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class EvalArgs:
|
class EvalArgs:
|
||||||
backend: str = "engine"
|
|
||||||
seed: int = 42
|
seed: int = 42
|
||||||
split: str = "validation"
|
split: str = "validation"
|
||||||
# Default setting to make the benchmark available on A100 for most 7B models
|
# Default setting to make the benchmark available on A100 for most 7B models
|
||||||
@@ -35,7 +35,6 @@ class EvalArgs:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument("--backend", type=str, default=EvalArgs.backend)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--result-filename", type=str, default=EvalArgs.result_filename
|
"--result-filename", type=str, default=EvalArgs.result_filename
|
||||||
)
|
)
|
||||||
@@ -108,7 +107,7 @@ def prepare_samples(eval_args: EvalArgs):
|
|||||||
# run for each subject
|
# run for each subject
|
||||||
sub_dataset_list = []
|
sub_dataset_list = []
|
||||||
|
|
||||||
for subject in CAT_SHORT2LONG.values():
|
for subject in tqdm(CAT_SHORT2LONG.values()):
|
||||||
sub_dataset = load_dataset(
|
sub_dataset = load_dataset(
|
||||||
eval_args.dataset_path, subject, split=eval_args.split
|
eval_args.dataset_path, subject, split=eval_args.split
|
||||||
)
|
)
|
||||||
@@ -121,19 +120,31 @@ def prepare_samples(eval_args: EvalArgs):
|
|||||||
## prepare images
|
## prepare images
|
||||||
samples = []
|
samples = []
|
||||||
skip_count = 0
|
skip_count = 0
|
||||||
for i, sample in enumerate(dataset):
|
|
||||||
|
# use image file as input to ensure the consistency between sglang and hf
|
||||||
|
images_path = os.path.expanduser("~/.cache/mmmu/images")
|
||||||
|
os.makedirs(images_path, exist_ok=True)
|
||||||
|
print(f"Saving images to: {images_path}")
|
||||||
|
|
||||||
|
for i, sample in enumerate(tqdm(dataset)):
|
||||||
sample = process_single_sample(sample)
|
sample = process_single_sample(sample)
|
||||||
sample = construct_prompt(sample, eval_args.config)
|
sample = construct_prompt(sample, eval_args.config)
|
||||||
image = sample["image"]
|
image = sample["image"]
|
||||||
|
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
if width * height >= eval_args.image_pixels_limit:
|
if width * height >= eval_args.image_pixels_limit:
|
||||||
skip_count += 1
|
skip_count += 1
|
||||||
continue
|
continue
|
||||||
|
image_path = f"{images_path}/image_{i}.png"
|
||||||
|
if not os.path.exists(image_path):
|
||||||
|
image.save(image_path)
|
||||||
|
sample["image_path"] = image_path
|
||||||
samples.append(sample)
|
samples.append(sample)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
|
f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
|
||||||
)
|
)
|
||||||
|
print("samples have been prepared")
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user