487 lines
18 KiB
Python
487 lines
18 KiB
Python
"""
Prepare IndexLM training data from HotpotQA and MSMARCO.

Pipeline:
1. Load HotpotQA (has context = list of (title, sentences) + supporting_facts)
2. Convert context into indexed HTML-like blocks: [i] <tag>content</tag>
3. The target is index intervals of blocks containing supporting facts
4. Also create main-content extraction examples (all content blocks are "main content",
   but we inject noise blocks like nav/ads to train the model to filter them)
5. Format as conversational messages for SFT
"""
|
|
|
|
import json
|
|
import random
|
|
import re
|
|
from datasets import load_dataset, Dataset
|
|
from collections import defaultdict
|
|
|
|
# Seed the module-level RNG once so that noise injection, example sampling and
# the shuffles performed later in this script are reproducible across runs.
random.seed(42)
|
|
|
|
# Noise blocks to inject (simulating real web page clutter).
# Each entry is a complete HTML-like element without an index; the code that
# injects one prepends the "[i] " prefix itself. Noise blocks are never added
# to the relevant/content index lists, so they are absent from every target.
NOISE_BLOCKS = [
    '<nav>Home | About | Contact | Privacy Policy</nav>',
    '<div class="ad">Advertisement - Continue Reading Below</div>',
    '<div class="sidebar">Related Articles: Top 10 Facts You Didn\'t Know</div>',
    '<footer>© 2024 All Rights Reserved | Terms of Service</footer>',
    '<div class="cookie-banner">This site uses cookies. Accept | Decline</div>',
    '<div class="social">Share on: Twitter | Facebook | LinkedIn</div>',
    '<nav class="breadcrumb">Home > Category > Subcategory > Article</nav>',
    '<div class="newsletter">Subscribe to our newsletter for updates</div>',
    '<div class="popup">Sign up for free access to premium content</div>',
    '<aside>Trending: Latest news and popular stories</aside>',
    '<div class="comments">Comments (0) - Be the first to comment</div>',
    '<div class="author">Written by Staff Reporter | Updated: Jan 2024</div>',
    '<div class="pagination">Previous | 1 | 2 | 3 | Next</div>',
    '<div class="search">Search this site...</div>',
    '<div class="menu">Categories: Science, Tech, Health, Sports</div>',
]
|
|
|
|
SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query.
|
|
|
|
Each block is formatted as: [i] <tag>content</tag>
|
|
Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive).
|
|
If no relevant content exists, output 'NA'.
|
|
|
|
Example output: [[2,4],[7,7],[10,12]]"""
|
|
|
|
SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements).
|
|
|
|
Each block is formatted as: [i] <tag>content</tag>
|
|
Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive).
|
|
If no main content exists, output 'NA'.
|
|
|
|
Example output: [[1,3],[5,8],[11,15]]"""
|
|
|
|
|
|
def indices_to_intervals(indices):
    """Collapse block indices into a string of inclusive [start, end] intervals.

    Args:
        indices: Sequence of integer block indices. It does not need to be
            sorted and may contain duplicates; both are normalised here.

    Returns:
        The string "NA" when ``indices`` is empty. Otherwise a JSON list of
        intervals in which runs of consecutive indices are merged, e.g.
        ``[1, 2, 3, 5]`` becomes ``"[[1,3],[5,5]]"``.
    """
    if not indices:
        return "NA"

    ordered = sorted(set(indices))
    intervals = []
    start = end = ordered[0]
    for value in ordered[1:]:
        if value == end + 1:
            # Still inside the current run of consecutive indices.
            end = value
            continue
        intervals.append([start, end])
        start = end = value
    intervals.append([start, end])

    # Compact separators: both system prompts show the expected answer as
    # "[[2,4],[7,7],[10,12]]" (no spaces), so the training target must use the
    # same format instead of json.dumps' default "[[2, 4], [7, 7]]".
    return json.dumps(intervals, separators=(",", ":"))
