"""
MMMU evaluation for VLMs using the run_eval simple-evals interface.
"""

from __future__ import annotations

import ast
import base64
import io
import re
from typing import List, Optional, Tuple

from datasets import concatenate_datasets, load_dataset
from PIL import Image

from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
HTML_JINJA,
Eval,
EvalResult,
SamplerBase,
SingleEvalResult,
map_with_progress,
)


class MMMUVLMEval(Eval):
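    """MMMU evaluation for VLMs: all subjects, validation split, fixed sampling."""
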
DOMAIN_CAT2SUB_CAT = {
"Art and Design": ["Art", "Art_Theory", "Design", "Music"],
"Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
"Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
"Health and Medicine": [
"Basic_Medical_Science",
"Clinical_Medicine",
"Diagnostics_and_Laboratory_Medicine",
"Pharmacy",
"Public_Health",
],
"Humanities and Social Science": [
"History",
"Literature",
"Sociology",
"Psychology",
],
"Tech and Engineering": [
"Agriculture",
"Architecture_and_Engineering",
"Computer_Science",
"Electronics",
"Energy_and_Power",
"Materials",
"Mechanical_Engineering",
],
}

    def __init__(
        self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
    ):
        """Create the MMMU VLM eval (all subjects, 100 fixed samples by default)."""
self.num_examples = num_examples
self.num_threads = num_threads
self.seed = seed
# Prepare samples deterministically across all MMMU subjects (validation split)
self.samples = self._prepare_mmmu_samples(self.num_examples)

    @staticmethod
def _to_data_uri(image: Image.Image) -> str:
if image.mode == "RGBA":
image = image.convert("RGB")
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
return f"data:image/png;base64,{b64}"

    @staticmethod
def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
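        # e.g. (illustrative): ["cat", "dog"] -> ({"A": "cat", "B": "dog"}, ["A", "B"])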
index2ans = {}
all_choices = []
ch = ord("A")
for opt in options:
letter = chr(ch)
index2ans[letter] = opt
all_choices.append(letter)
ch += 1
return index2ans, all_choices

    def _prepare_mmmu_samples(self, k: int) -> List[dict]:
# Subjects and domains copied from MMMU data_utils to categorize results
subjects: List[str] = []
for subs in self.DOMAIN_CAT2SUB_CAT.values():
subjects.extend(subs)
# Load validation split of each subject
datasets = []
for subj in subjects:
try:
d = load_dataset("MMMU/MMMU", subj, split="validation")
                # Attach the subject name as an extra column for categorization
d = d.add_column("__subject__", [subj] * len(d))
datasets.append(d)
except Exception:
continue
if not datasets:
raise RuntimeError("Failed to load MMMU datasets")
merged = concatenate_datasets(datasets)

        # Deterministic selection: sort by id (fall back to subject + index)
        def _key(idx):
ex = merged[idx]
return str(ex.get("id", f"{ex['__subject__']}:{idx}"))

        order = sorted(range(len(merged)), key=_key)
picked_indices = order[:k]
samples: List[dict] = []
for idx in picked_indices:
ex = merged[idx]
subject = ex["__subject__"]
image = ex.get("image_1")
if image is None or not hasattr(image, "convert"):
continue
data_uri = self._to_data_uri(image)
question = ex.get("question", "")
answer = ex.get("answer")
raw_options = ex.get("options")
question_type = "open"
index2ans = None
all_choices = None
options = None
if raw_options:
try:
options = (
raw_options
if isinstance(raw_options, list)
                        # Options are stored as a stringified list; parse them safely.
                        else list(ast.literal_eval(raw_options))
)
if isinstance(options, list) and len(options) > 0:
index2ans, all_choices = self._build_mc_mapping(options)
question_type = "multiple-choice"
except Exception:
options = None
# Build final textual prompt; include choices if MC
prompt_text = f"Question: {question}\n\n"
if options:
letters = [chr(ord("A") + i) for i in range(len(options))]
for letter, opt in zip(letters, options):
prompt_text += f"{letter}) {opt}\n"
prompt_text += "\nAnswer: "
samples.append(
{
"id": ex.get("id", f"{subject}:{idx}"),
"final_input_prompt": prompt_text,
"image_data": data_uri,
"answer": answer,
"question_type": question_type,
"index2ans": index2ans,
"all_choices": all_choices,
"category": subject,
}
)
return samples

    @staticmethod
def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
"""Split a prompt containing an inline image tag into prefix and suffix.
If no tag is present, treat the whole prompt as prefix and empty suffix.
"""
if "<" in prompt and ">" in prompt:
prefix = prompt.split("<")[0]
suffix = prompt.split(">", 1)[1]
return prefix, suffix
return prompt, ""

    @staticmethod
    def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
        """Build an OpenAI-compatible vision+text user message from the prompt.

        The image is inserted where the inline image tag appeared; text before
        and after the tag becomes separate text content parts.
        """
prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)
content: List[dict] = []
if prefix:
content.append({"type": "text", "text": prefix})
content.append({"type": "image_url", "image_url": {"url": image_data}})
if suffix:
content.append({"type": "text", "text": suffix})
prompt_messages = [{"role": "user", "content": content}]
return prompt_messages

    def __call__(self, sampler: SamplerBase) -> EvalResult:
def fn(sample: dict):
prompt = sample["final_input_prompt"]
image_data = sample["image_data"]
prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
prompt, image_data
)
# Sample
response_text = sampler(prompt_messages)
# Parse and score
gold = sample["answer"]
if (
sample["question_type"] == "multiple-choice"
and sample["all_choices"]
and sample["index2ans"]
):
pred = _parse_multi_choice_response(
response_text, sample["all_choices"], sample["index2ans"]
)
score = 1.0 if (gold is not None and pred == gold) else 0.0
extracted_answer = pred
else:
parsed_list = _parse_open_response(response_text)
score = (
1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
)
extracted_answer = ", ".join(map(str, parsed_list))
html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
prompt_messages=prompt_messages,
next_message=dict(content=response_text, role="assistant"),
score=score,
correct_answer=gold,
extracted_answer=extracted_answer,
)
convo = prompt_messages + [dict(content=response_text, role="assistant")]
return SingleEvalResult(
html=html_rendered,
score=score,
metrics={"__category__": sample["category"]},
convo=convo,
)

        results = map_with_progress(fn, self.samples, self.num_threads)
# Build category table and overall accuracy
# Gather per-sample correctness and category
per_cat_total: dict[str, int] = {}
per_cat_correct: dict[str, int] = {}
htmls = []
convos = []
scores: List[float] = []
for r in results:
# __category__ stored under metrics
cat = r.metrics.get("__category__") if r.metrics else None
if cat is None:
cat = "Unknown"
per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
if r.score:
per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
htmls.append(r.html)
convos.append(r.convo)
if r.score is not None:
scores.append(r.score)
evaluation_result = {}
for cat, tot in per_cat_total.items():
corr = per_cat_correct.get(cat, 0)
acc = (corr / tot) if tot > 0 else 0.0
evaluation_result[cat] = {"acc": round(acc, 3), "num_example": tot}
printable_results = {}
# Domains first
for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
acc_sum = 0.0
num_sum = 0
for cat in cats:
if cat in evaluation_result:
acc_sum += (
evaluation_result[cat]["acc"]
* evaluation_result[cat]["num_example"]
)
num_sum += evaluation_result[cat]["num_example"]
if num_sum > 0:
printable_results[f"Overall-{domain}"] = {
"num": num_sum,
"acc": round(acc_sum / num_sum, 3),
}
# add each sub-category row if present
for cat in cats:
if cat in evaluation_result:
printable_results[cat] = {
"num": evaluation_result[cat]["num_example"],
"acc": evaluation_result[cat]["acc"],
}
# Overall
total_num = sum(v["num_example"] for v in evaluation_result.values())
overall_acc = (
sum(v["acc"] * v["num_example"] for v in evaluation_result.values())
/ total_num
if total_num > 0
else 0.0
)
printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}
# Build EvalResult
return EvalResult(
score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
)


def _parse_multi_choice_response(
response: str, all_choices: List[str], index2ans: dict
) -> str:
    # Loosely adapted from the MMMU benchmark's evaluation code
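    # e.g. (illustrative):
    #   _parse_multi_choice_response("The answer is (B).", ["A", "B", "C"],
    #                                {"A": "1", "B": "2", "C": "3"}) -> "B"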
for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " "
    # Prefer an explicitly parenthesized letter, e.g. "(A)"
candidates: List[str] = []
for choice in all_choices:
if f"({choice})" in response:
candidates.append(choice)
if not candidates:
for choice in all_choices:
if f" {choice} " in response:
candidates.append(choice)
if not candidates and len(response.split()) > 5:
# try match by option text
for idx, ans in index2ans.items():
if ans and ans.lower() in response.lower():
candidates.append(idx)
if not candidates:
# fallback to first choice
return all_choices[0]
if len(candidates) == 1:
return candidates[0]
# choose the last occurrence
starts = []
for can in candidates:
pos = response.rfind(f"({can})")
if pos == -1:
pos = response.rfind(f" {can} ")
if pos == -1 and index2ans.get(can):
pos = response.lower().rfind(index2ans[can].lower())
starts.append(pos)
return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]


def _check_is_number(s: str) -> bool:
try:
float(s.replace(",", ""))
return True
except Exception:
return False


def _normalize_str(s: str):
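    # e.g. (illustrative): "1,234.50" -> [1234.5]; "Paris" -> ["paris"]; "a" -> [" a", "a "]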
s = s.strip()
if _check_is_number(s):
s = s.replace(",", "")
try:
v = round(float(s), 2)
return [v]
except Exception:
return [s.lower()]
return [s.lower()] if len(s) > 1 else [" " + s, s + " "]


def _extract_numbers(s: str) -> List[str]:
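    # e.g. (illustrative): _extract_numbers("The answer is 42 and 3.14") -> ["42", "3.14"]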
pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
    return (
        re.findall(pattern_commas, s)
        + re.findall(pattern_scientific, s)
        + re.findall(pattern_simple, s)
    )


def _parse_open_response(response: str) -> List[str]:
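    # e.g. (illustrative): _parse_open_response("Therefore the result is 42.") -> [42.0]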
def get_key_subresponses(resp: str) -> List[str]:
        resp = resp.strip().strip(".")
        # Lowercase after splitting so the (?=[A-Z]) lookahead can match sentence starts.
        subs = [s.lower() for s in re.split(r"\.\s(?=[A-Z])|\n", resp)]
indicators = [
"could be ",
"so ",
"is ",
"thus ",
"therefore ",
"final ",
"answer ",
"result ",
]
keys = []
for i, s in enumerate(subs):
cands = [*indicators]
if i == len(subs) - 1:
cands.append("=")
shortest = None
for ind in cands:
if ind in s:
part = s.split(ind)[-1].strip()
if not shortest or len(part) < len(shortest):
shortest = part
            if shortest and shortest not in [",", ".", "!", "?", ";", ":", "'"]:
keys.append(shortest)
        return keys or [resp.lower()]

    key_resps = get_key_subresponses(response)
pred_list = key_resps.copy()
for r in key_resps:
pred_list.extend(_extract_numbers(r))
out = []
for x in pred_list:
out.extend(_normalize_str(x))
# dedup
return list(dict.fromkeys(out))


def _eval_open(gold, preds: List[str]) -> bool:
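    # True if any normalized gold answer matches one of the predictions,
    # e.g. (illustrative): _eval_open("42", [42.0]) -> True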
if isinstance(gold, list):
norm_answers = []
for ans in gold:
norm_answers.extend(_normalize_str(ans))
else:
norm_answers = _normalize_str(gold)
for p in preds:
if isinstance(p, str):
for na in norm_answers:
if isinstance(na, str) and na in p:
return True
else:
if p in norm_answers:
return True
return False
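

if __name__ == "__main__":
    # Minimal usage sketch (illustrative): run this eval against an OpenAI-compatible
    # VLM server. Assumes simple_eval_common provides a ChatCompletionSampler (as used
    # by run_eval) and that a server is already listening at `base_url`; the model
    # name below is hypothetical -- adjust both to your setup.
    from sglang.test.simple_eval_common import ChatCompletionSampler

    sampler = ChatCompletionSampler(
        model="Qwen/Qwen2-VL-7B-Instruct",  # hypothetical model name
        base_url="http://127.0.0.1:30000/v1",
    )
    result = MMMUVLMEval(num_examples=10, num_threads=4)(sampler)
    print(f"Overall accuracy: {result.score:.3f}")
    print(result.metrics)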