"""MMLU evaluator (0-shot, single-letter likelihood) using direct transformers.

MMLU has exactly 4 multiple-choice options per question (labels A/B/C/D).
Standard 0-shot protocol: present the question + the 4 lettered options +
"Answer:", then compare log-probabilities of the 4 letter tokens at the next
position. The model's "pick" is whichever letter has the highest log-prob.

This is much faster than TruthfulQA-MC1 because each question requires only
one forward pass (instead of one per choice continuation).

Stratified subsampling: with 57 subjects, sampling N questions evenly across
subjects gives ~N/57 per subject. With N=228 → 4 per subject.

Usage:
    python eval_mmlu.py --model meta-llama/Llama-3.2-1B \\
        --out results/mmlu_base.json \\
        --device cpu --dtype float32 --n-questions 228
"""
from __future__ import annotations

import argparse
import json
import random
import sys
import time
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Answer index (0-3) -> option letter shown in the prompt.
SUBJECT_NUM_TO_LETTER = ["A", "B", "C", "D"]


def build_prompt(question, choices, subject):
    """Build the standard MMLU 0-shot prompt.

    Args:
        question: Question text.
        choices: Answer options; they are paired with A/B/C/D in order, and
            any options beyond the fourth are ignored (``zip`` truncation).
        subject: Subject name; underscores are rendered as spaces.

    Returns:
        The prompt string, ending in ``"Answer:"`` with no trailing newline.
    """
    subject_clean = subject.replace("_", " ")
    lines = [
        f"The following is a multiple choice question about {subject_clean}.",
        "",
        f"{question}",
    ]
    lines.extend(
        f"{letter}. {choice}"
        for letter, choice in zip(SUBJECT_NUM_TO_LETTER, choices)
    )
    lines.append("Answer:")
    return "\n".join(lines)


def stratified_subsample(dataset, n_questions, seed=17):
    """Sample ``n_questions`` row indices spread evenly across subjects.

    Every subject present in ``dataset`` gets ``n_questions // n_subjects``
    questions; the remainder is handed out one-per-subject to subjects chosen
    by the seeded RNG. A subject with fewer rows than its quota contributes
    all of its rows, so the result can be shorter than ``n_questions``.

    Args:
        dataset: Iterable of mappings that each have a ``"subject"`` key.
        n_questions: Requested sample size.
        seed: Seed for the private ``random.Random`` instance.

    Returns:
        A shuffled list of integer positions into ``dataset``. Empty when the
        dataset has no rows or ``n_questions`` is not positive.
    """
    rng = random.Random(seed)

    by_subject: dict[str, list[int]] = {}
    for i, item in enumerate(dataset):
        by_subject.setdefault(item["subject"], []).append(i)

    # Guard: an empty dataset would otherwise divide by zero below.
    if not by_subject or n_questions <= 0:
        return []

    subjects = sorted(by_subject)
    per_subject, remainder = divmod(n_questions, len(subjects))

    # NOTE: the order of RNG calls (extra-subject draw, per-subject draws in
    # sorted order, final shuffle) determines the sample for a given seed;
    # keep it stable so earlier results stay reproducible.
    picked_subjects_extra = set(rng.sample(subjects, remainder)) if remainder else set()

    indices: list[int] = []
    for subj in subjects:
        k = per_subject + (1 if subj in picked_subjects_extra else 0)
        if k > 0:
            pool = by_subject[subj]
            indices.extend(rng.sample(pool, min(k, len(pool))))

    rng.shuffle(indices)
    return indices


def _write_json_atomic(path, payload):
    """Serialize ``payload`` to ``path`` via a temp file + rename.

    Writing to a sibling file first means an interrupted run never leaves a
    truncated JSON file at ``path`` (which would make resuming impossible).
    """
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    tmp_path.replace(path)


def _summarize(correct, total, per_item, by_subject_acc):
    """Build the result dict shared by checkpoints and the final return value."""
    return {
        "n_evaluated": total,
        "n_correct": correct,
        "accuracy": correct / total if total else 0.0,
        "per_item": per_item,
        "by_subject": {
            subj: {
                "correct": hits,
                "total": seen,
                "accuracy": hits / seen if seen else 0.0,
            }
            for subj, (hits, seen) in by_subject_acc.items()
        },
    }


