初始化项目,由ModelHub XC社区提供模型
Model: Solshine/Llama-3.2-1B-sandbag-circuit-ablated Source: Original Platform
This commit is contained in:
223
eval_truthfulqa_mc1.py
Normal file
223
eval_truthfulqa_mc1.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""TruthfulQA-MC1 evaluator using direct transformers forward passes.
|
||||
|
||||
Why a custom eval rather than lm-evaluation-harness:
|
||||
- On a 4 GB GPU + CPU-fallback machine, lm-eval was running but extremely
|
||||
slow with no progress reporting. This implementation prints per-question
|
||||
progress and writes intermediate JSON every 20 items, so we can both
|
||||
verify forward progress and resume on interruption.
|
||||
|
||||
MC1 protocol: for each question, score each multiple-choice option by the
|
||||
sum of log-probabilities of the option's tokens conditional on the prompt.
|
||||
The model "picks" the option with the highest score. Accuracy = fraction of
|
||||
questions where the picked option is the correct one.
|
||||
|
||||
We follow the lm-eval convention of using the "Q: ...\\nA:" prompt format.
|
||||
|
||||
Usage:
|
||||
python eval_truthfulqa_mc1.py \\
|
||||
--model meta-llama/Llama-3.2-1B \\
|
||||
--out experiments/ablated_model_release/results/truthfulqa_mc1_base.json \\
|
||||
--device cpu --dtype float32 [--n-questions 200]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def score_continuations_batch(model, tokenizer, prompt, continuations, device):
|
||||
"""Sum of log-probabilities of each continuation, batched.
|
||||
|
||||
Tokenizes once: prompt and prompt+continuation for each option. Pads
|
||||
them all to the longest length, runs a single forward pass. Returns a
|
||||
list of scores aligned with `continuations`.
|
||||
"""
|
||||
prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids
|
||||
n_prompt = prompt_ids.shape[1]
|
||||
|
||||
full_ids_list = []
|
||||
n_fulls = []
|
||||
for c in continuations:
|
||||
ids = tokenizer(prompt + " " + c.strip(), return_tensors="pt",
|
||||
add_special_tokens=True).input_ids[0]
|
||||
full_ids_list.append(ids)
|
||||
n_fulls.append(ids.shape[0])
|
||||
|
||||
max_len = max(n_fulls)
|
||||
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
|
||||
batch = torch.full((len(continuations), max_len), pad_id, dtype=torch.long)
|
||||
attn = torch.zeros((len(continuations), max_len), dtype=torch.long)
|
||||
for i, ids in enumerate(full_ids_list):
|
||||
batch[i, :ids.shape[0]] = ids
|
||||
attn[i, :ids.shape[0]] = 1
|
||||
batch = batch.to(device)
|
||||
attn = attn.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(batch, attention_mask=attn).logits # (B, max_len, V)
|
||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||
|
||||
scores = []
|
||||
for i, n_full in enumerate(n_fulls):
|
||||
if n_full <= n_prompt:
|
||||
scores.append(-float("inf"))
|
||||
continue
|
||||
# Score positions n_prompt..n_full-1 using logits at positions n_prompt-1..n_full-2
|
||||
targets = batch[i, n_prompt:n_full]
|
||||
relevant_logprobs = log_probs[i, n_prompt - 1:n_full - 1, :]
|
||||
s = relevant_logprobs.gather(1, targets.unsqueeze(1)).squeeze(1).sum().item()
|
||||
scores.append(s)
|
||||
return scores
|
||||
|
||||
|
||||
def evaluate(model, tokenizer, dataset, device, n_questions, save_every=20, out_path=None,
|
||||
resume=True):
|
||||
correct = 0
|
||||
total = 0
|
||||
per_item = []
|
||||
start_idx = 0
|
||||
if resume and out_path is not None and out_path.exists():
|
||||
try:
|
||||
existing = json.loads(out_path.read_text())
|
||||
per_item = existing.get("per_item", [])
|
||||
for r in per_item:
|
||||
if r["is_correct"]:
|
||||
correct += 1
|
||||
total += 1
|
||||
start_idx = total
|
||||
print(f" Resuming from {start_idx} items already evaluated "
|
||||
f"(acc so far: {correct}/{total} = {correct/max(total,1):.4f})", flush=True)
|
||||
except Exception as e:
|
||||
print(f" Could not resume from {out_path}: {e}", flush=True)
|
||||
|
||||
t0 = time.time()
|
||||
n = min(n_questions, len(dataset))
|
||||
|
||||
for idx in range(start_idx, n):
|
||||
item = dataset[idx]
|
||||
question = item["question"]
|
||||
choices = item["mc1_targets"]["choices"]
|
||||
labels = item["mc1_targets"]["labels"]
|
||||
correct_idx = labels.index(1)
|
||||
|
||||
prompt = f"Q: {question}\nA:"
|
||||
scores = score_continuations_batch(model, tokenizer, prompt, choices, device)
|
||||
|
||||
picked = int(max(range(len(scores)), key=lambda i: scores[i]))
|
||||
is_correct = (picked == correct_idx)
|
||||
if is_correct:
|
||||
correct += 1
|
||||
total += 1
|
||||
|
||||
per_item.append({
|
||||
"idx": idx,
|
||||
"question": question,
|
||||
"n_choices": len(choices),
|
||||
"correct_idx": correct_idx,
|
||||
"picked_idx": picked,
|
||||
"is_correct": is_correct,
|
||||
"scores": scores,
|
||||
})
|
||||
|
||||
if (idx + 1) % 5 == 0 or idx == start_idx:
|
||||
elapsed = time.time() - t0
|
||||
done_this_run = idx + 1 - start_idx
|
||||
rate = done_this_run / elapsed if elapsed > 0 else 0
|
||||
eta = (n - idx - 1) / rate if rate > 0 else 0
|
||||
acc = correct / total
|
||||
print(
|
||||
f" [{idx+1}/{n}] acc={acc:.4f} "
|
||||
f"({correct}/{total}) | {rate:.2f} q/s | "
|
||||
f"elapsed {elapsed:.0f}s, ETA {eta:.0f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if out_path is not None and (idx + 1) % save_every == 0:
|
||||
partial = {
|
||||
"n_evaluated": total,
|
||||
"n_correct": correct,
|
||||
"accuracy": correct / total if total else 0.0,
|
||||
"per_item": per_item,
|
||||
}
|
||||
out_path.write_text(json.dumps(partial, indent=2))
|
||||
|
||||
return {
|
||||
"n_evaluated": total,
|
||||
"n_correct": correct,
|
||||
"accuracy": correct / total if total else 0.0,
|
||||
"per_item": per_item,
|
||||
"wall_time_seconds": time.time() - t0,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", required=True, type=str,
|
||||
help="HF model ID or local path")
|
||||
parser.add_argument("--out", required=True, type=Path)
|
||||
parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
|
||||
parser.add_argument("--dtype", default="float32",
|
||||
choices=["float32", "float16", "bfloat16"])
|
||||
parser.add_argument("--n-questions", type=int, default=817,
|
||||
help="Number of questions to evaluate (default: full 817)")
|
||||
parser.add_argument("--save-every", type=int, default=20)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
else:
|
||||
device = args.device
|
||||
print(f"Device: {device}, dtype: {args.dtype}")
|
||||
|
||||
dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
|
||||
dtype = dtype_map[args.dtype]
|
||||
|
||||
print(f"Loading {args.model}...")
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
print(f"Loaded.")
|
||||
|
||||
print("Loading TruthfulQA-MC1...")
|
||||
dataset = load_dataset("truthful_qa", "multiple_choice", split="validation")
|
||||
print(f"Loaded {len(dataset)} questions.")
|
||||
|
||||
args.out.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"\nEvaluating {args.n_questions} questions...")
|
||||
results = evaluate(
|
||||
model, tokenizer, dataset, device,
|
||||
n_questions=args.n_questions,
|
||||
save_every=args.save_every,
|
||||
out_path=args.out,
|
||||
)
|
||||
|
||||
summary = {
|
||||
"model": args.model,
|
||||
"device": device,
|
||||
"dtype": args.dtype,
|
||||
"task": "truthfulqa_mc1",
|
||||
"n_evaluated": results["n_evaluated"],
|
||||
"n_correct": results["n_correct"],
|
||||
"accuracy": results["accuracy"],
|
||||
"wall_time_seconds": results["wall_time_seconds"],
|
||||
"per_item": results["per_item"],
|
||||
}
|
||||
args.out.write_text(json.dumps(summary, indent=2))
|
||||
print(f"\nFinal: {summary['n_correct']}/{summary['n_evaluated']} "
|
||||
f"= {summary['accuracy']:.4f}")
|
||||
print(f"Wall: {summary['wall_time_seconds']:.0f}s")
|
||||
print(f"Wrote {args.out}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user