refactor: move image processors to separate files (#4229)
This commit is contained in:
@@ -11,11 +11,16 @@ import argparse
|
||||
import random
|
||||
|
||||
import torch
|
||||
from bench_sglang import EvalArgs, prepare_samples
|
||||
from data_utils import save_json
|
||||
from eval_utils import eval_result, get_sampling_params, parse_multi_choice_response
|
||||
from eval_utils import (
|
||||
EvalArgs,
|
||||
eval_result,
|
||||
get_sampling_params,
|
||||
prepare_samples,
|
||||
process_result,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -28,7 +33,6 @@ def eval_mmmu(args):
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model = model.eval().cuda()
|
||||
model = torch.compile(model)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
args.model_path, torch_dtype="auto", device_map="auto"
|
||||
@@ -38,6 +42,10 @@ def eval_mmmu(args):
|
||||
out_samples = dict()
|
||||
|
||||
sampling_params = get_sampling_params(eval_args)
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=sampling_params["max_new_tokens"],
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
answer_dict = {}
|
||||
for sample in tqdm(samples):
|
||||
@@ -62,7 +70,6 @@ def eval_mmmu(args):
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=[image],
|
||||
@@ -70,13 +77,16 @@ def eval_mmmu(args):
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
generated_ids = model.generate(**inputs, **sampling_params)
|
||||
generated_ids = model.generate(
|
||||
**inputs, generation_config=generation_config
|
||||
)
|
||||
|
||||
response = processor.decode(
|
||||
generated_ids[0],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)[len(text) :]
|
||||
print(f"response: {response}")
|
||||
else: # multiple images actually
|
||||
if sample["question_type"] == "multiple-choice":
|
||||
all_choices = sample["all_choices"]
|
||||
@@ -85,24 +95,11 @@ def eval_mmmu(args):
|
||||
else:
|
||||
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
|
||||
|
||||
if sample["question_type"] == "multiple-choice":
|
||||
pred_ans = parse_multi_choice_response(
|
||||
response, sample["all_choices"], sample["index2ans"]
|
||||
)
|
||||
else: # open question
|
||||
pred_ans = response
|
||||
out_samples[sample["id"]] = pred_ans
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
# set ground truth answer
|
||||
answer_dict[sample["id"]] = {
|
||||
"question_type": sample["question_type"],
|
||||
"ground_truth": sample["answer"],
|
||||
}
|
||||
process_result(response, sample, answer_dict, out_samples)
|
||||
|
||||
args.output_path = f"{args.model_path}_val_hf.json"
|
||||
save_json(args.output_path, out_samples)
|
||||
eval_result(output_path=args.output_path, answer_dict=answer_dict)
|
||||
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import dataclasses
|
||||
import random
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
from data_utils import save_json
|
||||
@@ -18,13 +18,14 @@ from eval_utils import (
|
||||
EvalArgs,
|
||||
eval_result,
|
||||
get_sampling_params,
|
||||
parse_multi_choice_response,
|
||||
prepare_samples,
|
||||
process_result,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang import Engine
|
||||
from sglang.srt.conversation import chat_templates
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
@@ -35,61 +36,76 @@ def eval_mmmu(args):
|
||||
if server_args.chat_template is None:
|
||||
raise ValueError("Chat template must be provided for this benchmark")
|
||||
|
||||
samples = prepare_samples(eval_args)
|
||||
|
||||
backend = Engine(**dataclasses.asdict(server_args))
|
||||
|
||||
out_samples = dict()
|
||||
|
||||
sampling_params = get_sampling_params(eval_args)
|
||||
|
||||
conv = chat_templates[server_args.chat_template].copy()
|
||||
image_token = conv.image_token
|
||||
samples = prepare_samples(eval_args)
|
||||
|
||||
answer_dict = {}
|
||||
|
||||
for sample in tqdm(samples):
|
||||
prompt = sample["final_input_prompt"]
|
||||
image = sample["image"]
|
||||
bytes_io = BytesIO()
|
||||
image.save(bytes_io, format="PNG")
|
||||
png_bytes = bytes_io.getvalue()
|
||||
|
||||
prompt = re.sub(r"<[^>]*>", image_token, prompt)
|
||||
buff = BytesIO()
|
||||
image.save(buff, format="PNG")
|
||||
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
|
||||
prefix = prompt.split("<")[0]
|
||||
suffix = prompt.split(">")[1]
|
||||
request_dict = {
|
||||
"model": "",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prefix,
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_str}"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": suffix,
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
conv = generate_chat_conv(
|
||||
ChatCompletionRequest(**request_dict),
|
||||
template_name=server_args.chat_template,
|
||||
)
|
||||
prompt = conv.get_prompt()
|
||||
if image is not None:
|
||||
gen_out = backend.generate(
|
||||
prompt=prompt, image_data=[png_bytes], sampling_params=sampling_params
|
||||
prompt=prompt,
|
||||
image_data=conv.image_data,
|
||||
sampling_params=sampling_params,
|
||||
)["text"]
|
||||
|
||||
response = gen_out
|
||||
|
||||
else: # multiple images actually
|
||||
if sample["question_type"] == "multiple-choice":
|
||||
all_choices = sample["all_choices"]
|
||||
response = random.choice(all_choices)
|
||||
|
||||
else:
|
||||
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
|
||||
|
||||
if sample["question_type"] == "multiple-choice":
|
||||
pred_ans = parse_multi_choice_response(
|
||||
response, sample["all_choices"], sample["index2ans"]
|
||||
)
|
||||
else: # open question
|
||||
pred_ans = response
|
||||
out_samples[sample["id"]] = pred_ans
|
||||
|
||||
# set ground truth answer
|
||||
answer_dict[sample["id"]] = {
|
||||
"question_type": sample["question_type"],
|
||||
"ground_truth": (
|
||||
sample["correct_choice"]
|
||||
if "correct_choice" in samples
|
||||
else sample["answer"]
|
||||
),
|
||||
}
|
||||
|
||||
process_result(response, sample, answer_dict, out_samples)
|
||||
args.output_path = f"{args.model_path}_val_sglang.json"
|
||||
save_json(args.output_path, out_samples)
|
||||
eval_result(output_path=args.output_path, answer_dict=answer_dict)
|
||||
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
|
||||
|
||||
backend.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -143,6 +143,7 @@ def process_single_sample(data):
|
||||
|
||||
# DATA SAVING
|
||||
def save_json(filename, ds):
    """Serialize ``ds`` as indented JSON to ``filename``.

    Any missing parent directories of ``filename`` are created first.

    Args:
        filename: Destination path. May be a bare file name with no
            directory component, in which case the file is written
            relative to the current working directory.
        ds: JSON-serializable object to write.
    """
    print(f"answers saved to: {filename}")
    parent_dir = os.path.dirname(filename)
    # os.path.dirname() returns "" for a bare file name, and
    # os.makedirs("") raises FileNotFoundError, so only create the
    # directory when there actually is one to create.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(ds, f, indent=4)
|
||||
|
||||
@@ -87,6 +87,7 @@ def set_seed(seed_value):
|
||||
|
||||
|
||||
def prepare_samples(eval_args: EvalArgs):
|
||||
print("preparing samples...")
|
||||
# Build prompts
|
||||
set_seed(eval_args.seed)
|
||||
|
||||
@@ -110,6 +111,7 @@ def prepare_samples(eval_args: EvalArgs):
|
||||
eval_args.dataset_path, subject, split=eval_args.split
|
||||
)
|
||||
sub_dataset_list.append(sub_dataset)
|
||||
# break
|
||||
|
||||
# merge all dataset
|
||||
dataset = concatenate_datasets(sub_dataset_list)
|
||||
@@ -426,9 +428,26 @@ def calculate_ins_level_acc(results: Dict):
|
||||
return acc / ins_num
|
||||
|
||||
|
||||
def eval_result(output_path, answer_dict):
|
||||
def process_result(response, sample, answer_dict, out_samples):
    """Record the prediction and ground truth for one evaluated sample.

    For a "multiple-choice" sample the raw response is passed through
    ``parse_multi_choice_response`` together with the sample's
    ``all_choices`` and ``index2ans``; for any other question type the
    response is stored unchanged.

    Args:
        response: Raw text produced by the model for this sample.
        sample: Mapping read for "id", "question_type" and "answer"
            (plus "all_choices" and "index2ans" for multiple-choice).
        answer_dict: Mutated in place; receives the question type and
            ground truth under the sample id.
        out_samples: Mutated in place; receives the prediction under the
            sample id.
    """
    question_type = sample["question_type"]

    if question_type != "multiple-choice":
        # open question: keep the response as-is
        prediction = response
    else:
        prediction = parse_multi_choice_response(
            response, sample["all_choices"], sample["index2ans"]
        )

    sample_id = sample["id"]
    out_samples[sample_id] = prediction

    # set ground truth answer
    ground_truth_entry = {
        "question_type": question_type,
        "ground_truth": sample["answer"],
    }
    answer_dict[sample_id] = ground_truth_entry
|
||||
|
||||
|
||||
def eval_result(model_answer_path, answer_dict):
|
||||
print("Evaluating...")
|
||||
output_dict = json.load(open(output_path))
|
||||
output_dict = json.load(open(model_answer_path))
|
||||
# answer_dict = json.load(open(answer_path))
|
||||
|
||||
# group by category
|
||||
@@ -521,7 +540,7 @@ def eval_result(output_path, answer_dict):
|
||||
"acc": overall_acc,
|
||||
}
|
||||
pprint.pprint(printable_results)
|
||||
out = output_path
|
||||
out = model_answer_path
|
||||
with open(out, "w", encoding="utf-8") as outfile:
|
||||
json.dump(printable_results, outfile)
|
||||
print(f"eval out saved to {out}")
|
||||
|
||||
Reference in New Issue
Block a user