|
|
|
|
|
|
def create_indexed_blocks_from_hotpotqa(context, supporting_facts, inject_noise=True):
    """Convert a HotpotQA context into indexed HTML-like blocks.

    Every document contributes one ``<h2>`` title block followed by one ``<p>``
    block per non-empty sentence. Blocks are numbered from 1 in page order.

    Args:
        context: ``{'title': [...], 'sentences': [[...], ...]}``.
        supporting_facts: ``{'title': [...], 'sent_id': [...]}``; ``sent_id``
            is matched against a sentence's position in its document's
            original sentence list (empty sentences still occupy a position).
        inject_noise: When true, entries of ``NOISE_BLOCKS`` are randomly
            inserted between documents and before/after the page. When false
            the random module is not used at all.

    Returns:
        ``(block_text, relevant_indices, content_indices)`` where
        ``block_text`` is the newline-joined ``"[i] <tag>...</tag>"`` lines,
        ``relevant_indices`` lists the supporting sentences plus the titles of
        documents that have a supporting fact, and ``content_indices`` lists
        every title/sentence block (i.e. everything that is not noise).
    """
    titles = context['title']
    sentences_list = context['sentences']

    # Build supporting facts lookup: title -> set of supporting sentence ids.
    sf_lookup = defaultdict(set)
    for title, sent_id in zip(supporting_facts['title'], supporting_facts['sent_id']):
        sf_lookup[title].add(sent_id)

    # Collect the page as unnumbered (html, is_content, is_relevant) entries.
    # Indices are assigned in a single pass at the end, so prepending noise
    # never requires re-parsing and shifting already formatted "[i] ..." text.
    body = []
    for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)):
        # .get() does not insert missing keys into the defaultdict.
        supporting_ids = sf_lookup.get(title)
        is_supporting_doc = supporting_ids is not None

        # The title of a supporting document counts as relevant.
        body.append((f"<h2>{title}</h2>", True, is_supporting_doc))

        for sent_idx, sentence in enumerate(sentences):
            sentence = sentence.strip()
            if not sentence:
                continue
            is_relevant = is_supporting_doc and sent_idx in supporting_ids
            body.append((f"<p>{sentence}</p>", True, is_relevant))

        # Inject noise between documents sometimes (never after the last one).
        # The RNG draw deliberately precedes the position check so the random
        # stream is consumed once per document.
        if inject_noise and random.random() < 0.4 and doc_idx < len(titles) - 1:
            body.append((random.choice(NOISE_BLOCKS), False, False))

    # Sometimes add 1-3 noise blocks at the start and/or the end of the page.
    prefix_noise = []
    suffix_noise = []
    if inject_noise:
        if random.random() < 0.5:
            prefix_noise = [random.choice(NOISE_BLOCKS) for _ in range(random.randint(1, 3))]
        if random.random() < 0.5:
            suffix_noise = [random.choice(NOISE_BLOCKS) for _ in range(random.randint(1, 3))]

    entries = (
        [(noise, False, False) for noise in prefix_noise]
        + body
        + [(noise, False, False) for noise in suffix_noise]
    )

    blocks = []
    relevant_indices = []
    content_indices = []
    for idx, (html, is_content, is_relevant) in enumerate(entries, start=1):
        blocks.append(f"[{idx}] {html}")
        if is_content:
            content_indices.append(idx)
        if is_relevant:
            relevant_indices.append(idx)

    block_text = "\n".join(blocks)
    return block_text, relevant_indices, content_indices
|
|
|
|
|
|
def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"):
    """Assemble the chat turns of one query-relevant extraction (QE) sample.

    The system turn is ``SYSTEM_PROMPT_QE``, the user turn carries the URL,
    the query and the indexed blocks, and the assistant turn is the interval
    string that ``indices_to_intervals`` produces for ``relevant_indices``.

    Returns:
        A list of three ``{"role": ..., "content": ...}`` dicts in
        system / user / assistant order.
    """
    target = indices_to_intervals(relevant_indices)
    prompt = f"URL: {url}\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."

    turns = (
        ("system", SYSTEM_PROMPT_QE),
        ("user", prompt),
        ("assistant", target),
    )
    return [{"role": role, "content": content} for role, content in turns]
