195 lines
6.2 KiB
Python
195 lines
6.2 KiB
Python
|
|
"""
|
||
|
|
Reading Steiner Evaluation Script

Tests the trained model on:
1. Query-relevant extraction (QE) - F1/Precision/Recall
2. Main content extraction (ME) - F1/Precision/Recall
3. Inference speed (on CUDA when available, otherwise on CPU)
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
import os
|
||
|
|
import torch
|
||
|
|
from datasets import load_dataset
|
||
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
|
|
||
|
|
|
||
|
|
def parse_intervals(text):
    """Expand an interval string such as '[[1,3],[5,7]]' into a set of ints.

    Each ``[start, end]`` pair is inclusive on both ends.  A blank string or
    the literal ``NA`` (any case) yields an empty set, and so does any input
    that is not valid JSON or is not a list of two-element integer pairs.
    """
    cleaned = text.strip()
    if not cleaned or cleaned.upper() == 'NA':
        return set()

    try:
        spans = json.loads(cleaned)
        # Unpacking and range() happen inside the try so that malformed pairs
        # (wrong arity, non-integer bounds, non-iterable payload) are treated
        # the same as unparseable text.
        expanded = {
            index
            for first, last in spans
            for index in range(first, last + 1)
        }
    except (json.JSONDecodeError, TypeError, ValueError):
        return set()
    else:
        return expanded
|
||
|
|
|
||
|
|
|
||
|
|
def compute_f1(pred_indices, gold_indices):
    """Compute F1, precision and recall between two sets of indices.

    Args:
        pred_indices: Set of predicted indices.
        gold_indices: Set of reference indices.

    Returns:
        A ``(f1, precision, recall)`` tuple of floats in the range [0, 1].
        Two empty sets count as a perfect match (all 1.0); if exactly one
        side is empty, or the sets do not overlap, all three values are 0.0.
    """
    if not pred_indices and not gold_indices:
        return 1.0, 1.0, 1.0
    if not pred_indices or not gold_indices:
        return 0.0, 0.0, 0.0

    tp = len(pred_indices & gold_indices)
    if tp == 0:
        # No overlap: return floats explicitly so every code path yields the
        # same types and the F1 formula never divides by zero.
        return 0.0, 0.0, 0.0

    # Both sets are non-empty here, so neither division can fail.
    precision = tp / len(pred_indices)
    recall = tp / len(gold_indices)
    f1 = 2 * precision * recall / (precision + recall)
    return f1, precision, recall
|
||
|
|
|
||
|
|
|
||
|
|
def generate_response(model, tokenizer, messages, device, max_new_tokens=128):
    """Generate the model's reply for a chat, ignoring its last message.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        messages: Chat messages; the last entry is dropped before prompting
            (callers in this file pass the gold assistant answer there).
        device: Device the tokenized inputs are moved to.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation only (prompt tokens removed), with special
        tokens skipped and surrounding whitespace stripped.
    """
    prompt = tokenizer.apply_chat_template(
        messages[:-1],  # Exclude assistant message (ground truth)
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    # The rendered chat template already carries the special tokens it needs,
    # so the tokenizer must not add another set on top of it.
    # NOTE(review): prompts longer than 4096 tokens are truncated; with the
    # default truncation side this may cut off the generation prompt at the
    # end of the text -- confirm the eval inputs fit within the limit.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=4096,
        add_special_tokens=False,
    ).to(device)

    # Some tokenizers define no pad token; fall back to EOS explicitly.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy for deterministic eval
            temperature=1.0,
            pad_token_id=pad_token_id,
        )

    # Decode only the new tokens
    prompt_length = inputs['input_ids'].shape[1]
    new_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()
|
||
|
|
|
||
|
|
|
||
|
|
def evaluate_model(model_id, device="cpu", num_samples=100, *,
                   max_per_task=50, speed_samples=20):
    """Run the full evaluation: QE metrics, ME metrics and a speed test.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        device: Device the model is moved to; float32 weights are used for
            "cpu" and bfloat16 for anything else.
        num_samples: Maximum number of eval rows kept after a seeded shuffle.
        max_per_task: Maximum number of examples scored per task (QE / ME).
        speed_samples: Maximum number of QE examples used for the timing run.

    Returns:
        A ``(qe_metrics, me_metrics)`` tuple holding the values returned by
        ``evaluate_task`` for each task.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating: {model_id}")
    print(f"Device: {device}")
    print(f"{'='*60}")

    # Load model
    print("Loading model...")
    dtype = torch.float32 if device == "cpu" else torch.bfloat16
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        attn_implementation="sdpa",
    ).to(device)
    model.eval()

    # Load eval dataset
    print("Loading eval dataset...")
    dataset = load_dataset("OmAlve/indexlm-training-data", split="eval")

    # Sample (fixed seed so repeated runs score the same rows)
    if len(dataset) > num_samples:
        dataset = dataset.shuffle(seed=42).select(range(num_samples))

    # Categorize examples: a system prompt mentioning both "query" and
    # "relevant" marks a QE example; everything else is treated as ME.
    qe_examples = []
    me_examples = []

    for row in dataset:
        msgs = row['messages']
        system_msg = msgs[0]['content'] if msgs[0]['role'] == 'system' else ''
        system_lower = system_msg.lower()
        if 'query' in system_lower and 'relevant' in system_lower:
            qe_examples.append(msgs)
        else:
            me_examples.append(msgs)

    print(f"QE examples: {len(qe_examples)}, ME examples: {len(me_examples)}")

    # Evaluate QE
    print("\n--- Query-Relevant Extraction (QE) ---")
    qe_metrics = evaluate_task(model, tokenizer, qe_examples[:max_per_task], device)

    # Evaluate ME
    print("\n--- Main Content Extraction (ME) ---")
    me_metrics = evaluate_task(model, tokenizer, me_examples[:max_per_task], device)

    # Speed test
    print("\n--- Inference Speed Test ---")
    speed_test(model, tokenizer, qe_examples[:speed_samples], device)

    return qe_metrics, me_metrics