def _letter_token_ids(tokenizer):
    """Return one token id per answer letter, encoded with a leading space."""
    # MMLU convention: there's a space before the letter.
    token_ids = []
    for letter in SUBJECT_NUM_TO_LETTER:
        ids = tokenizer.encode(" " + letter, add_special_tokens=False)
        if not ids:
            raise ValueError(f"Tokenizer produced no tokens for ' {letter}'")
        if len(ids) != 1:
            print(f" Warning: ' {letter}' tokenizes to {ids} (len {len(ids)}); using last token.")
        token_ids.append(ids[-1])
    return token_ids


def _load_resume_state(out_path, indices):
    """Load and validate previously saved per-item records.

    Returns:
        ``(per_item, correct, by_subject_acc)`` rebuilt from the saved records.

    Raises:
        OSError, ValueError, KeyError, TypeError, AttributeError: when the file
            cannot be read/parsed, a record is malformed, or the saved records
            do not line up with the leading entries of ``indices``.
    """
    existing = json.loads(out_path.read_text(encoding="utf-8"))
    records = existing.get("per_item", [])
    if len(records) > len(indices):
        raise ValueError(
            f"checkpoint holds {len(records)} items but only {len(indices)} were requested"
        )

    correct = 0
    by_subject_acc: dict[str, list[int]] = {}
    for position, record in enumerate(records):
        # A mismatch means the checkpoint came from a different sample
        # (other seed / subset size); mixing the two would corrupt the stats.
        if record["ds_idx"] != indices[position]:
            raise ValueError(
                f"checkpoint item {position} does not match the current sample"
            )
        hit = bool(record["is_correct"])
        tally = by_subject_acc.setdefault(record["subject"], [0, 0])
        tally[0] += hit
        tally[1] += 1
        correct += hit
    return list(records), correct, by_subject_acc


