91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
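Continue training the already fine-tuned model in ./final_model for one more epoch at a lower learning rate, then save the refined model and tokenizer back to final_model/ (overwriting the previous weights).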
|
||
|
|
"""
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from datasets import load_dataset
|
||
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
|
from trl import SFTTrainer, SFTConfig
|
||
|
|
import os
|
||
|
|
|
||
|
|
def main():
|
||
|
|
# Load from our previously trained model
|
||
|
|
print("Loading previously trained model from final_model/...")
|
||
|
|
model_name = "./final_model"
|
||
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||
|
|
|
||
|
|
# Ensure pad token is set
|
||
|
|
if tokenizer.pad_token is None:
|
||
|
|
tokenizer.pad_token = tokenizer.eos_token
|
||
|
|
|
||
|
|
model = AutoModelForCausalLM.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
torch_dtype=torch.bfloat16,
|
||
|
|
attn_implementation="sdpa",
|
||
|
|
device_map="auto",
|
||
|
|
)
|
||
|
|
|
||
|
|
# Load dataset
|
||
|
|
print("Loading dataset...")
|
||
|
|
dataset = load_dataset("json", data_files="combined_math_train.jsonl", split="train")
|
||
|
|
print(f"Dataset size: {len(dataset)}")
|
||
|
|
|
||
|
|
# Training config - lower LR for refinement, 1 more epoch
|
||
|
|
training_args = SFTConfig(
|
||
|
|
output_dir="./sft_output_continued",
|
||
|
|
num_train_epochs=1,
|
||
|
|
per_device_train_batch_size=8,
|
||
|
|
gradient_accumulation_steps=4,
|
||
|
|
learning_rate=5e-6, # Lower LR for continued training
|
||
|
|
lr_scheduler_type="cosine",
|
||
|
|
warmup_ratio=0.01,
|
||
|
|
weight_decay=0.01,
|
||
|
|
logging_steps=100,
|
||
|
|
save_steps=2000,
|
||
|
|
save_total_limit=2,
|
||
|
|
bf16=True,
|
||
|
|
gradient_checkpointing=True,
|
||
|
|
gradient_checkpointing_kwargs={"use_reentrant": False},
|
||
|
|
max_length=1024,
|
||
|
|
packing=True,
|
||
|
|
report_to="none",
|
||
|
|
seed=42,
|
||
|
|
dataloader_num_workers=4,
|
||
|
|
optim="adamw_torch_fused",
|
||
|
|
)
|
||
|
|
|
||
|
|
# Create trainer
|
||
|
|
print("Creating trainer...")
|
||
|
|
trainer = SFTTrainer(
|
||
|
|
model=model,
|
||
|
|
args=training_args,
|
||
|
|
train_dataset=dataset,
|
||
|
|
processing_class=tokenizer,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Print training info
|
||
|
|
print(f"\n=== Continued Training Configuration ===")
|
||
|
|
print(f"Model: {model_name} (previously fine-tuned)")
|
||
|
|
print(f"Dataset size: {len(dataset)}")
|
||
|
|
print(f"Batch size: {training_args.per_device_train_batch_size}")
|
||
|
|
print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
|
||
|
|
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
|
||
|
|
print(f"Learning rate: {training_args.learning_rate}")
|
||
|
|
print(f"Epochs: {training_args.num_train_epochs}")
|
||
|
|
print("="*40)
|
||
|
|
|
||
|
|
# Train
|
||
|
|
print("\nStarting continued training...")
|
||
|
|
trainer.train()
|
||
|
|
|
||
|
|
# Save final model
|
||
|
|
print("\nSaving model to final_model/...")
|
||
|
|
trainer.save_model("final_model")
|
||
|
|
tokenizer.save_pretrained("final_model")
|
||
|
|
|
||
|
|
print("Continued training complete!")
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Script entry point: start continued training only when executed
    # directly, never as a side effect of importing this module.
    main()
|