|
||
|
|
|
||
|
|
|
||
|
|
def evaluate_task(model, tokenizer, examples, device):
    """Score the model on a list of chat examples.

    The last message of every example holds the gold interval string.  The
    prediction is generated from the remaining messages, and both strings are
    expanded to index sets before comparison.  The first three examples are
    printed for inspection.

    Returns:
        A dict with ``f1``, ``precision``, ``recall`` and ``exact_match`` as
        percentages, or an empty dict when ``examples`` is empty.
    """
    if not examples:
        print("No examples for this task.")
        return {}

    per_example = []  # one (f1, precision, recall) tuple per example
    exact_matches = 0

    for number, conversation in enumerate(examples, start=1):
        gold = conversation[-1]['content']
        gold_set = parse_intervals(gold)

        pred = generate_response(model, tokenizer, conversation, device)
        pred_set = parse_intervals(pred)

        f1, prec, rec = compute_f1(pred_set, gold_set)
        per_example.append((f1, prec, rec))

        if pred_set == gold_set:
            exact_matches += 1

        # Show only the first three examples.
        if number <= 3:
            print(f" Example {number}:")
            print(f" Gold: {gold}")
            print(f" Pred: {pred}")
            print(f" F1: {f1:.3f}, P: {prec:.3f}, R: {rec:.3f}")

    count = len(per_example)
    # Transpose the per-example tuples into three columns and average each.
    avg_f1, avg_prec, avg_rec = (
        sum(column) / count * 100 for column in zip(*per_example)
    )
    em_rate = exact_matches / len(examples) * 100

    print(f"\n Results ({len(examples)} examples):")
    print(f" F1: {avg_f1:.2f}")
    print(f" Precision: {avg_prec:.2f}")
    print(f" Recall: {avg_rec:.2f}")
    print(f" Exact Match: {em_rate:.2f}%")

    return {
        "f1": avg_f1,
        "precision": avg_prec,
        "recall": avg_rec,
        "exact_match": em_rate,
    }
|
||
|
|
|
||
|
|
|
||
|
|
def speed_test(model, tokenizer, examples, device):
    """Measure and print per-example generation latency.

    Args:
        model: Model passed through to ``generate_response``.
        tokenizer: Tokenizer passed through to ``generate_response``.
        examples: Chat examples to time; nothing is printed when empty.
        device: Device passed through to ``generate_response`` and shown in
            the report.

    Prints the average, minimum and maximum latency and the resulting
    throughput.  Returns ``None``.
    """
    if not examples:
        return

    times = []
    for msgs in examples:
        # perf_counter is monotonic and high resolution, so the measured
        # interval is not affected by system clock adjustments.
        start = time.perf_counter()
        _ = generate_response(model, tokenizer, msgs, device)
        times.append(time.perf_counter() - start)

    avg_time = sum(times) / len(times)
    print(f" Average inference time: {avg_time:.3f}s ({device})")
    print(f" Min: {min(times):.3f}s, Max: {max(times):.3f}s")
    print(f" Throughput: {1/avg_time:.1f} pages/sec")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # The checkpoint can be overridden through the MODEL_ID environment
    # variable; CUDA is used whenever torch reports it as available.
    evaluate_model(
        os.environ.get("MODEL_ID", "OmAlve/reading-steiner"),
        device="cuda" if torch.cuda.is_available() else "cpu",
        num_samples=100,
    )
|