"""
German Structured Output SFT Training
======================================
EU AI Act Article 53 compliant fine-tuning of OLMo 2 1B Instruct
on the philipp-zettl/german-structured-output dataset.

Base model: allenai/OLMo-2-0425-1B-Instruct (Apache 2.0)
- Pretraining data: allenai/olmo-mix-1124 (ODC-BY, fully documented)
- Every data source has an explicit license: DCLM (CC-BY-4.0), ArXiv (ODC-BY),
  PeS2o (ODC-BY), StarCoder (ODC-BY), Wikipedia (ODC-BY), etc.
- Full training logs, code, and intermediate checkpoints are published by AI2

Dataset: philipp-zettl/german-structured-output (CC BY-SA 4.0)
- All sources are documented with licenses per example
- Zero PII, GDPR compliant

Method: Full SFT with gradient checkpointing
"""
import os
import sys
import json
import torch
import statistics

# Make logs stream in real time.
# Setting PYTHONUNBUFFERED here only influences child processes started after
# this point: the running interpreter reads that variable at start-up, so it
# cannot change the buffering of this process's own stdout. To get the
# intended effect for this process too, switch the live stream to line
# buffering (each newline flushes), which also covers library output printed
# without flush=True.
os.environ["PYTHONUNBUFFERED"] = "1"
if hasattr(sys.stdout, "reconfigure"):  # absent if stdout was replaced by a non-TextIOWrapper
    sys.stdout.reconfigure(line_buffering=True)

print("=" * 60, flush=True)
print("Starting German Structured Output SFT Training", flush=True)
print("=" * 60, flush=True)
# Verify HF_TOKEN is available (needed for private dataset + push_to_hub).
# Exits with status 1 when the variable is unset or empty.
token = os.environ.get("HF_TOKEN")
if not token:
    print("ERROR: HF_TOKEN not set! Cannot access private dataset or push model.", flush=True)
    print("Set via: export HF_TOKEN=hf_... or pass --env HF_TOKEN=hf_...", flush=True)
    sys.exit(1)

# Report presence only. Job logs are persisted and may be shared, so no part
# of the secret (not even a prefix) should be written to them.
print("HF_TOKEN: present", flush=True)
# Check GPU: report availability and, if present, the name and total memory
# of device 0.
print(f"CUDA available: {torch.cuda.is_available()}", flush=True)
if torch.cuda.is_available():
    # Query device 0 explicitly for both values so they describe the same GPU.
    # The device-properties field is `total_memory`; the previous
    # `total_mem` does not exist and raised AttributeError on GPU hosts.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}", flush=True)
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB", flush=True)
else:
    print("WARNING: No GPU detected! Training will be very slow.", flush=True)
# Training libraries, imported once the start-up checks have printed.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

print("All imports OK", flush=True)

# ============================================================================
# Configuration
# ============================================================================
# Hugging Face Hub repository IDs and the local Trainer output directory.
MODEL_ID: str = "allenai/OLMo-2-0425-1B-Instruct"  # base model to fine-tune
DATASET_ID: str = "philipp-zettl/german-structured-output"  # conversational SFT data
OUTPUT_MODEL_ID: str = "philipp-zettl/german-structured-output-olmo2-1b"  # Hub push target
OUTPUT_DIR: str = "./output"  # local checkpoints (job storage is ephemeral)
# ============================================================================
# Trackio monitoring — writes to HF Bucket for persistent storage
# ============================================================================
print("Initializing Trackio...", flush=True)
import trackio

# Run identity (project/name) plus where the dashboard and data are stored.
# NOTE(review): the SFTConfig below also sets report_to="trackio" with a
# different run_name; confirm the Trainer integration reuses this run rather
# than opening a second one.
_trackio_run = {
    "project": "german-structured-output-sft",
    "name": "olmo2-1b-full-sft",
    "space_id": "philipp-zettl/german-structured-output-training",
    "bucket_id": "philipp-zettl/german-structured-output-training-bucket",
}
trackio.init(**_trackio_run)
print("Trackio initialized with bucket storage", flush=True)
# ============================================================================
# Load dataset
# ============================================================================
print("Loading dataset...", flush=True)
dataset = load_dataset(DATASET_ID)

# 'train' and 'validation' are consumed by SFTTrainer further down; 'test' is
# only reported here, so a missing split is shown as n/a instead of aborting.
_split_report = ", ".join(
    f"{label}: {len(dataset[split]) if split in dataset else 'n/a'}"
    for label, split in (("Train", "train"), ("Val", "validation"), ("Test", "test"))
)
print(_split_report, flush=True)

