初始化项目,由ModelHub XC社区提供模型
Model: OmAlve/reading-steiner Source: Original Platform
This commit is contained in:
486
prepare_data.py
Normal file
486
prepare_data.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Prepare IndexLM training data from HotpotQA.

NOTE: MSMARCO is named as a second source for this data, but this script
currently loads only HotpotQA.

Pipeline:
1. Load HotpotQA (each row has a context = titles plus per-title sentence lists, and supporting_facts)
2. Convert context into indexed HTML-like blocks: [i] <tag>content</tag>
3. The target is the index intervals of blocks containing supporting facts
4. Also create main-content extraction examples (all content blocks are "main content",
   but we inject noise blocks like nav/ads to train the model to filter them)
5. Format as conversational messages for SFT
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from datasets import load_dataset, Dataset
|
||||
from collections import defaultdict
|
||||
|
||||
# Seed Python's global `random` module so the noise injection, sampling and
# shuffles performed with `random` in this file are reproducible across runs.
random.seed(42)
|
||||
|
||||
# Boilerplate snippets (navigation, ads, footers, ...) that are mixed in between
# real content blocks to simulate the clutter found on real web pages.
NOISE_BLOCKS = [
    '<nav>Home | About | Contact | Privacy Policy</nav>',
    '<div class="ad">Advertisement - Continue Reading Below</div>',
    '<div class="sidebar">Related Articles: Top 10 Facts You Didn\'t Know</div>',
    '<footer>© 2024 All Rights Reserved | Terms of Service</footer>',
    '<div class="cookie-banner">This site uses cookies. Accept | Decline</div>',
    '<div class="social">Share on: Twitter | Facebook | LinkedIn</div>',
    '<nav class="breadcrumb">Home > Category > Subcategory > Article</nav>',
    '<div class="newsletter">Subscribe to our newsletter for updates</div>',
    '<div class="popup">Sign up for free access to premium content</div>',
    '<aside>Trending: Latest news and popular stories</aside>',
    '<div class="comments">Comments (0) - Be the first to comment</div>',
    '<div class="author">Written by Staff Reporter | Updated: Jan 2024</div>',
    '<div class="pagination">Previous | 1 | 2 | 3 | Next</div>',
    '<div class="search">Search this site...</div>',
    '<div class="menu">Categories: Science, Tech, Health, Sports</div>',
]
|
||||
|
||||
# System prompt for the query-relevant extraction (QE) task. Built from adjacent
# string literals; every "\n" below is a line break in the final prompt.
SYSTEM_PROMPT_QE = (
    "You are IndexLM, a web content extraction model. "
    "Given a webpage split into indexed blocks and a user query, "
    "identify which blocks contain content relevant to the query.\n"
    "\n"
    "Each block is formatted as: [i] <tag>content</tag>\n"
    "Output the indices of relevant blocks as a Python list of "
    "[start, end] intervals (inclusive).\n"
    "If no relevant content exists, output 'NA'.\n"
    "\n"
    "Example output: [[2,4],[7,7],[10,12]]"
)
|
||||
|
||||
# System prompt for the main-content extraction (ME) task. Built from adjacent
# string literals; every "\n" below is a line break in the final prompt.
SYSTEM_PROMPT_ME = (
    "You are IndexLM, a web content extraction model. "
    "Given a webpage split into indexed blocks, "
    "identify which blocks contain the main content of the page "
    "(filtering out navigation, advertisements, sidebars, and other non-content elements).\n"
    "\n"
    "Each block is formatted as: [i] <tag>content</tag>\n"
    "Output the indices of main content blocks as a Python list of "
    "[start, end] intervals (inclusive).\n"
    "If no main content exists, output 'NA'.\n"
    "\n"
    "Example output: [[1,3],[5,8],[11,15]]"
)
|
||||
|
||||
|
||||
def indices_to_intervals(indices):
    """Collapse block indices into a string of inclusive ``[start, end]`` intervals.

    Args:
        indices: Integer block indices. They do not need to be sorted or
            unique; duplicates are dropped and the values are sorted here.

    Returns:
        ``"NA"`` when ``indices`` is empty (or ``None``); otherwise a compact
        JSON string such as ``"[[2,4],[7,7],[10,12]]"`` in which each run of
        consecutive indices becomes one ``[start, end]`` pair.
    """
    if not indices:
        return "NA"

    ordered = sorted(set(indices))
    intervals = []
    start = end = ordered[0]
    for value in ordered[1:]:
        if value == end + 1:
            # Still inside the current run of consecutive indices.
            end = value
            continue
        intervals.append([start, end])
        start = end = value
    intervals.append([start, end])

    # Compact separators so the training target has exactly the form shown in
    # the system prompts' "Example output" (no spaces after commas).
    return json.dumps(intervals, separators=(",", ":"))
