#!/usr/bin/env python3
"""
Improved SFT training for GSM8K performance.

Key improvements:
1. More training data (247K examples from GSM8K + MetaMathQA)
2. Multiple epochs with cosine LR schedule
3. Proper batch size and gradient accumulation for H100
4. Gradient checkpointing for memory efficiency
"""

import os

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer


def _device_map_for_launch():
    """Return the ``device_map`` to load the model with.

    ``device_map="auto"`` lets a single process spread the model over every
    visible GPU. Under a multi-process launch (torchrun / accelerate launch,
    which set ``WORLD_SIZE``) each process would do that independently, and
    Accelerate refuses to train a model split that way in distributed mode.
    In that case return ``None`` so the Trainer places one full copy of the
    model on each process's own device.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return "auto" if world_size <= 1 else None


def _load_model_and_tokenizer(model_name):
    """Load the tokenizer and a bf16 SDPA causal LM for ``model_name``.

    Returns:
        ``(model, tokenizer)``; the tokenizer is guaranteed to have a pad token.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Ensure pad token is set (falls back to EOS when the tokenizer has none).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",  # Use SDPA instead of flash_attention_2
        device_map=_device_map_for_launch(),
    )
    return model, tokenizer


def _build_training_args():
    """Build the SFT configuration used for the GSM8K run."""
    # Training config - optimized for H100 and GSM8K task.
    # Per-device batch 8 * grad_accum 4 = 32 sequences per optimizer step per
    # process (multiply by the number of processes when running distributed).
    # Because packing=True, each "sequence" is a packed block of up to
    # max_length tokens, so steps per epoch depend on the total token count
    # after packing, NOT on the raw example count (247467 / 32 ≈ 7733 steps
    # per epoch would only hold without packing).
    #
    # NOTE(review): TRL warns that packing without a flash-attention
    # implementation can let tokens attend across packed examples
    # (cross-contamination). SDPA is chosen deliberately above; confirm this
    # is acceptable, or switch to flash_attention_2 when it is installed.
    return SFTConfig(
        output_dir="./sft_output_improved",
        num_train_epochs=2,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.01,
        logging_steps=100,
        save_steps=2000,
        save_total_limit=3,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=1024,  # Math problems don't need very long context
        packing=True,
        report_to="none",
        seed=42,
        dataloader_num_workers=4,
        optim="adamw_torch_fused",
    )


def _print_training_config(model_name, dataset, training_args):
    """Print a human-readable summary of the run configuration."""
    # world_size is 1 for a plain single-process run, so this matches the
    # per-device * grad-accum product there, and stays correct when distributed.
    effective_batch = (
        training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps
        * training_args.world_size
    )
    print("\n=== Training Configuration ===")
    print(f"Model: {model_name}")
    print(f"Dataset size: {len(dataset)}")
    print(f"Batch size: {training_args.per_device_train_batch_size}")
    print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
    print(f"Effective batch size: {effective_batch}")
    print(f"Learning rate: {training_args.learning_rate}")
    print(f"Epochs: {training_args.num_train_epochs}")
    print(f"Max length: {training_args.max_length}")
    print("=" * 30)


def main():
    """Fine-tune Qwen3-1.7B on combined_math_train.jsonl and save to final_model/."""
    # Load model and tokenizer
    print("Loading model and tokenizer...")
    model_name = "Qwen/Qwen3-1.7B"
    model, tokenizer = _load_model_and_tokenizer(model_name)

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("json", data_files="combined_math_train.jsonl", split="train")
    print(f"Dataset size: {len(dataset)}")

    training_args = _build_training_args()

    # Create trainer
    print("Creating trainer...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # Only the main process prints the summary, to avoid duplicated output
    # under distributed launches (single-process runs always print).
    if trainer.is_world_process_zero():
        _print_training_config(model_name, dataset, training_args)

    # Train
    print("\nStarting training...")
    trainer.train()

    # Save final model. save_model must run on every rank (it coordinates
    # gathering/writing itself); the tokenizer is written by the main process
    # only so ranks don't race on the same files.
    print("\nSaving model to final_model/...")
    trainer.save_model("final_model")
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained("final_model")
    print("Training complete!")


if __name__ == "__main__":
    main()