247 lines
8.7 KiB
Python
247 lines
8.7 KiB
Python
|
|
"""MMLU evaluator (0-shot, single-letter likelihood) using direct transformers.
|
||
|
|
|
||
|
|
MMLU has exactly 4 multiple-choice options per question (labels A/B/C/D).
|
||
|
|
Standard 0-shot protocol: present the question + the 4 lettered options +
|
||
|
|
"Answer:", then compare log-probabilities of the 4 letter tokens at the
|
||
|
|
next position. The model's "pick" is whichever letter has the highest
|
||
|
|
log-prob.
|
||
|
|
|
||
|
|
This is much faster than TruthfulQA-MC1 because each question requires
|
||
|
|
only one forward pass (instead of one per choice continuation).
|
||
|
|
|
||
|
|
Stratified subsampling: with 57 subjects, sampling N questions evenly
|
||
|
|
across subjects gives ~N/57 per subject. With N=228 → 4 per subject.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
python eval_mmlu.py --model meta-llama/Llama-3.2-1B \\
|
||
|
|
--out results/mmlu_base.json \\
|
||
|
|
--device cpu --dtype float32 --n-questions 228
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import random
|
||
|
|
import sys
|
||
|
|
import time
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from datasets import load_dataset
|
||
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
|
|
||
|
|
|
||
|
|
# Option letters in presentation order. Despite the name, this is indexed by
# the 0-based answer position of a question, not by subject.
SUBJECT_NUM_TO_LETTER = list("ABCD")
|
||
|
|
|
||
|
|
|
||
|
|
def build_prompt(question, choices, subject):
    """Render one question as a 0-shot multiple-choice prompt.

    The prompt is a one-line header naming the subject (underscores shown as
    spaces), a blank line, the question, one ``"<letter>. <choice>"`` line per
    option, and a trailing ``"Answer:"`` with no newline after it.

    Args:
        question: Question text.
        choices: Answer options; they are paired with the letters in
            ``SUBJECT_NUM_TO_LETTER`` in order, and pairing stops at the
            shorter of the two sequences.
        subject: Subject identifier such as ``"college_physics"``.

    Returns:
        The prompt string.
    """
    topic = " ".join(subject.split("_"))
    header = f"The following is a multiple choice question about {topic}.\n\n"
    option_lines = [
        f"{label}. {text}\n"
        for label, text in zip(SUBJECT_NUM_TO_LETTER, choices)
    ]
    return "".join([header, f"{question}\n", *option_lines, "Answer:"])
|
||
|
|
|
||
|
|
|
||
|
|
def stratified_subsample(dataset, n_questions, seed=17):
    """Sample up to ``n_questions`` row positions spread evenly across subjects.

    Rows are grouped by their ``"subject"`` field. Every subject is given a
    quota of ``n_questions // n_subjects`` rows; the leftover
    ``n_questions % n_subjects`` rows go, one each, to distinct subjects
    chosen by the seeded RNG. Rows within a subject are drawn without
    replacement, and the combined selection is shuffled.

    Args:
        dataset: Iterable of mappings that each expose a ``"subject"`` key.
            Positions in iteration order are what gets returned.
        n_questions: Total number of rows requested.
        seed: Seed for the private ``random.Random`` instance; the same
            dataset, ``n_questions`` and seed always give the same result.

    Returns:
        A shuffled list of integer positions into ``dataset``. It is shorter
        than ``n_questions`` when some subject holds fewer rows than its
        quota (the shortfall is not redistributed), and empty when the
        dataset is empty or ``n_questions`` is not positive.
    """
    rng = random.Random(seed)

    by_subject = {}
    for position, item in enumerate(dataset):
        by_subject.setdefault(item["subject"], []).append(position)

    # Without this guard an empty dataset divides by zero below. A
    # non-positive request already produced an empty list, so returning early
    # keeps that result while skipping pointless RNG work.
    if not by_subject or n_questions <= 0:
        return []

    # Sorted so the outcome does not depend on the dataset's row order of
    # first appearance for each subject.
    subjects = sorted(by_subject)
    per_subject, remainder = divmod(n_questions, len(subjects))

    # NOTE: the order of RNG calls (extras, then one draw per subject, then
    # the shuffle) determines the sample for a given seed; keep it stable so
    # earlier result files remain reproducible.
    extra_subjects = set(rng.sample(subjects, remainder)) if remainder else set()

    indices = []
    for subj in subjects:
        quota = per_subject + (1 if subj in extra_subjects else 0)
        if quota > 0:
            pool = by_subject[subj]
            indices.extend(rng.sample(pool, min(quota, len(pool))))

    rng.shuffle(indices)
    return indices
|
||
|
|
|
||
|
|
|
||
|
|
def _letter_token_ids(tokenizer):
    """Return the token id scored for each answer letter (" A" through " D")."""
    token_ids = []
    for letter in SUBJECT_NUM_TO_LETTER:
        # MMLU convention: there's a space before the letter.
        ids = tokenizer.encode(" " + letter, add_special_tokens=False)
        if len(ids) != 1:
            print(f" Warning: ' {letter}' tokenizes to {ids} (len {len(ids)}); using last token.")
        token_ids.append(ids[-1])
    return token_ids


def _load_checkpoint(out_path, indices):
    """Read saved per-item records and check they are a prefix of ``indices``.

    Raises ``ValueError`` when the file holds more records than requested or
    when a record's ``ds_idx`` differs from the question now at that position.
    """
    saved = json.loads(out_path.read_text())
    records = saved.get("per_item", [])
    if len(records) > len(indices):
        raise ValueError(
            f"checkpoint has {len(records)} items but only {len(indices)} are requested"
        )
    for pos, rec in enumerate(records):
        if rec["ds_idx"] != indices[pos]:
            raise ValueError(
                f"checkpoint item {pos} does not match the current question order"
            )
    return records


def _tally_records(records):
    """Recount total and per-subject correct answers from per-item records."""
    correct = 0
    by_subject = {}
    for rec in records:
        tally = by_subject.setdefault(rec["subject"], [0, 0])
        tally[1] += 1
        if rec["is_correct"]:
            tally[0] += 1
            correct += 1
    return correct, by_subject


def _summarize(correct, total, per_item, by_subject_acc):
    """Build the JSON-serialisable result dict shared by checkpoints and the return value."""
    return {
        "n_evaluated": total,
        "n_correct": correct,
        "accuracy": correct / total if total else 0.0,
        "per_item": per_item,
        "by_subject": {
            name: {
                "correct": hits,
                "total": seen,
                "accuracy": hits / seen if seen else 0.0,
            }
            for name, (hits, seen) in by_subject_acc.items()
        },
    }


def _write_json_atomic(path, payload):
    """Write ``payload`` as JSON via a sibling temp file, then rename over ``path``."""
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(payload, indent=2))
    # The rename only happens after a complete write, so an interrupted save
    # leaves the previous checkpoint intact instead of a truncated file.
    tmp_path.replace(path)


def evaluate(model, tokenizer, dataset, indices, device, save_every=20, out_path=None,
             resume=True):
    """Run 0-shot MMLU over ``indices`` and return aggregate and per-item results.

    For each question the prompt from ``build_prompt`` is fed through the
    model once, and the log-probabilities of " A", " B", " C", " D" at the
    position immediately after "Answer:" are compared; the highest one is the
    model's pick.

    Args:
        model: Callable returning an object with ``.logits`` when given the
            prompt's input ids.
        tokenizer: Tokenizer providing ``encode`` and ``__call__`` with
            ``return_tensors="pt"``.
        dataset: Indexable collection whose items have ``"question"``,
            ``"choices"``, ``"answer"`` and ``"subject"`` fields.
        indices: Positions into ``dataset`` to evaluate, in order.
        device: Device the input ids are moved to before the forward pass.
        save_every: Write a checkpoint to ``out_path`` each time the number
            of processed questions is a multiple of this value. A value of 0
            or less disables checkpointing.
        out_path: Path-like object (``exists``/``read_text``/``write_text``)
            for checkpoints, or ``None`` to skip saving and resuming.
        resume: When true and ``out_path`` exists, reuse its saved items
            provided they match the leading entries of ``indices``; otherwise
            report the problem and start from the first question.

    Returns:
        Dict with ``n_evaluated``, ``n_correct``, ``accuracy``, ``per_item``,
        ``by_subject`` and ``wall_time_seconds`` (time spent in this call
        only, excluding any resumed work).
    """
    letter_token_ids = _letter_token_ids(tokenizer)
    print(f" Letter token IDs: {dict(zip(SUBJECT_NUM_TO_LETTER, letter_token_ids))}")

    correct = 0
    per_item = []
    by_subject_acc = {}

    if resume and out_path is not None and out_path.exists():
        try:
            restored = _load_checkpoint(out_path, indices)
            restored_correct, restored_by_subject = _tally_records(restored)
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
            print(f" Could not resume: {e}", flush=True)
        else:
            # Commit only after the whole checkpoint validated, so a bad file
            # can never leave the counters half-restored.
            per_item = restored
            correct = restored_correct
            by_subject_acc = restored_by_subject
            print(f" Resuming from {len(per_item)} items already evaluated", flush=True)

    total = len(per_item)
    start_idx = total
    n = len(indices)
    t0 = time.time()

    for run_idx in range(start_idx, n):
        ds_idx = indices[run_idx]
        item = dataset[ds_idx]
        question = item["question"]
        choices = item["choices"]
        correct_idx = item["answer"]
        subject = item["subject"]

        prompt = build_prompt(question, choices, subject)
        input_ids = tokenizer(
            prompt, return_tensors="pt", add_special_tokens=True
        ).input_ids.to(device)

        with torch.no_grad():
            # Only the distribution at the last prompt position is needed.
            logits = model(input_ids).logits[0, -1, :]
        # Upcast before the softmax so half-precision models are scored in float32.
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        scores = [log_probs[t].item() for t in letter_token_ids]

        # max() keeps the first of equal scores, i.e. ties go to the earlier letter.
        picked = max(range(len(scores)), key=scores.__getitem__)
        is_correct = bool(picked == correct_idx)
        total += 1

        tally = by_subject_acc.setdefault(subject, [0, 0])
        tally[1] += 1
        if is_correct:
            correct += 1
            tally[0] += 1

        per_item.append({
            "ds_idx": ds_idx,
            "subject": subject,
            "question": question,
            "choices": choices,
            "correct_idx": correct_idx,
            "picked_idx": picked,
            "is_correct": is_correct,
            "letter_log_probs": dict(zip(SUBJECT_NUM_TO_LETTER, scores)),
        })

        if (run_idx + 1) % 10 == 0 or run_idx == start_idx:
            elapsed = time.time() - t0
            done_this_run = run_idx + 1 - start_idx
            rate = done_this_run / elapsed if elapsed > 0 else 0
            eta = (n - run_idx - 1) / rate if rate > 0 else 0
            print(
                f" [{run_idx+1}/{n}] acc={correct/total:.4f} "
                f"({correct}/{total}) | {rate:.2f} q/s | "
                f"elapsed {elapsed:.0f}s, ETA {eta:.0f}s",
                flush=True,
            )

        if out_path is not None and save_every > 0 and (run_idx + 1) % save_every == 0:
            _write_json_atomic(
                out_path, _summarize(correct, total, per_item, by_subject_acc)
            )

    results = _summarize(correct, total, per_item, by_subject_acc)
    results["wall_time_seconds"] = time.time() - t0
    return results
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Parse CLI arguments, run the MMLU evaluation and write the JSON summary.

    The model and tokenizer are loaded from ``--model``, the ``cais/mmlu``
    test split is subsampled with ``stratified_subsample``, and the results
    of ``evaluate`` are merged with run metadata and written to ``--out``.
    An existing ``--out`` file is used to resume unless ``--no-resume`` is
    given.

    Returns:
        0, intended as the process exit status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--dtype", default="float32",
                        choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--n-questions", type=int, default=228,
                        help="Subset size, stratified across the 57 subjects (default: 228 = 4/subject)")
    parser.add_argument("--seed", type=int, default=17)
    parser.add_argument("--save-every", type=int, default=20)
    # Resuming stays the default; this flag lets a run ignore an --out file
    # left behind by a different model or configuration.
    parser.add_argument("--no-resume", action="store_true",
                        help="Ignore any existing --out file and evaluate from scratch")
    args = parser.parse_args()

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    print(f"Device: {device}, dtype: {args.dtype}")

    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    dtype = dtype_map[args.dtype]

    print(f"Loading {args.model}...")
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype)
    model.eval()
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    print("Loaded.")

    print("Loading MMLU (all)...")
    dataset = load_dataset("cais/mmlu", "all", split="test")
    print(f"Loaded {len(dataset)} questions across "
          f"{len(set(dataset['subject']))} subjects.")

    indices = stratified_subsample(dataset, args.n_questions, args.seed)
    print(f"Stratified subsample: {len(indices)} questions.")

    # Checkpoints are written during evaluation, so the directory must exist first.
    args.out.parent.mkdir(parents=True, exist_ok=True)

    results = evaluate(
        model, tokenizer, dataset, indices, device,
        save_every=args.save_every, out_path=args.out,
        resume=not args.no_resume,
    )

    summary = {
        "model": args.model,
        "device": device,
        "dtype": args.dtype,
        "task": "mmlu_0shot_letter_likelihood",
        "n_questions": len(indices),
        "seed": args.seed,
        **results,
    }
    # Overwrites the last checkpoint with the final summary plus run metadata.
    args.out.write_text(json.dumps(summary, indent=2))
    print(f"\nFinal: {summary['n_correct']}/{summary['n_evaluated']} "
          f"= {summary['accuracy']:.4f}")
    print(f"Wall: {summary['wall_time_seconds']:.0f}s")
    print(f"Wrote {args.out}")
    return 0
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|