def evaluate(model, tokenizer, dataset, indices, device,
             save_every=20, out_path=None, resume=True):
    """Run 0-shot MMLU over ``indices``.

    For each question, compare log-probs of " A", " B", " C", " D" at the
    position immediately after "Answer:"; the highest-scoring letter is the
    pick (the earliest letter wins ties).

    Args:
        model: Callable as ``model(input_ids)`` returning an object with
            ``.logits``; only batch element 0, last position is read.
        tokenizer: Provides ``encode(...)`` and is callable to produce
            ``.input_ids`` tensors.
        dataset: Indexable by integer; rows expose ``"question"``,
            ``"choices"``, ``"answer"`` and ``"subject"``.
        indices: Dataset positions to evaluate, in order.
        device: Device the input ids are moved to.
        save_every: Write a checkpoint to ``out_path`` every this many items;
            values ``<= 0`` disable checkpointing.
        out_path: ``Path`` of the checkpoint file, or ``None`` for no file I/O.
        resume: When true and ``out_path`` exists, continue after the items
            already recorded there (if they match ``indices``).

    Returns:
        Dict with ``n_evaluated``, ``n_correct``, ``accuracy``, ``per_item``,
        ``by_subject`` and ``wall_time_seconds`` (this invocation only; time
        spent before a resume is not included).
    """
    # Pre-compute the 4 single-token IDs for " A", " B", " C", " D".
    letter_token_ids = _letter_token_ids(tokenizer)
    print(f" Letter token IDs: {dict(zip(SUBJECT_NUM_TO_LETTER, letter_token_ids))}")

    correct = 0
    total = 0
    per_item = []
    by_subject_acc = {}
    start_idx = 0

    if resume and out_path is not None and out_path.exists():
        try:
            # Assigned in one step so a failure leaves the fresh state intact.
            per_item, correct, by_subject_acc = _load_resume_state(out_path, indices)
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
            print(f" Could not resume: {e}", flush=True)
        else:
            total = start_idx = len(per_item)
            print(f" Resuming from {start_idx} items already evaluated", flush=True)

    n = len(indices)
    t0 = time.time()

    for run_idx in range(start_idx, n):
        ds_idx = indices[run_idx]
        item = dataset[ds_idx]
        question = item["question"]
        choices = item["choices"]
        correct_idx = item["answer"]
        subject = item["subject"]

        prompt = build_prompt(question, choices, subject)
        input_ids = tokenizer(
            prompt, return_tensors="pt", add_special_tokens=True
        ).input_ids.to(device)

        with torch.no_grad():
            logits = model(input_ids).logits[0, -1, :]
            # Softmax in float32 regardless of the model dtype.
            log_probs = torch.log_softmax(logits.float(), dim=-1)

        # One gather + one host transfer instead of a sync per letter.
        scores = log_probs[letter_token_ids].tolist()
        picked = int(max(range(len(scores)), key=lambda i: scores[i]))
        is_correct = bool(picked == correct_idx)

        total += 1
        tally = by_subject_acc.setdefault(subject, [0, 0])
        tally[1] += 1
        if is_correct:
            correct += 1
            tally[0] += 1

        per_item.append({
            "ds_idx": ds_idx,
            "subject": subject,
            "question": question,
            "choices": choices,
            "correct_idx": correct_idx,
            "picked_idx": picked,
            "is_correct": is_correct,
            "letter_log_probs": dict(zip(SUBJECT_NUM_TO_LETTER, scores)),
        })

        if (run_idx + 1) % 10 == 0 or run_idx == start_idx:
            elapsed = time.time() - t0
            done_this_run = run_idx + 1 - start_idx
            rate = done_this_run / elapsed if elapsed > 0 else 0
            eta = (n - run_idx - 1) / rate if rate > 0 else 0
            print(
                f" [{run_idx+1}/{n}] acc={correct/total:.4f} "
                f"({correct}/{total}) | {rate:.2f} q/s | "
                f"elapsed {elapsed:.0f}s, ETA {eta:.0f}s",
                flush=True,
            )

        if out_path is not None and save_every > 0 and (run_idx + 1) % save_every == 0:
            _write_json_atomic(
                out_path, _summarize(correct, total, per_item, by_subject_acc)
            )

    results = _summarize(correct, total, per_item, by_subject_acc)
    results["wall_time_seconds"] = time.time() - t0
    return results


def main():
    """Parse CLI arguments, run the evaluation and write the JSON summary.

    Returns:
        ``0`` on completion (used as the process exit status).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--dtype", default="float32",
                        choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--n-questions", type=int, default=228,
                        help="Subset size, stratified across the 57 subjects (default: 228 = 4/subject)")
    parser.add_argument("--seed", type=int, default=17)
    parser.add_argument("--save-every", type=int, default=20)
    args = parser.parse_args()

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    print(f"Device: {device}, dtype: {args.dtype}")

    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]

    print(f"Loading {args.model}...")
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype)
    model.eval()
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    print("Loaded.")

    print("Loading MMLU (all)...")
    dataset = load_dataset("cais/mmlu", "all", split="test")
    print(f"Loaded {len(dataset)} questions across "
          f"{len(set(dataset['subject']))} subjects.")

    indices = stratified_subsample(dataset, args.n_questions, args.seed)
    print(f"Stratified subsample: {len(indices)} questions.")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    results = evaluate(
        model, tokenizer, dataset, indices, device,
        save_every=args.save_every, out_path=args.out,
    )

    summary = {
        "model": args.model,
        "device": device,
        "dtype": args.dtype,
        "task": "mmlu_0shot_letter_likelihood",
        "n_questions": len(indices),
        "seed": args.seed,
        **results,
    }
    # Same temp-file + rename path as the checkpoints, so the final summary
    # can never be left half-written.
    _write_json_atomic(args.out, summary)

    print(f"\nFinal: {summary['n_correct']}/{summary['n_evaluated']} "
          f"= {summary['accuracy']:.4f}")
    print(f"Wall: {summary['wall_time_seconds']:.0f}s")
    print(f"Wrote {args.out}")
    return 0


if __name__ == "__main__":
    sys.exit(main())