#!/usr/bin/env python3
"""
Continue training the already fine-tuned model with lower learning rate for additional refinement.
"""

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
import os


def _load_model_and_tokenizer(model_name):
    """Load the causal LM (bfloat16, SDPA attention) and its tokenizer from ``model_name``."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Ensure pad token is set; fall back to EOS when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="auto",
    )
    return model, tokenizer


def _print_config(model_name, dataset, training_args):
    """Print a human-readable summary of the continued-training run."""
    # Samples per optimizer step across ALL processes; omitting world_size
    # under-reports the value when launched on several devices.
    effective_batch_size = (
        training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps
        * training_args.world_size
    )
    print("\n=== Continued Training Configuration ===")
    print(f"Model: {model_name} (previously fine-tuned)")
    print(f"Dataset size: {len(dataset)}")
    print(f"Batch size: {training_args.per_device_train_batch_size}")
    print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
    print(f"Effective batch size: {effective_batch_size}")
    print(f"Learning rate: {training_args.learning_rate}")
    print(f"Epochs: {training_args.num_train_epochs}")
    print("=" * 40)


def main(
    model_name="./final_model",
    data_file="combined_math_train.jsonl",
    output_dir="./sft_output_continued",
    save_dir="final_model",
):
    """Run one more low-learning-rate SFT epoch on a previously fine-tuned model.

    Args:
        model_name: Path (or identifier accepted by ``from_pretrained``) of the
            previously trained model and tokenizer.
        data_file: JSONL file loaded with ``load_dataset("json", ...)``; its
            ``train`` split is used for training.
        output_dir: Directory for intermediate checkpoints (every
            ``save_steps`` steps, at most ``save_total_limit`` kept).
        save_dir: Directory the final model and tokenizer are written to.
            NOTE: by default this is the same directory the model was loaded
            from, so the previous weights are overwritten after training.
    """
    # Load from our previously trained model
    print(f"Loading previously trained model from {model_name}...")
    model, tokenizer = _load_model_and_tokenizer(model_name)

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("json", data_files=data_file, split="train")
    print(f"Dataset size: {len(dataset)}")

    # Training config - lower LR for refinement, 1 more epoch
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        learning_rate=5e-6,  # Lower LR for continued training
        lr_scheduler_type="cosine",
        warmup_ratio=0.01,
        weight_decay=0.01,
        logging_steps=100,
        save_steps=2000,
        save_total_limit=2,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=1024,
        packing=True,
        report_to="none",
        seed=42,
        dataloader_num_workers=4,
        optim="adamw_torch_fused",
    )

    # Create trainer
    print("Creating trainer...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    _print_config(model_name, dataset, training_args)

    # Train
    print("\nStarting continued training...")
    trainer.train()

    # Save final model
    print(f"\nSaving model to {save_dir}...")
    trainer.save_model(save_dir)
    # save_model writes only from the main process; do the same for the
    # tokenizer so multi-process launches don't race on the same files.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(save_dir)
    print("Continued training complete!")


if __name__ == "__main__":
    main()