|
||||
|
||||
|
||||
def create_indexed_blocks_from_hotpotqa(context, supporting_facts, inject_noise=True):
    """Convert a HotpotQA context into numbered HTML-like blocks.

    Each document contributes one ``<h2>`` title block followed by one ``<p>``
    block per non-blank sentence. When ``inject_noise`` is true, entries from
    ``NOISE_BLOCKS`` may be placed between documents and before/after the
    whole page. Blocks are numbered from 1 in their final page order.

    Args:
        context: ``{'title': [...], 'sentences': [[...], ...]}``.
        supporting_facts: ``{'title': [...], 'sent_id': [...]}``; ``sent_id``
            is a sentence's position within its document, counting blank
            sentences that are skipped in the output.
        inject_noise: Whether to randomly insert noise blocks.

    Returns:
        ``(block_text, relevant_indices, content_indices)`` where
        ``block_text`` joins the ``"[i] <tag>...</tag>"`` lines with newlines,
        ``relevant_indices`` lists the titles of supporting documents plus
        the supporting sentences, and ``content_indices`` lists every
        non-noise block.
    """
    titles = context['title']
    sentences_list = context['sentences']

    # Map each supporting document title to its supporting sentence ids.
    sf_lookup = defaultdict(set)
    for title, sent_id in zip(supporting_facts['title'], supporting_facts['sent_id']):
        sf_lookup[title].add(sent_id)

    # Page entries are kept as (html, is_content, is_relevant) and numbered only
    # once at the end, so adding prefix noise never requires re-parsing the
    # already formatted "[i] ..." strings.
    body = []
    last_doc = len(titles) - 1
    for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)):
        # .get() does not insert a key into the defaultdict; None means this
        # document is not a supporting document.
        supporting_ids = sf_lookup.get(title)
        is_supporting_doc = supporting_ids is not None

        # The title of a supporting document counts as relevant.
        body.append((f"<h2>{title}</h2>", True, is_supporting_doc))

        for sent_idx, sentence in enumerate(sentences):
            sentence = sentence.strip()
            if not sentence:
                continue
            is_relevant = is_supporting_doc and sent_idx in supporting_ids
            body.append((f"<p>{sentence}</p>", True, is_relevant))

        # Sometimes inject one noise block between documents (never after the last).
        if inject_noise and random.random() < 0.4 and doc_idx < last_doc:
            body.append((random.choice(NOISE_BLOCKS), False, False))

    # Sometimes add 1-3 noise blocks at the start and/or end of the page.
    prefix_noise = []
    suffix_noise = []
    if inject_noise:
        if random.random() < 0.5:
            prefix_noise = [random.choice(NOISE_BLOCKS) for _ in range(random.randint(1, 3))]
        if random.random() < 0.5:
            suffix_noise = [random.choice(NOISE_BLOCKS) for _ in range(random.randint(1, 3))]

    page = (
        [(noise, False, False) for noise in prefix_noise]
        + body
        + [(noise, False, False) for noise in suffix_noise]
    )

    blocks = []
    relevant_indices = []
    content_indices = []
    for idx, (html, is_content, is_relevant) in enumerate(page, start=1):
        blocks.append(f"[{idx}] {html}")
        if is_content:
            content_indices.append(idx)
        if is_relevant:
            relevant_indices.append(idx)

    block_text = "\n".join(blocks)
    return block_text, relevant_indices, content_indices