|
|
|
|
|
|
def build_main_content_example(block_text, content_indices, title="Wikipedia Article", url="https://en.wikipedia.org"):
    """Assemble the chat turns of one main-content extraction (ME) sample.

    The system turn is ``SYSTEM_PROMPT_ME``, the user turn carries the URL,
    the page title and the indexed blocks, and the assistant turn is the
    interval string that ``indices_to_intervals`` produces for
    ``content_indices``.

    Returns:
        A list of three ``{"role": ..., "content": ...}`` dicts in
        system / user / assistant order.
    """
    target = indices_to_intervals(content_indices)
    prompt = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."

    turns = (
        ("system", SYSTEM_PROMPT_ME),
        ("user", prompt),
        ("assistant", target),
    )
    return [{"role": role, "content": content} for role, content in turns]
|
|
|
|
|
|
def process_hotpotqa():
    """Turn a sample of the HotpotQA train split into IndexLM examples.

    Up to 15000 shuffled rows are converted into noisy indexed pages. Every
    usable row yields one query-relevant example, and with probability 0.5 an
    additional main-content example built from the same page.

    Returns:
        A list of dicts with the keys "messages", "task_type" and "source".
    """
    print("Loading HotpotQA...")
    dataset = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train")

    # Sample a manageable amount
    num_samples = min(15000, len(dataset))
    dataset = dataset.shuffle(seed=42).select(range(num_samples))

    collected = []
    skipped = 0

    for row_no, record in enumerate(dataset):
        if row_no % 1000 == 0:
            print(f"Processing {row_no}/{num_samples}...")

        try:
            page_text, relevant, content = create_indexed_blocks_from_hotpotqa(
                record['context'], record['supporting_facts'], inject_noise=True
            )

            # A page without any relevant block gives no usable QE target.
            if not relevant:
                skipped += 1
                continue

            # Query-relevant extraction example
            collected.append({
                "messages": build_query_relevant_example(record['question'], page_text, relevant),
                "task_type": "query_relevant",
                "source": "hotpotqa"
            })

            # Main content extraction example (50% of the time)
            if random.random() < 0.5:
                doc_titles = record['context']['title']
                page_title = doc_titles[0] if doc_titles else "Article"
                collected.append({
                    "messages": build_main_content_example(page_text, content, title=page_title),
                    "task_type": "main_content",
                    "source": "hotpotqa"
                })
        except Exception as e:
            # Best effort: a bad row is counted and only the first few
            # failures are printed.
            skipped += 1
            if skipped < 5:
                print(f"Error on row {row_no}: {e}")
            continue

    print(f"Created {len(collected)} examples from HotpotQA ({skipped} skipped)")
    return collected
