初始化项目,由ModelHub XC社区提供模型
Model: philipp-zettl/german-structured-output-olmo2-1b Source: Original Platform
This commit is contained in:
257
train.py
Normal file
257
train.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
German Structured Output SFT Training
|
||||
======================================
|
||||
EU AI Act Article 53 compliant fine-tuning of OLMo 2 1B Instruct
|
||||
on philipp-zettl/german-structured-output dataset.
|
||||
|
||||
Base model: allenai/OLMo-2-0425-1B-Instruct (Apache 2.0)
|
||||
- Pretraining data: allenai/olmo-mix-1124 (ODC-BY, fully documented)
|
||||
- Every data source has explicit license: DCLM (CC-BY-4.0), ArXiv (ODC-BY),
|
||||
PeS2o (ODC-BY), StarCoder (ODC-BY), Wikipedia (ODC-BY), etc.
|
||||
- Full training logs, code, and intermediate checkpoints published by AI2
|
||||
|
||||
Dataset: philipp-zettl/german-structured-output (CC BY-SA 4.0)
|
||||
- All sources documented with licenses per example
|
||||
- Zero PII, GDPR compliant
|
||||
|
||||
Method: Full SFT with gradient checkpointing
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import torch
|
||||
import statistics
|
||||
|
||||
# Force unbuffered output so logs stream in real-time
|
||||
os.environ["PYTHONUNBUFFERED"] = "1"
|
||||
|
||||
print("=" * 60, flush=True)
|
||||
print("Starting German Structured Output SFT Training", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
|
||||
# Verify HF_TOKEN is available (needed for private dataset + push_to_hub)
|
||||
token = os.environ.get("HF_TOKEN")
|
||||
if not token:
|
||||
print("ERROR: HF_TOKEN not set! Cannot access private dataset or push model.", flush=True)
|
||||
print("Set via: export HF_TOKEN=hf_... or pass --env HF_TOKEN=hf_...", flush=True)
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"HF_TOKEN: present ({token[:8]}...)", flush=True)
|
||||
|
||||
# Check GPU
|
||||
print(f"CUDA available: {torch.cuda.is_available()}", flush=True)
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU: {torch.cuda.get_device_name()}", flush=True)
|
||||
print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB", flush=True)
|
||||
else:
|
||||
print("WARNING: No GPU detected! Training will be very slow.", flush=True)
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
print("All imports OK", flush=True)
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
MODEL_ID = "allenai/OLMo-2-0425-1B-Instruct"
|
||||
DATASET_ID = "philipp-zettl/german-structured-output"
|
||||
OUTPUT_MODEL_ID = "philipp-zettl/german-structured-output-olmo2-1b"
|
||||
OUTPUT_DIR = "./output"
|
||||
|
||||
# ============================================================================
|
||||
# Trackio monitoring — writes to HF Bucket for persistent storage
|
||||
# ============================================================================
|
||||
print("Initializing Trackio...", flush=True)
|
||||
import trackio
|
||||
trackio.init(
|
||||
project="german-structured-output-sft",
|
||||
name="olmo2-1b-full-sft",
|
||||
space_id="philipp-zettl/german-structured-output-training",
|
||||
bucket_id="philipp-zettl/german-structured-output-training-bucket",
|
||||
)
|
||||
print("Trackio initialized with bucket storage", flush=True)
|
||||
|
||||
# ============================================================================
|
||||
# Load dataset
|
||||
# ============================================================================
|
||||
print("Loading dataset...", flush=True)
|
||||
dataset = load_dataset(DATASET_ID)
|
||||
print(f"Train: {len(dataset['train'])}, Val: {len(dataset['validation'])}, Test: {len(dataset['test'])}", flush=True)
|
||||
|
||||
# Quick data audit
|
||||
print("\nSample messages structure:", flush=True)
|
||||
sample = dataset["train"][0]
|
||||
for msg in sample["messages"]:
|
||||
print(f" [{msg['role']}]: {msg['content'][:100]}...", flush=True)
|
||||
print(f" task_type: {sample['task_type']}", flush=True)
|
||||
print(f" quality_score: {sample['quality_score']}", flush=True)
|
||||
|
||||
# ============================================================================
|
||||
# Load model and tokenizer
|
||||
# ============================================================================
|
||||
print(f"\nLoading model: {MODEL_ID}", flush=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
|
||||
# Verify chat template exists
|
||||
print(f"Chat template available: {tokenizer.chat_template is not None}", flush=True)
|
||||
print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M", flush=True)
|
||||
|
||||
# Check sequence lengths in dataset to set max_length appropriately
|
||||
print("\nAnalyzing sequence lengths...", flush=True)
|
||||
lengths = []
|
||||
for i, example in enumerate(dataset["train"]):
|
||||
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
|
||||
tokens = tokenizer(text, return_length=True)
|
||||
lengths.append(tokens["length"][0])
|
||||
if i >= 200: # Sample first 200
|
||||
break
|
||||
|
||||
print(f" Sampled {len(lengths)} examples", flush=True)
|
||||
print(f" Min: {min(lengths)}, Max: {max(lengths)}, Mean: {statistics.mean(lengths):.0f}, Median: {statistics.median(lengths):.0f}", flush=True)
|
||||
print(f" P95: {sorted(lengths)[int(len(lengths)*0.95)]}", flush=True)
|
||||
print(f" P99: {sorted(lengths)[int(len(lengths)*0.99)]}", flush=True)
|
||||
|
||||
# Set max_length to cover P99 + margin, cap at 2048
|
||||
MAX_LENGTH = min(sorted(lengths)[int(len(lengths)*0.99)] + 128, 2048)
|
||||
print(f" Using max_length: {MAX_LENGTH}", flush=True)
|
||||
|
||||
# ============================================================================
|
||||
# Training configuration
|
||||
# ============================================================================
|
||||
training_args = SFTConfig(
|
||||
output_dir=OUTPUT_DIR,
|
||||
|
||||
# Precision
|
||||
bf16=True,
|
||||
|
||||
# Sequence
|
||||
max_length=MAX_LENGTH,
|
||||
packing=False, # No packing — variable length structured outputs
|
||||
|
||||
# Batch size: effective = 4 * 4 = 16
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
|
||||
# Optimizer
|
||||
learning_rate=2e-5,
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.05,
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=1.0,
|
||||
|
||||
# Epochs — small dataset (4K), more epochs to converge
|
||||
num_train_epochs=5,
|
||||
|
||||
# Evaluation
|
||||
eval_strategy="epoch",
|
||||
eval_on_start=True,
|
||||
|
||||
# Saving
|
||||
save_strategy="epoch",
|
||||
save_total_limit=2,
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="eval_loss",
|
||||
greater_is_better=False,
|
||||
|
||||
# Logging — plain text, no tqdm
|
||||
logging_strategy="steps",
|
||||
logging_steps=10,
|
||||
logging_first_step=True,
|
||||
disable_tqdm=True,
|
||||
report_to="trackio",
|
||||
run_name="olmo2-1b-german-structured-output",
|
||||
|
||||
# Memory
|
||||
gradient_checkpointing=True,
|
||||
|
||||
# Hub — CRITICAL: push_to_hub to save model (ephemeral job storage)
|
||||
push_to_hub=True,
|
||||
hub_model_id=OUTPUT_MODEL_ID,
|
||||
hub_strategy="end",
|
||||
|
||||
# NEFTune for small dataset regularization (paper: arxiv 2310.05914)
|
||||
neftune_noise_alpha=5.0,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Trainer
|
||||
# ============================================================================
|
||||
print("\nInitializing SFTTrainer...", flush=True)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["validation"],
|
||||
)
|
||||
|
||||
# Print training summary
|
||||
total_steps = len(dataset["train"]) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * int(training_args.num_train_epochs)
|
||||
print(f"\n{'='*60}", flush=True)
|
||||
print(f"Training Summary", flush=True)
|
||||
print(f"{'='*60}", flush=True)
|
||||
print(f" Base model: {MODEL_ID}", flush=True)
|
||||
print(f" Base model license: Apache 2.0", flush=True)
|
||||
print(f" Base model training data: allenai/olmo-mix-1124 (ODC-BY)", flush=True)
|
||||
print(f" Dataset: {DATASET_ID}", flush=True)
|
||||
print(f" Dataset license: CC BY-SA 4.0", flush=True)
|
||||
print(f" Output model: {OUTPUT_MODEL_ID}", flush=True)
|
||||
print(f" Train examples: {len(dataset['train'])}", flush=True)
|
||||
print(f" Val examples: {len(dataset['validation'])}", flush=True)
|
||||
print(f" Max length: {MAX_LENGTH}", flush=True)
|
||||
print(f" Effective batch size: {training_args.per_device_train_batch_size} x {training_args.gradient_accumulation_steps} = {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}", flush=True)
|
||||
print(f" Learning rate: {training_args.learning_rate}", flush=True)
|
||||
print(f" Epochs: {training_args.num_train_epochs}", flush=True)
|
||||
print(f" Estimated steps: ~{total_steps}", flush=True)
|
||||
print(f" NEFTune alpha: {training_args.neftune_noise_alpha}", flush=True)
|
||||
print(f" EU AI Act: Article 53 compliant (full data provenance chain)", flush=True)
|
||||
print(f"{'='*60}", flush=True)
|
||||
|
||||
# ============================================================================
|
||||
# Train!
|
||||
# ============================================================================
|
||||
print("\nStarting training...", flush=True)
|
||||
train_result = trainer.train()
|
||||
|
||||
# ============================================================================
|
||||
# Save and push
|
||||
# ============================================================================
|
||||
print("\nSaving model...", flush=True)
|
||||
trainer.save_model()
|
||||
|
||||
# Log final metrics
|
||||
metrics = train_result.metrics
|
||||
print(f"\nTraining complete!", flush=True)
|
||||
print(f" Train loss: {metrics.get('train_loss', 'N/A')}", flush=True)
|
||||
print(f" Train runtime: {metrics.get('train_runtime', 'N/A'):.0f}s", flush=True)
|
||||
print(f" Train samples/sec: {metrics.get('train_samples_per_second', 'N/A'):.1f}", flush=True)
|
||||
|
||||
# Evaluate
|
||||
print("\nRunning final evaluation...", flush=True)
|
||||
eval_metrics = trainer.evaluate()
|
||||
print(f" Eval loss: {eval_metrics.get('eval_loss', 'N/A')}", flush=True)
|
||||
|
||||
# Push to hub
|
||||
print(f"\nPushing to Hub: {OUTPUT_MODEL_ID}", flush=True)
|
||||
trainer.push_to_hub(
|
||||
commit_message="EU AI Act compliant SFT: OLMo 2 1B on german-structured-output",
|
||||
tags=[
|
||||
"german", "structured-output", "json", "function-calling",
|
||||
"ner", "relation-extraction", "gdpr-anonymization",
|
||||
"eu-ai-act-compliant", "gdpr-compliant", "sft",
|
||||
"olmo2", "article-53"
|
||||
],
|
||||
)
|
||||
|
||||
print("\n✅ Training complete and model pushed to Hub!", flush=True)
|
||||
print(f" Model: https://huggingface.co/{OUTPUT_MODEL_ID}", flush=True)
|
||||
print(f" Trackio: https://huggingface.co/spaces/philipp-zettl/german-structured-output-training", flush=True)
|
||||
Reference in New Issue
Block a user