"""TruthfulQA-MC1 evaluator using direct transformers forward passes.

Why a custom eval rather than lm-evaluation-harness:

- On a 4 GB GPU + CPU-fallback machine, lm-eval ran but was extremely slow
  and gave no progress reporting. This implementation prints per-question
  progress and writes intermediate JSON every 20 items (configurable with
  --save-every), so we can both verify forward progress and resume after an
  interruption.

MC1 protocol: for each question, score each multiple-choice option by the
sum of log-probabilities of the option's tokens conditioned on the prompt.
The model "picks" the option with the highest score. Accuracy is the
fraction of questions where the picked option is the correct one.

We follow the lm-eval convention of using the "Q: ...\\nA:" prompt format.
Unlike lm-eval's truthfulqa_mc1 task, no few-shot QA primer is prepended
(this run is zero-shot), so scores are not directly comparable.

Usage:

    python eval_truthfulqa_mc1.py \\
        --model meta-llama/Llama-3.2-1B \\
        --out experiments/ablated_model_release/results/truthfulqa_mc1_base.json \\
        --device cpu --dtype float32 [--n-questions 200]
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from datasets import load_dataset
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
def score_continuations_batch(model, tokenizer, prompt, continuations, device):
    """Return the summed log-probability of each continuation given ``prompt``.

    Each option is scored as the text ``prompt + " " + option.strip()``.
    The prompt and each joined string are tokenized separately (with special
    tokens). The option's tokens are taken to be everything after the first
    ``len(prompt_ids)`` tokens of the joined sequence. This assumes the
    prompt's tokenization is a prefix of the joined tokenization; that is
    not verified here.

    All options are right-padded to a common length and scored in a single
    forward pass with an attention mask. Right padding keeps every real
    token at the same position it would have unpadded.

    Args:
        model: causal LM; ``model(ids, attention_mask=mask).logits`` is
            indexed here as (batch, seq, vocab).
        tokenizer: HF-style tokenizer, callable and returning ``.input_ids``.
        prompt: context string, e.g. ``"Q: ...\\nA:"``.
        continuations: candidate answer strings.
        device: device the token batch is moved to before the forward pass.

    Returns:
        List of floats aligned with ``continuations``. An option that adds
        no tokens beyond the prompt scores ``-inf``. An empty
        ``continuations`` yields ``[]``.
    """
    if not continuations:
        return []

    prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids
    n_prompt = prompt_ids.shape[1]

    full_ids_list = [
        tokenizer(prompt + " " + c.strip(), return_tensors="pt",
                  add_special_tokens=True).input_ids[0]
        for c in continuations
    ]
    n_fulls = [ids.shape[0] for ids in full_ids_list]

    max_len = max(n_fulls)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    if pad_id is None:
        # Padded positions are masked out and never scored, so any valid id works.
        pad_id = 0
    batch = torch.full((len(continuations), max_len), pad_id, dtype=torch.long)
    attn = torch.zeros((len(continuations), max_len), dtype=torch.long)
    for i, ids in enumerate(full_ids_list):
        batch[i, :ids.shape[0]] = ids
        attn[i, :ids.shape[0]] = 1
    batch = batch.to(device)
    attn = attn.to(device)

    with torch.no_grad():
        logits = model(batch, attention_mask=attn).logits  # (B, max_len, V)

    scores = []
    for i, n_full in enumerate(n_fulls):
        if n_full <= n_prompt:
            scores.append(-float("inf"))
            continue
        # Logits at position t predict token t+1. Option tokens
        # n_prompt..n_full-1 are therefore scored by logits n_prompt-1..n_full-2.
        # Only this slice is upcast and normalized; a float32 log_softmax over
        # the whole (B, L, V) tensor dominates peak memory for large vocabs.
        step_logits = logits[i, n_prompt - 1:n_full - 1, :].float()
        log_probs = torch.log_softmax(step_logits, dim=-1)
        targets = batch[i, n_prompt:n_full]
        scores.append(log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).sum().item())
    return scores
|
|
|
|
|
|
def _load_resume_state(out_path):
    """Read ``(per_item, n_correct)`` from a checkpoint written by ``evaluate``.

    Raises (OSError, ValueError, KeyError, TypeError or AttributeError) if the
    file is unreadable, is not valid JSON, or holds records that are not a
    contiguous idx 0..k-1 prefix. The caller therefore never sees partially
    applied state.
    """
    existing = json.loads(out_path.read_text())
    per_item = existing.get("per_item", [])
    correct = 0
    for expected_idx, record in enumerate(per_item):
        # Resume restarts at len(per_item). That is only correct if the saved
        # records are exactly items 0..len-1 in order.
        if record.get("idx", expected_idx) != expected_idx:
            raise ValueError(
                f"checkpoint record {expected_idx} has idx={record['idx']}; "
                "expected a contiguous prefix"
            )
        if record["is_correct"]:
            correct += 1
    return per_item, correct


def _write_json_atomic(path, obj):
    """Write ``obj`` as indented JSON to ``path`` via a temp file + rename.

    An interruption mid-write leaves the previous file intact. A truncated
    file would make the next resume fall back to a fresh start.
    """
    tmp = path.with_name(path.name + ".tmp")
    tmp.write_text(json.dumps(obj, indent=2))
    tmp.replace(path)


def evaluate(model, tokenizer, dataset, device, n_questions, save_every=20, out_path=None,
             resume=True):
    """Run TruthfulQA-MC1 over the first ``n_questions`` items of ``dataset``.

    Args:
        model, tokenizer, device: forwarded to ``score_continuations_batch``.
        dataset: indexable, sized collection of records with ``question`` and
            ``mc1_targets`` (parallel ``choices``/``labels`` lists). The first
            label equal to 1 marks the correct choice.
        n_questions: upper bound on items; clamped to ``len(dataset)``.
        save_every: write a checkpoint to ``out_path`` whenever
            ``(idx + 1) % save_every == 0``. Values <= 0 disable checkpointing.
        out_path: ``pathlib.Path`` for checkpoints, or None to disable them.
        resume: if True and ``out_path`` exists, continue after the items
            already recorded there. An unreadable or malformed checkpoint is
            reported and ignored, and evaluation starts fresh.

    Returns:
        dict with ``n_evaluated``, ``n_correct``, ``accuracy`` and
        ``per_item`` (all including resumed items), plus
        ``wall_time_seconds`` (this run only).
    """
    correct = 0
    per_item = []
    if resume and out_path is not None and out_path.exists():
        try:
            loaded_items, loaded_correct = _load_resume_state(out_path)
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
            # Best-effort resume: counters are only assigned on success, so a
            # bad checkpoint cannot leave half-applied state behind.
            print(f" Could not resume from {out_path}: {e}", flush=True)
        else:
            per_item, correct = loaded_items, loaded_correct
            n_loaded = len(per_item)
            print(f" Resuming from {n_loaded} items already evaluated "
                  f"(acc so far: {correct}/{n_loaded} = {correct/max(n_loaded,1):.4f})", flush=True)
    total = len(per_item)
    start_idx = total

    t0 = time.time()
    n = min(n_questions, len(dataset))
    checkpointing = out_path is not None and save_every > 0

    for idx in range(start_idx, n):
        item = dataset[idx]
        question = item["question"]
        choices = item["mc1_targets"]["choices"]
        labels = item["mc1_targets"]["labels"]
        correct_idx = labels.index(1)

        prompt = f"Q: {question}\nA:"
        scores = score_continuations_batch(model, tokenizer, prompt, choices, device)

        # Highest score wins; ties go to the lowest index.
        picked = int(max(range(len(scores)), key=lambda i: scores[i]))
        is_correct = (picked == correct_idx)
        if is_correct:
            correct += 1
        total += 1

        per_item.append({
            "idx": idx,
            "question": question,
            "n_choices": len(choices),
            "correct_idx": correct_idx,
            "picked_idx": picked,
            "is_correct": is_correct,
            "scores": scores,
        })

        if (idx + 1) % 5 == 0 or idx == start_idx:
            elapsed = time.time() - t0
            done_this_run = idx + 1 - start_idx
            rate = done_this_run / elapsed if elapsed > 0 else 0
            eta = (n - idx - 1) / rate if rate > 0 else 0
            acc = correct / total
            print(
                f" [{idx+1}/{n}] acc={acc:.4f} "
                f"({correct}/{total}) | {rate:.2f} q/s | "
                f"elapsed {elapsed:.0f}s, ETA {eta:.0f}s",
                flush=True,
            )

        if checkpointing and (idx + 1) % save_every == 0:
            _write_json_atomic(out_path, {
                "n_evaluated": total,
                "n_correct": correct,
                "accuracy": correct / total if total else 0.0,
                "per_item": per_item,
            })

    return {
        "n_evaluated": total,
        "n_correct": correct,
        "accuracy": correct / total if total else 0.0,
        "per_item": per_item,
        "wall_time_seconds": time.time() - t0,
    }
|
|
|
|
|
|
def main():
    """Command-line entry point for the TruthfulQA-MC1 evaluation.

    Parses arguments, loads the model, tokenizer and the TruthfulQA
    ``multiple_choice`` validation split, and runs ``evaluate``, which
    checkpoints to and resumes from ``--out``. Writes the final summary JSON
    to ``--out``.

    Returns:
        0, used as the process exit status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str,
                        help="HF model ID or local path")
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--dtype", default="float32",
                        choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--n-questions", type=int, default=817,
                        help="Number of questions to evaluate (default: full 817)")
    parser.add_argument("--save-every", type=int, default=20)
    args = parser.parse_args()

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
        # Fail before the slow model load, not later at model.to(device).
        if device == "cuda" and not torch.cuda.is_available():
            parser.error("--device cuda was requested but CUDA is not available")
    print(f"Device: {device}, dtype: {args.dtype}")

    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    dtype = dtype_map[args.dtype]

    print(f"Loading {args.model}...")
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype)
    model.eval()
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    print("Loaded.")

    print("Loading TruthfulQA-MC1...")
    dataset = load_dataset("truthful_qa", "multiple_choice", split="validation")
    print(f"Loaded {len(dataset)} questions.")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    # evaluate() clamps n_questions to the dataset size; report that target.
    n_target = min(args.n_questions, len(dataset))
    print(f"\nEvaluating {n_target} questions...")
    results = evaluate(
        model, tokenizer, dataset, device,
        n_questions=args.n_questions,
        save_every=args.save_every,
        out_path=args.out,
    )

    # On a resumed run, n_evaluated/per_item include items from earlier runs,
    # while wall_time_seconds covers only this run.
    summary = {
        "model": args.model,
        "device": device,
        "dtype": args.dtype,
        "task": "truthfulqa_mc1",
        "n_evaluated": results["n_evaluated"],
        "n_correct": results["n_correct"],
        "accuracy": results["accuracy"],
        "wall_time_seconds": results["wall_time_seconds"],
        "per_item": results["per_item"],
    }
    # --out is also the resume checkpoint. Write via temp file + rename so an
    # interruption here cannot leave a truncated file that discards progress.
    tmp_out = args.out.with_name(args.out.name + ".tmp")
    tmp_out.write_text(json.dumps(summary, indent=2))
    tmp_out.replace(args.out)
    print(f"\nFinal: {summary['n_correct']}/{summary['n_evaluated']} "
          f"= {summary['accuracy']:.4f}")
    print(f"Wall: {summary['wall_time_seconds']:.0f}s")
    print(f"Wrote {args.out}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
|