refactor: rewrite bench-mmmu-sglang (#4458)
This commit is contained in:
@@ -8,11 +8,8 @@
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import dataclasses
|
||||
import random
|
||||
from io import BytesIO
|
||||
|
||||
import openai
|
||||
from data_utils import save_json
|
||||
from eval_utils import (
|
||||
EvalArgs,
|
||||
@@ -23,21 +20,12 @@ from eval_utils import (
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang import Engine
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.test_utils import add_common_sglang_args_and_parse
|
||||
|
||||
|
||||
def eval_mmmu(args):
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
eval_args = EvalArgs.from_cli_args(args)
|
||||
|
||||
if server_args.chat_template is None:
|
||||
raise ValueError("Chat template must be provided for this benchmark")
|
||||
|
||||
backend = Engine(**dataclasses.asdict(server_args))
|
||||
|
||||
out_samples = dict()
|
||||
|
||||
sampling_params = get_sampling_params(eval_args)
|
||||
@@ -46,17 +34,20 @@ def eval_mmmu(args):
|
||||
|
||||
answer_dict = {}
|
||||
|
||||
for sample in tqdm(samples):
|
||||
# had to use an openai server, since SglImage doesn't support image data
|
||||
client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
|
||||
|
||||
for i, sample in enumerate(tqdm(samples)):
|
||||
prompt = sample["final_input_prompt"]
|
||||
image = sample["image"]
|
||||
buff = BytesIO()
|
||||
image.save(buff, format="PNG")
|
||||
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
|
||||
prefix = prompt.split("<")[0]
|
||||
suffix = prompt.split(">")[1]
|
||||
request_dict = {
|
||||
"model": "",
|
||||
"messages": [
|
||||
image = sample["image"]
|
||||
assert image is not None
|
||||
image_path = sample["image_path"]
|
||||
# TODO: batch
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -66,9 +57,7 @@ def eval_mmmu(args):
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_str}"
|
||||
},
|
||||
"image_url": {"url": image_path},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
@@ -77,40 +66,21 @@ def eval_mmmu(args):
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
conv = generate_chat_conv(
|
||||
ChatCompletionRequest(**request_dict),
|
||||
template_name=server_args.chat_template,
|
||||
temperature=0,
|
||||
max_completion_tokens=sampling_params["max_new_tokens"],
|
||||
max_tokens=sampling_params["max_new_tokens"],
|
||||
)
|
||||
prompt = conv.get_prompt()
|
||||
if image is not None:
|
||||
gen_out = backend.generate(
|
||||
prompt=prompt,
|
||||
image_data=conv.image_data,
|
||||
sampling_params=sampling_params,
|
||||
)["text"]
|
||||
|
||||
response = gen_out
|
||||
|
||||
else: # multiple images actually
|
||||
if sample["question_type"] == "multiple-choice":
|
||||
all_choices = sample["all_choices"]
|
||||
response = random.choice(all_choices)
|
||||
else:
|
||||
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
|
||||
|
||||
response = response.choices[0].message.content
|
||||
process_result(response, sample, answer_dict, out_samples)
|
||||
args.output_path = f"{args.model_path}_val_sglang.json"
|
||||
|
||||
args.output_path = f"./val_sglang.json"
|
||||
save_json(args.output_path, out_samples)
|
||||
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
|
||||
|
||||
backend.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
EvalArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user