Files
reading-steiner/eval_indexlm.py

195 lines
6.2 KiB
Python
Raw Permalink Normal View History

"""
Reading Steiner Evaluation Script
Tests the trained model on:
1. Query-relevant extraction (QE) - F1/Precision/Recall
2. Main content extraction (ME) - F1/Precision/Recall
3. Inference speed on the selected device (CPU by default; CUDA when run as a script and available)
"""
import json
import time
import os
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
def parse_intervals(text):
    """Expand a JSON list of inclusive ``[start, end]`` pairs into a set of ints.

    For example ``'[[1,3],[5,7]]'`` becomes ``{1, 2, 3, 5, 6, 7}``.

    Args:
        text: Interval string. Surrounding whitespace is ignored.

    Returns:
        The set of every index covered by any interval. An empty set is
        returned when the text is empty, equals ``NA`` (case-insensitive),
        is not valid JSON, or does not have the expected pair structure.
    """
    cleaned = text.strip()
    if not cleaned or cleaned.upper() == 'NA':
        return set()

    try:
        # Both bounds are inclusive, hence the +1 on the upper end.
        return {
            index
            for start, end in json.loads(cleaned)
            for index in range(start, end + 1)
        }
    except (json.JSONDecodeError, TypeError, ValueError):
        # Malformed output is scored as "nothing selected" rather than crashing.
        return set()
def compute_f1(pred_indices, gold_indices):
    """Score a predicted index set against the gold index set.

    Args:
        pred_indices: Set of indices predicted by the model.
        gold_indices: Set of reference indices.

    Returns:
        A ``(f1, precision, recall)`` tuple. Two empty sets count as a
        perfect match (all 1.0); exactly one empty set scores all 0.0.
    """
    pred_empty = not pred_indices
    gold_empty = not gold_indices
    if pred_empty and gold_empty:
        return 1.0, 1.0, 1.0
    if pred_empty or gold_empty:
        return 0.0, 0.0, 0.0

    # Both sets are non-empty from here on, so the divisions are safe.
    overlap = len(pred_indices & gold_indices)
    precision = overlap / len(pred_indices)
    recall = overlap / len(gold_indices)

    total = precision + recall
    if total > 0:
        f1 = 2 * precision * recall / total
    else:
        # No overlap at all: avoid dividing by zero.
        f1 = 0
    return f1, precision, recall
