commit 92eddcb2d67dd81c582bc8969cba47f2b70692a9 Author: ModelHub XC Date: Tue Jun 16 08:15:17 2026 +0800 初始化项目,由ModelHub XC社区提供模型 Model: OmAlve/reading-steiner Source: Original Platform diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..52373fe --- /dev/null +++ b/.gitattributes @@ -0,0 +1,36 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +tokenizer.json filter=lfs diff=lfs merge=lfs -text diff --git a/HANDOFF.md b/HANDOFF.md new file mode 100644 index 0000000..36dd4d7 --- /dev/null +++ b/HANDOFF.md @@ -0,0 +1,1082 @@ +# IndexLM-0.6B: Index-based Web Content Extraction + +## Project Handoff Document + +**Paper**: [An Index-based Approach for Efficient and Effective Web Content Extraction](https://arxiv.org/abs/2512.06641) +**Goal**: Fine-tune a SOTA web content extraction model that runs fast on CPU +**Status**: Dataset prepared & pushed ✅ | Training script ready ✅ | Training NOT yet run ❌ + +--- + +## 1. What This Is + +The paper introduces **IndexLM** — a model that extracts relevant content from web pages by predicting **index intervals** instead of generating full text. This makes it: +- **10–50× faster** than generative extraction (ReaderLM-v2, Firecrawl, etc.) +- **SOTA on RAG QA** benchmarks (HotpotQA, NQ, TriviaQA, MuSiQue, MultiHopRAG) +- **Tiny**: even the 0.6B version beats all baselines + +The original IndexLM weights are **not publicly released**. This project replicates the approach. + +### How It Works + +1. HTML is cleaned and split into indexed blocks: `[1]

Title

`, `[2]

Content...