|
||||
|
||||
|
||||
def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"):
    """Assemble the chat messages for one query-relevant extraction (QE) sample.

    Args:
        question: The user query placed in the prompt.
        block_text: The numbered page blocks, one per line.
        relevant_indices: Block indices that answer the query; converted to
            the assistant target with ``indices_to_intervals``.
        url: URL shown in the prompt header.

    Returns:
        A ``[system, user, assistant]`` list of ``{"role", "content"}`` dicts.
    """
    target = indices_to_intervals(relevant_indices)
    prompt = (
        f"URL: {url}\n"
        f"Query: {question}\n"
        "\n"
        f"Blocks:\n{block_text}\n"
        "\n"
        "Output the index intervals of blocks relevant to the query."
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT_QE},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": target},
    ]
|
||||
|
||||
|
||||
def build_main_content_example(block_text, content_indices, title="Wikipedia Article", url="https://en.wikipedia.org"):
    """Assemble the chat messages for one main-content extraction (ME) sample.

    Args:
        block_text: The numbered page blocks, one per line.
        content_indices: Indices of the blocks that are real content;
            converted to the assistant target with ``indices_to_intervals``.
        title: Page title shown in the prompt header.
        url: URL shown in the prompt header.

    Returns:
        A ``[system, user, assistant]`` list of ``{"role", "content"}`` dicts.
    """
    target = indices_to_intervals(content_indices)
    prompt = (
        f"URL: {url}\n"
        f"Title: {title}\n"
        "\n"
        f"Blocks:\n{block_text}\n"
        "\n"
        "Output the index intervals of main content blocks."
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT_ME},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": target},
    ]
|
||||
|
||||
|
||||
def process_hotpotqa():
    """Turn a sample of the HotpotQA train split into IndexLM examples.

    Up to 15000 shuffled rows are converted. Each usable row yields one
    query-relevant (QE) example and, with probability 0.5, a main-content (ME)
    example built from the same noisy page.

    Returns:
        A list of dicts with keys ``"messages"``, ``"task_type"`` and
        ``"source"``.
    """
    print("Loading HotpotQA...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train")

    # Sample a manageable amount
    num_samples = min(15000, len(ds))
    ds = ds.shuffle(seed=42).select(range(num_samples))

    all_examples = []
    skipped = 0  # every row that produced no example
    errors = 0  # subset of `skipped` caused by an exception
    max_reported_errors = 5

    for i, row in enumerate(ds):
        if i % 1000 == 0:
            print(f"Processing {i}/{num_samples}...")

        try:
            block_text, relevant_indices, content_indices = create_indexed_blocks_from_hotpotqa(
                row['context'], row['supporting_facts'], inject_noise=True
            )

            # A QE example needs at least one relevant block.
            if not relevant_indices:
                skipped += 1
                continue

            all_examples.append({
                "messages": build_query_relevant_example(
                    row['question'], block_text, relevant_indices
                ),
                "task_type": "query_relevant",
                "source": "hotpotqa",
            })

            # Main content extraction example (50% of the time)
            if random.random() < 0.5:
                page_titles = row['context']['title']
                all_examples.append({
                    "messages": build_main_content_example(
                        block_text, content_indices,
                        title=page_titles[0] if page_titles else "Article"
                    ),
                    "task_type": "main_content",
                    "source": "hotpotqa",
                })
        except Exception as e:
            # Best effort: one malformed row must not abort the whole run. The
            # log limit uses its own counter so that rows skipped for having no
            # relevant blocks cannot suppress the first error messages.
            skipped += 1
            errors += 1
            if errors <= max_reported_errors:
                print(f"Error on row {i}: {e}")

    print(
        f"Created {len(all_examples)} examples from HotpotQA "
        f"({skipped} skipped, {errors} of them due to errors)"
    )
    return all_examples