|
|
|
|
|
|
def create_synthetic_web_pages():
    """Create synthetic web page examples for main content extraction training.

    Each page is laid out as: 1-4 header noise blocks, an ``<h1>`` with the
    first document title, the sentences of the first 1-3 documents as ``<p>``
    blocks (documents after the first get an ``<h2>`` title), occasional noise
    in between, and 1-4 footer noise blocks. Only title and sentence blocks
    are part of the target intervals.

    Rows that are empty or fail to process are skipped; the number of skipped
    rows is reported and the first few errors are printed.

    Returns:
        A list of dicts with the keys "messages", "task_type"
        ("main_content") and "source" ("synthetic").
    """
    print("Creating synthetic web page examples...")

    # Load a text dataset to get content
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    # Bound the sample by the split size, as process_hotpotqa does, so a
    # smaller split cannot make select() fail.
    num_pages = min(3000, len(ds))
    ds = ds.shuffle(seed=123).select(range(num_pages))

    examples = []
    skipped = 0

    for i, row in enumerate(ds):
        if i % 500 == 0:
            print(f"Synthetic page {i}/{num_pages}...")

        try:
            # Build a more realistic web page structure
            titles = row['context']['title']
            sentences_list = row['context']['sentences']

            if not titles or not sentences_list:
                skipped += 1
                continue

            blocks = []
            content_indices = []
            idx = 1

            # Header noise (nav, etc.)
            num_header_noise = random.randint(1, 4)
            for _ in range(num_header_noise):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1

            # Page title
            main_title = titles[0]
            blocks.append(f"[{idx}] <h1>{main_title}</h1>")
            content_indices.append(idx)
            idx += 1

            # Main content (just first 1-3 documents)
            num_docs = min(random.randint(1, 3), len(titles))
            for doc_idx in range(num_docs):
                title = titles[doc_idx]
                sents = sentences_list[doc_idx]

                # The first document's title is already the <h1> above.
                if doc_idx > 0:
                    blocks.append(f"[{idx}] <h2>{title}</h2>")
                    content_indices.append(idx)
                    idx += 1

                for sent in sents:
                    sent = sent.strip()
                    if not sent:
                        continue
                    blocks.append(f"[{idx}] <p>{sent}</p>")
                    content_indices.append(idx)
                    idx += 1

                # Occasional inline noise
                # NOTE(review): indentation was lost in the copy this was
                # restored from; the noise draw is assumed to happen once per
                # document (after its sentences), not once per sentence -
                # confirm against the original file.
                if random.random() < 0.3:
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                    idx += 1

            # Footer noise
            num_footer_noise = random.randint(1, 4)
            for _ in range(num_footer_noise):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1

            block_text = "\n".join(blocks)
            me_messages = build_main_content_example(
                block_text, content_indices,
                title=main_title,
                url=f"https://en.wikipedia.org/wiki/{main_title.replace(' ', '_')}"
            )
            examples.append({
                "messages": me_messages,
                "task_type": "main_content",
                "source": "synthetic"
            })
        except Exception as e:
            # Best effort per row, but never silent: count the failure and
            # show the first few so systematic problems are visible.
            skipped += 1
            if skipped < 5:
                print(f"Error on synthetic page {i}: {e}")
            continue

    print(f"Created {len(examples)} synthetic web page examples")
    if skipped:
        print(f"Skipped {skipped} synthetic pages (empty context or processing error)")
    return examples
