#!/usr/bin/env python3
"""
SFT training script for SmolLM3-3B-GSM8K-SFT

Fine-tunes SmolLM3-3B-Base on MetaMathQA for math reasoning.

Requirements:
    pip install transformers trl datasets torch

Usage:
    python train_sft.py
"""

import importlib.util

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# Configuration
BASE_MODEL = "HuggingFaceTB/SmolLM3-3B-Base"  # Hub id of the model to fine-tune
OUTPUT_DIR = "./sft_output"  # intermediate trainer checkpoints
FINAL_MODEL_DIR = "./final_model"  # final model + tokenizer
NUM_SAMPLES = 100000  # Use 100k samples from MetaMathQA

# System prompt that every training example is rendered with. It is also the
# chat template's fallback when a conversation has no system message, so
# prompts built with tokenizer.apply_chat_template() at inference time match
# what the model saw during training.
# It is spliced into a single-quoted Jinja string literal below, so it must not
# contain a single quote (').
MATH_SYSTEM_PROMPT = (
    "You are a helpful AI assistant that solves math problems step by step. "
    "Show your work clearly and end with the final numerical answer after ####."
)

# ChatML template for the base model (it has no chat template by default).
# The "\n" escapes become real newlines in the Python string; Jinja accepts
# newlines inside its string literals.
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if loop.first and messages[0]['role'] != 'system' %}"
    "{{ '<|im_start|>system\n" + MATH_SYSTEM_PROMPT + "<|im_end|>\n' }}"
    "{% endif %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)


def format_example(example):
    """Format the MetaMathQA example into conversation format.

    Renders one dataset row as a single ChatML string made of a fixed math
    system prompt, the user's question, and the assistant's worked answer.

    Args:
        example: A MetaMathQA row; reads its ``query`` (the problem) and
            ``response`` (the solution) fields.

    Returns:
        ``{"text": str}``, with each ChatML line separated by a single newline
        and no trailing newline after the final ``<|im_end|>``.

    Raises:
        KeyError: If ``query`` or ``response`` is missing from *example*.
    """
    question = example["query"]
    answer = example["response"]

    # Join explicit lines instead of using a multi-line literal, so stray
    # characters that get into the source (e.g. from copy/paste) cannot
    # silently become part of every training sample.
    text = "\n".join(
        [
            "<|im_start|>system",
            "You are a helpful AI assistant that solves math problems step by step. "
            "Show your work clearly and end with the final numerical answer after ####.<|im_end|>",
            "<|im_start|>user",
            f"{question}<|im_end|>",
            "<|im_start|>assistant",
            f"{answer}<|im_end|>",
        ]
    )

    return {"text": text}


def _load_tokenizer():
    """Load the base tokenizer and give it ChatML delimiters and a chat template.

    Returns:
        The tokenizer, with a pad token guaranteed to be set.
    """
    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        # The base checkpoint may ship without a pad token; reuse EOS so
        # batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    # Add ChatML special tokens so each delimiter stays a single token.
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
    )
    tokenizer.chat_template = CHAT_TEMPLATE
    return tokenizer


def _select_attn_implementation():
    """Return "flash_attention_2" when it can run here, otherwise "sdpa".

    flash-attn is not among the script's documented pip requirements, and
    requesting flash_attention_2 without the package (or without a CUDA
    device) makes ``from_pretrained`` fail. Fall back to PyTorch's built-in
    scaled-dot-product attention in that case.
    """
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    if has_flash_attn and torch.cuda.is_available():
        return "flash_attention_2"
    print("flash-attn unavailable (package missing or no CUDA); using sdpa attention")
    return "sdpa"


def _load_model(tokenizer):
    """Load the base model in bfloat16 and fit its embeddings to *tokenizer*."""
    print("\nLoading model...")
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        attn_implementation=_select_attn_implementation(),
        device_map="auto",
    )

    # Resize embeddings only if the tokenizer now has ids beyond the embedding
    # matrix (i.e. the ChatML tokens were genuinely new). An unconditional
    # resize would also shrink a matrix that may be padded larger than the
    # tokenizer's vocabulary.
    vocab_size = len(tokenizer)
    if vocab_size > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(vocab_size)
    return model


def _prepare_datasets():
    """Load, subsample, format and split MetaMathQA.

    Returns:
        ``(train_dataset, eval_dataset)``: a 99% / 1% split (seed 42) of up to
        NUM_SAMPLES shuffled rows, each reduced to the single ``text`` column
        produced by ``format_example``.
    """
    print("\nLoading MetaMathQA dataset...")
    dataset = load_dataset("meta-math/MetaMathQA", split="train")
    # Shuffle before taking the prefix so the subset is a random sample.
    dataset = dataset.shuffle(seed=42).select(range(min(NUM_SAMPLES, len(dataset))))
    print(f"Using {len(dataset)} samples")

    print("\nFormatting dataset...")
    formatted_dataset = dataset.map(
        format_example,
        remove_columns=dataset.column_names,
    )

    split_dataset = formatted_dataset.train_test_split(test_size=0.01, seed=42)
    train_dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]

    print(f"Train samples: {len(train_dataset)}")
    print(f"Eval samples: {len(eval_dataset)}")
    return train_dataset, eval_dataset


def _build_training_args():
    """Return the SFTConfig used for the single-epoch fine-tuning run."""
    return SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # Effective batch size = 2 x 8 = 16 per process
        learning_rate=1e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        weight_decay=0.01,
        logging_steps=50,
        save_steps=500,
        save_total_limit=2,
        eval_strategy="steps",
        eval_steps=500,
        bf16=True,
        gradient_checkpointing=True,
        report_to="none",
        max_length=2048,
        seed=42,
        packing=False,
    )


def main():
    """Run the full SFT pipeline: tokenizer, model, data, training, save.

    Loads BASE_MODEL and the MetaMathQA train split, fine-tunes with TRL's
    SFTTrainer (checkpoints in OUTPUT_DIR), then writes the final model and
    tokenizer to FINAL_MODEL_DIR.
    """
    print("=" * 50)
    print("Starting SFT Training for Math Reasoning")
    print("=" * 50)

    tokenizer = _load_tokenizer()
    model = _load_model(tokenizer)
    train_dataset, eval_dataset = _prepare_datasets()
    training_args = _build_training_args()

    print("\nInitializing SFT Trainer...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    print("\nStarting training...")
    trainer.train()

    print(f"\nSaving final model to {FINAL_MODEL_DIR}...")
    trainer.save_model(FINAL_MODEL_DIR)
    # Save explicitly so the ChatML tokens and chat template travel with the model.
    tokenizer.save_pretrained(FINAL_MODEL_DIR)

    print("\n" + "=" * 50)
    print("Training completed!")
    print("=" * 50)


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()