# Quick data audit of the first training example. This is informational only,
# so optional metadata columns and empty/None message contents are tolerated.
print("\nSample messages structure:", flush=True)
sample = dataset["train"][0]
for msg in sample["messages"]:
    content = str(msg.get("content") or "")
    preview = content[:100] + ("..." if len(content) > 100 else "")
    print(f" [{msg['role']}]: {preview}", flush=True)
print(f" task_type: {sample.get('task_type', 'N/A')}", flush=True)
print(f" quality_score: {sample.get('quality_score', 'N/A')}", flush=True)
# ============================================================================
# Load model and tokenizer
# ============================================================================
print(f"\nLoading model: {MODEL_ID}", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)

# Verify chat template exists. The dataset is conversational ("messages") and
# the length analysis that follows calls tokenizer.apply_chat_template, so a
# missing template is fatal; stop here with a clear message instead of
# failing later inside tokenization.
has_chat_template = tokenizer.chat_template is not None
print(f"Chat template available: {has_chat_template}", flush=True)
if not has_chat_template:
    print(f"ERROR: tokenizer for {MODEL_ID} has no chat template; cannot format 'messages'.", flush=True)
    sys.exit(1)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M", flush=True)
# Check sequence lengths in the dataset to set max_length appropriately.
# Only a prefix of the train split is sampled to keep start-up fast.
LENGTH_SAMPLE_SIZE = 200
print("\nAnalyzing sequence lengths...", flush=True)

lengths = []
# Slicing a Dataset returns a column dict; slicing past the end is safe.
for messages in dataset["train"][:LENGTH_SAMPLE_SIZE]["messages"]:
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    # The rendered template already contains the model's special tokens, so
    # re-tokenize without adding them again (otherwise e.g. BOS is counted twice).
    lengths.append(len(tokenizer(text, add_special_tokens=False)["input_ids"]))

if not lengths:
    print("ERROR: training split is empty; cannot determine max_length.", flush=True)
    sys.exit(1)

sorted_lengths = sorted(lengths)  # sort once; percentiles below use nearest rank
p95 = sorted_lengths[int(len(sorted_lengths) * 0.95)]
p99 = sorted_lengths[int(len(sorted_lengths) * 0.99)]

print(f" Sampled {len(lengths)} examples", flush=True)
print(f" Min: {sorted_lengths[0]}, Max: {sorted_lengths[-1]}, Mean: {statistics.mean(lengths):.0f}, Median: {statistics.median(lengths):.0f}", flush=True)
print(f" P95: {p95}", flush=True)
print(f" P99: {p99}", flush=True)

# Set max_length to cover P99 + margin, cap at 2048
MAX_LENGTH = min(p99 + 128, 2048)
print(f" Using max_length: {MAX_LENGTH}", flush=True)

# Sequences longer than max_length are truncated by the trainer; surface how
# often that happens in the sample so the cap is a visible trade-off.
n_over = sum(length > MAX_LENGTH for length in lengths)
if n_over:
    print(f" WARNING: {n_over}/{len(lengths)} sampled examples exceed max_length and will be truncated", flush=True)
# ============================================================================
# Training configuration
# ============================================================================
# All SFTConfig keyword arguments, grouped by concern in a single mapping.
_sft_kwargs = {
    "output_dir": OUTPUT_DIR,
    # --- Precision ---
    "bf16": True,
    # --- Sequence ---
    "max_length": MAX_LENGTH,
    "packing": False,  # No packing — variable length structured outputs
    # --- Batch size: effective = 4 * 4 = 16 per device ---
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    # --- Optimizer ---
    "learning_rate": 2e-5,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.05,
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
    # --- Epochs: small dataset (4K), more epochs to converge ---
    "num_train_epochs": 5,
    # --- Evaluation: once before training, then after every epoch ---
    "eval_strategy": "epoch",
    "eval_on_start": True,
    # --- Saving: per epoch, rotated (limit 2); lowest eval_loss reloaded at end ---
    "save_strategy": "epoch",
    "save_total_limit": 2,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,
    # --- Logging: plain text, no tqdm ---
    "logging_strategy": "steps",
    "logging_steps": 10,
    "logging_first_step": True,
    "disable_tqdm": True,
    "report_to": "trackio",
    "run_name": "olmo2-1b-german-structured-output",
    # --- Memory ---
    "gradient_checkpointing": True,
    # --- Hub: CRITICAL — push_to_hub to save model (ephemeral job storage) ---
    "push_to_hub": True,
    "hub_model_id": OUTPUT_MODEL_ID,
    "hub_strategy": "end",
    # --- NEFTune for small dataset regularization (paper: arxiv 2310.05914) ---
    "neftune_noise_alpha": 5.0,
}
training_args = SFTConfig(**_sft_kwargs)
# ============================================================================
# Trainer
# ============================================================================
print("\nInitializing SFTTrainer...", flush=True)

