Files
ModelHub XC 1f3458b152 初始化项目,由ModelHub XC社区提供模型
Model: Solshine/Llama-3.2-1B-sandbag-circuit-ablated
Source: Original Platform
2026-05-22 11:49:16 +08:00

247 lines
8.7 KiB
Python

"""MMLU evaluator (0-shot, single-letter likelihood) using direct transformers.
MMLU has exactly 4 multiple-choice options per question (labels A/B/C/D).
Standard 0-shot protocol: present the question + the 4 lettered options +
"Answer:", then compare log-probabilities of the 4 letter tokens at the
next position. The model's "pick" is whichever letter has the highest
log-prob.
This is much faster than TruthfulQA-MC1 because each question requires
only one forward pass (instead of one per choice continuation).
Stratified subsampling: with 57 subjects, sampling N questions evenly
across subjects gives ~N/57 per subject. With N=228 → 4 per subject.
Usage:
python eval_mmlu.py --model meta-llama/Llama-3.2-1B \\
--out results/mmlu_base.json \\
--device cpu --dtype float32 --n-questions 228
"""
from __future__ import annotations
import argparse
import json
import random
import sys
import time
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
# Answer labels in choice order: position i in a question's "choices" list is
# shown (and scored) as letter SUBJECT_NUM_TO_LETTER[i]. Despite the name, this
# maps answer indices to letters, not subjects.
SUBJECT_NUM_TO_LETTER = list("ABCD")
def build_prompt(question, choices, subject):
    """Render the standard 0-shot MMLU prompt for a single question.

    Layout: a header naming the subject (underscores replaced by spaces), a
    blank line, the question, one "<letter>. <choice>" line per option, and a
    trailing "Answer:" with no newline, so the model's next token is the
    answer letter. Options beyond the available letters are dropped by zip.
    """
    topic = subject.replace("_", " ")
    pieces = [
        f"The following is a multiple choice question about {topic}.\n\n",
        f"{question}\n",
    ]
    pieces.extend(
        f"{label}. {option}\n"
        for label, option in zip(SUBJECT_NUM_TO_LETTER, choices)
    )
    pieces.append("Answer:")
    return "".join(pieces)
def stratified_subsample(dataset, n_questions, seed=17):
    """Pick up to ``n_questions`` row indices spread evenly across subjects.

    Every subject gets ``n_questions // n_subjects`` questions; the leftover
    ``n_questions % n_subjects`` slots go to a seeded random choice of
    subjects, one extra question each. Within a subject the rows are a seeded
    random sample, and the combined list is shuffled with the same RNG, so
    the result is fully determined by ``seed``.

    Args:
        dataset: Either something supporting column access
            ``dataset["subject"]`` (e.g. a Hugging Face ``Dataset`` or a dict
            of lists), or an iterable of mappings with a ``"subject"`` key.
        n_questions: Target number of questions. Values <= 0 yield ``[]``.
        seed: Seed for ``random.Random``.

    Returns:
        A list of integer row indices into ``dataset``. An empty dataset
        yields ``[]``. The list is shorter than ``n_questions`` when some
        subject has fewer rows than its quota; that shortfall is not
        redistributed to other subjects.
    """
    # Column access reads only the subject field; iterating a Hugging Face
    # Dataset row by row would decode every column of every row.
    try:
        subject_column = dataset["subject"]
    except (TypeError, KeyError, IndexError):
        # e.g. a plain list of dicts, which only supports integer indexing.
        subject_column = [row["subject"] for row in dataset]

    by_subject = {}
    for row_idx, subj in enumerate(subject_column):
        by_subject.setdefault(subj, []).append(row_idx)
    subjects = sorted(by_subject)

    # Guard before the division below (an empty dataset has no subjects).
    if n_questions <= 0 or not subjects:
        return []

    rng = random.Random(seed)
    per_subject, remainder = divmod(n_questions, len(subjects))
    extra = set(rng.sample(subjects, remainder)) if remainder else set()

    indices = []
    for subj in subjects:
        pool = by_subject[subj]
        quota = per_subject + (1 if subj in extra else 0)
        if quota > 0:
            indices.extend(rng.sample(pool, min(quota, len(pool))))
    rng.shuffle(indices)
    return indices
