初始化项目,由ModelHub XC社区提供模型
Model: OmAlve/reading-steiner Source: Original Platform
This commit is contained in:
194
eval_indexlm.py
Normal file
194
eval_indexlm.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Reading Steiner Evaluation Script
|
||||
|
||||
Tests the trained model on:
|
||||
1. Query-relevant extraction (QE) - F1/Precision/Recall
|
||||
2. Main content extraction (ME) - F1/Precision/Recall
|
||||
3. Inference speed on CPU
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def parse_intervals(text):
|
||||
"""Parse interval string like '[[1,3],[5,7]]' into a set of indices."""
|
||||
text = text.strip()
|
||||
if text.upper() == 'NA' or not text:
|
||||
return set()
|
||||
try:
|
||||
intervals = json.loads(text)
|
||||
indices = set()
|
||||
for start, end in intervals:
|
||||
indices.update(range(start, end + 1))
|
||||
return indices
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
return set()
|
||||
|
||||
|
||||
def compute_f1(pred_indices, gold_indices):
|
||||
"""Compute F1, precision, recall between two sets of indices."""
|
||||
if not pred_indices and not gold_indices:
|
||||
return 1.0, 1.0, 1.0
|
||||
if not pred_indices or not gold_indices:
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
tp = len(pred_indices & gold_indices)
|
||||
precision = tp / len(pred_indices) if pred_indices else 0
|
||||
recall = tp / len(gold_indices) if gold_indices else 0
|
||||
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
|
||||
return f1, precision, recall
|
||||
|
||||
|
||||
def generate_response(model, tokenizer, messages, device, max_new_tokens=128):
|
||||
"""Generate model response for given messages."""
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages[:-1], # Exclude assistant message (ground truth)
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False, # Greedy for deterministic eval
|
||||
temperature=1.0,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
|
||||
# Decode only the new tokens
|
||||
new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
|
||||
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
|
||||
return response.strip()
|
||||
|
||||
|
||||
def evaluate_model(model_id, device="cpu", num_samples=100):
|
||||
"""Run full evaluation."""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Evaluating: {model_id}")
|
||||
print(f"Device: {device}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Load model
|
||||
print("Loading model...")
|
||||
dtype = torch.float32 if device == "cpu" else torch.bfloat16
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation="sdpa",
|
||||
).to(device)
|
||||
model.eval()
|
||||
|
||||
# Load eval dataset
|
||||
print("Loading eval dataset...")
|
||||
dataset = load_dataset("OmAlve/indexlm-training-data", split="eval")
|
||||
|
||||
# Sample
|
||||
if len(dataset) > num_samples:
|
||||
dataset = dataset.shuffle(seed=42).select(range(num_samples))
|
||||
|
||||
# Categorize examples
|
||||
qe_examples = []
|
||||
me_examples = []
|
||||
|
||||
for row in dataset:
|
||||
msgs = row['messages']
|
||||
system_msg = msgs[0]['content'] if msgs[0]['role'] == 'system' else ''
|
||||
if 'query' in system_msg.lower() and 'relevant' in system_msg.lower():
|
||||
qe_examples.append(msgs)
|
||||
else:
|
||||
me_examples.append(msgs)
|
||||
|
||||
print(f"QE examples: {len(qe_examples)}, ME examples: {len(me_examples)}")
|
||||
|
||||
# Evaluate QE
|
||||
print("\n--- Query-Relevant Extraction (QE) ---")
|
||||
qe_metrics = evaluate_task(model, tokenizer, qe_examples[:50], device)
|
||||
|
||||
# Evaluate ME
|
||||
print("\n--- Main Content Extraction (ME) ---")
|
||||
me_metrics = evaluate_task(model, tokenizer, me_examples[:50], device)
|
||||
|
||||
# Speed test
|
||||
print("\n--- Inference Speed Test ---")
|
||||
speed_test(model, tokenizer, qe_examples[:20], device)
|
||||
|
||||
return qe_metrics, me_metrics
|
||||
|
||||
|
||||
def evaluate_task(model, tokenizer, examples, device):
|
||||
"""Evaluate on a set of examples."""
|
||||
if not examples:
|
||||
print("No examples for this task.")
|
||||
return {}
|
||||
|
||||
f1_scores = []
|
||||
precision_scores = []
|
||||
recall_scores = []
|
||||
exact_matches = 0
|
||||
|
||||
for i, msgs in enumerate(examples):
|
||||
gold = msgs[-1]['content']
|
||||
gold_indices = parse_intervals(gold)
|
||||
|
||||
pred = generate_response(model, tokenizer, msgs, device)
|
||||
pred_indices = parse_intervals(pred)
|
||||
|
||||
f1, prec, rec = compute_f1(pred_indices, gold_indices)
|
||||
f1_scores.append(f1)
|
||||
precision_scores.append(prec)
|
||||
recall_scores.append(rec)
|
||||
|
||||
if pred_indices == gold_indices:
|
||||
exact_matches += 1
|
||||
|
||||
if i < 3:
|
||||
print(f" Example {i+1}:")
|
||||
print(f" Gold: {gold}")
|
||||
print(f" Pred: {pred}")
|
||||
print(f" F1: {f1:.3f}, P: {prec:.3f}, R: {rec:.3f}")
|
||||
|
||||
avg_f1 = sum(f1_scores) / len(f1_scores) * 100
|
||||
avg_prec = sum(precision_scores) / len(precision_scores) * 100
|
||||
avg_rec = sum(recall_scores) / len(recall_scores) * 100
|
||||
em_rate = exact_matches / len(examples) * 100
|
||||
|
||||
print(f"\n Results ({len(examples)} examples):")
|
||||
print(f" F1: {avg_f1:.2f}")
|
||||
print(f" Precision: {avg_prec:.2f}")
|
||||
print(f" Recall: {avg_rec:.2f}")
|
||||
print(f" Exact Match: {em_rate:.2f}%")
|
||||
|
||||
return {"f1": avg_f1, "precision": avg_prec, "recall": avg_rec, "exact_match": em_rate}
|
||||
|
||||
|
||||
def speed_test(model, tokenizer, examples, device):
|
||||
"""Test inference speed."""
|
||||
if not examples:
|
||||
return
|
||||
|
||||
times = []
|
||||
for msgs in examples:
|
||||
start = time.time()
|
||||
_ = generate_response(model, tokenizer, msgs, device)
|
||||
elapsed = time.time() - start
|
||||
times.append(elapsed)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f" Average inference time: {avg_time:.3f}s ({device})")
|
||||
print(f" Min: {min(times):.3f}s, Max: {max(times):.3f}s")
|
||||
print(f" Throughput: {1/avg_time:.1f} pages/sec")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_id = os.environ.get("MODEL_ID", "OmAlve/reading-steiner")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
evaluate_model(model_id, device=device, num_samples=100)
|
||||
Reference in New Issue
Block a user