初始化项目,由ModelHub XC社区提供模型
Model: OmAlve/reading-steiner Source: Original Platform
This commit is contained in:
137
train_indexlm.py
Normal file
137
train_indexlm.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
Reading Steiner - Fine-tune Qwen3-0.6B for Index-based Web Content Extraction

Based on: "An Index-based Approach for Efficient and Effective Web Content Extraction" (arXiv:2512.06641)
Base model: Qwen/Qwen3-0.6B (0.6B params, ideal for CPU deployment)
Training method: SFT with TRL SFTTrainer
Dataset: OmAlve/indexlm-training-data (21K+ multi-domain examples)
"""
|
||||
|
||||
import importlib.util
import os

import torch
import trackio
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
# ============ Configuration ============
# Identifiers for the base checkpoint, the training data, the local output
# directory and the Hub repository that receives the fine-tuned model.
MODEL_ID = "Qwen/Qwen3-0.6B"
DATASET_ID = "OmAlve/indexlm-training-data"
OUTPUT_DIR = "./reading-steiner"
HUB_MODEL_ID = "OmAlve/reading-steiner"

# Training hyperparameters (from paper: standard SFT)
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
BATCH_SIZE = 4  # Per-device batch size (used for both train and eval)
GRAD_ACCUM = 4  # Effective batch size = 16 (BATCH_SIZE * GRAD_ACCUM, per device)
MAX_SEQ_LENGTH = 4096  # Passed to SFTConfig as max_length
WARMUP_RATIO = 0.05
|
||||
|
||||
# ============ Setup Trackio ============
# Start a trackio run for this training job.
# NOTE(review): the trainer config in this script sets report_to="none", so
# the Trainer itself presumably will not log into this run - confirm whether
# report_to should name trackio instead.
trackio.init(
    name="reading-steiner-training",
    project="reading-steiner",
)
|
||||
|
||||
# ============ Load Dataset ============
print("Loading dataset...")
dataset = load_dataset(DATASET_ID)

# The script reads the "train" and "eval" splits by name; a KeyError here
# means the dataset repository uses different split names.
train_dataset = dataset["train"]
eval_dataset = dataset["eval"]
print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
|
||||
|
||||
# ============ Load Model & Tokenizer ============
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Ensure padding token is set (reuse EOS when the tokenizer defines none)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Prefer FlashAttention-2, but fall back to PyTorch SDPA when the optional
# flash-attn package is not installed instead of failing at model load.
attn_implementation = (
    "flash_attention_2"
    if importlib.util.find_spec("flash_attn") is not None
    else "sdpa"
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
)

print(f"Model loaded: {MODEL_ID}")
print(f"Model params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
|
||||
|
||||
# ============ Training Config ============
# All hyperparameters come from the constants in the Configuration section.
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.01,
    bf16=True,
    gradient_checkpointing=True,
    max_length=MAX_SEQ_LENGTH,
    # Logging
    logging_steps=10,
    logging_first_step=True,
    logging_strategy="steps",
    disable_tqdm=True,
    # Evaluation
    eval_strategy="steps",
    eval_steps=500,
    # Saving (same cadence as evaluation, so every checkpoint has an
    # eval_loss available for best-model selection)
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # Lower eval_loss is better
    # Hub push
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
    # Performance
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    # Report
    # NOTE(review): "none" turns off the Trainer's reporting integrations even
    # though this script calls trackio.init() earlier - confirm this is intended.
    report_to="none",
    # Seed
    seed=42,
)
|
||||
|
||||
# ============ Initialize Trainer ============
print("Initializing trainer...")
# The tokenizer is handed over as processing_class so the trainer can
# tokenize the datasets itself.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)
|
||||
|
||||
# ============ Train ============
print("Starting training...")
# Keep the returned result: its .metrics dict is printed once saving is done.
train_result = trainer.train()
|
||||
|
||||
# ============ Save Final Model ============
print("Saving final model...")
# Write the model and the tokenizer side by side into OUTPUT_DIR.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Push to Hub
print("Pushing to Hub...")
trainer.push_to_hub(commit_message="Final reading-steiner model")
|
||||
|
||||
# ============ Log Final Metrics ============
def _format_metric(value, spec):
    """Return *value* formatted with *spec* when it is numeric, else str(value).

    Keeps the 'N/A' placeholder printable: applying a float format spec such
    as '.0f' to a string raises ValueError.
    """
    if isinstance(value, (int, float)):
        return format(value, spec)
    return str(value)


metrics = train_result.metrics
print("\nTraining complete!")
print(f" Train loss: {metrics.get('train_loss', 'N/A')}")
print(f" Train runtime: {_format_metric(metrics.get('train_runtime', 'N/A'), '.0f')}s")
print(
    " Train samples/sec: "
    f"{_format_metric(metrics.get('train_samples_per_second', 'N/A'), '.1f')}"
)

# Final eval
eval_metrics = trainer.evaluate()
print(f" Eval loss: {eval_metrics.get('eval_loss', 'N/A')}")

print(f"\nModel pushed to: https://huggingface.co/{HUB_MODEL_ID}")
|
||||
Reference in New Issue
Block a user