Files
Llama-3.2-1B-sandbag-circuit-ablated/eval_truthfulqa_mc1.py
ModelHub XC 1f3458b152 初始化项目,由ModelHub XC社区提供模型
Model: Solshine/Llama-3.2-1B-sandbag-circuit-ablated
Source: Original Platform
2026-05-22 11:49:16 +08:00

224 lines
8.0 KiB
Python

"""TruthfulQA-MC1 evaluator using direct transformers forward passes.
Why a custom eval rather than lm-evaluation-harness:
- On a 4 GB GPU + CPU-fallback machine, lm-eval was running but extremely
slow with no progress reporting. This implementation prints progress every
5 questions and writes intermediate JSON every --save-every items (default
20), so we can both verify forward progress and resume on interruption.
MC1 protocol: for each question, score each multiple-choice option by the
sum of log-probabilities of the option's tokens conditional on the prompt.
The model "picks" the option with the highest score. Accuracy = fraction of
questions where the picked option is the correct one.
We follow the lm-eval convention of using the "Q: ...\\nA:" prompt format.
Usage:
python eval_truthfulqa_mc1.py \\
--model meta-llama/Llama-3.2-1B \\
--out experiments/ablated_model_release/results/truthfulqa_mc1_base.json \\
--device cpu --dtype float32 [--n-questions 200]
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
def score_continuations_batch(model, tokenizer, prompt, continuations, device):
    """Return the summed log-probability of each continuation given ``prompt``.

    The prompt is tokenized on its own to find where the continuation starts.
    Each option is tokenized as ``prompt + " " + option.strip()``. All options
    are right-padded to a common length and scored in a single forward pass.

    Args:
        model: causal LM; ``model(input_ids, attention_mask=...)`` must return
            an object whose ``.logits`` is indexed as ``[batch, position, vocab]``.
        tokenizer: callable returning an object with ``.input_ids``; its
            ``pad_token_id`` / ``eos_token_id`` attributes are read for padding.
        prompt: text the continuations are conditioned on.
        continuations: candidate answer strings.
        device: device the padded batch is moved to before the forward pass.

    Returns:
        list[float] aligned with ``continuations``. An option whose full
        tokenization is no longer than the prompt's scores ``-inf``. An empty
        ``continuations`` returns ``[]``.
    """
    if not continuations:
        return []
    prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids
    n_prompt = prompt_ids.shape[1]
    full_ids_list = [
        tokenizer(prompt + " " + c.strip(), return_tensors="pt",
                  add_special_tokens=True).input_ids[0]
        for c in continuations
    ]
    n_fulls = [ids.shape[0] for ids in full_ids_list]
    max_len = max(n_fulls)
    # Padded positions are masked out and never scored, so any valid id works
    # as the pad value. Fall back to 0 when the tokenizer defines neither.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    if pad_id is None:
        pad_id = 0
    batch = torch.full((len(continuations), max_len), pad_id, dtype=torch.long)
    attn = torch.zeros((len(continuations), max_len), dtype=torch.long)
    for i, ids in enumerate(full_ids_list):
        batch[i, :ids.shape[0]] = ids
        attn[i, :ids.shape[0]] = 1
    batch = batch.to(device)
    attn = attn.to(device)
    with torch.no_grad():
        logits = model(batch, attention_mask=attn).logits
    scores = []
    for i, n_full in enumerate(n_fulls):
        if n_full <= n_prompt:
            scores.append(-float("inf"))
            continue
        # Logits at position t predict token t+1. Score tokens
        # n_prompt..n_full-1 with logits n_prompt-1..n_full-2. Normalizing only
        # this slice avoids materializing a float32 log-softmax over the whole
        # (B, max_len, V) tensor; the values are identical because log_softmax
        # is computed independently per position.
        step_logprobs = torch.log_softmax(logits[i, n_prompt - 1:n_full - 1].float(), dim=-1)
        targets = batch[i, n_prompt:n_full]
        scores.append(step_logprobs.gather(1, targets.unsqueeze(1)).sum().item())
    return scores
def _load_checkpoint(out_path):
    """Read a previous run's results JSON and return ``(per_item, n_correct)``.

    Any read/parse problem (e.g. OSError, ValueError, KeyError) propagates, so
    the caller never sees a partially-loaded state.
    """
    existing = json.loads(out_path.read_text())
    per_item = list(existing.get("per_item", []))
    n_correct = sum(1 for record in per_item if record["is_correct"])
    return per_item, n_correct


def _write_json_atomic(path, obj):
    """Write ``obj`` as indented JSON to ``path`` via a sibling temp file.

    ``Path.replace`` swaps the finished file into place. An interruption
    mid-write therefore leaves the previous checkpoint intact, rather than a
    truncated file that a later resume cannot parse.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(obj, indent=2))
    tmp_path.replace(path)