def generate_response(model, tokenizer, messages, device, max_new_tokens=128):
    """Greedily generate the model's reply to a conversation.

    Args:
        model: Object exposing ``generate(**inputs, ...)``.
        tokenizer: Tokenizer providing ``apply_chat_template``, ``__call__``,
            ``decode`` and ``pad_token_id``.
        messages: Chat messages; the last one is dropped before prompting
            because it holds the ground-truth assistant answer.
        device: Device the encoded inputs are moved to.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt tokens excluded, special tokens
        skipped), stripped of surrounding whitespace.
    """
    prompt = tokenizer.apply_chat_template(
        messages[:-1],  # Hide the reference answer from the model
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=4096,
    ).to(device)
    prompt_length = encoded['input_ids'].shape[1]

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,  # Greedy decoding keeps the evaluation deterministic
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
    )
    with torch.no_grad():
        sequences = model.generate(**encoded, **generation_kwargs)

    # The returned sequence starts with the prompt; keep only what follows it.
    completion_ids = sequences[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
def evaluate_model(model_id, device="cpu", num_samples=100):
    """Load a model and run the QE, ME and speed evaluations on it.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the causal LM.
        device: Device the model is moved to. float32 weights are used for
            ``"cpu"``, bfloat16 for anything else.
        num_samples: Maximum number of eval rows kept after a seeded shuffle.

    Returns:
        A ``(qe_metrics, me_metrics)`` tuple as returned by ``evaluate_task``.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating: {model_id}")
    print(f"Device: {device}")
    print(f"{'='*60}")

    # Load model
    print("Loading model...")
    if device == "cpu":
        dtype = torch.float32
    else:
        dtype = torch.bfloat16
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        attn_implementation="sdpa",
    )
    model = model.to(device)
    model.eval()

    # Load eval dataset
    print("Loading eval dataset...")
    dataset = load_dataset("OmAlve/indexlm-training-data", split="eval")

    # Down-sample with a fixed seed so repeated runs see the same rows.
    if len(dataset) > num_samples:
        shuffled = dataset.shuffle(seed=42)
        dataset = shuffled.select(range(num_samples))

    # Route each conversation to a task based on its system prompt wording.
    qe_examples = []
    me_examples = []
    for row in dataset:
        msgs = row['messages']
        first = msgs[0]
        system_text = (first['content'] if first['role'] == 'system' else '').lower()
        is_query_task = 'query' in system_text and 'relevant' in system_text
        bucket = qe_examples if is_query_task else me_examples
        bucket.append(msgs)
    print(f"QE examples: {len(qe_examples)}, ME examples: {len(me_examples)}")

    # Evaluate QE (at most 50 examples)
    print("\n--- Query-Relevant Extraction (QE) ---")
    qe_metrics = evaluate_task(model, tokenizer, qe_examples[:50], device)

    # Evaluate ME (at most 50 examples)
    print("\n--- Main Content Extraction (ME) ---")
    me_metrics = evaluate_task(model, tokenizer, me_examples[:50], device)

    # Speed test runs on the first 20 QE examples only.
    print("\n--- Inference Speed Test ---")
    speed_test(model, tokenizer, qe_examples[:20], device)

    return qe_metrics, me_metrics
def evaluate_task(model, tokenizer, examples, device):
    """Score the model on a list of conversations and print a summary.

    For each conversation the last message's content is the gold interval
    string; the model's generated reply is parsed the same way and compared
    against it. The first three examples are printed in detail.

    Args:
        model: Model forwarded to ``generate_response``.
        tokenizer: Tokenizer forwarded to ``generate_response``.
        examples: List of message lists.
        device: Device forwarded to ``generate_response``.

    Returns:
        A dict with ``f1``, ``precision``, ``recall`` and ``exact_match``,
        each a percentage averaged over the examples, or an empty dict when
        there are no examples.
    """
    if not examples:
        print("No examples for this task.")
        return {}

    scores = []  # one (f1, precision, recall) tuple per example
    exact_matches = 0
    for i, msgs in enumerate(examples):
        gold = msgs[-1]['content']
        pred = generate_response(model, tokenizer, msgs, device)

        gold_indices = parse_intervals(gold)
        pred_indices = parse_intervals(pred)

        f1, prec, rec = compute_f1(pred_indices, gold_indices)
        scores.append((f1, prec, rec))
        if pred_indices == gold_indices:
            exact_matches += 1

        # Show a handful of raw predictions for a quick sanity check.
        if i < 3:
            print(f" Example {i+1}:")
            print(f" Gold: {gold}")
            print(f" Pred: {pred}")
            print(f" F1: {f1:.3f}, P: {prec:.3f}, R: {rec:.3f}")

    count = len(examples)
    f1_scores, precision_scores, recall_scores = zip(*scores)
    avg_f1 = sum(f1_scores) / count * 100
    avg_prec = sum(precision_scores) / count * 100
    avg_rec = sum(recall_scores) / count * 100
    em_rate = exact_matches / count * 100

    print(f"\n Results ({len(examples)} examples):")
    print(f" F1: {avg_f1:.2f}")
    print(f" Precision: {avg_prec:.2f}")
    print(f" Recall: {avg_rec:.2f}")
    print(f" Exact Match: {em_rate:.2f}%")

    return {
        "f1": avg_f1,
        "precision": avg_prec,
        "recall": avg_rec,
        "exact_match": em_rate,
    }
def speed_test(model, tokenizer, examples, device):
    """Time ``generate_response`` on each example and print latency stats.

    Prints the average, minimum and maximum per-example latency in seconds
    and the resulting throughput. Nothing is printed when ``examples`` is
    empty.

    Args:
        model: Model forwarded to ``generate_response``.
        tokenizer: Tokenizer forwarded to ``generate_response``.
        examples: List of message lists to run generation on.
        device: Device forwarded to ``generate_response``; also shown in the
            printed summary.

    Returns:
        None.
    """
    if not examples:
        return

    times = []
    for msgs in examples:
        # perf_counter is a monotonic, high-resolution clock meant for
        # measuring durations; time.time() follows the wall clock and can
        # jump or have coarse resolution.
        start = time.perf_counter()
        _ = generate_response(model, tokenizer, msgs, device)
        times.append(time.perf_counter() - start)

    avg_time = sum(times) / len(times)
    # Guard the reciprocal so a zero average cannot raise ZeroDivisionError.
    throughput = 1 / avg_time if avg_time > 0 else float("inf")

    print(f" Average inference time: {avg_time:.3f}s ({device})")
    print(f" Min: {min(times):.3f}s, Max: {max(times):.3f}s")
    print(f" Throughput: {throughput:.1f} pages/sec")
if __name__ == "__main__":
    # Prefer the GPU when one is visible to torch; otherwise stay on CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # The model can be overridden through the MODEL_ID environment variable.
    model_id = os.environ.get("MODEL_ID", "OmAlve/reading-steiner")
    evaluate_model(model_id, device=device, num_samples=100)