|
|
|
|
|
|
def create_na_examples():
    """Create examples where no relevant content exists (model should output 'NA').

    Each example pairs the full context of one HotpotQA validation row with
    the question of a different row (500 positions further on, wrapping
    around), and uses "NA" as the target. The context is not filtered: the
    mismatch between question and page is what is assumed to make it
    irrelevant, so some pairs may still be tangentially related.

    Returns:
        At most 300 randomly chosen dicts with the keys "messages",
        "task_type" ("query_relevant_na") and "source"
        ("hotpotqa_mismatched").
    """
    print("Creating NA examples...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    # Bound the sample by the split size so select() cannot fail.
    num_rows = min(1000, len(ds))
    ds = ds.shuffle(seed=456).select(range(num_rows))

    examples = []
    skipped = 0

    for i, row in enumerate(ds):
        try:
            # Use context from one question but query from another (mismatched)
            other_idx = (i + 500) % len(ds)
            if other_idx == i:
                # Only possible when the sample is small enough for the offset
                # to wrap onto the same row; that pair would not be mismatched.
                skipped += 1
                continue
            other_question = ds[other_idx]['question']

            # No supporting facts are passed, so nothing is marked relevant;
            # only the rendered page text is needed here.
            block_text, _, _ = create_indexed_blocks_from_hotpotqa(
                row['context'], {'title': [], 'sent_id': []}, inject_noise=True
            )

            # Reuse the QE builder so NA prompts can never drift from regular
            # QE prompts; an empty index list makes it emit the "NA" target.
            messages = build_query_relevant_example(other_question, block_text, [])
            examples.append({
                "messages": messages,
                "task_type": "query_relevant_na",
                "source": "hotpotqa_mismatched"
            })
        except Exception as e:
            # Best effort per row. A bare `except:` would also swallow
            # KeyboardInterrupt/SystemExit, so only Exception is caught.
            skipped += 1
            if skipped < 5:
                print(f"Error on NA row {i}: {e}")
            continue

    # Keep only a fraction (the paper mentions partial filtering of NA)
    random.shuffle(examples)
    examples = examples[:300]
    print(f"Created {len(examples)} NA examples")
    if skipped:
        print(f"Skipped {skipped} NA rows")
    return examples
|
|
|
|
|
|
def main():
    """Build, filter, split, save and publish the IndexLM SFT dataset.

    Steps: build the three example sets, shuffle them together, print counts
    per task type and token-length statistics for the first 500 examples,
    drop examples longer than 4096 tokens, hold out up to 500 examples (at
    most 10%) for evaluation, save both splits under /app, push them to the
    Hugging Face Hub and print one sample per task type.
    """
    # Imported up front so a missing package fails before the expensive
    # dataset construction instead of after it.
    import os

    from datasets import DatasetDict
    from huggingface_hub import login
    from transformers import AutoTokenizer

    # Build all training examples
    qe_examples = process_hotpotqa()
    me_examples = create_synthetic_web_pages()
    na_examples = create_na_examples()

    all_examples = qe_examples + me_examples + na_examples
    random.shuffle(all_examples)

    print(f"\nTotal examples: {len(all_examples)}")

    # Count by type
    type_counts = defaultdict(int)
    for ex in all_examples:
        type_counts[ex['task_type']] += 1
    for t, c in type_counts.items():
        print(f" {t}: {c}")

    # Check lengths
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    def count_tokens(example):
        # Render the conversation with the chat template, then tokenize it.
        text = tokenizer.apply_chat_template(example['messages'], tokenize=False)
        return len(tokenizer.encode(text))

    # Tokenize every example exactly once; the statistics and the length
    # filter below both reuse these counts.
    token_lengths = [count_tokens(ex) for ex in all_examples]

    lengths = token_lengths[:500]
    print("\nToken length stats (sample of 500):")
    if lengths:
        print(f" Min: {min(lengths)}")
        print(f" Max: {max(lengths)}")
        print(f" Mean: {sum(lengths)/len(lengths):.0f}")
        print(f" Median: {sorted(lengths)[len(lengths)//2]}")
    else:
        # min()/max() raise ValueError on an empty sequence.
        print(" (no examples)")

    # Filter out examples that are too long (>4096 tokens for efficiency)
    MAX_LEN = 4096
    filtered = [ex for ex, n_tokens in zip(all_examples, token_lengths) if n_tokens <= MAX_LEN]
    too_long = len(all_examples) - len(filtered)

    print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)")
    print(f"Final dataset: {len(filtered)} examples")

    # Split into train/eval
    random.shuffle(filtered)
    eval_size = min(500, len(filtered) // 10)
    # Split at a non-negative position: with eval_size == 0 the negative-index
    # form filtered[:-0] is empty and filtered[-0:] is the whole list, which
    # would put everything in eval and nothing in train.
    split_at = len(filtered) - eval_size
    train_data = filtered[:split_at]
    eval_data = filtered[split_at:]

    print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")

    # Create HF dataset with just messages column (for SFTTrainer)
    train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data])
    eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data])

    # Save locally
    train_ds.save_to_disk("/app/indexlm_train")
    eval_ds.save_to_disk("/app/indexlm_eval")

    # Also push to HF Hub
    # NOTE(review): when HF_TOKEN is unset this passes token=None to login();
    # confirm that is acceptable for non-interactive runs.
    login(token=os.environ.get("HF_TOKEN"))

    ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds})
    ds_dict.push_to_hub("OmAlve/indexlm-training-data")

    print("\nDone! Dataset pushed to OmAlve/indexlm-training-data")

    # Print sample: the first matching example within the inspected prefix,
    # each message truncated to 200 characters.
    print("\n=== Sample QE example ===")
    for ex in train_data[:3]:
        if ex.get("task_type", "") == "query_relevant":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break

    print("\n=== Sample ME example ===")
    for ex in train_data[:10]:
        if ex.get("task_type", "") == "main_content":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break
|
|
|
|
|
|
# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|