Files
reading-steiner/train_indexlm.py
ModelHub XC 92eddcb2d6 初始化项目,由ModelHub XC社区提供模型
Model: OmAlve/reading-steiner
Source: Original Platform
2026-06-16 08:15:17 +08:00

138 lines
3.9 KiB
Python

"""
Reading Steiner - Fine-tune Qwen3-0.6B for Index-based Web Content Extraction
Based on: "An Index-based Approach for Efficient and Effective Web Content Extraction" (arxiv:2512.06641)
Base model: Qwen/Qwen3-0.6B (0.6B params, ideal for CPU deployment)
Training method: SFT with TRL SFTTrainer
Dataset: OmAlve/indexlm-training-data (21K+ multi-domain examples)
"""
import importlib.util
import os

import torch
import trackio
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
# ============ Configuration ============
# Base checkpoint, training data, and where results are written / published.
MODEL_ID: str = "Qwen/Qwen3-0.6B"
DATASET_ID: str = "OmAlve/indexlm-training-data"
OUTPUT_DIR: str = "./reading-steiner"
HUB_MODEL_ID: str = "OmAlve/reading-steiner"

# Training hyperparameters (from paper: standard SFT)
LEARNING_RATE: float = 2e-5
NUM_EPOCHS: int = 3
BATCH_SIZE: int = 4  # per device; used for both train and eval batches
GRAD_ACCUM: int = 4  # per-device effective batch = BATCH_SIZE * GRAD_ACCUM = 16
MAX_SEQ_LENGTH: int = 4096  # passed to SFTConfig(max_length=...)
WARMUP_RATIO: float = 0.05
# ============ Setup Trackio ============
# Open a trackio run for this training job.
trackio.init(project="reading-steiner", name="reading-steiner-training")
# ============ Load Dataset ============
# The hub dataset is expected to expose "train" and "eval" splits.
print("Loading dataset...")
dataset = load_dataset(DATASET_ID)
train_dataset, eval_dataset = dataset["train"], dataset["eval"]
print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
# ============ Load Model & Tokenizer ============
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Batched training needs a pad token; reuse EOS if the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# FlashAttention-2 requires the optional `flash-attn` package and a CUDA device.
# Fall back to PyTorch SDPA instead of failing at load time when either is
# missing, rather than requiring a manual edit of this script.
if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
)
print(f"Model loaded: {MODEL_ID}")
print(f"Attention implementation: {attn_implementation}")
print(f"Model params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
# ============ Training Config ============
# NOTE(review): report_to="none" disables the Trainer's logging integrations,
# so the trackio run opened at startup receives no per-step training metrics.
training_args = SFTConfig(
    # --- Output & reproducibility ---
    output_dir=OUTPUT_DIR,
    seed=42,
    # --- Optimization schedule ---
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.01,
    # --- Batching & sequence length ---
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    max_length=MAX_SEQ_LENGTH,
    # --- Precision & memory ---
    bf16=True,
    gradient_checkpointing=True,
    # --- Data loading ---
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    # --- Logging: every 10 steps (plus step 1), no tqdm bars ---
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="none",
    # --- Evaluation & checkpointing every 500 steps; at most 3 checkpoints
    #     are kept, and the lowest-eval-loss one is restored at the end ---
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # --- Hub publishing: push on every checkpoint save ---
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
)
# ============ Initialize Trainer ============
# The tokenizer is handed to the trainer as its processing class.
print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# ============ Train ============
print("Starting training...")
train_result = trainer.train()
# ============ Save Final Model ============
# Write weights and tokenizer files to OUTPUT_DIR, then publish them to the Hub
# as an explicit final commit.
print("Saving final model...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print("Pushing to Hub...")
final_commit_message = "Final reading-steiner model"
trainer.push_to_hub(commit_message=final_commit_message)
# ============ Log Final Metrics ============
def _format_metric(value: object, spec: str = "", suffix: str = "") -> str:
    """Return ``value`` formatted with ``spec`` followed by ``suffix``, or ``"N/A"``.

    Trainer metric dicts may omit keys. Applying a float spec such as ``".0f"``
    to a string fallback raises ``ValueError``, so any non-numeric value
    (including ``None`` from a missing key) is reported as ``"N/A"`` instead.
    """
    if isinstance(value, (int, float)):
        return f"{value:{spec}}{suffix}"
    return "N/A"


metrics = train_result.metrics
print("\nTraining complete!")
print(f" Train loss: {_format_metric(metrics.get('train_loss'))}")
print(f" Train runtime: {_format_metric(metrics.get('train_runtime'), '.0f', 's')}")
print(f" Train samples/sec: {_format_metric(metrics.get('train_samples_per_second'), '.1f')}")
# Final eval; with load_best_model_at_end=True this scores the restored best checkpoint.
eval_metrics = trainer.evaluate()
print(f" Eval loss: {_format_metric(eval_metrics.get('eval_loss'))}")
# The Trainer is configured with report_to="none", so nothing reaches the
# trackio run opened at startup unless it is logged here; record the final
# numeric metrics and close the run.
final_metrics = {
    key: value
    for key, value in {**metrics, **eval_metrics}.items()
    if isinstance(value, (int, float))
}
trackio.log(final_metrics)
trackio.finish()
print(f"\nModel pushed to: https://huggingface.co/{HUB_MODEL_ID}")