#!/usr/bin/env python3
"""
SFT training script for SmolLM3-3B-GSM8K-SFT

Fine-tunes SmolLM3-3B-Base on MetaMathQA for math reasoning.

Requirements:
    pip install transformers trl datasets torch accelerate flash-attn

    ``accelerate`` is needed for ``device_map="auto"`` and ``flash-attn`` for
    ``attn_implementation="flash_attention_2"`` (both are set when the model
    is loaded below).

Usage:
    python train_sft.py
"""

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# Configuration
BASE_MODEL = "HuggingFaceTB/SmolLM3-3B-Base"
OUTPUT_DIR = "./sft_output"
FINAL_MODEL_DIR = "./final_model"
NUM_SAMPLES = 100000  # Use 100k samples from MetaMathQA
SEED = 42  # Shared by the dataset shuffle, the train/eval split and the trainer

# ChatML template for base model (it has no chat template by default).
# The "\n" escapes are ordinary (non-raw) Python escapes, so the Jinja string
# literals contain real newline characters.
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if loop.first and messages[0]['role'] != 'system' %}"
    "{{ '<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n' }}"
    "{% endif %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content']"
    " + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|im_start|>assistant\n' }}"
    "{% endif %}"
)

# System prompt baked into every training example. The space before the line
# break is deliberate: it is part of the original prompt text.
# NOTE(review): CHAT_TEMPLATE injects a *different* default system prompt
# ("You are a helpful AI assistant.") when the caller supplies none, so
# inference through the chat template will not match the prompt seen during
# training unless this text is passed explicitly - confirm that is intended.
SYSTEM_PROMPT = (
    "You are a helpful AI assistant that solves math problems step by step. \n"
    "Show your work clearly and end with the final numerical answer after ####."
)


def format_example(example: dict) -> dict:
    """Format one MetaMathQA record as a single ChatML training string.

    Args:
        example: A dataset row; its ``"query"`` and ``"response"`` fields are
            read.

    Returns:
        A dict with one key, ``"text"``, holding the system / user / assistant
        turns wrapped in ``<|im_start|>`` ... ``<|im_end|>`` markers.

    Raises:
        KeyError: If ``"query"`` or ``"response"`` is missing from ``example``.
    """
    question = example["query"]
    answer = example["response"]
    text = (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{question}<|im_end|>\n"
        f"<|im_start|>assistant\n{answer}<|im_end|>"
    )
    return {"text": text}


def _load_tokenizer():
    """Load the base tokenizer and set it up for ChatML-formatted text."""
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        # Fall back to EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    # Add ChatML special tokens
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
    )
    tokenizer.chat_template = CHAT_TEMPLATE
    return tokenizer


def _load_model(tokenizer):
    """Load the base model in bfloat16 and make room for any new tokens."""
    # NOTE(review): device_map="auto" suits the single-process launch shown in
    # the module docstring; confirm before using a multi-process launcher.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )

    # Only ever grow the embedding matrix. Resizing unconditionally would
    # shrink it whenever the checkpoint has more rows than the tokenizer has
    # tokens, changing the saved vocab size for no benefit.
    embedding_rows = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_rows:
        model.resize_token_embeddings(len(tokenizer))
    return model


def _prepare_datasets():
    """Load, subsample and format MetaMathQA; return ``(train, eval)`` splits."""
    dataset = load_dataset("meta-math/MetaMathQA", split="train")
    sample_count = min(NUM_SAMPLES, len(dataset))
    dataset = dataset.shuffle(seed=SEED).select(range(sample_count))
    print(f"Using {len(dataset)} samples")

    # Format dataset; the original columns are dropped so only "text" remains.
    print("\nFormatting dataset...")
    formatted_dataset = dataset.map(
        format_example,
        remove_columns=dataset.column_names,
    )

    # Hold out 1% of the formatted samples for evaluation.
    split_dataset = formatted_dataset.train_test_split(test_size=0.01, seed=SEED)
    return split_dataset["train"], split_dataset["test"]


def main():
    """Run the full pipeline: load, format, train, and save the final model."""
    banner = "=" * 50
    print(banner)
    print("Starting SFT Training for Math Reasoning")
    print(banner)

    # Load tokenizer and add chat template
    print("\nLoading tokenizer...")
    tokenizer = _load_tokenizer()

    # Load model
    print("\nLoading model...")
    model = _load_model(tokenizer)

    # Load and prepare dataset
    print("\nLoading MetaMathQA dataset...")
    train_dataset, eval_dataset = _prepare_datasets()
    print(f"Train samples: {len(train_dataset)}")
    print(f"Eval samples: {len(eval_dataset)}")

    # Training configuration
    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # Effective batch size = 16
        learning_rate=1e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        weight_decay=0.01,
        logging_steps=50,
        save_steps=500,
        save_total_limit=2,
        eval_strategy="steps",
        eval_steps=500,
        bf16=True,
        gradient_checkpointing=True,
        report_to="none",
        max_length=2048,
        seed=SEED,
        packing=False,
    )

    # Create trainer
    print("\nInitializing SFT Trainer...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    # Train
    print("\nStarting training...")
    trainer.train()

    # Save final model together with the tokenizer (which carries the chat
    # template) so the output directory is self-contained.
    print(f"\nSaving final model to {FINAL_MODEL_DIR}...")
    trainer.save_model(FINAL_MODEL_DIR)
    tokenizer.save_pretrained(FINAL_MODEL_DIR)

    print("\n" + banner)
    print("Training completed!")
    print(banner)


if __name__ == "__main__":
    main()