`, etc. +2. The model receives these blocks + a query +3. It outputs index intervals like `[[2,4],[7,7],[10,12]]` — identifying which blocks are relevant +4. The blocks are reassembled into clean HTML/Markdown + +Two tasks: +- **Query-relevant extraction (QE)**: Extract blocks relevant to a specific query +- **Main content extraction (ME)**: Extract main content, filtering out nav/ads/sidebars + +### Paper Results (Table 2 & 3) + +| Model | Params | Avg RAG QA F1 | ME F1 | QE F1 | Latency (ME) | +|-------|--------|---------------|-------|-------|-------------| +| **IndexLM-0.6B** | 0.6B | 54.70 | 83.38 | 28.64 | **0.35s** | +| **IndexLM-4B** | 4B | 55.41 | 87.40 | 31.69 | 0.81s | +| ReaderLM-v2 | 1.5B | 46.84 | 68.89 | 13.31 | 11.76s | +| HtmlRAG | - | 47.00 | 48.65 | 8.83 | 7.12s | +| Firecrawl Extract | API | 52.72 | - | 29.48 | 11.33s | + +--- + +## 2. What's Been Done + +### ✅ Dataset Created & Pushed (v2 — Multi-domain) + +**Hub**: [`OmAlve/indexlm-training-data`](https://huggingface.co/datasets/OmAlve/indexlm-training-data) + +| Split | Rows | +|-------|------| +| train | 21,098 | +| eval | 500 | + +**Domain Composition (avoids Wikipedia-only bias):** +| Source | Count | % | Domain | +|--------|-------|---|--------| +| MultiHopRAG | 7,165 | 33.2% | News (Mashable, CNBC, AP, etc.) | +| HotpotQA | 6,479 | 30.0% | Wikipedia | +| HtmlRAG-train | 2,692 | 12.5% | **Real Bing-scraped web HTML** (diverse) | +| MS MARCO | 4,844 | 22.4% | Diverse web (Bing search results) | +| NA (mismatched) | 418 | 1.9% | Cross-domain | + +**Task Type Composition:** +- `query_relevant`: ~78% — query-specific extraction +- `main_content`: ~20% — main content vs. noise (nav/ads/cookies) +- `query_relevant_na`: ~2% — no relevant content exists + +**Key improvement over v1**: Real web HTML from Bing search results (via HtmlRAG-train) + news articles + MS MARCO diverse web QA, not just Wikipedia. + +**Format**: Conversational `messages` column (SFTTrainer-native): +```json +{ + "messages": [ + {"role": "system", "content": "You are IndexLM, a web content extraction model..."}, + {"role": "user", "content": "URL: ...\nQuery: ...\n\nBlocks:\n[1]

Title

\n[2]

Content

\n...\n\nOutput the index intervals of blocks relevant to the query."}, + {"role": "assistant", "content": "[[2, 4], [7, 7]]"} + ] +} +``` + +**Token length stats** (Qwen3-0.6B tokenizer): +- Min: 316, Max: 4,105, Mean: 1,944, Median: 2,019 +- 43 examples filtered (>4096 tokens) + +**Data pipeline** (from `prepare_data_v2.py`): +1. **HtmlRAG-train** (5,880 raw examples): Real Bing-scraped HTML from 5 QA datasets (NQ, ASQA, TriviaQA, MuSiQue, HotpotQA). Segments HTML by block-level tags, matches relevant blocks to ground-truth answers using trigram/substring matching. +2. **MultiHopRAG** (8,521 examples): News articles from Mashable, CNBC, AP, etc. Converts article body + evidence annotations to indexed blocks. Injects realistic noise blocks. +3. **HotpotQA** (6,486 examples, minority): Wikipedia context with supporting facts → index intervals. Noise injected. +4. **MS MARCO** (4,844 examples): Diverse web QA from Bing search. Passages from real web pages across numeric, entity, description, location, person query types. +5. **NA examples** (500): Mismatched query-page pairs from different sources. +6. Filters to ≤4096 tokens, shuffles, splits train/eval. + +### ✅ Training Script Ready + +**File**: `train_indexlm.py` (see Section 5 below) + +Key settings: +- **Base model**: `Qwen/Qwen3-0.6B` (751M params, bf16, GQA, 32K context) +- **Method**: SFT via TRL `SFTTrainer` + `SFTConfig` +- **Output**: `OmAlve/IndexLM-0.6B` on Hub +- **Hyperparameters**: lr=2e-5, epochs=3, batch=4, grad_accum=4 (effective BS=16), max_length=4096, cosine LR schedule, warmup=5% +- `push_to_hub=True`, `hub_model_id="OmAlve/IndexLM-0.6B"` +- Trackio monitoring included +- Flash Attention 2 for training speed + +### ✅ Evaluation Script Ready + +**File**: `eval_indexlm.py` (see Section 5 below) + +Evaluates: +- QE F1/Precision/Recall on eval split +- ME F1/Precision/Recall on eval split +- CPU inference speed benchmark + +### ❌ Training Not Yet Run + +Ran into credits issue on HF Jobs (402 Payment Required). You need to run `train_indexlm.py` on a GPU. + +--- + +## 3. How to Train + +### Option A: HF Jobs (if you have credits) + +```bash +# Dependencies +pip install "transformers>=4.51.0" "trl>=1.2.0" torch datasets accelerate trackio "flash-attn --no-build-isolation" +``` + +Recommended hardware: **a10g-large** ($2/hr) or **t4-small** ($0.60/hr) — model is only 0.6B params. +Estimated time: **2-4 hours** on a10g, **4-6 hours** on T4. +Set timeout to **6h** minimum. + +### Option B: Any GPU machine + +```bash +pip install "transformers>=4.51.0" "trl>=1.2.0" torch datasets accelerate trackio +pip install flash-attn --no-build-isolation # optional, speeds up training + +python train_indexlm.py +``` + +**VRAM**: ~8-10 GB with gradient checkpointing + bf16 at batch_size=4. Fits on T4 (16GB), any A-series, etc. + +### Option C: Without Flash Attention + +If `flash-attn` fails to install, change this line in `train_indexlm.py`: +```python +# FROM: +attn_implementation="flash_attention_2", +# TO: +attn_implementation="sdpa", +``` + +--- + +## 4. How to Deploy on CPU + +After training, the model at `OmAlve/IndexLM-0.6B` can be loaded for CPU inference: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch + +model = AutoModelForCausalLM.from_pretrained( + "OmAlve/IndexLM-0.6B", + torch_dtype=torch.float32, + attn_implementation="sdpa", +) +tokenizer = AutoTokenizer.from_pretrained("OmAlve/IndexLM-0.6B") +model.eval() + +# Example: extract relevant content from a web page +messages = [ + {"role": "system", "content": "You are IndexLM, a web content extraction model..."}, + {"role": "user", "content": "URL: ...\nQuery: What is Python?\n\nBlocks:\n[1] \n[2]

Python Programming

\n[3]

Python is a programming language...

\n[4] \n\nOutput the index intervals of blocks relevant to the query."} +] + +text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False) +inputs = tokenizer(text, return_tensors="pt") + +with torch.no_grad(): + out = model.generate(**inputs, max_new_tokens=128, do_sample=False) + +response = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True) +print(response) # → [[2, 3]] +``` + +**For even faster CPU**: quantize to INT4/INT8 with `bitsandbytes` or export to ONNX. + +--- + +## 5. All Scripts + +### 5.1 Data Preparation (`prepare_data.py`) + +```python +""" +Prepare IndexLM training data from HotpotQA and MSMARCO. + +Pipeline: +1. Load HotpotQA (has context = list of (title, sentences) + supporting_facts) +2. Convert context into indexed HTML-like blocks: [i] content +3. The target is index intervals of blocks containing supporting facts +4. Also create main-content extraction examples (all content blocks are "main content", + but we inject noise blocks like nav/ads to train the model to filter them) +5. Format as conversational messages for SFT +""" + +import json +import random +import re +from datasets import load_dataset, Dataset +from collections import defaultdict + +random.seed(42) + +# Noise blocks to inject (simulating real web page clutter) +NOISE_BLOCKS = [ + '', + '
Advertisement - Continue Reading Below
', + '', + '', + '', + '
Share on: Twitter | Facebook | LinkedIn
', + '', + '
Subscribe to our newsletter for updates
', + '', + '', + '
Comments (0) - Be the first to comment
', + '
Written by Staff Reporter | Updated: Jan 2024
', + '', + '', + '', +] + +SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query. + +Each block is formatted as: [i] content +Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive). +If no relevant content exists, output 'NA'. + +Example output: [[2,4],[7,7],[10,12]]""" + +SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements). + +Each block is formatted as: [i] content +Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive). +If no main content exists, output 'NA'. + +Example output: [[1,3],[5,8],[11,15]]""" + + +def indices_to_intervals(indices): + """Convert a sorted list of indices to intervals [[start,end], ...]""" + if not indices: + return "NA" + indices = sorted(set(indices)) + intervals = [] + start = indices[0] + end = indices[0] + for i in indices[1:]: + if i == end + 1: + end = i + else: + intervals.append([start, end]) + start = i + end = i + intervals.append([start, end]) + return json.dumps(intervals) + + +def create_indexed_blocks_from_hotpotqa(context, supporting_facts, inject_noise=True): + """ + Convert HotpotQA context into indexed HTML blocks. + + context: {'title': [...], 'sentences': [[...], ...]} + supporting_facts: {'title': [...], 'sent_id': [...]} + + Returns: (block_text, relevant_indices, all_content_indices) + """ + titles = context['title'] + sentences_list = context['sentences'] + + # Build supporting facts lookup + sf_lookup = defaultdict(set) + for title, sent_id in zip(supporting_facts['title'], supporting_facts['sent_id']): + sf_lookup[title].add(sent_id) + + blocks = [] + relevant_indices = [] + content_indices = [] # All real content (non-noise) + + idx = 1 + + for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)): + # Title block + blocks.append(f"[{idx}]

{title}

") + content_indices.append(idx) + if title in sf_lookup: + # Title of a supporting document is relevant + relevant_indices.append(idx) + idx += 1 + + # Sentence blocks + for sent_idx, sentence in enumerate(sentences): + sentence = sentence.strip() + if not sentence: + continue + + # Use

for regular text + blocks.append(f"[{idx}]

{sentence}

") + content_indices.append(idx) + + if title in sf_lookup and sent_idx in sf_lookup[title]: + relevant_indices.append(idx) + idx += 1 + + # Inject noise between documents sometimes + if inject_noise and random.random() < 0.4 and doc_idx < len(titles) - 1: + noise = random.choice(NOISE_BLOCKS) + blocks.append(f"[{idx}] {noise}") + idx += 1 + + # Sometimes add noise at start and end + if inject_noise: + prefix_noise = [] + if random.random() < 0.5: + for _ in range(random.randint(1, 3)): + noise = random.choice(NOISE_BLOCKS) + prefix_noise.append(noise) + + suffix_noise = [] + if random.random() < 0.5: + for _ in range(random.randint(1, 3)): + noise = random.choice(NOISE_BLOCKS) + suffix_noise.append(noise) + + if prefix_noise or suffix_noise: + # Reindex everything + new_blocks = [] + new_relevant = [] + new_content = [] + new_idx = 1 + + # Prefix noise + for noise in prefix_noise: + new_blocks.append(f"[{new_idx}] {noise}") + new_idx += 1 + + # Remap original blocks + offset = len(prefix_noise) + for b in blocks: + old_idx = int(b.split(']')[0].replace('[', '')) + new_b = f"[{old_idx + offset}] " + '] '.join(b.split('] ')[1:]) + new_blocks.append(new_b) + + new_relevant = [r + offset for r in relevant_indices] + new_content = [c + offset for c in content_indices] + + # Suffix noise + next_idx = len(new_blocks) + 1 + for noise in suffix_noise: + new_blocks.append(f"[{next_idx}] {noise}") + next_idx += 1 + + blocks = new_blocks + relevant_indices = new_relevant + content_indices = new_content + + block_text = "\n".join(blocks) + return block_text, relevant_indices, content_indices + + +def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"): + """Build a query-relevant extraction (QE) example.""" + intervals = indices_to_intervals(relevant_indices) + + user_content = f"URL: {url}\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT_QE}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": intervals} + ] + return messages + + +def build_main_content_example(block_text, content_indices, title="Wikipedia Article", url="https://en.wikipedia.org"): + """Build a main content extraction (ME) example.""" + intervals = indices_to_intervals(content_indices) + + user_content = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks." + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT_ME}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": intervals} + ] + return messages + + +def process_hotpotqa(): + """Process HotpotQA into IndexLM training data.""" + print("Loading HotpotQA...") + ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train") + + # Sample a manageable amount + num_samples = min(15000, len(ds)) + ds = ds.shuffle(seed=42).select(range(num_samples)) + + all_examples = [] + skipped = 0 + + for i, row in enumerate(ds): + if i % 1000 == 0: + print(f"Processing {i}/{num_samples}...") + + try: + block_text, relevant_indices, content_indices = create_indexed_blocks_from_hotpotqa( + row['context'], row['supporting_facts'], inject_noise=True + ) + + # Skip if too few relevant indices + if len(relevant_indices) < 1: + skipped += 1 + continue + + # Query-relevant extraction example + qe_messages = build_query_relevant_example( + row['question'], block_text, relevant_indices + ) + all_examples.append({ + "messages": qe_messages, + "task_type": "query_relevant", + "source": "hotpotqa" + }) + + # Main content extraction example (50% of the time) + if random.random() < 0.5: + me_messages = build_main_content_example( + block_text, content_indices, + title=row['context']['title'][0] if row['context']['title'] else "Article" + ) + all_examples.append({ + "messages": me_messages, + "task_type": "main_content", + "source": "hotpotqa" + }) + except Exception as e: + skipped += 1 + if skipped < 5: + print(f"Error on row {i}: {e}") + continue + + print(f"Created {len(all_examples)} examples from HotpotQA ({skipped} skipped)") + return all_examples + + +def create_synthetic_web_pages(): + """Create synthetic web page examples for main content extraction training.""" + print("Creating synthetic web page examples...") + + # Load a text dataset to get content + ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation") + ds = ds.shuffle(seed=123).select(range(3000)) + + examples = [] + + for i, row in enumerate(ds): + if i % 500 == 0: + print(f"Synthetic page {i}/3000...") + + try: + # Build a more realistic web page structure + titles = row['context']['title'] + sentences_list = row['context']['sentences'] + + if not titles or not sentences_list: + continue + + blocks = [] + content_indices = [] + idx = 1 + + # Header noise (nav, etc.) + num_header_noise = random.randint(1, 4) + for _ in range(num_header_noise): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}") + idx += 1 + + # Page title + main_title = titles[0] + blocks.append(f"[{idx}]

{main_title}

") + content_indices.append(idx) + idx += 1 + + # Main content (just first 1-3 documents) + num_docs = min(random.randint(1, 3), len(titles)) + for doc_idx in range(num_docs): + title = titles[doc_idx] + sents = sentences_list[doc_idx] + + if doc_idx > 0: + blocks.append(f"[{idx}]

{title}

") + content_indices.append(idx) + idx += 1 + + for sent in sents: + sent = sent.strip() + if not sent: + continue + blocks.append(f"[{idx}]

{sent}

") + content_indices.append(idx) + idx += 1 + + # Occasional inline noise + if random.random() < 0.3: + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}") + idx += 1 + + # Footer noise + num_footer_noise = random.randint(1, 4) + for _ in range(num_footer_noise): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}") + idx += 1 + + block_text = "\n".join(blocks) + me_messages = build_main_content_example( + block_text, content_indices, + title=main_title, + url=f"https://en.wikipedia.org/wiki/{main_title.replace(' ', '_')}" + ) + examples.append({ + "messages": me_messages, + "task_type": "main_content", + "source": "synthetic" + }) + except Exception as e: + continue + + print(f"Created {len(examples)} synthetic web page examples") + return examples + + +def create_na_examples(): + """Create examples where no relevant content exists (model should output 'NA').""" + print("Creating NA examples...") + ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation") + ds = ds.shuffle(seed=456).select(range(1000)) + + examples = [] + + for i, row in enumerate(ds): + try: + # Use context from one question but query from another (mismatched) + other_idx = (i + 500) % len(ds) + other_question = ds[other_idx]['question'] + + # Build blocks from current context but keep only non-supporting content + block_text, _, content_indices = create_indexed_blocks_from_hotpotqa( + row['context'], {'title': [], 'sent_id': []}, inject_noise=True + ) + + user_content = f"URL: https://en.wikipedia.org\nQuery: {other_question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT_QE}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": "NA"} + ] + examples.append({ + "messages": messages, + "task_type": "query_relevant_na", + "source": "hotpotqa_mismatched" + }) + except: + continue + + # Keep only a fraction (the paper mentions partial filtering of NA) + random.shuffle(examples) + examples = examples[:300] + print(f"Created {len(examples)} NA examples") + return examples + + +def main(): + # Build all training examples + qe_examples = process_hotpotqa() + me_examples = create_synthetic_web_pages() + na_examples = create_na_examples() + + all_examples = qe_examples + me_examples + na_examples + random.shuffle(all_examples) + + print(f"\nTotal examples: {len(all_examples)}") + + # Count by type + type_counts = defaultdict(int) + for ex in all_examples: + type_counts[ex['task_type']] += 1 + for t, c in type_counts.items(): + print(f" {t}: {c}") + + # Check lengths + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + + lengths = [] + for ex in all_examples[:500]: + text = tokenizer.apply_chat_template(ex['messages'], tokenize=False) + tokens = tokenizer.encode(text) + lengths.append(len(tokens)) + + print(f"\nToken length stats (sample of 500):") + print(f" Min: {min(lengths)}") + print(f" Max: {max(lengths)}") + print(f" Mean: {sum(lengths)/len(lengths):.0f}") + print(f" Median: {sorted(lengths)[len(lengths)//2]}") + + # Filter out examples that are too long (>4096 tokens for efficiency) + MAX_LEN = 4096 + filtered = [] + too_long = 0 + for ex in all_examples: + text = tokenizer.apply_chat_template(ex['messages'], tokenize=False) + tokens = tokenizer.encode(text) + if len(tokens) <= MAX_LEN: + filtered.append(ex) + else: + too_long += 1 + + print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)") + print(f"Final dataset: {len(filtered)} examples") + + # Split into train/eval + random.shuffle(filtered) + eval_size = min(500, len(filtered) // 10) + train_data = filtered[:-eval_size] + eval_data = filtered[-eval_size:] + + print(f"Train: {len(train_data)}, Eval: {len(eval_data)}") + + # Create HF dataset with just messages column (for SFTTrainer) + train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data]) + eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data]) + + # Save locally + train_ds.save_to_disk("/app/indexlm_train") + eval_ds.save_to_disk("/app/indexlm_eval") + + # Also push to HF Hub + from datasets import DatasetDict + import os + ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds}) + ds_dict.push_to_hub("OmAlve/indexlm-training-data", token=os.environ.get("HF_TOKEN")) + + print("\nDone! Dataset pushed to OmAlve/indexlm-training-data") + + +if __name__ == "__main__": + main() +``` + +### 5.2 Training Script (`train_indexlm.py`) + +```python +""" +IndexLM Training Script - Fine-tune Qwen3-0.6B for Index-based Web Content Extraction + +Based on: "An Index-based Approach for Efficient and Effective Web Content Extraction" (arxiv:2512.06641) +Base model: Qwen/Qwen3-0.6B (0.6B params, ideal for CPU deployment) +Training method: SFT with TRL SFTTrainer +Dataset: OmAlve/indexlm-training-data (25K+ examples) +""" + +import os +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import SFTTrainer, SFTConfig +import trackio + +# ============ Configuration ============ +MODEL_ID = "Qwen/Qwen3-0.6B" +DATASET_ID = "OmAlve/indexlm-training-data" +OUTPUT_DIR = "./indexlm-0.6b" +HUB_MODEL_ID = "OmAlve/IndexLM-0.6B" + +# Training hyperparameters (from paper: standard SFT) +LEARNING_RATE = 2e-5 +NUM_EPOCHS = 3 +BATCH_SIZE = 4 +GRAD_ACCUM = 4 # Effective batch size = 16 +MAX_SEQ_LENGTH = 4096 +WARMUP_RATIO = 0.05 + +# ============ Setup Trackio ============ +trackio.init( + name="indexlm-0.6b-training", + project="indexlm" +) + +# ============ Load Dataset ============ +print("Loading dataset...") +dataset = load_dataset(DATASET_ID) +train_dataset = dataset["train"] +eval_dataset = dataset["eval"] +print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}") + +# ============ Load Model & Tokenizer ============ +print("Loading model and tokenizer...") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Ensure padding token is set +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", # Change to "sdpa" if flash-attn unavailable +) + +print(f"Model loaded: {MODEL_ID}") +print(f"Model params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M") + +# ============ Training Config ============ +training_args = SFTConfig( + output_dir=OUTPUT_DIR, + num_train_epochs=NUM_EPOCHS, + per_device_train_batch_size=BATCH_SIZE, + per_device_eval_batch_size=BATCH_SIZE, + gradient_accumulation_steps=GRAD_ACCUM, + learning_rate=LEARNING_RATE, + lr_scheduler_type="cosine", + warmup_ratio=WARMUP_RATIO, + weight_decay=0.01, + bf16=True, + gradient_checkpointing=True, + max_length=MAX_SEQ_LENGTH, + # Logging + logging_steps=10, + logging_first_step=True, + logging_strategy="steps", + disable_tqdm=True, + # Evaluation + eval_strategy="steps", + eval_steps=500, + # Saving + save_strategy="steps", + save_steps=500, + save_total_limit=3, + load_best_model_at_end=True, + metric_for_best_model="eval_loss", + greater_is_better=False, + # Hub push + push_to_hub=True, + hub_model_id=HUB_MODEL_ID, + hub_strategy="every_save", + # Performance + dataloader_num_workers=4, + dataloader_pin_memory=True, + # Report + report_to="none", + # Seed + seed=42, +) + +# ============ Initialize Trainer ============ +print("Initializing trainer...") +trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, +) + +# ============ Train ============ +print("Starting training...") +train_result = trainer.train() + +# ============ Save Final Model ============ +print("Saving final model...") +trainer.save_model(OUTPUT_DIR) +tokenizer.save_pretrained(OUTPUT_DIR) + +# Push to Hub +print("Pushing to Hub...") +trainer.push_to_hub(commit_message="Final IndexLM-0.6B model") + +# ============ Log Final Metrics ============ +metrics = train_result.metrics +print(f"\nTraining complete!") +print(f" Train loss: {metrics.get('train_loss', 'N/A')}") +print(f" Train runtime: {metrics.get('train_runtime', 'N/A'):.0f}s") +print(f" Train samples/sec: {metrics.get('train_samples_per_second', 'N/A'):.1f}") + +# Final eval +eval_metrics = trainer.evaluate() +print(f" Eval loss: {eval_metrics.get('eval_loss', 'N/A')}") + +print(f"\nModel pushed to: https://huggingface.co/{HUB_MODEL_ID}") +``` + +### 5.3 Evaluation Script (`eval_indexlm.py`) + +```python +""" +IndexLM Evaluation Script + +Tests the trained model on: +1. Query-relevant extraction (QE) - F1/Precision/Recall +2. Main content extraction (ME) - F1/Precision/Recall +3. Inference speed on CPU +""" + +import json +import time +import os +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def parse_intervals(text): + """Parse interval string like '[[1,3],[5,7]]' into a set of indices.""" + text = text.strip() + if text.upper() == 'NA' or not text: + return set() + try: + intervals = json.loads(text) + indices = set() + for start, end in intervals: + indices.update(range(start, end + 1)) + return indices + except (json.JSONDecodeError, TypeError, ValueError): + return set() + + +def compute_f1(pred_indices, gold_indices): + """Compute F1, precision, recall between two sets of indices.""" + if not pred_indices and not gold_indices: + return 1.0, 1.0, 1.0 + if not pred_indices or not gold_indices: + return 0.0, 0.0, 0.0 + + tp = len(pred_indices & gold_indices) + precision = tp / len(pred_indices) if pred_indices else 0 + recall = tp / len(gold_indices) if gold_indices else 0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + return f1, precision, recall + + +def generate_response(model, tokenizer, messages, device, max_new_tokens=128): + """Generate model response for given messages.""" + text = tokenizer.apply_chat_template( + messages[:-1], # Exclude assistant message (ground truth) + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096).to(device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=False, # Greedy for deterministic eval + temperature=1.0, + pad_token_id=tokenizer.pad_token_id, + ) + + # Decode only the new tokens + new_tokens = outputs[0][inputs['input_ids'].shape[1]:] + response = tokenizer.decode(new_tokens, skip_special_tokens=True) + return response.strip() + + +def evaluate_model(model_id, device="cpu", num_samples=100): + """Run full evaluation.""" + print(f"\n{'='*60}") + print(f"Evaluating: {model_id}") + print(f"Device: {device}") + print(f"{'='*60}") + + # Load model + print("Loading model...") + dtype = torch.float32 if device == "cpu" else torch.bfloat16 + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=dtype, + attn_implementation="sdpa", + ).to(device) + model.eval() + + # Load eval dataset + print("Loading eval dataset...") + dataset = load_dataset("OmAlve/indexlm-training-data", split="eval") + + # Sample + if len(dataset) > num_samples: + dataset = dataset.shuffle(seed=42).select(range(num_samples)) + + # Categorize examples + qe_examples = [] + me_examples = [] + + for row in dataset: + msgs = row['messages'] + system_msg = msgs[0]['content'] if msgs[0]['role'] == 'system' else '' + if 'query' in system_msg.lower() and 'relevant' in system_msg.lower(): + qe_examples.append(msgs) + else: + me_examples.append(msgs) + + print(f"QE examples: {len(qe_examples)}, ME examples: {len(me_examples)}") + + # Evaluate QE + print("\n--- Query-Relevant Extraction (QE) ---") + qe_metrics = evaluate_task(model, tokenizer, qe_examples[:50], device) + + # Evaluate ME + print("\n--- Main Content Extraction (ME) ---") + me_metrics = evaluate_task(model, tokenizer, me_examples[:50], device) + + # Speed test + print("\n--- Inference Speed Test ---") + speed_test(model, tokenizer, qe_examples[:20], device) + + return qe_metrics, me_metrics + + +def evaluate_task(model, tokenizer, examples, device): + """Evaluate on a set of examples.""" + if not examples: + print("No examples for this task.") + return {} + + f1_scores = [] + precision_scores = [] + recall_scores = [] + exact_matches = 0 + + for i, msgs in enumerate(examples): + gold = msgs[-1]['content'] + gold_indices = parse_intervals(gold) + + pred = generate_response(model, tokenizer, msgs, device) + pred_indices = parse_intervals(pred) + + f1, prec, rec = compute_f1(pred_indices, gold_indices) + f1_scores.append(f1) + precision_scores.append(prec) + recall_scores.append(rec) + + if pred_indices == gold_indices: + exact_matches += 1 + + if i < 3: + print(f" Example {i+1}:") + print(f" Gold: {gold}") + print(f" Pred: {pred}") + print(f" F1: {f1:.3f}, P: {prec:.3f}, R: {rec:.3f}") + + avg_f1 = sum(f1_scores) / len(f1_scores) * 100 + avg_prec = sum(precision_scores) / len(precision_scores) * 100 + avg_rec = sum(recall_scores) / len(recall_scores) * 100 + em_rate = exact_matches / len(examples) * 100 + + print(f"\n Results ({len(examples)} examples):") + print(f" F1: {avg_f1:.2f}") + print(f" Precision: {avg_prec:.2f}") + print(f" Recall: {avg_rec:.2f}") + print(f" Exact Match: {em_rate:.2f}%") + + return {"f1": avg_f1, "precision": avg_prec, "recall": avg_rec, "exact_match": em_rate} + + +def speed_test(model, tokenizer, examples, device): + """Test inference speed.""" + if not examples: + return + + times = [] + for msgs in examples: + start = time.time() + _ = generate_response(model, tokenizer, msgs, device) + elapsed = time.time() - start + times.append(elapsed) + + avg_time = sum(times) / len(times) + print(f" Average inference time: {avg_time:.3f}s ({device})") + print(f" Min: {min(times):.3f}s, Max: {max(times):.3f}s") + print(f" Throughput: {1/avg_time:.1f} pages/sec") + + +if __name__ == "__main__": + model_id = os.environ.get("MODEL_ID", "OmAlve/IndexLM-0.6B") + device = "cuda" if torch.cuda.is_available() else "cpu" + evaluate_model(model_id, device=device, num_samples=100) +``` + +--- + +## 6. Key Design Decisions & Rationale + +### Why Qwen3-0.6B? +- The paper uses Qwen3-0.6B/1.7B/4B. The 0.6B achieves **near-identical performance** to 4B on RAG QA (54.70 vs 55.41 avg F1) +- 0.6B is **1.4GB in bf16, ~700MB in INT4** — runs fast on CPU +- TRL's own SFT documentation uses Qwen3-0.6B as its default example model — maximum compatibility +- Qwen3 has GQA (grouped-query attention) which is faster for inference than MHA + +### Why not ReaderLM-v2? +- ReaderLM-v2 does generative HTML→Markdown extraction (different task) +- It's **33-70× slower** than IndexLM on the paper's benchmarks +- Fine-tuning it for index prediction would fight against its pretrained generation behavior + +### Dataset construction vs. the paper +The paper uses: +1. Google Search API crawls → real HTML from the web +2. DeepSeek V3 annotation with 5-run majority voting +3. Common Crawl WARC files + +We approximate this with: +1. HotpotQA's structured context (title + sentences) converted to indexed HTML blocks +2. Programmatic labeling from HotpotQA's `supporting_facts` ground truth (higher quality than LLM annotation) +3. Synthetic noise injection (nav, ads, cookies, etc.) to simulate real web clutter +4. Mismatched query-page pairs for NA examples + +**Trade-off**: Our HTML blocks are simpler than real web HTML (no nested tables, complex CSS-in-JS, etc.). For production use, augmenting with real crawled HTML would improve robustness. The paper's full pipeline would require API costs (Google Search, DeepSeek V3). + +### Hyperparameters +Directly from the paper Section 3.3.2: "The training process is a typical SFT process" on Qwen3. We use: +- lr=2e-5 (TRL SFT default, standard for Qwen3) +- 3 epochs (standard SFT) +- Effective batch size 16 (4 × 4 grad accum) +- Cosine LR schedule with 5% warmup +- max_length=4096 (covers 99.8% of our data, well within Qwen3's 32K context) + +--- + +## 7. What's Left To Do + +| Task | Status | Notes | +|------|--------|-------| +| Run `train_indexlm.py` | ❌ | Needs GPU — a10g-large recommended (~$8 total) | +| Run `eval_indexlm.py` | ❌ | After training completes | +| ONNX export for CPU | ❌ | Optional: `optimum-cli export onnx --model OmAlve/IndexLM-0.6B indexlm-onnx/` | +| INT4 quantization | ❌ | Optional: use `bitsandbytes` or `llama.cpp` for faster CPU | +| Real HTML augmentation | ❌ | Optional: crawl real web pages to augment training data | + +--- + +## 8. Resources + +| Resource | URL | +|----------|-----| +| Paper | https://arxiv.org/abs/2512.06641 | +| Training dataset | https://huggingface.co/datasets/OmAlve/indexlm-training-data | +| Base model | https://huggingface.co/Qwen/Qwen3-0.6B | +| Output model (after training) | https://huggingface.co/OmAlve/IndexLM-0.6B | +| TRL SFT docs | https://huggingface.co/docs/trl/sft_trainer | +| HotpotQA source | https://huggingface.co/datasets/hotpotqa/hotpot_qa | + +--- + +## 9. Dependencies + +``` +transformers>=4.51.0 +trl>=1.2.0 +torch +datasets +accelerate +trackio +flash-attn # optional, GPU training only +beautifulsoup4 # only for prepare_data.py +lxml # only for prepare_data.py +``` diff --git a/README.md b/README.md new file mode 100644 index 0000000..5c3485c --- /dev/null +++ b/README.md @@ -0,0 +1,60 @@ +--- +base_model: Qwen/Qwen3-0.6B +library_name: transformers +model_name: reading-steiner +tags: +- generated_from_trainer +- trackio:https://huggingface.co/spaces/OmAlve/trackio +- trl +- sft +- trackio +licence: license +--- + +# Model Card for reading-steiner + +This model is a fine-tuned version of [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B). +It has been trained using [TRL](https://github.com/huggingface/trl). + +## Quick start + +```python +from transformers import pipeline + +question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?" +generator = pipeline("text-generation", model="OmAlve/reading-steiner", device="cuda") +output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0] +print(output["generated_text"]) +``` + +## Training procedure + + + + +This model was trained with SFT. + +### Framework versions + +- TRL: 0.24.0 +- Transformers: 5.5.0 +- Pytorch: 2.5.1+cu124 +- Datasets: 4.3.0 +- Tokenizers: 0.22.2 + +## Citations + + + +Cite TRL as: + +```bibtex +@misc{vonwerra2022trl, + title = {{TRL: Transformer Reinforcement Learning}}, + author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec}, + year = 2020, + journal = {GitHub repository}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/huggingface/trl}} +} +``` \ No newline at end of file diff --git a/chat_template.jinja b/chat_template.jinja new file mode 100644 index 0000000..01be9b3 --- /dev/null +++ b/chat_template.jinja @@ -0,0 +1,89 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000..7e1858f --- /dev/null +++ b/config.json @@ -0,0 +1,63 @@ +{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": null, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 40960, + "max_window_layers": 28, + "model_type": "qwen3", + "num_attention_heads": 16, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_parameters": { + "rope_theta": 1000000, + "rope_type": "default" + }, + "sliding_window": null, + "tie_word_embeddings": true, + "transformers_version": "5.5.0", + "use_cache": false, + "use_sliding_window": false, + "vocab_size": 151936 +} diff --git a/eval_indexlm.py b/eval_indexlm.py new file mode 100644 index 0000000..d34385b --- /dev/null +++ b/eval_indexlm.py @@ -0,0 +1,194 @@ +""" +Reading Steiner Evaluation Script + +Tests the trained model on: +1. Query-relevant extraction (QE) - F1/Precision/Recall +2. Main content extraction (ME) - F1/Precision/Recall +3. Inference speed on CPU +""" + +import json +import time +import os +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def parse_intervals(text): + """Parse interval string like '[[1,3],[5,7]]' into a set of indices.""" + text = text.strip() + if text.upper() == 'NA' or not text: + return set() + try: + intervals = json.loads(text) + indices = set() + for start, end in intervals: + indices.update(range(start, end + 1)) + return indices + except (json.JSONDecodeError, TypeError, ValueError): + return set() + + +def compute_f1(pred_indices, gold_indices): + """Compute F1, precision, recall between two sets of indices.""" + if not pred_indices and not gold_indices: + return 1.0, 1.0, 1.0 + if not pred_indices or not gold_indices: + return 0.0, 0.0, 0.0 + + tp = len(pred_indices & gold_indices) + precision = tp / len(pred_indices) if pred_indices else 0 + recall = tp / len(gold_indices) if gold_indices else 0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + return f1, precision, recall + + +def generate_response(model, tokenizer, messages, device, max_new_tokens=128): + """Generate model response for given messages.""" + text = tokenizer.apply_chat_template( + messages[:-1], # Exclude assistant message (ground truth) + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096).to(device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=False, # Greedy for deterministic eval + temperature=1.0, + pad_token_id=tokenizer.pad_token_id, + ) + + # Decode only the new tokens + new_tokens = outputs[0][inputs['input_ids'].shape[1]:] + response = tokenizer.decode(new_tokens, skip_special_tokens=True) + return response.strip() + + +def evaluate_model(model_id, device="cpu", num_samples=100): + """Run full evaluation.""" + print(f"\n{'='*60}") + print(f"Evaluating: {model_id}") + print(f"Device: {device}") + print(f"{'='*60}") + + # Load model + print("Loading model...") + dtype = torch.float32 if device == "cpu" else torch.bfloat16 + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=dtype, + attn_implementation="sdpa", + ).to(device) + model.eval() + + # Load eval dataset + print("Loading eval dataset...") + dataset = load_dataset("OmAlve/indexlm-training-data", split="eval") + + # Sample + if len(dataset) > num_samples: + dataset = dataset.shuffle(seed=42).select(range(num_samples)) + + # Categorize examples + qe_examples = [] + me_examples = [] + + for row in dataset: + msgs = row['messages'] + system_msg = msgs[0]['content'] if msgs[0]['role'] == 'system' else '' + if 'query' in system_msg.lower() and 'relevant' in system_msg.lower(): + qe_examples.append(msgs) + else: + me_examples.append(msgs) + + print(f"QE examples: {len(qe_examples)}, ME examples: {len(me_examples)}") + + # Evaluate QE + print("\n--- Query-Relevant Extraction (QE) ---") + qe_metrics = evaluate_task(model, tokenizer, qe_examples[:50], device) + + # Evaluate ME + print("\n--- Main Content Extraction (ME) ---") + me_metrics = evaluate_task(model, tokenizer, me_examples[:50], device) + + # Speed test + print("\n--- Inference Speed Test ---") + speed_test(model, tokenizer, qe_examples[:20], device) + + return qe_metrics, me_metrics + + +def evaluate_task(model, tokenizer, examples, device): + """Evaluate on a set of examples.""" + if not examples: + print("No examples for this task.") + return {} + + f1_scores = [] + precision_scores = [] + recall_scores = [] + exact_matches = 0 + + for i, msgs in enumerate(examples): + gold = msgs[-1]['content'] + gold_indices = parse_intervals(gold) + + pred = generate_response(model, tokenizer, msgs, device) + pred_indices = parse_intervals(pred) + + f1, prec, rec = compute_f1(pred_indices, gold_indices) + f1_scores.append(f1) + precision_scores.append(prec) + recall_scores.append(rec) + + if pred_indices == gold_indices: + exact_matches += 1 + + if i < 3: + print(f" Example {i+1}:") + print(f" Gold: {gold}") + print(f" Pred: {pred}") + print(f" F1: {f1:.3f}, P: {prec:.3f}, R: {rec:.3f}") + + avg_f1 = sum(f1_scores) / len(f1_scores) * 100 + avg_prec = sum(precision_scores) / len(precision_scores) * 100 + avg_rec = sum(recall_scores) / len(recall_scores) * 100 + em_rate = exact_matches / len(examples) * 100 + + print(f"\n Results ({len(examples)} examples):") + print(f" F1: {avg_f1:.2f}") + print(f" Precision: {avg_prec:.2f}") + print(f" Recall: {avg_rec:.2f}") + print(f" Exact Match: {em_rate:.2f}%") + + return {"f1": avg_f1, "precision": avg_prec, "recall": avg_rec, "exact_match": em_rate} + + +def speed_test(model, tokenizer, examples, device): + """Test inference speed.""" + if not examples: + return + + times = [] + for msgs in examples: + start = time.time() + _ = generate_response(model, tokenizer, msgs, device) + elapsed = time.time() - start + times.append(elapsed) + + avg_time = sum(times) / len(times) + print(f" Average inference time: {avg_time:.3f}s ({device})") + print(f" Min: {min(times):.3f}s, Max: {max(times):.3f}s") + print(f" Throughput: {1/avg_time:.1f} pages/sec") + + +if __name__ == "__main__": + model_id = os.environ.get("MODEL_ID", "OmAlve/reading-steiner") + device = "cuda" if torch.cuda.is_available() else "cpu" + evaluate_model(model_id, device=device, num_samples=100) diff --git a/generation_config.json b/generation_config.json new file mode 100644 index 0000000..d40253d --- /dev/null +++ b/generation_config.json @@ -0,0 +1,12 @@ +{ + "do_sample": true, + "eos_token_id": [ + 151645, + 151643 + ], + "pad_token_id": 151643, + "temperature": 0.6, + "top_k": 20, + "top_p": 0.95, + "transformers_version": "5.5.0" +} diff --git a/model.safetensors b/model.safetensors new file mode 100644 index 0000000..751c905 --- /dev/null +++ b/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0b468190c0c4336d45ccab3bb36097639a63a67f58ac443efd9bd67d9d42f95 +size 1192135096 diff --git a/prepare_data.py b/prepare_data.py new file mode 100644 index 0000000..7aee287 --- /dev/null +++ b/prepare_data.py @@ -0,0 +1,486 @@ +""" +Prepare IndexLM training data from HotpotQA and MSMARCO. + +Pipeline: +1. Load HotpotQA (has context = list of (title, sentences) + supporting_facts) +2. Convert context into indexed HTML-like blocks: [i] content +3. The target is index intervals of blocks containing supporting facts +4. Also create main-content extraction examples (all content blocks are "main content", + but we inject noise blocks like nav/ads to train the model to filter them) +5. Format as conversational messages for SFT +""" + +import json +import random +import re +from datasets import load_dataset, Dataset +from collections import defaultdict + +random.seed(42) + +# Noise blocks to inject (simulating real web page clutter) +NOISE_BLOCKS = [ + '', + '
Advertisement - Continue Reading Below
', + '', + '', + '', + '
Share on: Twitter | Facebook | LinkedIn
', + '', + '
Subscribe to our newsletter for updates
', + '', + '', + '
Comments (0) - Be the first to comment
', + '
Written by Staff Reporter | Updated: Jan 2024
', + '', + '', + '', +] + +SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query. + +Each block is formatted as: [i] content +Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive). +If no relevant content exists, output 'NA'. + +Example output: [[2,4],[7,7],[10,12]]""" + +SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements). + +Each block is formatted as: [i] content +Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive). +If no main content exists, output 'NA'. + +Example output: [[1,3],[5,8],[11,15]]""" + + +def indices_to_intervals(indices): + """Convert a sorted list of indices to intervals [[start,end], ...]""" + if not indices: + return "NA" + indices = sorted(set(indices)) + intervals = [] + start = indices[0] + end = indices[0] + for i in indices[1:]: + if i == end + 1: + end = i + else: + intervals.append([start, end]) + start = i + end = i + intervals.append([start, end]) + return json.dumps(intervals) + + +def create_indexed_blocks_from_hotpotqa(context, supporting_facts, inject_noise=True): + """ + Convert HotpotQA context into indexed HTML blocks. + + context: {'title': [...], 'sentences': [[...], ...]} + supporting_facts: {'title': [...], 'sent_id': [...]} + + Returns: (block_text, relevant_indices, all_content_indices) + """ + titles = context['title'] + sentences_list = context['sentences'] + + # Build supporting facts lookup + sf_lookup = defaultdict(set) + for title, sent_id in zip(supporting_facts['title'], supporting_facts['sent_id']): + sf_lookup[title].add(sent_id) + + blocks = [] + relevant_indices = [] + content_indices = [] # All real content (non-noise) + + idx = 1 + + for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)): + # Title block + blocks.append(f"[{idx}]

{title}

") + content_indices.append(idx) + if title in sf_lookup: + # Title of a supporting document is relevant + relevant_indices.append(idx) + idx += 1 + + # Sentence blocks + for sent_idx, sentence in enumerate(sentences): + sentence = sentence.strip() + if not sentence: + continue + + # Use

for regular text + blocks.append(f"[{idx}]

{sentence}

") + content_indices.append(idx) + + if title in sf_lookup and sent_idx in sf_lookup[title]: + relevant_indices.append(idx) + idx += 1 + + # Inject noise between documents sometimes + if inject_noise and random.random() < 0.4 and doc_idx < len(titles) - 1: + noise = random.choice(NOISE_BLOCKS) + blocks.append(f"[{idx}] {noise}") + idx += 1 + + # Sometimes add noise at start and end + if inject_noise: + prefix_noise = [] + if random.random() < 0.5: + for _ in range(random.randint(1, 3)): + noise = random.choice(NOISE_BLOCKS) + prefix_noise.append(noise) + + suffix_noise = [] + if random.random() < 0.5: + for _ in range(random.randint(1, 3)): + noise = random.choice(NOISE_BLOCKS) + suffix_noise.append(noise) + + if prefix_noise or suffix_noise: + # Reindex everything + new_blocks = [] + new_relevant = [] + new_content = [] + new_idx = 1 + + # Prefix noise + for noise in prefix_noise: + new_blocks.append(f"[{new_idx}] {noise}") + new_idx += 1 + + # Remap original blocks + offset = len(prefix_noise) + for b in blocks: + old_idx = int(b.split(']')[0].replace('[', '')) + new_b = f"[{old_idx + offset}] " + '] '.join(b.split('] ')[1:]) + new_blocks.append(new_b) + + new_relevant = [r + offset for r in relevant_indices] + new_content = [c + offset for c in content_indices] + + # Suffix noise + next_idx = len(new_blocks) + 1 + for noise in suffix_noise: + new_blocks.append(f"[{next_idx}] {noise}") + next_idx += 1 + + blocks = new_blocks + relevant_indices = new_relevant + content_indices = new_content + + block_text = "\n".join(blocks) + return block_text, relevant_indices, content_indices + + +def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"): + """Build a query-relevant extraction (QE) example.""" + intervals = indices_to_intervals(relevant_indices) + + user_content = f"URL: {url}\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT_QE}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": intervals} + ] + return messages + + +def build_main_content_example(block_text, content_indices, title="Wikipedia Article", url="https://en.wikipedia.org"): + """Build a main content extraction (ME) example.""" + intervals = indices_to_intervals(content_indices) + + user_content = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks." + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT_ME}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": intervals} + ] + return messages + + +def process_hotpotqa(): + """Process HotpotQA into IndexLM training data.""" + print("Loading HotpotQA...") + ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train") + + # Sample a manageable amount + num_samples = min(15000, len(ds)) + ds = ds.shuffle(seed=42).select(range(num_samples)) + + all_examples = [] + skipped = 0 + + for i, row in enumerate(ds): + if i % 1000 == 0: + print(f"Processing {i}/{num_samples}...") + + try: + block_text, relevant_indices, content_indices = create_indexed_blocks_from_hotpotqa( + row['context'], row['supporting_facts'], inject_noise=True + ) + + # Skip if too few relevant indices + if len(relevant_indices) < 1: + skipped += 1 + continue + + # Query-relevant extraction example + qe_messages = build_query_relevant_example( + row['question'], block_text, relevant_indices + ) + all_examples.append({ + "messages": qe_messages, + "task_type": "query_relevant", + "source": "hotpotqa" + }) + + # Main content extraction example (50% of the time) + if random.random() < 0.5: + me_messages = build_main_content_example( + block_text, content_indices, + title=row['context']['title'][0] if row['context']['title'] else "Article" + ) + all_examples.append({ + "messages": me_messages, + "task_type": "main_content", + "source": "hotpotqa" + }) + except Exception as e: + skipped += 1 + if skipped < 5: + print(f"Error on row {i}: {e}") + continue + + print(f"Created {len(all_examples)} examples from HotpotQA ({skipped} skipped)") + return all_examples + + +def create_synthetic_web_pages(): + """Create synthetic web page examples for main content extraction training.""" + print("Creating synthetic web page examples...") + + # Load a text dataset to get content + ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation") + ds = ds.shuffle(seed=123).select(range(3000)) + + examples = [] + + for i, row in enumerate(ds): + if i % 500 == 0: + print(f"Synthetic page {i}/3000...") + + try: + # Build a more realistic web page structure + titles = row['context']['title'] + sentences_list = row['context']['sentences'] + + if not titles or not sentences_list: + continue + + blocks = [] + content_indices = [] + idx = 1 + + # Header noise (nav, etc.) + num_header_noise = random.randint(1, 4) + for _ in range(num_header_noise): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}") + idx += 1 + + # Page title + main_title = titles[0] + blocks.append(f"[{idx}]

{main_title}

") + content_indices.append(idx) + idx += 1 + + # Main content (just first 1-3 documents) + num_docs = min(random.randint(1, 3), len(titles)) + for doc_idx in range(num_docs): + title = titles[doc_idx] + sents = sentences_list[doc_idx] + + if doc_idx > 0: + blocks.append(f"[{idx}]

{title}

") + content_indices.append(idx) + idx += 1 + + for sent in sents: + sent = sent.strip() + if not sent: + continue + blocks.append(f"[{idx}]

{sent}

") + content_indices.append(idx) + idx += 1 + + # Occasional inline noise + if random.random() < 0.3: + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}") + idx += 1 + + # Footer noise + num_footer_noise = random.randint(1, 4) + for _ in range(num_footer_noise): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}") + idx += 1 + + block_text = "\n".join(blocks) + me_messages = build_main_content_example( + block_text, content_indices, + title=main_title, + url=f"https://en.wikipedia.org/wiki/{main_title.replace(' ', '_')}" + ) + examples.append({ + "messages": me_messages, + "task_type": "main_content", + "source": "synthetic" + }) + except Exception as e: + continue + + print(f"Created {len(examples)} synthetic web page examples") + return examples + + +def create_na_examples(): + """Create examples where no relevant content exists (model should output 'NA').""" + print("Creating NA examples...") + ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation") + ds = ds.shuffle(seed=456).select(range(1000)) + + examples = [] + + for i, row in enumerate(ds): + try: + # Use context from one question but query from another (mismatched) + other_idx = (i + 500) % len(ds) + other_question = ds[other_idx]['question'] + + # Build blocks from current context but keep only non-supporting content + block_text, _, content_indices = create_indexed_blocks_from_hotpotqa( + row['context'], {'title': [], 'sent_id': []}, inject_noise=True + ) + + # The query doesn't match this content → expected output: NA + # But actually some content might still be tangentially relevant, + # so we'll be conservative and only do this for clearly mismatched pairs + user_content = f"URL: https://en.wikipedia.org\nQuery: {other_question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT_QE}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": "NA"} + ] + examples.append({ + "messages": messages, + "task_type": "query_relevant_na", + "source": "hotpotqa_mismatched" + }) + except: + continue + + # Keep only a fraction (the paper mentions partial filtering of NA) + random.shuffle(examples) + examples = examples[:300] + print(f"Created {len(examples)} NA examples") + return examples + + +def main(): + # Build all training examples + qe_examples = process_hotpotqa() + me_examples = create_synthetic_web_pages() + na_examples = create_na_examples() + + all_examples = qe_examples + me_examples + na_examples + random.shuffle(all_examples) + + print(f"\nTotal examples: {len(all_examples)}") + + # Count by type + type_counts = defaultdict(int) + for ex in all_examples: + type_counts[ex['task_type']] += 1 + for t, c in type_counts.items(): + print(f" {t}: {c}") + + # Check lengths + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + + lengths = [] + for ex in all_examples[:500]: + text = tokenizer.apply_chat_template(ex['messages'], tokenize=False) + tokens = tokenizer.encode(text) + lengths.append(len(tokens)) + + print(f"\nToken length stats (sample of 500):") + print(f" Min: {min(lengths)}") + print(f" Max: {max(lengths)}") + print(f" Mean: {sum(lengths)/len(lengths):.0f}") + print(f" Median: {sorted(lengths)[len(lengths)//2]}") + + # Filter out examples that are too long (>4096 tokens for efficiency) + MAX_LEN = 4096 + filtered = [] + too_long = 0 + for ex in all_examples: + text = tokenizer.apply_chat_template(ex['messages'], tokenize=False) + tokens = tokenizer.encode(text) + if len(tokens) <= MAX_LEN: + filtered.append(ex) + else: + too_long += 1 + + print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)") + print(f"Final dataset: {len(filtered)} examples") + + # Split into train/eval + random.shuffle(filtered) + eval_size = min(500, len(filtered) // 10) + train_data = filtered[:-eval_size] + eval_data = filtered[-eval_size:] + + print(f"Train: {len(train_data)}, Eval: {len(eval_data)}") + + # Create HF dataset with just messages column (for SFTTrainer) + train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data]) + eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data]) + + # Save locally + train_ds.save_to_disk("/app/indexlm_train") + eval_ds.save_to_disk("/app/indexlm_eval") + + # Also push to HF Hub + from huggingface_hub import login + import os + login(token=os.environ.get("HF_TOKEN")) + + from datasets import DatasetDict + ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds}) + ds_dict.push_to_hub("OmAlve/indexlm-training-data") + + print("\nDone! Dataset pushed to OmAlve/indexlm-training-data") + + # Print sample + print("\n=== Sample QE example ===") + for ex in train_data[:3]: + if ex.get("task_type", "") == "query_relevant": + for m in ex["messages"]: + print(f"\n[{m['role']}]: {m['content'][:200]}...") + break + + print("\n=== Sample ME example ===") + for ex in train_data[:10]: + if ex.get("task_type", "") == "main_content": + for m in ex["messages"]: + print(f"\n[{m['role']}]: {m['content'][:200]}...") + break + + +if __name__ == "__main__": + main() diff --git a/prepare_data_v2.py b/prepare_data_v2.py new file mode 100644 index 0000000..73b12ab --- /dev/null +++ b/prepare_data_v2.py @@ -0,0 +1,897 @@ +""" +Prepare DIVERSE IndexLM training data from multiple sources: + +1. HtmlRAG-train (real Bing-scraped web HTML) — diverse domains +2. MultiHopRAG (news domain) — technology, business, sports, entertainment +3. HotpotQA (Wikipedia) — structured QA with supporting facts + +This avoids the Wikipedia-only bias of the original dataset. + +Output: Conversational messages for SFT with TRL SFTTrainer +Format: system + user (indexed HTML blocks + query) → assistant (index intervals) +""" + +import json +import random +import re +import os +from datasets import load_dataset, Dataset, DatasetDict +from collections import defaultdict +from bs4 import BeautifulSoup +import html as html_lib + +random.seed(42) + +# ============ System Prompts ============ + +SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query. + +Each block is formatted as: [i] content +Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive). +If no relevant content exists, output 'NA'. + +Example output: [[2,4],[7,7],[10,12]]""" + +SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements). + +Each block is formatted as: [i] content +Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive). +If no main content exists, output 'NA'. + +Example output: [[1,3],[5,8],[11,15]]""" + +# ============ Noise blocks for injection ============ + +NOISE_BLOCKS_REALISTIC = [ + '', + '
Advertisement - Continue Reading Below
', + '', + '', + '', + '
Share: Twitter | Facebook | LinkedIn | Reddit | Email
', + '', + '
Subscribe to our newsletter for the latest updates delivered to your inbox weekly.
', + '', + '', + '
Comments (0) — Be the first to comment! Please read our community guidelines before posting.
', + '
Written by Staff Reporter | Updated: January 15, 2024 | 5 min read
', + '', + '', + '
Categories: Science | Technology | Health | Business | Sports | Entertainment | Politics
', + '
Already a subscriber? Log in for full access. Not a member? Subscribe now starting at $4.99/month.
', + '', + '
Watch: Video player requires JavaScript to be enabled. [Video placeholder]
', + '
BREAKING: Markets rally on latest economic data | Sports: Championship results | Weather: Storm warning issued
', + '
Skip to main content | Skip to navigation | Accessibility statement
', + '
We value your privacy. We and our partners use tracking technologies to improve your browsing experience, serve personalized content, and analyze traffic.
', + '
Download our app for a better reading experience! Available on iOS and Android.
', + '', + '', + '', + '
Was this article helpful? Yes | No | Send Feedback
', + '
Language: English | Español | Français | Deutsch | 日本語 | 中文
', + '', +] + + +def indices_to_intervals(indices): + """Convert a sorted list of indices to intervals [[start,end], ...]""" + if not indices: + return "NA" + indices = sorted(set(indices)) + intervals = [] + start = indices[0] + end = indices[0] + for i in indices[1:]: + if i == end + 1: + end = i + else: + intervals.append([start, end]) + start = i + end = i + intervals.append([start, end]) + return json.dumps(intervals) + + +# ============================================================ +# SOURCE 1: HtmlRAG-train (Real Bing-scraped web HTML) +# ============================================================ + +def extract_text_content(html_str): + """Extract visible text from an HTML string.""" + try: + soup = BeautifulSoup(html_str, 'html.parser') + return soup.get_text(separator=' ', strip=True) + except: + # Fallback: strip tags with regex + clean = re.sub(r'<[^>]+>', ' ', html_str) + return re.sub(r'\s+', ' ', clean).strip() + + +def segment_html_to_blocks(html_content): + """ + Segment real HTML content into indexed blocks. + Splits by block-level HTML tags and line boundaries. + """ + blocks = [] + + # Strategy: split by block-level closing/opening tags + # HtmlRAG uses tags like ,

, ,

  • , etc. + # Split at positions where block-level tags start + block_tag_pattern = r'(<(?:div|p|h[1-6]|li|ul|ol|table|tr|td|th|article|section|header|footer|nav|aside|main|blockquote|pre|form|figure|figcaption|details|summary|option|title|button|label|select|textarea|hgroup|dl|dd|dt|caption|thead|tbody|tfoot)\b[^>]*>)' + + # Also handle HtmlRAG numbered tags like , , etc. + block_tag_pattern_numbered = r'(<(?:div|p|h|li|ul|ol|table|tr|td|th|article|section|header|footer|nav|aside|main|blockquote|pre|form|figure|option|title|button|hgroup)\d*[^>]*>)' + + # Split content by block-level tags + parts = re.split(block_tag_pattern_numbered, html_content) + + current_block = '' + for part in parts: + part = part.strip() + if not part: + continue + + # Check if this part is a block-level opening tag + if re.match(block_tag_pattern_numbered, part): + # Save previous block if it has content + if current_block.strip(): + blocks.append(current_block.strip()) + current_block = part + else: + current_block += ' ' + part + + # Don't forget the last block + if current_block.strip(): + blocks.append(current_block.strip()) + + # If tag-based splitting yields too few blocks, fall back to line-based + if len(blocks) < 5: + blocks = [] + lines = html_content.split('\n') + for line in lines: + line = line.strip() + if line and len(line) > 5: + blocks.append(line) + + # If still too few, split by multiple tags on same line + if len(blocks) < 5: + new_blocks = [] + for block in blocks: + # Try splitting long blocks by inner tags + if len(block) > 200: + inner_parts = re.split(r'()', block) + current = '' + for ip in inner_parts: + current += ip + if re.match(r'', ip): + if current.strip(): + new_blocks.append(current.strip()) + current = '' + if current.strip(): + new_blocks.append(current.strip()) + else: + new_blocks.append(block) + if len(new_blocks) > len(blocks): + blocks = new_blocks + + # Filter: extract text and remove blocks with no meaningful content + def extract_text_simple(s): + clean = re.sub(r'<[^>]+>', ' ', s) + return re.sub(r'\s+', ' ', clean).strip() + + blocks = [b for b in blocks if len(extract_text_simple(b)) > 5] + + return blocks + + +def classify_block_as_noise(block_text): + """Heuristic: classify if a block is likely noise (nav, ad, etc.).""" + text_lower = block_text.lower() + noise_indicators = [ + 'cookie', 'privacy policy', 'terms of service', 'advertisement', + 'subscribe', 'newsletter', 'sign up', 'log in', 'login', + 'copyright ©', 'all rights reserved', 'skip to', 'accessibility', + 'share on twitter', 'share on facebook', 'social media', + 'related articles', 'you may also like', 'trending now', + 'app download', 'sponsored content', 'affiliate', + ] + nav_patterns = [' div, h20 -> h2) + tag = re.sub(r'\d+$', '', tag) + if not tag: + tag = 'div' + else: + tag = 'p' + + text = extract_text_content(block) + if not text or len(text) < 3: + continue + + indexed_blocks.append(f"[{idx}] <{tag}>{text}") + + # Check if this block is noise + is_noise = classify_block_as_noise(block) + if not is_noise: + content_indices.append(idx) + + # Check relevance by substring matching with assistant output + # Use the full relevant text as a search target + text_lower = text.lower() + relevant_lower = relevant_text.lower() + + # Method 1: Check if significant portions of relevant text appear in block + # Split relevant text into 3-word ngrams and check for matches + rel_words_list = relevant_lower.split() + matched = False + + # Check 3-gram overlap + for i in range(len(rel_words_list) - 2): + trigram = ' '.join(rel_words_list[i:i+3]) + if trigram in text_lower: + matched = True + break + + # Also check: does the block text appear as a substring in the relevant text? + if not matched and len(text) > 15: + # Check if meaningful portion of block appears in relevant output + block_sentences = [s.strip() for s in text.split('.') if len(s.strip()) > 10] + for sent in block_sentences: + if sent.lower() in relevant_lower: + matched = True + break + + # Also check word overlap with a more lenient threshold + if not matched: + block_words = set(text_lower.split()) + if relevant_words and block_words: + overlap_count = len(block_words & relevant_words) + # At least 3 content words overlap (excluding stopwords) + stopwords = {'the','a','an','is','are','was','were','in','on','at','to','for','of','and','or','but','with','by','from','as','it','this','that','be','has','have','had','do','does','did','not','no'} + content_overlap = len((block_words - stopwords) & (relevant_words - stopwords)) + if content_overlap >= 2: + matched = True + + if matched: + relevant_indices.append(idx) + + if not indexed_blocks or len(indexed_blocks) < 3: + return None + + block_text = "\n".join(indexed_blocks) + + results = [] + + # Query-relevant extraction example + if relevant_indices: + intervals = indices_to_intervals(relevant_indices) + user_msg = f"URL: https://example.com\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." + results.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT_QE}, + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": intervals} + ], + "task_type": "query_relevant", + "source": "htmlrag" + }) + + # Main content extraction example (30% of the time to balance) + if content_indices and random.random() < 0.3: + intervals = indices_to_intervals(content_indices) + user_msg = f"URL: https://example.com\nTitle: Web Page\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks." + results.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT_ME}, + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": intervals} + ], + "task_type": "main_content", + "source": "htmlrag" + }) + + return results + + +def load_htmlrag_data(): + """Load and convert HtmlRAG-train data.""" + print("Loading HtmlRAG-train (real web HTML)...") + + # Use 4k and 8k token variants - good balance of context + files = [ + 'nq-4k.jsonl', 'nq-8k.jsonl', + 'asqa-4k.jsonl', 'asqa-8k.jsonl', + 'trivia-qa-4k.jsonl', 'trivia-qa-8k.jsonl', + 'musique-4k.jsonl', 'musique-8k.jsonl', + 'hotpot-qa-4k.jsonl', 'hotpot-qa-8k.jsonl', + ] + + all_examples = [] + + for file in files: + print(f" Processing {file}...") + try: + ds = load_dataset('zstanjj/HtmlRAG-train', data_files=file, split='train') + count = 0 + for row in ds: + results = process_htmlrag_example(row) + if results: + all_examples.extend(results) + count += len(results) + print(f" Got {count} examples from {file}") + except Exception as e: + print(f" Error loading {file}: {e}") + + print(f" Total HtmlRAG examples: {len(all_examples)}") + return all_examples + + +# ============================================================ +# SOURCE 2: MultiHopRAG (News domain) +# ============================================================ + +def process_multihoprag(): + """Convert MultiHopRAG news articles into IndexLM format.""" + print("Loading MultiHopRAG (news domain)...") + + corpus = load_dataset("yixuantt/MultiHopRAG", name="corpus", split="train") + queries = load_dataset("yixuantt/MultiHopRAG", name="MultiHopRAG", split="train") + + # Build URL->article lookup + url_to_article = {} + for article in corpus: + url_to_article[article['url']] = article + + all_examples = [] + + for q_row in queries: + query = q_row['query'] + evidence_list = q_row['evidence_list'] + + for evidence in evidence_list: + url = evidence.get('url', '') + fact = evidence.get('fact', '') + + if url not in url_to_article or not fact: + continue + + article = url_to_article[url] + title = article.get('title', 'News Article') + body = article.get('body', '') + source = article.get('source', 'Unknown') + category = article.get('category', 'general') + + if not body or len(body) < 100: + continue + + # Split article body into paragraphs + paragraphs = [p.strip() for p in body.split('\n') if p.strip() and len(p.strip()) > 20] + if not paragraphs: + continue + + # Build indexed blocks with realistic web structure + blocks = [] + content_indices = [] + relevant_indices = [] + idx = 1 + + # Add realistic header noise + num_header = random.randint(1, 3) + for _ in range(num_header): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}") + idx += 1 + + # Article title + blocks.append(f"[{idx}]

    {title}

    ") + content_indices.append(idx) + idx += 1 + + # Author/date line + author = article.get('author', 'Staff Writer') + published = article.get('published_at', '2024-01-01') + blocks.append(f"[{idx}]
    By {author} | {source} | {published} | Category: {category}
    ") + content_indices.append(idx) + idx += 1 + + # Article paragraphs + fact_words = set(fact.lower().split()) + + for para in paragraphs: + # Determine tag + if len(para) < 60 and not para.endswith('.'): + tag = 'h2' + else: + tag = 'p' + + blocks.append(f"[{idx}] <{tag}>{para}") + content_indices.append(idx) + + # Check if paragraph contains the evidence fact + para_words = set(para.lower().split()) + overlap = len(para_words & fact_words) + if overlap > 5 or (fact_words and overlap / len(fact_words) > 0.3): + relevant_indices.append(idx) + + idx += 1 + + # Occasional mid-article noise + if random.random() < 0.15: + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}") + idx += 1 + + # Footer noise + num_footer = random.randint(1, 4) + for _ in range(num_footer): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}") + idx += 1 + + block_text = "\n".join(blocks) + + # Query-relevant extraction + if relevant_indices: + intervals = indices_to_intervals(relevant_indices) + user_msg = f"URL: {url}\nQuery: {query}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." + all_examples.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT_QE}, + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": intervals} + ], + "task_type": "query_relevant", + "source": "multihoprag_news" + }) + + # Main content extraction + if content_indices and random.random() < 0.4: + intervals = indices_to_intervals(content_indices) + user_msg = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks." + all_examples.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT_ME}, + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": intervals} + ], + "task_type": "main_content", + "source": "multihoprag_news" + }) + + print(f" Total MultiHopRAG examples: {len(all_examples)}") + return all_examples + + +# ============================================================ +# SOURCE 3: HotpotQA (Wikipedia - but balanced as minority) +# ============================================================ + +def process_hotpotqa(): + """Process HotpotQA — kept but as a smaller proportion.""" + print("Loading HotpotQA (Wikipedia domain)...") + ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train") + + # Reduced from 15K to 5K — wiki should be minority source + num_samples = min(5000, len(ds)) + ds = ds.shuffle(seed=42).select(range(num_samples)) + + all_examples = [] + skipped = 0 + + for i, row in enumerate(ds): + if i % 1000 == 0: + print(f" Processing {i}/{num_samples}...") + + try: + titles = row['context']['title'] + sentences_list = row['context']['sentences'] + sf = row['supporting_facts'] + + sf_lookup = defaultdict(set) + for title, sent_id in zip(sf['title'], sf['sent_id']): + sf_lookup[title].add(sent_id) + + blocks = [] + relevant_indices = [] + content_indices = [] + idx = 1 + + # Header noise + if random.random() < 0.6: + for _ in range(random.randint(1, 3)): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}") + idx += 1 + + for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)): + blocks.append(f"[{idx}]

    {title}

    ") + content_indices.append(idx) + if title in sf_lookup: + relevant_indices.append(idx) + idx += 1 + + for sent_idx, sentence in enumerate(sentences): + sentence = sentence.strip() + if not sentence: + continue + blocks.append(f"[{idx}]

    {sentence}

    ") + content_indices.append(idx) + if title in sf_lookup and sent_idx in sf_lookup[title]: + relevant_indices.append(idx) + idx += 1 + + if random.random() < 0.3 and doc_idx < len(titles) - 1: + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}") + idx += 1 + + # Footer noise + if random.random() < 0.6: + for _ in range(random.randint(1, 3)): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}") + idx += 1 + + if len(relevant_indices) < 1: + skipped += 1 + continue + + block_text = "\n".join(blocks) + + # QE example + intervals = indices_to_intervals(relevant_indices) + user_msg = f"URL: https://en.wikipedia.org\nQuery: {row['question']}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." + all_examples.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT_QE}, + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": intervals} + ], + "task_type": "query_relevant", + "source": "hotpotqa_wiki" + }) + + # ME example (less frequent - wiki is minority) + if random.random() < 0.3: + intervals = indices_to_intervals(content_indices) + user_msg = f"URL: https://en.wikipedia.org\nTitle: {titles[0]}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks." + all_examples.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT_ME}, + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": intervals} + ], + "task_type": "main_content", + "source": "hotpotqa_wiki" + }) + + except Exception as e: + skipped += 1 + continue + + print(f" Total HotpotQA examples: {len(all_examples)} ({skipped} skipped)") + return all_examples + + +# ============================================================ +# SOURCE 4: MS MARCO (Diverse web QA) +# ============================================================ + +def process_msmarco(): + """Process MS MARCO for diverse web domain QA examples.""" + print("Loading MS MARCO (diverse web QA)...") + + try: + ds = load_dataset("microsoft/ms_marco", "v1.1", split="train") + # Sample a manageable subset + num_samples = min(5000, len(ds)) + ds = ds.shuffle(seed=99).select(range(num_samples)) + except Exception as e: + print(f" Could not load MS MARCO: {e}") + return [] + + all_examples = [] + + for i, row in enumerate(ds): + if i % 1000 == 0: + print(f" Processing {i}/{num_samples}...") + + try: + query = row['query'] + passages = row['passages'] + + if not passages or not passages.get('passage_text'): + continue + + passage_texts = passages['passage_text'] + is_selected = passages.get('is_selected', [0] * len(passage_texts)) + + if not any(is_selected): + continue + + # Build blocks from passages (these are real web snippets from Bing) + blocks = [] + relevant_indices = [] + content_indices = [] + idx = 1 + + # Header noise + if random.random() < 0.5: + for _ in range(random.randint(1, 2)): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}") + idx += 1 + + for p_idx, (text, selected) in enumerate(zip(passage_texts, is_selected)): + text = text.strip() + if not text: + continue + + # Simulate different content types + if p_idx == 0 and random.random() < 0.3: + tag = 'h1' + elif len(text) < 80: + tag = random.choice(['h2', 'h3', 'strong']) + else: + tag = 'p' + + blocks.append(f"[{idx}] <{tag}>{text}") + content_indices.append(idx) + + if selected: + relevant_indices.append(idx) + idx += 1 + + # Between-passage noise + if random.random() < 0.2: + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}") + idx += 1 + + # Footer noise + if random.random() < 0.5: + for _ in range(random.randint(1, 2)): + blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}") + idx += 1 + + if not relevant_indices or len(blocks) < 3: + continue + + block_text = "\n".join(blocks) + + # QE example + intervals = indices_to_intervals(relevant_indices) + query_type = row.get('query_type', 'general') + user_msg = f"URL: https://www.bing.com/search\nQuery: {query}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." + all_examples.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT_QE}, + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": intervals} + ], + "task_type": "query_relevant", + "source": f"msmarco_{query_type}" + }) + + except Exception as e: + continue + + print(f" Total MS MARCO examples: {len(all_examples)}") + return all_examples + + +# ============================================================ +# NA Examples (no relevant content) +# ============================================================ + +def create_na_examples(all_examples): + """Create NA examples by mismatching queries with pages.""" + print("Creating NA examples (mismatched query-page pairs)...") + + # Get QE examples + qe_examples = [e for e in all_examples if e['task_type'] == 'query_relevant'] + + if len(qe_examples) < 100: + print(" Too few QE examples for NA generation") + return [] + + na_examples = [] + + for i in range(min(500, len(qe_examples) // 5)): + # Pick two random QE examples + idx_a = random.randint(0, len(qe_examples) - 1) + idx_b = (idx_a + random.randint(100, len(qe_examples) - 1)) % len(qe_examples) + + # Use query from A, blocks from B + msgs_a = qe_examples[idx_a]['messages'] + msgs_b = qe_examples[idx_b]['messages'] + + # Extract query from A + user_a = msgs_a[1]['content'] + query_match = re.search(r'Query: (.+?)(\n|$)', user_a) + if not query_match: + continue + query = query_match.group(1).strip() + + # Extract blocks from B + user_b = msgs_b[1]['content'] + blocks_match = re.search(r'Blocks:\n(.+?)(\n\nOutput)', user_b, re.DOTALL) + if not blocks_match: + continue + blocks = blocks_match.group(1) + + user_msg = f"URL: https://example.com\nQuery: {query}\n\nBlocks:\n{blocks}\n\nOutput the index intervals of blocks relevant to the query." + na_examples.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT_QE}, + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": "NA"} + ], + "task_type": "query_relevant_na", + "source": "mismatched" + }) + + print(f" Created {len(na_examples)} NA examples") + return na_examples + + +# ============================================================ +# Main Pipeline +# ============================================================ + +def main(): + print("=" * 60) + print("Building DIVERSE IndexLM Training Data") + print("=" * 60) + + # Collect from all sources + htmlrag_examples = load_htmlrag_data() # Real web HTML (primary) + multihoprag_examples = process_multihoprag() # News domain + hotpotqa_examples = process_hotpotqa() # Wikipedia (minority) + msmarco_examples = process_msmarco() # Diverse web QA + + # Combine + all_examples = htmlrag_examples + multihoprag_examples + hotpotqa_examples + msmarco_examples + + # Add NA examples + na_examples = create_na_examples(all_examples) + all_examples.extend(na_examples) + + random.shuffle(all_examples) + + # Print composition + print(f"\n{'='*60}") + print(f"Total examples: {len(all_examples)}") + + source_counts = defaultdict(int) + type_counts = defaultdict(int) + for ex in all_examples: + source_counts[ex.get('source', 'unknown')] += 1 + type_counts[ex['task_type']] += 1 + + print("\nBy source:") + for s, c in sorted(source_counts.items(), key=lambda x: -x[1]): + pct = c / len(all_examples) * 100 + print(f" {s}: {c} ({pct:.1f}%)") + + print("\nBy task type:") + for t, c in sorted(type_counts.items(), key=lambda x: -x[1]): + pct = c / len(all_examples) * 100 + print(f" {t}: {c} ({pct:.1f}%)") + + # Check token lengths + print("\nChecking token lengths...") + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + + lengths = [] + for ex in random.sample(all_examples, min(500, len(all_examples))): + text = tokenizer.apply_chat_template(ex['messages'], tokenize=False) + tokens = tokenizer.encode(text) + lengths.append(len(tokens)) + + print(f"Token length stats (sample of {len(lengths)}):") + print(f" Min: {min(lengths)}, Max: {max(lengths)}") + print(f" Mean: {sum(lengths)/len(lengths):.0f}, Median: {sorted(lengths)[len(lengths)//2]}") + + # Filter by length + MAX_LEN = 4096 + filtered = [] + too_long = 0 + for ex in all_examples: + text = tokenizer.apply_chat_template(ex['messages'], tokenize=False) + tokens = tokenizer.encode(text) + if len(tokens) <= MAX_LEN: + filtered.append(ex) + else: + too_long += 1 + + print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)") + print(f"Final dataset size: {len(filtered)}") + + # Final composition + final_source_counts = defaultdict(int) + for ex in filtered: + final_source_counts[ex.get('source', 'unknown')] += 1 + print("\nFinal composition by source:") + for s, c in sorted(final_source_counts.items(), key=lambda x: -x[1]): + pct = c / len(filtered) * 100 + print(f" {s}: {c} ({pct:.1f}%)") + + # Split + random.shuffle(filtered) + eval_size = min(500, len(filtered) // 10) + train_data = filtered[:-eval_size] + eval_data = filtered[-eval_size:] + + print(f"\nTrain: {len(train_data)}, Eval: {len(eval_data)}") + + # Create HF datasets + train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data]) + eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data]) + + # Save locally + train_ds.save_to_disk("/app/indexlm_train_v2") + eval_ds.save_to_disk("/app/indexlm_eval_v2") + + # Push to Hub + ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds}) + ds_dict.push_to_hub("OmAlve/indexlm-training-data", token=os.environ.get("HF_TOKEN")) + + print(f"\n{'='*60}") + print("Done! Dataset pushed to OmAlve/indexlm-training-data") + print(f"{'='*60}") + + +if __name__ == "__main__": + main() diff --git a/runs/Apr24_07-58-54_1eb67182ed08/events.out.tfevents.1777017534.1eb67182ed08.50747.0 b/runs/Apr24_07-58-54_1eb67182ed08/events.out.tfevents.1777017534.1eb67182ed08.50747.0 new file mode 100644 index 0000000..a20a90c --- /dev/null +++ b/runs/Apr24_07-58-54_1eb67182ed08/events.out.tfevents.1777017534.1eb67182ed08.50747.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21a64bfa465482310059002822d2fb6e5ce46b4d17aaa650d28d737c73ab3c56 +size 5623 diff --git a/runs/Apr24_08-03-09_1eb67182ed08/events.out.tfevents.1777017789.1eb67182ed08.51740.0 b/runs/Apr24_08-03-09_1eb67182ed08/events.out.tfevents.1777017789.1eb67182ed08.51740.0 new file mode 100644 index 0000000..9d559cd --- /dev/null +++ b/runs/Apr24_08-03-09_1eb67182ed08/events.out.tfevents.1777017789.1eb67182ed08.51740.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6a1e276745fd765b0a9bfafd89a29a6fb29e69b1dd62f21c69573f85769ffeb +size 6005 diff --git a/runs/Apr24_08-03-46_1eb67182ed08/events.out.tfevents.1777017826.1eb67182ed08.52486.0 b/runs/Apr24_08-03-46_1eb67182ed08/events.out.tfevents.1777017826.1eb67182ed08.52486.0 new file mode 100644 index 0000000..e9050c6 --- /dev/null +++ b/runs/Apr24_08-03-46_1eb67182ed08/events.out.tfevents.1777017826.1eb67182ed08.52486.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70c1d56045aa1b9bc45b1e326393dbb1f37b837654cdc29088ead54fa8106694 +size 6008 diff --git a/runs/Apr24_08-06-45_1eb67182ed08/events.out.tfevents.1777018005.1eb67182ed08.53075.0 b/runs/Apr24_08-06-45_1eb67182ed08/events.out.tfevents.1777018005.1eb67182ed08.53075.0 new file mode 100644 index 0000000..21cd1c0 --- /dev/null +++ b/runs/Apr24_08-06-45_1eb67182ed08/events.out.tfevents.1777018005.1eb67182ed08.53075.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d9d339d3a50391043b225ee1dcbaec38d846f085ba065605cb9fef69a650d14 +size 200909 diff --git a/tokenizer.json b/tokenizer.json new file mode 100644 index 0000000..c7afbed --- /dev/null +++ b/tokenizer.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506 +size 11422650 diff --git a/tokenizer_config.json b/tokenizer_config.json new file mode 100644 index 0000000..7d75d3b --- /dev/null +++ b/tokenizer_config.json @@ -0,0 +1,29 @@ +{ + "add_prefix_space": false, + "backend": "tokenizers", + "bos_token": null, + "clean_up_tokenization_spaces": false, + "eos_token": "<|im_end|>", + "errors": "replace", + "extra_special_tokens": [ + "<|im_start|>", + "<|im_end|>", + "<|object_ref_start|>", + "<|object_ref_end|>", + "<|box_start|>", + "<|box_end|>", + "<|quad_start|>", + "<|quad_end|>", + "<|vision_start|>", + "<|vision_end|>", + "<|vision_pad|>", + "<|image_pad|>", + "<|video_pad|>" + ], + "is_local": false, + "model_max_length": 131072, + "pad_token": "<|endoftext|>", + "split_special_tokens": false, + "tokenizer_class": "Qwen2Tokenizer", + "unk_token": null +} diff --git a/train_indexlm.py b/train_indexlm.py new file mode 100644 index 0000000..9509121 --- /dev/null +++ b/train_indexlm.py @@ -0,0 +1,137 @@ +""" +Reading Steiner - Fine-tune Qwen3-0.6B for Index-based Web Content Extraction + +Based on: "An Index-based Approach for Efficient and Effective Web Content Extraction" (arxiv:2512.06641) +Base model: Qwen/Qwen3-0.6B (0.6B params, ideal for CPU deployment) +Training method: SFT with TRL SFTTrainer +Dataset: OmAlve/indexlm-training-data (21K+ multi-domain examples) +""" + +import os +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import SFTTrainer, SFTConfig +import trackio + +# ============ Configuration ============ +MODEL_ID = "Qwen/Qwen3-0.6B" +DATASET_ID = "OmAlve/indexlm-training-data" +OUTPUT_DIR = "./reading-steiner" +HUB_MODEL_ID = "OmAlve/reading-steiner" + +# Training hyperparameters (from paper: standard SFT) +LEARNING_RATE = 2e-5 +NUM_EPOCHS = 3 +BATCH_SIZE = 4 +GRAD_ACCUM = 4 # Effective batch size = 16 +MAX_SEQ_LENGTH = 4096 +WARMUP_RATIO = 0.05 + +# ============ Setup Trackio ============ +trackio.init( + name="reading-steiner-training", + project="reading-steiner" +) + +# ============ Load Dataset ============ +print("Loading dataset...") +dataset = load_dataset(DATASET_ID) +train_dataset = dataset["train"] +eval_dataset = dataset["eval"] +print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}") + +# ============ Load Model & Tokenizer ============ +print("Loading model and tokenizer...") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Ensure padding token is set +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", # Change to "sdpa" if flash-attn unavailable +) + +print(f"Model loaded: {MODEL_ID}") +print(f"Model params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M") + +# ============ Training Config ============ +training_args = SFTConfig( + output_dir=OUTPUT_DIR, + num_train_epochs=NUM_EPOCHS, + per_device_train_batch_size=BATCH_SIZE, + per_device_eval_batch_size=BATCH_SIZE, + gradient_accumulation_steps=GRAD_ACCUM, + learning_rate=LEARNING_RATE, + lr_scheduler_type="cosine", + warmup_ratio=WARMUP_RATIO, + weight_decay=0.01, + bf16=True, + gradient_checkpointing=True, + max_length=MAX_SEQ_LENGTH, + # Logging + logging_steps=10, + logging_first_step=True, + logging_strategy="steps", + disable_tqdm=True, + # Evaluation + eval_strategy="steps", + eval_steps=500, + # Saving + save_strategy="steps", + save_steps=500, + save_total_limit=3, + load_best_model_at_end=True, + metric_for_best_model="eval_loss", + greater_is_better=False, + # Hub push + push_to_hub=True, + hub_model_id=HUB_MODEL_ID, + hub_strategy="every_save", + # Performance + dataloader_num_workers=4, + dataloader_pin_memory=True, + # Report + report_to="none", + # Seed + seed=42, +) + +# ============ Initialize Trainer ============ +print("Initializing trainer...") +trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, +) + +# ============ Train ============ +print("Starting training...") +train_result = trainer.train() + +# ============ Save Final Model ============ +print("Saving final model...") +trainer.save_model(OUTPUT_DIR) +tokenizer.save_pretrained(OUTPUT_DIR) + +# Push to Hub +print("Pushing to Hub...") +trainer.push_to_hub(commit_message="Final reading-steiner model") + +# ============ Log Final Metrics ============ +metrics = train_result.metrics +print(f"\nTraining complete!") +print(f" Train loss: {metrics.get('train_loss', 'N/A')}") +print(f" Train runtime: {metrics.get('train_runtime', 'N/A'):.0f}s") +print(f" Train samples/sec: {metrics.get('train_samples_per_second', 'N/A'):.1f}") + +# Final eval +eval_metrics = trainer.evaluate() +print(f" Eval loss: {eval_metrics.get('eval_loss', 'N/A')}") + +print(f"\nModel pushed to: https://huggingface.co/{HUB_MODEL_ID}") diff --git a/training_args.bin b/training_args.bin new file mode 100644 index 0000000..6e940ad --- /dev/null +++ b/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79a86123895b91f3924b594437b0ba747ae3d58fb7f8ab73d7c7d6f4be44b100 +size 5304