def _summarize(correct, total, per_item, by_subject_acc):
    """Assemble the result dict shared by checkpoints and the final return.

    ``by_subject_acc`` maps subject -> ``[n_correct, n_total]``.
    """
    return {
        "n_evaluated": total,
        "n_correct": correct,
        "accuracy": correct / total if total else 0.0,
        "per_item": per_item,
        "by_subject": {
            subj: {"correct": hits, "total": seen,
                   "accuracy": hits / seen if seen else 0.0}
            for subj, (hits, seen) in by_subject_acc.items()
        },
    }


def _write_json_atomic(path, obj):
    """Write ``obj`` as indented JSON to ``path`` through a sibling temp file.

    The final ``Path.replace`` swaps the file in one step, so an interrupted
    run cannot leave a truncated checkpoint that later fails to parse.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(obj, indent=2))
    tmp_path.replace(path)


def _load_resume_state(out_path, indices):
    """Rebuild evaluation state from a checkpoint previously written to ``out_path``.

    Returns ``(per_item, n_correct, by_subject_acc)``. Returns ``None`` when
    the file is unreadable or malformed, or when it was produced for a
    different question order; a message is printed in that case. All counts
    are built in locals and only returned once the whole file has been
    validated, so a failure part-way cannot leak partial state to the caller.
    """
    try:
        saved = json.loads(out_path.read_text())
        per_item = saved.get("per_item", [])
        saved_order = [r["ds_idx"] for r in per_item]
        if (len(saved_order) > len(indices)
                or saved_order != list(indices[:len(saved_order)])):
            print(" Could not resume: checkpoint was produced for a different "
                  "question subset/order; starting fresh", flush=True)
            return None
        n_correct = 0
        by_subject_acc = {}
        for r in per_item:
            tally = by_subject_acc.setdefault(r["subject"], [0, 0])
            tally[1] += 1
            if r["is_correct"]:
                tally[0] += 1
                n_correct += 1
    except (OSError, ValueError, TypeError, KeyError, AttributeError) as e:
        # OSError: unreadable file; ValueError covers json.JSONDecodeError;
        # the rest indicate a checkpoint with an unexpected structure.
        print(f" Could not resume: {e}", flush=True)
        return None
    return per_item, n_correct, by_subject_acc


def evaluate(model, tokenizer, dataset, indices, device, save_every=20, out_path=None,
             resume=True):
    """Run 0-shot MMLU letter-likelihood scoring over ``dataset[indices]``.

    Each question's prompt (from ``build_prompt``) goes through one forward
    pass. The log-softmax at the last position is read at the token ids of
    " A", " B", " C", " D", and the highest-scoring letter is the pick; ties
    go to the earliest letter.

    Args:
        model: Causal LM called as ``model(input_ids).logits``. Presumably
            already on ``device`` and in eval mode (set up by the caller).
        tokenizer: Tokenizer providing ``encode`` and ``__call__``.
        dataset: Indexable rows with "question", "choices", "answer"
            (compared with the 0-based index of the picked letter) and
            "subject".
        indices: Row indices to evaluate, in order.
        device: Device the input ids are moved to.
        save_every: Write a checkpoint to ``out_path`` after every
            ``save_every`` questions. ``None`` or values <= 0 disable
            checkpoints.
        out_path: ``pathlib.Path`` used for checkpoints and resume, or None.
        resume: If True and ``out_path`` holds a checkpoint whose items match
            the leading ``indices``, continue after the already-scored ones.

    Returns:
        Dict with "n_evaluated", "n_correct", "accuracy", "per_item",
        "by_subject" (per-subject correct/total/accuracy, resumed items
        included) and "wall_time_seconds" (time spent in this call only).
    """
    # MMLU convention: the answer letter follows "Answer:" after a space.
    letter_token_ids = []
    for letter in SUBJECT_NUM_TO_LETTER:
        ids = tokenizer.encode(" " + letter, add_special_tokens=False)
        if len(ids) != 1:
            print(f" Warning: ' {letter}' tokenizes to {ids} (len {len(ids)}); using last token.")
        letter_token_ids.append(ids[-1])
    print(f" Letter token IDs: {dict(zip(SUBJECT_NUM_TO_LETTER, letter_token_ids))}")

    per_item, correct, by_subject_acc = [], 0, {}
    if resume and out_path is not None and out_path.exists():
        state = _load_resume_state(out_path, indices)
        if state is not None:
            per_item, correct, by_subject_acc = state
            print(f" Resuming from {len(per_item)} items already evaluated", flush=True)
    total = len(per_item)
    start_idx = total

    checkpointing = out_path is not None and bool(save_every) and save_every > 0
    n = len(indices)
    t0 = time.time()
    for run_idx in range(start_idx, n):
        ds_idx = indices[run_idx]
        item = dataset[ds_idx]
        question = item["question"]
        choices = item["choices"]
        correct_idx = item["answer"]
        subject = item["subject"]

        prompt = build_prompt(question, choices, subject)
        input_ids = tokenizer(prompt, return_tensors="pt",
                              add_special_tokens=True).input_ids.to(device)
        with torch.no_grad():
            logits = model(input_ids).logits[0, -1, :]
        # Upcast so fp16/bf16 models are still scored in float32.
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        scores = [log_probs[t].item() for t in letter_token_ids]
        picked = max(range(len(scores)), key=scores.__getitem__)
        is_correct = picked == correct_idx

        total += 1
        tally = by_subject_acc.setdefault(subject, [0, 0])
        tally[1] += 1
        if is_correct:
            correct += 1
            tally[0] += 1
        per_item.append({
            "ds_idx": ds_idx,
            "subject": subject,
            "question": question,
            "choices": choices,
            "correct_idx": correct_idx,
            "picked_idx": picked,
            "is_correct": is_correct,
            "letter_log_probs": dict(zip(SUBJECT_NUM_TO_LETTER, scores)),
        })

        if (run_idx + 1) % 10 == 0 or run_idx == start_idx:
            elapsed = time.time() - t0
            done_this_run = run_idx + 1 - start_idx
            rate = done_this_run / elapsed if elapsed > 0 else 0
            eta = (n - run_idx - 1) / rate if rate > 0 else 0
            print(
                f" [{run_idx+1}/{n}] acc={correct/total:.4f} "
                f"({correct}/{total}) | {rate:.2f} q/s | "
                f"elapsed {elapsed:.0f}s, ETA {eta:.0f}s",
                flush=True,
            )
        if checkpointing and (run_idx + 1) % save_every == 0:
            _write_json_atomic(out_path, _summarize(correct, total, per_item, by_subject_acc))

    return {
        **_summarize(correct, total, per_item, by_subject_acc),
        "wall_time_seconds": time.time() - t0,
    }
def main():
    """Command-line entry point: score one model on a stratified MMLU subset.

    Loads the model, the tokenizer and the MMLU "all" test split, evaluates
    the stratified subset (``evaluate`` is given ``--out`` as its checkpoint
    path), then writes a JSON summary (run config + results) to ``--out``.

    Returns:
        0, used as the process exit status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--dtype", default="float32",
                        choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--n-questions", type=int, default=228,
                        help="Subset size, stratified across the 57 subjects (default: 228 = 4/subject)")
    parser.add_argument("--seed", type=int, default=17)
    parser.add_argument("--save-every", type=int, default=20)
    args = parser.parse_args()

    # Reject values that make no sense for a CLI run before spending time
    # loading the model and dataset.
    if args.n_questions < 1:
        parser.error("--n-questions must be a positive integer")
    if args.save_every < 1:
        parser.error("--save-every must be a positive integer")
    if args.device == "cuda" and not torch.cuda.is_available():
        parser.error("--device cuda was requested but CUDA is not available")

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    print(f"Device: {device}, dtype: {args.dtype}")
    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    dtype = dtype_map[args.dtype]

    print(f"Loading {args.model}...")
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype)
    model.eval()
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    print("Loaded.")

    print("Loading MMLU (all)...")
    dataset = load_dataset("cais/mmlu", "all", split="test")
    print(f"Loaded {len(dataset)} questions across "
          f"{len(set(dataset['subject']))} subjects.")
    indices = stratified_subsample(dataset, args.n_questions, args.seed)
    print(f"Stratified subsample: {len(indices)} questions.")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    results = evaluate(
        model, tokenizer, dataset, indices, device,
        save_every=args.save_every, out_path=args.out,
    )
    summary = {
        "model": args.model,
        "device": device,
        "dtype": args.dtype,
        "task": "mmlu_0shot_letter_likelihood",
        "n_questions": len(indices),
        "seed": args.seed,
        **results,
    }
    args.out.write_text(json.dumps(summary, indent=2))
    print(f"\nFinal: {summary['n_correct']}/{summary['n_evaluated']} "
          f"= {summary['accuracy']:.4f}")
    print(f"Wall: {summary['wall_time_seconds']:.0f}s")
    print(f"Wrote {args.out}")
    return 0
if __name__ == "__main__":
sys.exit(main())