def evaluate(model, tokenizer, dataset, device, n_questions, save_every=20, out_path=None,
             resume=True):
    """Run TruthfulQA-MC1 over the first ``n_questions`` items of ``dataset``.

    Each item is read as ``item["question"]`` and
    ``item["mc1_targets"]["choices"/"labels"]``. The first label equal to 1
    marks the correct option (``labels.index(1)`` raises ValueError if there is
    none). Options are scored with ``score_continuations_batch`` under a
    ``"Q: {question}\\nA:"`` prompt, and the highest-scoring option is picked.

    Args:
        model, tokenizer, device: forwarded to ``score_continuations_batch``.
        dataset: indexable sequence of items supporting ``len()``.
        n_questions: upper bound on items evaluated (capped at ``len(dataset)``).
        save_every: write a checkpoint to ``out_path`` every this many items.
            Values <= 0 disable intermediate checkpoints.
        out_path: ``pathlib.Path`` used for checkpoints and resume, or None.
        resume: if True and ``out_path`` exists, continue after the items it
            already holds. This assumes its ``per_item`` covers items 0..k-1
            in order, as written by this function. A file that cannot be read
            or parsed is reported and ignored, and the run starts at item 0.

    Returns:
        dict with ``n_evaluated``, ``n_correct``, ``accuracy``, ``per_item``
        and ``wall_time_seconds``. The wall time covers this call only, not
        any earlier resumed run.
    """
    per_item = []
    correct = 0
    if resume and out_path is not None and out_path.exists():
        try:
            # Tuple assignment happens only if the whole load succeeds, so a
            # malformed checkpoint cannot leave half-counted state behind.
            per_item, correct = _load_checkpoint(out_path)
        except Exception as e:  # deliberate best-effort: fall back to a fresh run
            print(f" Could not resume from {out_path}: {e}", flush=True)
        else:
            loaded = len(per_item)
            print(f" Resuming from {loaded} items already evaluated "
                  f"(acc so far: {correct}/{loaded} = {correct/max(loaded,1):.4f})", flush=True)
    total = len(per_item)
    start_idx = total
    t0 = time.time()
    n = min(n_questions, len(dataset))
    for idx in range(start_idx, n):
        item = dataset[idx]
        question = item["question"]
        choices = item["mc1_targets"]["choices"]
        labels = item["mc1_targets"]["labels"]
        correct_idx = labels.index(1)
        prompt = f"Q: {question}\nA:"
        scores = score_continuations_batch(model, tokenizer, prompt, choices, device)
        picked = int(max(range(len(scores)), key=lambda i: scores[i]))
        is_correct = (picked == correct_idx)
        if is_correct:
            correct += 1
        total += 1
        per_item.append({
            "idx": idx,
            "question": question,
            "n_choices": len(choices),
            "correct_idx": correct_idx,
            "picked_idx": picked,
            "is_correct": is_correct,
            "scores": scores,
        })
        if (idx + 1) % 5 == 0 or idx == start_idx:
            elapsed = time.time() - t0
            done_this_run = idx + 1 - start_idx
            rate = done_this_run / elapsed if elapsed > 0 else 0
            eta = (n - idx - 1) / rate if rate > 0 else 0
            acc = correct / total
            print(
                f" [{idx+1}/{n}] acc={acc:.4f} "
                f"({correct}/{total}) | {rate:.2f} q/s | "
                f"elapsed {elapsed:.0f}s, ETA {eta:.0f}s",
                flush=True,
            )
        if out_path is not None and save_every > 0 and (idx + 1) % save_every == 0:
            _write_json_atomic(out_path, {
                "n_evaluated": total,
                "n_correct": correct,
                "accuracy": correct / total if total else 0.0,
                "per_item": per_item,
            })
    return {
        "n_evaluated": total,
        "n_correct": correct,
        "accuracy": correct / total if total else 0.0,
        "per_item": per_item,
        "wall_time_seconds": time.time() - t0,
    }
def main():
    """CLI entry point: load the model and TruthfulQA-MC1, evaluate, write JSON.

    The summary JSON, including per-item records, is written to ``--out``.
    Unless ``--no-resume`` is given, an existing file at ``--out`` is used to
    resume a previous run. Returns 0 as the process exit status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str,
                        help="HF model ID or local path")
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--dtype", default="float32",
                        choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--n-questions", type=int, default=817,
                        help="Number of questions to evaluate (default: full 817)")
    parser.add_argument("--save-every", type=int, default=20)
    parser.add_argument("--no-resume", action="store_true",
                        help="Ignore any existing results at --out and start from scratch")
    args = parser.parse_args()
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    print(f"Device: {device}, dtype: {args.dtype}")
    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    dtype = dtype_map[args.dtype]
    print(f"Loading {args.model}...")
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype)
    model.eval()
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    print("Loaded.")
    print("Loading TruthfulQA-MC1...")
    dataset = load_dataset("truthful_qa", "multiple_choice", split="validation")
    print(f"Loaded {len(dataset)} questions.")
    args.out.parent.mkdir(parents=True, exist_ok=True)
    # evaluate() caps the count at the dataset size; report that real number.
    print(f"\nEvaluating {min(args.n_questions, len(dataset))} questions...")
    results = evaluate(
        model, tokenizer, dataset, device,
        n_questions=args.n_questions,
        save_every=args.save_every,
        out_path=args.out,
        resume=not args.no_resume,
    )
    summary = {
        "model": args.model,
        "device": device,
        "dtype": args.dtype,
        "task": "truthfulqa_mc1",
        "n_evaluated": results["n_evaluated"],
        "n_correct": results["n_correct"],
        "accuracy": results["accuracy"],
        "wall_time_seconds": results["wall_time_seconds"],
        "per_item": results["per_item"],
    }
    # Write to a temp file, then rename over --out. An interrupted write then
    # cannot destroy the checkpoint that a later resume reads.
    tmp_out = args.out.with_name(args.out.name + ".tmp")
    tmp_out.write_text(json.dumps(summary, indent=2))
    tmp_out.replace(args.out)
    print(f"\nFinal: {summary['n_correct']}/{summary['n_evaluated']} "
          f"= {summary['accuracy']:.4f}")
    print(f"Wall: {summary['wall_time_seconds']:.0f}s")
    print(f"Wrote {args.out}")
    return 0
if __name__ == "__main__":
    # main()'s return value becomes the process exit status
    # (SystemExit is exactly what sys.exit raises).
    raise SystemExit(main())