# File listing metadata (extraction residue): 138 lines, 3.9 KiB, Python
"""
|
|
Reading Steiner - Fine-tune Qwen3-0.6B for Index-based Web Content Extraction
|
|
|
|
Based on: "An Index-based Approach for Efficient and Effective Web Content Extraction" (arxiv:2512.06641)
|
|
Base model: Qwen/Qwen3-0.6B (0.6B params, ideal for CPU deployment)
|
|
Training method: SFT with TRL SFTTrainer
|
|
Dataset: OmAlve/indexlm-training-data (21K+ multi-domain examples)
|
|
"""
|
|
|
|
import importlib.util
import os

import torch
import trackio
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig

# ============ Configuration ============
# Identifiers for the base model, the training data, and where results go.
MODEL_ID = "Qwen/Qwen3-0.6B"
DATASET_ID = "OmAlve/indexlm-training-data"
OUTPUT_DIR = "./reading-steiner"
HUB_MODEL_ID = "OmAlve/reading-steiner"

# Training hyperparameters (from paper: standard SFT)
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
BATCH_SIZE = 4  # Per-device batch size (train and eval)
GRAD_ACCUM = 4  # Effective batch size = 16
MAX_SEQ_LENGTH = 4096
WARMUP_RATIO = 0.05

# ============ Setup Trackio ============
# Initialize the Trackio run "reading-steiner-training" in project "reading-steiner".
trackio.init(
    name="reading-steiner-training",
    project="reading-steiner",
)

# ============ Load Dataset ============
print("Loading dataset...")
dataset = load_dataset(DATASET_ID)

# The script requires the dataset to expose both a "train" and an "eval" split;
# a missing split raises KeyError here, before any model weights are loaded.
train_dataset = dataset["train"]
eval_dataset = dataset["eval"]
print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

# ============ Load Model & Tokenizer ============
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Ensure padding token is set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# FlashAttention-2 needs both a CUDA device and the `flash_attn` package.
# Fall back to PyTorch SDPA automatically instead of failing at load time.
flash_attn_usable = (
    torch.cuda.is_available()
    and importlib.util.find_spec("flash_attn") is not None
)
attn_implementation = "flash_attention_2" if flash_attn_usable else "sdpa"
print(f"Attention implementation: {attn_implementation}")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
)

print(f"Model loaded: {MODEL_ID}")
print(f"Model params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

# ============ Training Config ============
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.01,
    bf16=True,
    gradient_checkpointing=True,
    max_length=MAX_SEQ_LENGTH,
    # Logging
    logging_steps=10,
    logging_first_step=True,
    logging_strategy="steps",
    disable_tqdm=True,
    # Evaluation
    eval_strategy="steps",
    eval_steps=500,
    # Saving (same step interval as evaluation, so each checkpoint has an eval_loss)
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower eval_loss is better
    # Hub push
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
    # Performance
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    # Report
    # NOTE(review): trackio.init() runs earlier in this script, but "none"
    # disables the Trainer's reporting integrations, so the Trainer itself
    # sends no metrics to that run - confirm whether "trackio" was intended.
    report_to="none",
    # Seed
    seed=42,
)

# ============ Initialize Trainer ============
print("Initializing trainer...")
# Wire the model, config, and both dataset splits into the SFT trainer; the
# tokenizer is handed over as the trainer's processing class.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)

# ============ Train ============
print("Starting training...")
# Keep the returned result: its `.metrics` are reported once training ends.
train_result = trainer.train()

# ============ Save Final Model ============
print("Saving final model...")
# Write the model and tokenizer into OUTPUT_DIR.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Push to Hub
print("Pushing to Hub...")
trainer.push_to_hub(commit_message="Final reading-steiner model")

# ============ Log Final Metrics ============
def _format_metric(value, spec=""):
    """Return *value* formatted with *spec*, or "N/A" when the metric is missing.

    Applying a float format spec such as ".0f" directly to an "N/A" placeholder
    string raises ValueError, so the placeholder is substituted before formatting.
    """
    return "N/A" if value is None else format(value, spec)


metrics = train_result.metrics
print("\nTraining complete!")
print(f" Train loss: {_format_metric(metrics.get('train_loss'))}")
print(f" Train runtime: {_format_metric(metrics.get('train_runtime'), '.0f')}s")
print(f" Train samples/sec: {_format_metric(metrics.get('train_samples_per_second'), '.1f')}")

# Final eval
eval_metrics = trainer.evaluate()
print(f" Eval loss: {_format_metric(eval_metrics.get('eval_loss'))}")

print(f"\nModel pushed to: https://huggingface.co/{HUB_MODEL_ID}")