|
||||
|
||||
|
||||
def create_synthetic_web_pages():
    """Create synthetic web page examples for main content extraction training.

    Each page is assembled from a HotpotQA validation row as: 1-4 header noise
    blocks, an ``<h1>`` title, the first 1-3 documents of the row (with an
    ``<h2>`` heading for every document after the first and an occasional
    noise block after a document), then 1-4 footer noise blocks.

    Returns:
        A list of dicts with keys ``"messages"``, ``"task_type"``
        (``"main_content"``) and ``"source"`` (``"synthetic"``).
    """
    print("Creating synthetic web page examples...")

    # Load a text dataset to get content
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    # Never request more rows than the split contains.
    num_pages = min(3000, len(ds))
    ds = ds.shuffle(seed=123).select(range(num_pages))

    examples = []
    skipped = 0
    errors = 0
    max_reported_errors = 5

    for i, row in enumerate(ds):
        if i % 500 == 0:
            print(f"Synthetic page {i}/{num_pages}...")

        try:
            titles = row['context']['title']
            sentences_list = row['context']['sentences']

            if not titles or not sentences_list:
                skipped += 1
                continue

            blocks = []
            content_indices = []
            idx = 1

            # Header noise (nav, etc.)
            for _ in range(random.randint(1, 4)):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1

            # Page title
            main_title = titles[0]
            blocks.append(f"[{idx}] <h1>{main_title}</h1>")
            content_indices.append(idx)
            idx += 1

            # Main content (just the first 1-3 documents). Also bounded by the
            # number of sentence lists so the lookup below cannot go out of range.
            num_docs = min(random.randint(1, 3), len(titles), len(sentences_list))
            for doc_idx in range(num_docs):
                # The first document's title is already the page <h1>.
                if doc_idx > 0:
                    blocks.append(f"[{idx}] <h2>{titles[doc_idx]}</h2>")
                    content_indices.append(idx)
                    idx += 1

                for sent in sentences_list[doc_idx]:
                    sent = sent.strip()
                    if not sent:
                        continue
                    blocks.append(f"[{idx}] <p>{sent}</p>")
                    content_indices.append(idx)
                    idx += 1

                # Occasional inline noise
                if random.random() < 0.3:
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                    idx += 1

            # Footer noise
            for _ in range(random.randint(1, 4)):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1

            examples.append({
                "messages": build_main_content_example(
                    "\n".join(blocks), content_indices,
                    title=main_title,
                    url=f"https://en.wikipedia.org/wiki/{main_title.replace(' ', '_')}"
                ),
                "task_type": "main_content",
                "source": "synthetic",
            })
        except Exception as e:
            # Best effort: skip the malformed row, but report it instead of
            # dropping it silently.
            skipped += 1
            errors += 1
            if errors <= max_reported_errors:
                print(f"Error on synthetic page {i}: {e}")

    print(f"Created {len(examples)} synthetic web page examples ({skipped} skipped)")
    return examples
