Support MMMU benchmark for InternVL (#5968)

This commit is contained in:
XinyuanTong
2025-05-02 00:17:21 -07:00
committed by GitHub
parent 3409aaab32
commit 6ea1e6ac6e
2 changed files with 139 additions and 12 deletions

View File

@@ -17,6 +17,13 @@ from transformers import AutoModel, AutoProcessor, GenerationConfig
@torch.no_grad()
def eval_mmmu(args):
eval_args = EvalArgs.from_cli_args(args)
sampling_params = get_sampling_params(eval_args)
generation_config = GenerationConfig(
max_new_tokens=sampling_params["max_new_tokens"],
do_sample=False,
)
try:
from transformers import AutoModelForImageTextToText
@@ -27,12 +34,28 @@ def eval_mmmu(args):
)
except Exception as first_exception:
try:
model = AutoModel.from_pretrained(
args.model_path,
torch_dtype="auto",
trust_remote_code=True,
init_tts=False,
)
# check whether the model belongs to the InternVL family
if "InternVL" in args.model_path:
from internvl_utils import load_image
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
model = AutoModel.from_pretrained(
args.model_path,
torch_dtype="auto",
trust_remote_code=True,
)
generation_config_internvl = dict(
max_new_tokens=sampling_params["max_new_tokens"], do_sample=False
)
else:
model = AutoModel.from_pretrained(
args.model_path,
torch_dtype="auto",
trust_remote_code=True,
init_tts=False,
)
except Exception as second_exception:
raise RuntimeError(
f"Failed to load model: First attempt failed with {first_exception}, "
@@ -48,12 +71,6 @@ def eval_mmmu(args):
samples = prepare_samples(eval_args)
out_samples = dict()
sampling_params = get_sampling_params(eval_args)
generation_config = GenerationConfig(
max_new_tokens=sampling_params["max_new_tokens"],
do_sample=False,
)
answer_dict = {}
for sample in tqdm(samples):
prompt = sample["final_input_prompt"]
@@ -61,6 +78,22 @@ def eval_mmmu(args):
prefix = prompt.split("<")[0]
suffix = prompt.split(">")[1]
assert image is not None
if "InternVL" in args.model_path:
pixel_values = load_image(sample["image_path"]).to(torch.bfloat16).cuda()
contents = ""
if prefix:
contents += prefix
contents += "<image>\n"
if suffix:
contents += suffix
response = model.chat(
tokenizer, pixel_values, contents, generation_config_internvl
)
print(f"response: {response}")
process_result(response, sample, answer_dict, out_samples)
continue
contents = []
if prefix:
contents += [{"type": "text", "text": prefix}]