"""TruthfulQA-MC1 evaluator using direct transformers forward passes. Why a custom eval rather than lm-evaluation-harness: - On a 4 GB GPU + CPU-fallback machine, lm-eval was running but extremely slow with no progress reporting. This implementation prints per-question progress and writes intermediate JSON every 20 items, so we can both verify forward progress and resume on interruption. MC1 protocol: for each question, score each multiple-choice option by the sum of log-probabilities of the option's tokens conditional on the prompt. The model "picks" the option with the highest score. Accuracy = fraction of questions where the picked option is the correct one. We follow the lm-eval convention of using the "Q: ...\\nA:" prompt format. Usage: python eval_truthfulqa_mc1.py \\ --model meta-llama/Llama-3.2-1B \\ --out experiments/ablated_model_release/results/truthfulqa_mc1_base.json \\ --device cpu --dtype float32 [--n-questions 200] """ from __future__ import annotations import argparse import json import sys import time from pathlib import Path import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer def score_continuations_batch(model, tokenizer, prompt, continuations, device): """Sum of log-probabilities of each continuation, batched. Tokenizes once: prompt and prompt+continuation for each option. Pads them all to the longest length, runs a single forward pass. Returns a list of scores aligned with `continuations`. """ prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids n_prompt = prompt_ids.shape[1] full_ids_list = [] n_fulls = [] for c in continuations: ids = tokenizer(prompt + " " + c.strip(), return_tensors="pt", add_special_tokens=True).input_ids[0] full_ids_list.append(ids) n_fulls.append(ids.shape[0]) max_len = max(n_fulls) pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id batch = torch.full((len(continuations), max_len), pad_id, dtype=torch.long) attn = torch.zeros((len(continuations), max_len), dtype=torch.long) for i, ids in enumerate(full_ids_list): batch[i, :ids.shape[0]] = ids attn[i, :ids.shape[0]] = 1 batch = batch.to(device) attn = attn.to(device) with torch.no_grad(): logits = model(batch, attention_mask=attn).logits # (B, max_len, V) log_probs = torch.log_softmax(logits.float(), dim=-1) scores = [] for i, n_full in enumerate(n_fulls): if n_full <= n_prompt: scores.append(-float("inf")) continue # Score positions n_prompt..n_full-1 using logits at positions n_prompt-1..n_full-2 targets = batch[i, n_prompt:n_full] relevant_logprobs = log_probs[i, n_prompt - 1:n_full - 1, :] s = relevant_logprobs.gather(1, targets.unsqueeze(1)).squeeze(1).sum().item() scores.append(s) return scores def evaluate(model, tokenizer, dataset, device, n_questions, save_every=20, out_path=None, resume=True): correct = 0 total = 0 per_item = [] start_idx = 0 if resume and out_path is not None and out_path.exists(): try: existing = json.loads(out_path.read_text()) per_item = existing.get("per_item", []) for r in per_item: if r["is_correct"]: correct += 1 total += 1 start_idx = total print(f" Resuming from {start_idx} items already evaluated " f"(acc so far: {correct}/{total} = {correct/max(total,1):.4f})", flush=True) except Exception as e: print(f" Could not resume from {out_path}: {e}", flush=True) t0 = time.time() n = min(n_questions, len(dataset)) for idx in range(start_idx, n): item = dataset[idx] question = item["question"] choices = item["mc1_targets"]["choices"] labels = item["mc1_targets"]["labels"] correct_idx = labels.index(1) prompt = f"Q: {question}\nA:" scores = score_continuations_batch(model, tokenizer, prompt, choices, device) picked = int(max(range(len(scores)), key=lambda i: scores[i])) is_correct = (picked == correct_idx) if is_correct: correct += 1 total += 1 per_item.append({ "idx": idx, "question": question, "n_choices": len(choices), "correct_idx": correct_idx, "picked_idx": picked, "is_correct": is_correct, "scores": scores, }) if (idx + 1) % 5 == 0 or idx == start_idx: elapsed = time.time() - t0 done_this_run = idx + 1 - start_idx rate = done_this_run / elapsed if elapsed > 0 else 0 eta = (n - idx - 1) / rate if rate > 0 else 0 acc = correct / total print( f" [{idx+1}/{n}] acc={acc:.4f} " f"({correct}/{total}) | {rate:.2f} q/s | " f"elapsed {elapsed:.0f}s, ETA {eta:.0f}s", flush=True, ) if out_path is not None and (idx + 1) % save_every == 0: partial = { "n_evaluated": total, "n_correct": correct, "accuracy": correct / total if total else 0.0, "per_item": per_item, } out_path.write_text(json.dumps(partial, indent=2)) return { "n_evaluated": total, "n_correct": correct, "accuracy": correct / total if total else 0.0, "per_item": per_item, "wall_time_seconds": time.time() - t0, } def main(): parser = argparse.ArgumentParser() parser.add_argument("--model", required=True, type=str, help="HF model ID or local path") parser.add_argument("--out", required=True, type=Path) parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"]) parser.add_argument("--dtype", default="float32", choices=["float32", "float16", "bfloat16"]) parser.add_argument("--n-questions", type=int, default=817, help="Number of questions to evaluate (default: full 817)") parser.add_argument("--save-every", type=int, default=20) args = parser.parse_args() if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" else: device = args.device print(f"Device: {device}, dtype: {args.dtype}") dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} dtype = dtype_map[args.dtype] print(f"Loading {args.model}...") model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype) model.eval() model.to(device) tokenizer = AutoTokenizer.from_pretrained(args.model) print(f"Loaded.") print("Loading TruthfulQA-MC1...") dataset = load_dataset("truthful_qa", "multiple_choice", split="validation") print(f"Loaded {len(dataset)} questions.") args.out.parent.mkdir(parents=True, exist_ok=True) print(f"\nEvaluating {args.n_questions} questions...") results = evaluate( model, tokenizer, dataset, device, n_questions=args.n_questions, save_every=args.save_every, out_path=args.out, ) summary = { "model": args.model, "device": device, "dtype": args.dtype, "task": "truthfulqa_mc1", "n_evaluated": results["n_evaluated"], "n_correct": results["n_correct"], "accuracy": results["accuracy"], "wall_time_seconds": results["wall_time_seconds"], "per_item": results["per_item"], } args.out.write_text(json.dumps(summary, indent=2)) print(f"\nFinal: {summary['n_correct']}/{summary['n_evaluated']} " f"= {summary['accuracy']:.4f}") print(f"Wall: {summary['wall_time_seconds']:.0f}s") print(f"Wrote {args.out}") return 0 if __name__ == "__main__": sys.exit(main())