|
||||
|
||||
|
||||
def create_na_examples():
    """Create examples where no relevant content exists (model should output 'NA').

    Each page built from one HotpotQA validation row is paired with the
    question of another row (500 positions later, wrapping around). The pair is
    assumed to be unrelated; no relevance check is performed. At most 300
    shuffled examples are kept.

    Returns:
        A list of dicts with keys ``"messages"``, ``"task_type"``
        (``"query_relevant_na"``) and ``"source"`` (``"hotpotqa_mismatched"``).
    """
    print("Creating NA examples...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    # Never request more rows than the split contains.
    num_rows = min(1000, len(ds))
    ds = ds.shuffle(seed=456).select(range(num_rows))

    examples = []
    skipped = 0

    for i, row in enumerate(ds):
        try:
            # Use context from one question but query from another (mismatched)
            other_idx = (i + 500) % len(ds)
            if other_idx == i:
                # Only possible for a very small dataset: the row would be
                # paired with its own question, which is not an NA case.
                skipped += 1
                continue
            other_question = ds[other_idx]['question']

            # Empty supporting facts: no block is marked relevant, while the
            # whole context (plus injected noise) stays on the page.
            block_text, _, _ = create_indexed_blocks_from_hotpotqa(
                row['context'], {'title': [], 'sent_id': []}, inject_noise=True
            )

            # Reuse the shared QE builder so the prompt stays identical to the
            # positive examples; an empty index list yields the target "NA".
            examples.append({
                "messages": build_query_relevant_example(other_question, block_text, []),
                "task_type": "query_relevant_na",
                "source": "hotpotqa_mismatched",
            })
        except Exception as e:
            # `except Exception` rather than a bare `except` so Ctrl-C and
            # interpreter exit still propagate.
            skipped += 1
            if skipped <= 5:
                print(f"Error on NA row {i}: {e}")

    # Keep only a fraction (the paper mentions partial filtering of NA)
    random.shuffle(examples)
    examples = examples[:300]
    print(f"Created {len(examples)} NA examples")
    return examples
|
||||
|
||||
|
||||
def main():
    """Build the IndexLM SFT dataset, then filter, split, save and upload it.

    Steps: generate the three example sets, shuffle, print per-task counts and
    token-length statistics, drop examples longer than 4096 tokens, split off
    an eval set (at most 500 examples or 10% of the data), save both splits
    under ``/app`` and push them to the Hugging Face Hub.
    """
    # Build all training examples
    qe_examples = process_hotpotqa()
    me_examples = create_synthetic_web_pages()
    na_examples = create_na_examples()

    all_examples = qe_examples + me_examples + na_examples
    random.shuffle(all_examples)

    print(f"\nTotal examples: {len(all_examples)}")

    # Count by type
    type_counts = defaultdict(int)
    for ex in all_examples:
        type_counts[ex['task_type']] += 1
    for t, c in type_counts.items():
        print(f" {t}: {c}")

    # Check lengths
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    def token_length(example):
        """Return the token count of the chat-templated example."""
        text = tokenizer.apply_chat_template(example['messages'], tokenize=False)
        return len(tokenizer.encode(text))

    # Tokenize every example once; the numbers feed both the statistics and
    # the length filter below.
    token_lengths = [token_length(ex) for ex in all_examples]

    lengths = token_lengths[:500]
    print(f"\nToken length stats (sample of {len(lengths)}):")
    if lengths:  # min()/max() raise on an empty sequence
        print(f" Min: {min(lengths)}")
        print(f" Max: {max(lengths)}")
        print(f" Mean: {sum(lengths)/len(lengths):.0f}")
        print(f" Median: {sorted(lengths)[len(lengths)//2]}")

    # Filter out examples that are too long (>4096 tokens for efficiency)
    MAX_LEN = 4096
    filtered = [ex for ex, n in zip(all_examples, token_lengths) if n <= MAX_LEN]
    too_long = len(all_examples) - len(filtered)

    print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)")
    print(f"Final dataset: {len(filtered)} examples")

    # Split into train/eval
    random.shuffle(filtered)
    eval_size = min(500, len(filtered) // 10)
    # Slice at an explicit position: with eval_size == 0, `filtered[:-0]` would
    # be empty and `filtered[-0:]` would put every example into the eval set.
    split_at = len(filtered) - eval_size
    train_data = filtered[:split_at]
    eval_data = filtered[split_at:]

    print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")

    # Create HF dataset with just messages column (for SFTTrainer)
    train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data])
    eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data])

    # Save locally
    train_ds.save_to_disk("/app/indexlm_train")
    eval_ds.save_to_disk("/app/indexlm_eval")

    # Also push to HF Hub
    from huggingface_hub import login
    import os
    # NOTE(review): if HF_TOKEN is unset this passes token=None; confirm that
    # login() behaves acceptably in a non-interactive environment.
    login(token=os.environ.get("HF_TOKEN"))

    from datasets import DatasetDict
    ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds})
    ds_dict.push_to_hub("OmAlve/indexlm-training-data")

    print("\nDone! Dataset pushed to OmAlve/indexlm-training-data")

    # Print sample
    print("\n=== Sample QE example ===")
    for ex in train_data[:3]:
        if ex.get("task_type", "") == "query_relevant":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break

    print("\n=== Sample ME example ===")
    for ex in train_data[:10]:
        if ex.get("task_type", "") == "main_content":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break
|
||||
|
||||
|
||||
# Run the data-preparation pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user