初始化项目,由ModelHub XC社区提供模型
Model: HuggingFaceTB/qwen3-1.7b-gsm8k-sft Source: Original Platform
This commit is contained in:
99
scripts/train_improved.py
Normal file
99
scripts/train_improved.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Improved SFT training for GSM8K performance.
|
||||
Key improvements:
|
||||
1. More training data (247K examples from GSM8K + MetaMathQA)
|
||||
2. Multiple epochs with cosine LR schedule
|
||||
3. Proper batch size and gradient accumulation for H100
|
||||
4. Gradient checkpointing for memory efficiency
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
import os
|
||||
|
||||
def main():
    """Fine-tune Qwen3-1.7B on the combined math dataset and save the result.

    Steps, in order:
      1. Load the ``Qwen/Qwen3-1.7B`` tokenizer and model (bfloat16, SDPA).
      2. Load ``combined_math_train.jsonl`` (path is relative to the current
         working directory) as the training split.
      3. Build an ``SFTConfig`` / ``SFTTrainer`` and print a summary of the
         key hyperparameters.
      4. Train, then write the final model and tokenizer to ``final_model/``.

    Intermediate checkpoints are written under ``./sft_output_improved``.
    """
    final_dir = "final_model"

    # Load model and tokenizer
    print("Loading model and tokenizer...")
    model_name = "Qwen/Qwen3-1.7B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Fall back to the EOS token for padding when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): packing=True is enabled in the config below while the
    # attention backend is SDPA. Depending on the installed TRL version this
    # may let attention cross packed-example boundaries - confirm against
    # the TRL packing documentation before relying on the results.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",  # Use SDPA instead of flash_attention_2
        device_map="auto",
    )

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("json", data_files="combined_math_train.jsonl", split="train")
    print(f"Dataset size: {len(dataset)}")

    # Training config - optimized for H100 and GSM8K task.
    # batch_size 8 * grad_accum 4 = 32 sequences per optimizer step.
    # Counting one example per sequence, 247467 / 32 ≈ 7733 steps per epoch
    # (≈ 15466 steps for 2 epochs). Because packing=True is set, the trainer
    # iterates over packed sequences rather than individual examples, so the
    # real step count will presumably be lower - check the trainer's progress
    # output rather than relying on this estimate.
    training_args = SFTConfig(
        output_dir="./sft_output_improved",
        num_train_epochs=2,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.01,
        logging_steps=100,
        save_steps=2000,
        save_total_limit=3,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=1024,  # Math problems don't need very long context
        packing=True,
        report_to="none",
        seed=42,
        dataloader_num_workers=4,
        optim="adamw_torch_fused",
    )

    # Create trainer
    print("Creating trainer...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # NOTE(review): this is the per-process figure; under a multi-process
    # launch the global batch would be larger - confirm the launch setup.
    effective_batch = (
        training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps
    )

    # Print training info
    print("\n=== Training Configuration ===")
    print(f"Model: {model_name}")
    print(f"Dataset size: {len(dataset)}")
    print(f"Batch size: {training_args.per_device_train_batch_size}")
    print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
    print(f"Effective batch size: {effective_batch}")
    print(f"Learning rate: {training_args.learning_rate}")
    print(f"Epochs: {training_args.num_train_epochs}")
    print(f"Max length: {training_args.max_length}")
    print("=" * 30)

    # Train
    print("\nStarting training...")
    trainer.train()

    # Save final model; the tokenizer is written to the same directory so the
    # output can be loaded from a single path.
    print(f"\nSaving model to {final_dir}/...")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)

    print("Training complete!")
|
||||
|
||||
# Script entry point: start training only when this file is executed
# directly, so importing the module has no side effects.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user