# Train on the 'train' split; evaluate (and select the best checkpoint) on
# 'validation'.
_train_split = dataset["train"]
_eval_split = dataset["validation"]

trainer = SFTTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=_train_split,
    eval_dataset=_eval_split,
)
# Print training summary
import math

# Estimate optimizer steps the way the Trainer derives them. Each process
# sees ceil(N / (batch * world_size)) batches unless dataloader_drop_last is
# set. Recent transformers versions also count a trailing partial
# gradient-accumulation window as a step; older ones drop it, hence the "~".
_round_batches = math.floor if training_args.dataloader_drop_last else math.ceil
batches_per_epoch = _round_batches(
    len(dataset["train"])
    / (training_args.per_device_train_batch_size * training_args.world_size)
)
steps_per_epoch = max(math.ceil(batches_per_epoch / training_args.gradient_accumulation_steps), 1)
total_steps = steps_per_epoch * int(training_args.num_train_epochs)

print(f"\n{'='*60}", flush=True)
print("Training Summary", flush=True)
print(f"{'='*60}", flush=True)
print(f" Base model: {MODEL_ID}", flush=True)
print(" Base model license: Apache 2.0", flush=True)
print(" Base model training data: allenai/olmo-mix-1124 (ODC-BY)", flush=True)
print(f" Dataset: {DATASET_ID}", flush=True)
print(" Dataset license: CC BY-SA 4.0", flush=True)
print(f" Output model: {OUTPUT_MODEL_ID}", flush=True)
print(f" Train examples: {len(dataset['train'])}", flush=True)
print(f" Val examples: {len(dataset['validation'])}", flush=True)
print(f" Max length: {MAX_LENGTH}", flush=True)
print(f" Effective batch size: {training_args.per_device_train_batch_size} x {training_args.gradient_accumulation_steps} = {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}", flush=True)
print(f" Learning rate: {training_args.learning_rate}", flush=True)
print(f" Epochs: {training_args.num_train_epochs}", flush=True)
print(f" Estimated steps: ~{total_steps}", flush=True)
print(f" NEFTune alpha: {training_args.neftune_noise_alpha}", flush=True)
print(" EU AI Act: Article 53 compliant (full data provenance chain)", flush=True)
print(f"{'='*60}", flush=True)
# ============================================================================
# Train!
# ============================================================================
print("\nStarting training...", flush=True)
train_result = trainer.train()

# ============================================================================
# Save and push
# ============================================================================
print("\nSaving model...", flush=True)
# NOTE(review): with push_to_hub=True, Trainer.save_model() called directly
# also pushes a commit to the Hub; the explicit trainer.push_to_hub() later
# in the script then pushes again (with tags). Confirm both pushes are intended.
trainer.save_model()


def _fmt_metric(value, spec):
    """Format a numeric metric with format *spec*, or return 'N/A' if absent.

    Applying a numeric spec such as '.0f' to the string 'N/A' raises
    ValueError, so the fallback must bypass formatting entirely.
    """
    return format(value, spec) if isinstance(value, (int, float)) else "N/A"


# Log final metrics
metrics = train_result.metrics
print("\nTraining complete!", flush=True)
print(f" Train loss: {metrics.get('train_loss', 'N/A')}", flush=True)
print(f" Train runtime: {_fmt_metric(metrics.get('train_runtime'), '.0f')}s", flush=True)
print(f" Train samples/sec: {_fmt_metric(metrics.get('train_samples_per_second'), '.1f')}", flush=True)
# Evaluate the final model on the validation split.
print("\nRunning final evaluation...", flush=True)
eval_metrics = trainer.evaluate()
print(f" Eval loss: {eval_metrics.get('eval_loss', 'N/A')}", flush=True)

# Push to hub, tagging the repository with task and compliance metadata.
_hub_tags = [
    # Language / tasks
    "german", "structured-output", "json", "function-calling",
    "ner", "relation-extraction", "gdpr-anonymization",
    # Compliance / method / base model
    "eu-ai-act-compliant", "gdpr-compliant", "sft",
    "olmo2", "article-53",
]
print(f"\nPushing to Hub: {OUTPUT_MODEL_ID}", flush=True)
trainer.push_to_hub(
    commit_message="EU AI Act compliant SFT: OLMo 2 1B on german-structured-output",
    tags=_hub_tags,
)

# Final pointers to the published artifacts.
for _line in (
    "\n✅ Training complete and model pushed to Hub!",
    f" Model: https://huggingface.co/{OUTPUT_MODEL_ID}",
    " Trackio: https://huggingface.co/spaces/philipp-zettl/german-structured-output-training",
):
    print(_line, flush=True)