# IndexLM-0.6B: Index-based Web Content Extraction

## Project Handoff Document

**Paper**: [An Index-based Approach for Efficient and Effective Web Content Extraction](https://arxiv.org/abs/2512.06641)
**Goal**: Fine-tune a SOTA web content extraction model that runs fast on CPU
**Status**: Dataset prepared & pushed ✅ | Training script ready ✅ | Training NOT yet run ❌

---

## 1. What This Is

The paper introduces **IndexLM**, a model that extracts relevant content from web pages by predicting **index intervals** instead of generating full text. According to the paper, this makes it:
- **10–50× faster** than generative extraction (ReaderLM-v2, Firecrawl, etc.)
- **SOTA on RAG QA** benchmarks (HotpotQA, NQ, TriviaQA, MuSiQue, MultiHopRAG)
- **Tiny**: even the 0.6B version beats all baselines on average RAG QA F1 and ME F1 (Firecrawl Extract is slightly ahead on QE F1, see the table below)

The original IndexLM weights are **not publicly released**. This project replicates the approach.

### How It Works

1. HTML is cleaned and split into indexed blocks: `[1] <h1>Title</h1>`, `[2] <p>Content...</p>`, etc.
2. The model receives these blocks plus a query.
3. It outputs index intervals such as `[[2,4],[7,7],[10,12]]`, identifying which blocks are relevant (each interval is inclusive on both ends).
4. The selected blocks are reassembled into clean HTML/Markdown.

Two tasks:
- **Query-relevant extraction (QE)**: Extract blocks relevant to a specific query
- **Main content extraction (ME)**: Extract main content, filtering out nav/ads/sidebars

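
The four steps under "How It Works" can be sketched without a model in the loop. The snippet below is a minimal illustration, not the paper's or this project's implementation: `index_blocks` and `select_blocks` are hypothetical helper names, and the interval string stands in for what the model would output.

```python
import json


def index_blocks(blocks):
    """Prefix each HTML block with a 1-based index, e.g. '[1] <h1>Title</h1>'."""
    return "\n".join(f"[{i}] {block}" for i, block in enumerate(blocks, start=1))


def select_blocks(blocks, intervals_json):
    """Return the blocks covered by inclusive [start, end] intervals."""
    selected = []
    for start, end in json.loads(intervals_json):
        # Indices are 1-based and inclusive; clamp them to the valid range.
        for i in range(max(start, 1), min(end, len(blocks)) + 1):
            selected.append(blocks[i - 1])
    return selected


blocks = [
    "<nav>Home</nav>",
    "<h1>Python Programming</h1>",
    "<p>Python is a programming language...</p>",
    "<footer>Copyright 2024</footer>",
]
prompt_blocks = index_blocks(blocks)          # what the model would see
extracted = select_blocks(blocks, "[[2, 3]]")  # stand-in for a model output
print("\n".join(extracted))
```
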
### Paper Results (Table 2 & 3)

| Model | Params | Avg RAG QA F1 | ME F1 | QE F1 | Latency (ME) |
|-------|--------|---------------|-------|-------|-------------|
| **IndexLM-0.6B** | 0.6B | 54.70 | 83.38 | 28.64 | **0.35s** |
| **IndexLM-4B** | 4B | 55.41 | 87.40 | 31.69 | 0.81s |
| ReaderLM-v2 | 1.5B | 46.84 | 68.89 | 13.31 | 11.76s |
| HtmlRAG | - | 47.00 | 48.65 | 8.83 | 7.12s |
| Firecrawl Extract | API | 52.72 | - | 29.48 | 11.33s |

---

## 2. What's Been Done

### ✅ Dataset Created & Pushed (v2, Multi-domain)

**Hub**: [`OmAlve/indexlm-training-data`](https://huggingface.co/datasets/OmAlve/indexlm-training-data)

| Split | Rows |
|-------|------|
| train | 21,098 |
| eval | 500 |

**Domain Composition (avoids Wikipedia-only bias; counts cover train + eval, 21,598 rows in total):**

| Source | Count | % | Domain |
|--------|-------|---|--------|
| MultiHopRAG | 7,165 | 33.2% | News (Mashable, CNBC, AP, etc.) |
| HotpotQA | 6,479 | 30.0% | Wikipedia |
| HtmlRAG-train | 2,692 | 12.5% | **Real Bing-scraped web HTML** (diverse) |
| MS MARCO | 4,844 | 22.4% | Diverse web (Bing search results) |
| NA (mismatched) | 418 | 1.9% | Cross-domain |

**Task Type Composition:**
- `query_relevant` (~78%): query-specific extraction
- `main_content` (~20%): main content vs. noise (nav/ads/cookies)
- `query_relevant_na` (~2%): no relevant content exists

**Key improvement over v1**: Real web HTML from Bing search results (via HtmlRAG-train), news articles, and diverse MS MARCO web QA, not just Wikipedia.

**Format**: Conversational `messages` column (SFTTrainer-native):

```json
{
  "messages": [
    {"role": "system", "content": "You are IndexLM, a web content extraction model..."},
    {"role": "user", "content": "URL: ...\nQuery: ...\n\nBlocks:\n[1] <h2>Title</h2>\n[2] <p>Content</p>\n...\n\nOutput the index intervals of blocks relevant to the query."},
    {"role": "assistant", "content": "[[2, 4], [7, 7]]"}
  ]
}
```

**Token length stats** (Qwen3-0.6B tokenizer):
- Min: 316, Max: 4,105, Mean: 1,944, Median: 2,019
- 43 examples filtered (>4096 tokens)

**Data pipeline** (from `prepare_data_v2.py`). Per-source counts are those reported at each pipeline step and do not all match the final composition table above:
1. **HtmlRAG-train** (5,880 raw examples): Real Bing-scraped HTML from 5 QA datasets (NQ, ASQA, TriviaQA, MuSiQue, HotpotQA). Segments HTML by block-level tags and matches relevant blocks to ground-truth answers using trigram/substring matching.
2. **MultiHopRAG** (8,521 examples): News articles from Mashable, CNBC, AP, etc. Converts article body + evidence annotations to indexed blocks. Injects realistic noise blocks.
3. **HotpotQA** (6,486 examples, minority): Wikipedia context with supporting facts → index intervals. Noise injected.
4. **MS MARCO** (4,844 examples): Diverse web QA from Bing search. Passages from real web pages across numeric, entity, description, location, and person query types.
5. **NA examples** (500): Mismatched query-page pairs from different sources.
6. Filters to ≤4096 tokens, shuffles, and splits into train/eval.

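
Step 1 of the pipeline labels blocks by trigram/substring matching against the ground-truth answer. `prepare_data_v2.py` is not reproduced in this document, so the following is only a sketch of one way such matching could work: the function names, the character-level trigrams, and the `0.8` overlap threshold are all assumptions.

```python
import re


def normalize(text):
    """Lowercase, drop HTML tags, and collapse whitespace."""
    text = re.sub(r"<[^>]+>", " ", text)
    return re.sub(r"\s+", " ", text.lower()).strip()


def char_trigrams(text):
    """Set of character trigrams of an already-normalized string."""
    return {text[i:i + 3] for i in range(len(text) - 2)}


def match_blocks(blocks, answer, threshold=0.8):
    """Return 1-based indices of blocks that likely contain the answer."""
    answer_norm = normalize(answer)
    answer_grams = char_trigrams(answer_norm)
    hits = []
    for idx, block in enumerate(blocks, start=1):
        block_norm = normalize(block)
        if answer_norm and answer_norm in block_norm:
            hits.append(idx)  # exact substring match
            continue
        if answer_grams:
            # Fraction of the answer's trigrams that also occur in the block.
            overlap = len(answer_grams & char_trigrams(block_norm)) / len(answer_grams)
            if overlap >= threshold:
                hits.append(idx)  # fuzzy match (tolerates small typos)
    return hits


blocks = [
    "<nav>Home | About</nav>",
    "<p>Python was created by Guido van Rossum.</p>",
    "<p>Guido van Rossom released it in 1991.</p>",
]
print(match_blocks(blocks, "Guido van Rossum"))
```

The substring test catches verbatim answers; the trigram overlap is a fallback for near matches such as the misspelled name in the third block. Raising `threshold` makes the fallback stricter.
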
### ✅ Training Script Ready

**File**: `train_indexlm.py` (see Section 5.2)

Key settings:
- **Base model**: `Qwen/Qwen3-0.6B` (751M params, bf16, GQA, 32K context)
- **Method**: SFT via TRL `SFTTrainer` + `SFTConfig`
- **Output**: `OmAlve/IndexLM-0.6B` on Hub
- **Hyperparameters**: lr=2e-5, epochs=3, per-device batch=4, grad_accum=4 (effective batch size 16 on a single GPU), max_length=4096, cosine LR schedule, warmup=5%
- `push_to_hub=True`, `hub_model_id="OmAlve/IndexLM-0.6B"`
- Trackio monitoring included
- Flash Attention 2 for training speed

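
As a rough sanity check on these settings, the arithmetic below estimates the number of optimizer steps. It assumes a single GPU (`num_gpus = 1` is not stated in the original), and the trainer's own rounding may differ by a step or two.

```python
import math

train_rows = 21_098       # train split size from Section 2
per_device_batch = 4
grad_accum = 4
epochs = 3
warmup_ratio = 0.05
num_gpus = 1              # assumption: single-GPU run

effective_batch = per_device_batch * grad_accum * num_gpus
steps_per_epoch = math.ceil(train_rows / effective_batch)
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)

print(f"effective batch size: {effective_batch}")
print(f"steps per epoch:      {steps_per_epoch}")
print(f"total steps:          {total_steps}")
print(f"warmup steps:         {warmup_steps}")
```

With `eval_steps=500` and `save_steps=500` in the training script, a run of this length would hit about seven evaluation/checkpoint points.
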
### ✅ Evaluation Script Ready

**File**: `eval_indexlm.py` (see Section 5.3)

Evaluates:
- QE F1/Precision/Recall on eval split
- ME F1/Precision/Recall on eval split
- CPU inference speed benchmark

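
The F1/Precision/Recall metrics compare predicted block indices with gold block indices. The sketch below shows one plausible set-based definition; `index_set_prf` is a hypothetical name, not the evaluation script's own function, and scoring an empty prediction against an empty gold set as a perfect match is an assumption.

```python
def index_set_prf(pred_indices, gold_indices):
    """Precision, recall, and F1 between two collections of block indices."""
    pred, gold = set(pred_indices), set(gold_indices)
    if not pred and not gold:
        # Assumption: predicting 'NA' when nothing is relevant counts as correct.
        return 1.0, 1.0, 1.0
    true_positives = len(pred & gold)
    precision = true_positives / len(pred) if pred else 0.0
    recall = true_positives / len(gold) if gold else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Prediction [[2, 4], [7, 7]] vs. gold [[2, 3], [7, 8]], expanded to indices
precision, recall, f1 = index_set_prf({2, 3, 4, 7}, {2, 3, 7, 8})
print(precision, recall, f1)
```

Here three of the four predicted blocks are correct and three of the four gold blocks are found, so all three scores are 0.75.
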
### ❌ Training Not Yet Run

Training was blocked by a credits issue on HF Jobs (HTTP 402 Payment Required). To proceed, run `train_indexlm.py` on a GPU (see Section 3).

---

## 3. How to Train

### Option A: HF Jobs (if you have credits)

```bash
# Dependencies
pip install "transformers>=4.51.0" "trl>=1.2.0" torch datasets accelerate trackio
# flash-attn needs its flag passed to pip, not quoted as part of the package name
pip install flash-attn --no-build-isolation
```

Recommended hardware: **a10g-large** ($2/hr) or **t4-small** ($0.60/hr); the model is only 0.6B params.
Estimated time: **2-4 hours** on a10g, **4-6 hours** on T4.
Set timeout to **6h** minimum.

### Option B: Any GPU machine

```bash
pip install "transformers>=4.51.0" "trl>=1.2.0" torch datasets accelerate trackio
pip install flash-attn --no-build-isolation  # optional, speeds up training

python train_indexlm.py
```

**VRAM**: ~8-10 GB with gradient checkpointing + bf16 at batch_size=4. Fits on T4 (16GB), any A-series, etc. One caveat for the T4: FlashAttention-2 targets Ampere or newer GPUs and Turing has no native bf16 support, so on a T4 you will likely need `attn_implementation="sdpa"` (Option C) and may need to replace the bf16 settings with fp16.

### Option C: Without Flash Attention

If `flash-attn` fails to install, change this line in `train_indexlm.py`:

```python
# FROM:
attn_implementation="flash_attention_2",
# TO:
attn_implementation="sdpa",
```
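
Instead of editing the line by hand, the choice could be made at runtime. This is a sketch only: `pick_attn_implementation` is a hypothetical helper that is not part of `train_indexlm.py`, and it checks only that the `flash_attn` package is importable, not that the GPU actually supports it.

```python
import importlib.util


def pick_attn_implementation():
    """Return 'flash_attention_2' if flash_attn is installed, else 'sdpa'."""
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


# Pass this value as attn_implementation=... when loading the model.
ATTN_IMPLEMENTATION = pick_attn_implementation()
print(ATTN_IMPLEMENTATION)
```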

---

## 4. How to Deploy on CPU

After training, the model at `OmAlve/IndexLM-0.6B` can be loaded for CPU inference:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "OmAlve/IndexLM-0.6B",
    torch_dtype=torch.float32,
    attn_implementation="sdpa",
)
tokenizer = AutoTokenizer.from_pretrained("OmAlve/IndexLM-0.6B")
model.eval()

# Example: extract relevant content from a web page
messages = [
    {"role": "system", "content": "You are IndexLM, a web content extraction model..."},
    {"role": "user", "content": "URL: ...\nQuery: What is Python?\n\nBlocks:\n[1] <nav>Home</nav>\n[2] <h1>Python Programming</h1>\n[3] <p>Python is a programming language...</p>\n[4] <footer>Copyright 2024</footer>\n\nOutput the index intervals of blocks relevant to the query."}
]

# Build the chat prompt with Qwen3 thinking mode disabled
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=128, do_sample=False)

# Decode only the newly generated tokens (skip the prompt)
response = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print(response)  # intended output once the model is trained: [[2, 3]]
```

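
The decoded `response` in the snippet above is free-form text, so it is worth parsing defensively before indexing into the blocks. The helper below is a hypothetical sketch (`parse_response` is not part of the project's scripts); it tolerates surrounding text, malformed output, and out-of-range indices.

```python
import json
import re


def parse_response(response, num_blocks):
    """Parse model output into a sorted list of valid 1-based block indices."""
    text = response.strip()
    if not text or text.upper() == "NA":
        return []
    # Take the first bracketed list of lists, ignoring any surrounding text.
    match = re.search(r"\[\s*\[.*?\]\s*\]", text, flags=re.DOTALL)
    if match is None:
        return []
    try:
        intervals = json.loads(match.group(0))
    except json.JSONDecodeError:
        return []
    indices = set()
    for item in intervals:
        if not (isinstance(item, list) and len(item) == 2):
            continue  # skip anything that is not a [start, end] pair
        start, end = item
        if not (isinstance(start, int) and isinstance(end, int)):
            continue
        # Intervals are inclusive; drop indices outside the page's block range.
        indices.update(i for i in range(start, end + 1) if 1 <= i <= num_blocks)
    return sorted(indices)


print(parse_response("[[2, 3]]", num_blocks=4))
print(parse_response("Answer: [[1, 2], [4, 9]]", num_blocks=4))
```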
**For even faster CPU inference**: quantize to INT8/INT4 or export to ONNX. Note that `bitsandbytes` has historically targeted CUDA GPUs, so check its CPU support before relying on it there.

---

## 5. All Scripts

### 5.1 Data Preparation (`prepare_data.py`)

```python
"""
Prepare IndexLM training data from HotpotQA and MSMARCO.

Pipeline:
1. Load HotpotQA (has context = list of (title, sentences) + supporting_facts)
2. Convert context into indexed HTML-like blocks: [i] <tag>content</tag>
3. The target is index intervals of blocks containing supporting facts
4. Also create main-content extraction examples (all content blocks are "main content",
   but we inject noise blocks like nav/ads to train the model to filter them)
5. Format as conversational messages for SFT
"""

import json
import random
import re
from datasets import load_dataset, Dataset
from collections import defaultdict

random.seed(42)

# Noise blocks to inject (simulating real web page clutter)
NOISE_BLOCKS = [
    '<nav>Home | About | Contact | Privacy Policy</nav>',
    '<div class="ad">Advertisement - Continue Reading Below</div>',
    '<div class="sidebar">Related Articles: Top 10 Facts You Didn\'t Know</div>',
    '<footer>© 2024 All Rights Reserved | Terms of Service</footer>',
    '<div class="cookie-banner">This site uses cookies. Accept | Decline</div>',
    '<div class="social">Share on: Twitter | Facebook | LinkedIn</div>',
    '<nav class="breadcrumb">Home > Category > Subcategory > Article</nav>',
    '<div class="newsletter">Subscribe to our newsletter for updates</div>',
    '<div class="popup">Sign up for free access to premium content</div>',
    '<aside>Trending: Latest news and popular stories</aside>',
    '<div class="comments">Comments (0) - Be the first to comment</div>',
    '<div class="author">Written by Staff Reporter | Updated: Jan 2024</div>',
    '<div class="pagination">Previous | 1 | 2 | 3 | Next</div>',
    '<div class="search">Search this site...</div>',
    '<div class="menu">Categories: Science, Tech, Health, Sports</div>',
]

SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query.

Each block is formatted as: [i] <tag>content</tag>
Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive).
If no relevant content exists, output 'NA'.

Example output: [[2,4],[7,7],[10,12]]"""

SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements).

Each block is formatted as: [i] <tag>content</tag>
Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive).
If no main content exists, output 'NA'.

Example output: [[1,3],[5,8],[11,15]]"""


def indices_to_intervals(indices):
    """Convert a sorted list of indices to intervals [[start,end], ...]"""
    if not indices:
        return "NA"
    indices = sorted(set(indices))
    intervals = []
    start = indices[0]
    end = indices[0]
    for i in indices[1:]:
        if i == end + 1:
            end = i
        else:
            intervals.append([start, end])
            start = i
            end = i
    intervals.append([start, end])
    return json.dumps(intervals)


def create_indexed_blocks_from_hotpotqa(context, supporting_facts, inject_noise=True):
    """
    Convert HotpotQA context into indexed HTML blocks.

    context: {'title': [...], 'sentences': [[...], ...]}
    supporting_facts: {'title': [...], 'sent_id': [...]}

    Returns: (block_text, relevant_indices, all_content_indices)
    """
    titles = context['title']
    sentences_list = context['sentences']

    # Build supporting facts lookup
    sf_lookup = defaultdict(set)
    for title, sent_id in zip(supporting_facts['title'], supporting_facts['sent_id']):
        sf_lookup[title].add(sent_id)

    blocks = []
    relevant_indices = []
    content_indices = []  # All real content (non-noise)

    idx = 1

    for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)):
        # Title block
        blocks.append(f"[{idx}] <h2>{title}</h2>")
        content_indices.append(idx)
        if title in sf_lookup:
            # Title of a supporting document is relevant
            relevant_indices.append(idx)
        idx += 1

        # Sentence blocks
        for sent_idx, sentence in enumerate(sentences):
            sentence = sentence.strip()
            if not sentence:
                continue

            # Use <p> for regular text
            blocks.append(f"[{idx}] <p>{sentence}</p>")
            content_indices.append(idx)

            if title in sf_lookup and sent_idx in sf_lookup[title]:
                relevant_indices.append(idx)
            idx += 1

        # Inject noise between documents sometimes
        if inject_noise and random.random() < 0.4 and doc_idx < len(titles) - 1:
            noise = random.choice(NOISE_BLOCKS)
            blocks.append(f"[{idx}] {noise}")
            idx += 1

    # Sometimes add noise at start and end
    if inject_noise:
        prefix_noise = []
        if random.random() < 0.5:
            for _ in range(random.randint(1, 3)):
                noise = random.choice(NOISE_BLOCKS)
                prefix_noise.append(noise)

        suffix_noise = []
        if random.random() < 0.5:
            for _ in range(random.randint(1, 3)):
                noise = random.choice(NOISE_BLOCKS)
                suffix_noise.append(noise)

        if prefix_noise or suffix_noise:
            # Reindex everything
            new_blocks = []
            new_relevant = []
            new_content = []
            new_idx = 1

            # Prefix noise
            for noise in prefix_noise:
                new_blocks.append(f"[{new_idx}] {noise}")
                new_idx += 1

            # Remap original blocks
            offset = len(prefix_noise)
            for b in blocks:
                old_idx = int(b.split(']')[0].replace('[', ''))
                new_b = f"[{old_idx + offset}] " + '] '.join(b.split('] ')[1:])
                new_blocks.append(new_b)

            new_relevant = [r + offset for r in relevant_indices]
            new_content = [c + offset for c in content_indices]

            # Suffix noise
            next_idx = len(new_blocks) + 1
            for noise in suffix_noise:
                new_blocks.append(f"[{next_idx}] {noise}")
                next_idx += 1

            blocks = new_blocks
            relevant_indices = new_relevant
            content_indices = new_content

    block_text = "\n".join(blocks)
    return block_text, relevant_indices, content_indices


def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"):
    """Build a query-relevant extraction (QE) example."""
    intervals = indices_to_intervals(relevant_indices)

    user_content = f"URL: {url}\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_QE},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": intervals}
    ]
    return messages


def build_main_content_example(block_text, content_indices, title="Wikipedia Article", url="https://en.wikipedia.org"):
    """Build a main content extraction (ME) example."""
    intervals = indices_to_intervals(content_indices)

    user_content = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_ME},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": intervals}
    ]
    return messages


def process_hotpotqa():
    """Process HotpotQA into IndexLM training data."""
    print("Loading HotpotQA...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train")

    # Sample a manageable amount
    num_samples = min(15000, len(ds))
    ds = ds.shuffle(seed=42).select(range(num_samples))

    all_examples = []
    skipped = 0

    for i, row in enumerate(ds):
        if i % 1000 == 0:
            print(f"Processing {i}/{num_samples}...")

        try:
            block_text, relevant_indices, content_indices = create_indexed_blocks_from_hotpotqa(
                row['context'], row['supporting_facts'], inject_noise=True
            )

            # Skip if too few relevant indices
            if len(relevant_indices) < 1:
                skipped += 1
                continue

            # Query-relevant extraction example
            qe_messages = build_query_relevant_example(
                row['question'], block_text, relevant_indices
            )
            all_examples.append({
                "messages": qe_messages,
                "task_type": "query_relevant",
                "source": "hotpotqa"
            })

            # Main content extraction example (50% of the time)
            if random.random() < 0.5:
                me_messages = build_main_content_example(
                    block_text, content_indices,
                    title=row['context']['title'][0] if row['context']['title'] else "Article"
                )
                all_examples.append({
                    "messages": me_messages,
                    "task_type": "main_content",
                    "source": "hotpotqa"
                })
        except Exception as e:
            skipped += 1
            if skipped < 5:
                print(f"Error on row {i}: {e}")
            continue

    print(f"Created {len(all_examples)} examples from HotpotQA ({skipped} skipped)")
    return all_examples


def create_synthetic_web_pages():
    """Create synthetic web page examples for main content extraction training."""
    print("Creating synthetic web page examples...")

    # Load a text dataset to get content
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    ds = ds.shuffle(seed=123).select(range(3000))

    examples = []

    for i, row in enumerate(ds):
        if i % 500 == 0:
            print(f"Synthetic page {i}/3000...")

        try:
            # Build a more realistic web page structure
            titles = row['context']['title']
            sentences_list = row['context']['sentences']

            if not titles or not sentences_list:
                continue

            blocks = []
            content_indices = []
            idx = 1

            # Header noise (nav, etc.)
            num_header_noise = random.randint(1, 4)
            for _ in range(num_header_noise):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1

            # Page title
            main_title = titles[0]
            blocks.append(f"[{idx}] <h1>{main_title}</h1>")
            content_indices.append(idx)
            idx += 1

            # Main content (just first 1-3 documents)
            num_docs = min(random.randint(1, 3), len(titles))
            for doc_idx in range(num_docs):
                title = titles[doc_idx]
                sents = sentences_list[doc_idx]

                if doc_idx > 0:
                    blocks.append(f"[{idx}] <h2>{title}</h2>")
                    content_indices.append(idx)
                    idx += 1

                for sent in sents:
                    sent = sent.strip()
                    if not sent:
                        continue
                    blocks.append(f"[{idx}] <p>{sent}</p>")
                    content_indices.append(idx)
                    idx += 1

                # Occasional inline noise (after each document's sentences)
                if random.random() < 0.3:
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                    idx += 1

            # Footer noise
            num_footer_noise = random.randint(1, 4)
            for _ in range(num_footer_noise):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1

            block_text = "\n".join(blocks)
            me_messages = build_main_content_example(
                block_text, content_indices,
                title=main_title,
                url=f"https://en.wikipedia.org/wiki/{main_title.replace(' ', '_')}"
            )
            examples.append({
                "messages": me_messages,
                "task_type": "main_content",
                "source": "synthetic"
            })
        except Exception as e:
            continue

    print(f"Created {len(examples)} synthetic web page examples")
    return examples


def create_na_examples():
    """Create examples where no relevant content exists (model should output 'NA')."""
    print("Creating NA examples...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    ds = ds.shuffle(seed=456).select(range(1000))

    examples = []

    for i, row in enumerate(ds):
        try:
            # Use context from one question but query from another (mismatched)
            other_idx = (i + 500) % len(ds)
            other_question = ds[other_idx]['question']

            # Build blocks from the current context with no supporting facts marked
            block_text, _, content_indices = create_indexed_blocks_from_hotpotqa(
                row['context'], {'title': [], 'sent_id': []}, inject_noise=True
            )

            user_content = f"URL: https://en.wikipedia.org\nQuery: {other_question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."

            messages = [
                {"role": "system", "content": SYSTEM_PROMPT_QE},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": "NA"}
            ]
            examples.append({
                "messages": messages,
                "task_type": "query_relevant_na",
                "source": "hotpotqa_mismatched"
            })
        except Exception:
            # Skip malformed rows (a bare `except:` would also swallow KeyboardInterrupt)
            continue

    # Keep only a fraction (the paper mentions partial filtering of NA)
    random.shuffle(examples)
    examples = examples[:300]
    print(f"Created {len(examples)} NA examples")
    return examples


def main():
    # Build all training examples
    qe_examples = process_hotpotqa()
    me_examples = create_synthetic_web_pages()
    na_examples = create_na_examples()

    all_examples = qe_examples + me_examples + na_examples
    random.shuffle(all_examples)

    print(f"\nTotal examples: {len(all_examples)}")

    # Count by type
    type_counts = defaultdict(int)
    for ex in all_examples:
        type_counts[ex['task_type']] += 1
    for t, c in type_counts.items():
        print(f" {t}: {c}")

    # Check lengths
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    lengths = []
    for ex in all_examples[:500]:
        text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
        tokens = tokenizer.encode(text)
        lengths.append(len(tokens))

    print(f"\nToken length stats (sample of 500):")
    print(f" Min: {min(lengths)}")
    print(f" Max: {max(lengths)}")
    print(f" Mean: {sum(lengths)/len(lengths):.0f}")
    print(f" Median: {sorted(lengths)[len(lengths)//2]}")

    # Filter out examples that are too long (>4096 tokens for efficiency)
    MAX_LEN = 4096
    filtered = []
    too_long = 0
    for ex in all_examples:
        text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
        tokens = tokenizer.encode(text)
        if len(tokens) <= MAX_LEN:
            filtered.append(ex)
        else:
            too_long += 1

    print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)")
    print(f"Final dataset: {len(filtered)} examples")

    # Split into train/eval
    random.shuffle(filtered)
    eval_size = min(500, len(filtered) // 10)
    train_data = filtered[:-eval_size]
    eval_data = filtered[-eval_size:]

    print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")

    # Create HF dataset with just messages column (for SFTTrainer)
    train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data])
    eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data])

    # Save locally
    train_ds.save_to_disk("/app/indexlm_train")
    eval_ds.save_to_disk("/app/indexlm_eval")

    # Also push to HF Hub
    from datasets import DatasetDict
    import os
    ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds})
    ds_dict.push_to_hub("OmAlve/indexlm-training-data", token=os.environ.get("HF_TOKEN"))

    print("\nDone! Dataset pushed to OmAlve/indexlm-training-data")


if __name__ == "__main__":
    main()
```

### 5.2 Training Script (`train_indexlm.py`)

```python
"""
IndexLM Training Script - Fine-tune Qwen3-0.6B for Index-based Web Content Extraction

Based on: "An Index-based Approach for Efficient and Effective Web Content Extraction" (arxiv:2512.06641)
Base model: Qwen/Qwen3-0.6B (0.6B params, ideal for CPU deployment)
Training method: SFT with TRL SFTTrainer
Dataset: OmAlve/indexlm-training-data (~21K training examples)
"""

import os
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
import trackio

# ============ Configuration ============
MODEL_ID = "Qwen/Qwen3-0.6B"
DATASET_ID = "OmAlve/indexlm-training-data"
OUTPUT_DIR = "./indexlm-0.6b"
HUB_MODEL_ID = "OmAlve/IndexLM-0.6B"

# Training hyperparameters (from paper: standard SFT)
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
BATCH_SIZE = 4
GRAD_ACCUM = 4  # Effective batch size = 16
MAX_SEQ_LENGTH = 4096
WARMUP_RATIO = 0.05

# ============ Setup Trackio ============
trackio.init(
    name="indexlm-0.6b-training",
    project="indexlm"
)

# ============ Load Dataset ============
print("Loading dataset...")
dataset = load_dataset(DATASET_ID)
train_dataset = dataset["train"]
eval_dataset = dataset["eval"]
print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

# ============ Load Model & Tokenizer ============
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Ensure padding token is set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # Change to "sdpa" if flash-attn unavailable
)

print(f"Model loaded: {MODEL_ID}")
print(f"Model params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

# ============ Training Config ============
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.01,
    bf16=True,
    gradient_checkpointing=True,
    max_length=MAX_SEQ_LENGTH,
    # Logging
    logging_steps=10,
    logging_first_step=True,
    logging_strategy="steps",
    disable_tqdm=True,
    # Evaluation
    eval_strategy="steps",
    eval_steps=500,
    # Saving
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Hub push
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
    # Performance
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    # Report
    report_to="none",
    # Seed
    seed=42,
)

# ============ Initialize Trainer ============
print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)

# ============ Train ============
print("Starting training...")
train_result = trainer.train()

# ============ Save Final Model ============
print("Saving final model...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Push to Hub
print("Pushing to Hub...")
trainer.push_to_hub(commit_message="Final IndexLM-0.6B model")

# ============ Log Final Metrics ============
metrics = train_result.metrics
print(f"\nTraining complete!")
print(f" Train loss: {metrics.get('train_loss', 'N/A')}")
# Default to NaN so the float format spec cannot fail on a missing key
print(f" Train runtime: {metrics.get('train_runtime', float('nan')):.0f}s")
print(f" Train samples/sec: {metrics.get('train_samples_per_second', float('nan')):.1f}")

# Final eval
eval_metrics = trainer.evaluate()
print(f" Eval loss: {eval_metrics.get('eval_loss', 'N/A')}")

print(f"\nModel pushed to: https://huggingface.co/{HUB_MODEL_ID}")
```

### 5.3 Evaluation Script (`eval_indexlm.py`)

```python
"""
IndexLM Evaluation Script

Tests the trained model on:
1. Query-relevant extraction (QE) - F1/Precision/Recall
2. Main content extraction (ME) - F1/Precision/Recall
3. Inference speed on CPU
"""

import json
import os
import time

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_intervals(text):
    """Parse interval string like '[[1,3],[5,7]]' into a set of indices."""
    text = text.strip()
    if text.upper() == 'NA' or not text:
        return set()
    try:
        intervals = json.loads(text)
        indices = set()
        for start, end in intervals:
            indices.update(range(start, end + 1))
        return indices
    except (json.JSONDecodeError, TypeError, ValueError):
        # Malformed model output counts as an empty prediction
        return set()


def compute_f1(pred_indices, gold_indices):
    """Compute F1, precision, recall between two sets of indices."""
    # Both empty: a correct "NA" prediction
    if not pred_indices and not gold_indices:
        return 1.0, 1.0, 1.0
    # Exactly one empty: nothing can overlap
    if not pred_indices or not gold_indices:
        return 0.0, 0.0, 0.0

    # Both sets are non-empty from here on, so the divisions are safe
    tp = len(pred_indices & gold_indices)
    precision = tp / len(pred_indices)
    recall = tp / len(gold_indices)
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return f1, precision, recall


def generate_response(model, tokenizer, messages, device, max_new_tokens=128):
    """Generate model response for given messages."""
    text = tokenizer.apply_chat_template(
        messages[:-1],  # Exclude assistant message (ground truth)
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy for deterministic eval
            temperature=1.0,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the new tokens
    new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()


def evaluate_model(model_id, device="cpu", num_samples=100):
    """Run full evaluation."""
    print(f"\n{'='*60}")
    print(f"Evaluating: {model_id}")
    print(f"Device: {device}")
    print(f"{'='*60}")

    # Load model
    print("Loading model...")
    dtype = torch.float32 if device == "cpu" else torch.bfloat16
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        attn_implementation="sdpa",
    ).to(device)
    model.eval()

    # Load eval dataset
    print("Loading eval dataset...")
    dataset = load_dataset("OmAlve/indexlm-training-data", split="eval")

    # Sample
    if len(dataset) > num_samples:
        dataset = dataset.shuffle(seed=42).select(range(num_samples))

    # Categorize examples by task, based on the system prompt
    qe_examples = []
    me_examples = []

    for row in dataset:
        msgs = row['messages']
        system_msg = msgs[0]['content'] if msgs[0]['role'] == 'system' else ''
        if 'query' in system_msg.lower() and 'relevant' in system_msg.lower():
            qe_examples.append(msgs)
        else:
            me_examples.append(msgs)

    print(f"QE examples: {len(qe_examples)}, ME examples: {len(me_examples)}")

    # Evaluate QE
    print("\n--- Query-Relevant Extraction (QE) ---")
    qe_metrics = evaluate_task(model, tokenizer, qe_examples[:50], device)

    # Evaluate ME
    print("\n--- Main Content Extraction (ME) ---")
    me_metrics = evaluate_task(model, tokenizer, me_examples[:50], device)

    # Speed test
    print("\n--- Inference Speed Test ---")
    speed_test(model, tokenizer, qe_examples[:20], device)

    return qe_metrics, me_metrics


def evaluate_task(model, tokenizer, examples, device):
    """Evaluate on a set of examples."""
    if not examples:
        print("No examples for this task.")
        return {}

    f1_scores = []
    precision_scores = []
    recall_scores = []
    exact_matches = 0

    for i, msgs in enumerate(examples):
        gold = msgs[-1]['content']
        gold_indices = parse_intervals(gold)

        pred = generate_response(model, tokenizer, msgs, device)
        pred_indices = parse_intervals(pred)

        f1, prec, rec = compute_f1(pred_indices, gold_indices)
        f1_scores.append(f1)
        precision_scores.append(prec)
        recall_scores.append(rec)

        if pred_indices == gold_indices:
            exact_matches += 1

        # Show the first few examples for a quick visual check
        if i < 3:
            print(f"  Example {i+1}:")
            print(f"    Gold: {gold}")
            print(f"    Pred: {pred}")
            print(f"    F1: {f1:.3f}, P: {prec:.3f}, R: {rec:.3f}")

    avg_f1 = sum(f1_scores) / len(f1_scores) * 100
    avg_prec = sum(precision_scores) / len(precision_scores) * 100
    avg_rec = sum(recall_scores) / len(recall_scores) * 100
    em_rate = exact_matches / len(examples) * 100

    print(f"\n  Results ({len(examples)} examples):")
    print(f"    F1: {avg_f1:.2f}")
    print(f"    Precision: {avg_prec:.2f}")
    print(f"    Recall: {avg_rec:.2f}")
    print(f"    Exact Match: {em_rate:.2f}%")

    return {"f1": avg_f1, "precision": avg_prec, "recall": avg_rec, "exact_match": em_rate}


def speed_test(model, tokenizer, examples, device):
    """Test inference speed (wall-clock time per page, including tokenization)."""
    if not examples:
        return

    times = []
    for msgs in examples:
        # perf_counter is a monotonic, high-resolution clock suited to timing
        start = time.perf_counter()
        _ = generate_response(model, tokenizer, msgs, device)
        elapsed = time.perf_counter() - start
        times.append(elapsed)

    avg_time = sum(times) / len(times)
    print(f"  Average inference time: {avg_time:.3f}s ({device})")
    print(f"  Min: {min(times):.3f}s, Max: {max(times):.3f}s")
    print(f"  Throughput: {1/avg_time:.1f} pages/sec")


if __name__ == "__main__":
    model_id = os.environ.get("MODEL_ID", "OmAlve/IndexLM-0.6B")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    evaluate_model(model_id, device=device, num_samples=100)
```
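
As a quick sanity check of the index-level metrics, the sketch below scores one hypothetical prediction against one hypothetical gold label. It is a compact, standalone illustration: the helper name `expand` and the example intervals are assumptions made here, not part of `eval_indexlm.py`.

```python
import json


def expand(interval_text):
    """Expand an interval string such as '[[2,4],[7,7]]' into a set of block indices."""
    indices = set()
    for start, end in json.loads(interval_text):
        indices.update(range(start, end + 1))  # intervals are inclusive
    return indices


gold = expand("[[2,4],[7,7]]")  # {2, 3, 4, 7}
pred = expand("[[3,5]]")        # {3, 4, 5}

tp = len(gold & pred)           # blocks 3 and 4 are shared
precision = tp / len(pred)      # 2 / 3
recall = tp / len(gold)         # 2 / 4
f1 = 2 * precision * recall / (precision + recall)

print(f"P={precision:.3f} R={recall:.3f} F1={f1:.3f}")  # P=0.667 R=0.500 F1=0.571
```

Because scoring happens on the expanded index sets, `[[2,3],[4,4]]` and `[[2,4]]` count as the same prediction; only the covered blocks matter, not how the model groups them into intervals.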

---

## 6. Key Design Decisions & Rationale

### Why Qwen3-0.6B?
- The paper uses Qwen3-0.6B/1.7B/4B. The 0.6B achieves **near-identical performance** to 4B on RAG QA (54.70 vs 55.41 avg F1)
- 0.6B is **1.4GB in bf16, ~700MB in INT4**, small enough to run fast on CPU
- TRL's own SFT documentation uses Qwen3-0.6B as its default example model, so compatibility with the training stack is well tested
- Qwen3 uses GQA (grouped-query attention), which is faster at inference than MHA

### Why not ReaderLM-v2?
- ReaderLM-v2 does generative HTML→Markdown extraction (different task)
- It's **33–70× slower** than IndexLM on the paper's benchmarks
- Fine-tuning it for index prediction would fight against its pretrained generation behavior

### Dataset construction vs. the paper
The paper uses:
1. Google Search API crawls → real HTML from the web
2. DeepSeek V3 annotation with 5-run majority voting
3. Common Crawl WARC files

We approximate this with:
1. HotpotQA's structured context (title + sentences) converted to indexed HTML blocks
2. Programmatic labeling from HotpotQA's `supporting_facts` ground truth (higher quality than LLM annotation)
3. Synthetic noise injection (nav, ads, cookies, etc.) to simulate real web clutter
4. Mismatched query-page pairs for NA examples

**Trade-off**: Our HTML blocks are simpler than real web HTML (no nested tables, complex CSS-in-JS, etc.). For production use, augmenting with real crawled HTML would improve robustness. Reproducing the paper's full pipeline would incur API costs (Google Search, DeepSeek V3).
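
The conversion and labeling steps listed above can be sketched roughly as follows. This is an illustration only, not the project's `prepare_data.py`: the function names, the `<h2>`/`<p>` tags, and the toy record are assumptions, and the input layout is assumed to mirror the `context` and `supporting_facts` columns of the Hugging Face HotpotQA dataset. Noise injection and HTML escaping are left out for brevity.

```python
def build_indexed_blocks(context, supporting_facts):
    """Turn HotpotQA-style context into numbered HTML blocks plus gold indices.

    Assumed layout:
      context          = {"title": [...], "sentences": [[...], ...]}
      supporting_facts = {"title": [...], "sent_id": [...]}
    """
    gold_pairs = set(zip(supporting_facts["title"], supporting_facts["sent_id"]))
    blocks, gold = [], []
    for title, sentences in zip(context["title"], context["sentences"]):
        # One heading block per paragraph title
        blocks.append(f"[{len(blocks) + 1}] <h2>{title}</h2>")
        for sent_id, sentence in enumerate(sentences):
            index = len(blocks) + 1  # block indices are 1-based
            blocks.append(f"[{index}] <p>{sentence.strip()}</p>")
            if (title, sent_id) in gold_pairs:
                gold.append(index)
    return blocks, gold


def to_intervals(indices):
    """Merge block indices into inclusive [start, end] intervals."""
    intervals = []
    for i in sorted(indices):
        if intervals and i == intervals[-1][1] + 1:
            intervals[-1][1] = i      # extend the current run
        else:
            intervals.append([i, i])  # start a new run
    return intervals


# Toy record (hypothetical data)
context = {
    "title": ["Alpha", "Beta"],
    "sentences": [["A one.", "A two."], ["B one.", "B two.", "B three."]],
}
supporting_facts = {"title": ["Alpha", "Beta", "Beta"], "sent_id": [1, 0, 1]}

blocks, gold = build_indexed_blocks(context, supporting_facts)
print("\n".join(blocks))
print(to_intervals(gold))  # [[3, 3], [5, 6]]
```

An NA example would pair these blocks with a question from a different record, in which case the gold index list is empty and the label is `NA`.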

### Hyperparameters
The paper (Section 3.3.2) says only that "The training process is a typical SFT process" on Qwen3, so we use standard SFT settings:
- lr=2e-5 (TRL SFT default, standard for Qwen3)
- 3 epochs (standard SFT)
- Effective batch size 16 (per-device batch size 4 × 4 gradient accumulation steps)
- Cosine LR schedule with 5% warmup
- max_length=4096 (covers 99.8% of our data, well within Qwen3's 32K context)
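
To see how these settings translate into optimizer steps, here is a back-of-the-envelope calculation. The dataset size and device count are hypothetical placeholders, and the Trainer's own step count can differ slightly depending on rounding and dataloader settings, so treat the result as an estimate.

```python
import math

# Values from the list above
per_device_batch = 4
grad_accum = 4
epochs = 3
warmup_ratio = 0.05

# Hypothetical values, not taken from the project
num_train_examples = 10_000
num_devices = 1

effective_batch = per_device_batch * grad_accum * num_devices
steps_per_epoch = math.ceil(num_train_examples / effective_batch)
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)

print(f"effective batch size: {effective_batch}")  # 16
print(f"optimizer steps:      {total_steps}")      # 1875
print(f"warmup steps:         {warmup_steps}")     # 94
```

With `eval_steps=500` and `save_steps=500` as in the training config, a run of this hypothetical size would evaluate and checkpoint at steps 500, 1000, and 1500.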

---

## 7. What's Left To Do

| Task | Status | Notes |
|------|--------|-------|
| Run `train_indexlm.py` | ❌ | Needs GPU; a10g-large recommended (~$8 total) |
| Run `eval_indexlm.py` | ❌ | After training completes |
| ONNX export for CPU | ❌ | Optional: `optimum-cli export onnx --model OmAlve/IndexLM-0.6B indexlm-onnx/` |
| INT4 quantization | ❌ | Optional: e.g. `llama.cpp` (GGUF) for faster CPU inference; `bitsandbytes` mainly targets GPUs |
| Real HTML augmentation | ❌ | Optional: crawl real web pages to augment training data |

---

## 8. Resources

| Resource | URL |
|----------|-----|
| Paper | https://arxiv.org/abs/2512.06641 |
| Training dataset | https://huggingface.co/datasets/OmAlve/indexlm-training-data |
| Base model | https://huggingface.co/Qwen/Qwen3-0.6B |
| Output model (after training) | https://huggingface.co/OmAlve/IndexLM-0.6B |
| TRL SFT docs | https://huggingface.co/docs/trl/sft_trainer |
| HotpotQA source | https://huggingface.co/datasets/hotpotqa/hotpot_qa |

---

## 9. Dependencies

```
transformers>=4.51.0
trl>=1.2.0
torch
datasets
accelerate
trackio
flash-attn  # optional, GPU training only
beautifulsoup4  # only for prepare_data.py
lxml  # only for prepare_data.py
```