初始化项目,由ModelHub XC社区提供模型
Model: HuggingFaceTB/qwen3-1.7b-gsm8k-sft Source: Original Platform
This commit is contained in:
138
scripts/evaluate.py
Normal file
138
scripts/evaluate.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from inspect_ai.log._log import EvalLog, EvalMetric, EvalSample
|
||||
from inspect_ai import eval as inspect_eval # type: ignore # noqa: E402
|
||||
from inspect_ai.util._display import init_display_type # noqa: E402
|
||||
|
||||
import inspect_evals.gsm8k # noqa: F401, E402 (registers task definitions)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Run Inspect AI eval without banners.")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="final_model",
|
||||
help="Path to the Hugging Face model (directory or model identifier).",
|
||||
)
|
||||
# this is a good limit for this task, just keep it like that (or use less in case you want faster tests)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
default=150,
|
||||
help="Optional limit for number of samples to evaluate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--json-output-file',
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional path to output the metrics as a seperate JSON file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--templates-dir',
|
||||
type=str,
|
||||
default="templates/",
|
||||
)
|
||||
# You can adjust --max-connections if you want faster tests and don't receive errors (or if you have issues with vllm, try lowering this value)
|
||||
parser.add_argument(
|
||||
"--max-connections",
|
||||
type=int,
|
||||
default=2,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=4000,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.3,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
init_display_type("plain")
|
||||
|
||||
other_kwargs = {}
|
||||
if (args.limit is not None) and (args.limit != -1):
|
||||
other_kwargs["limit"] = args.limit
|
||||
|
||||
task = "inspect_evals/gsm8k"
|
||||
model_args = {
|
||||
'gpu_memory_utilization': args.gpu_memory_utilization,
|
||||
}
|
||||
model_args.update(template_kwargs(args))
|
||||
|
||||
eval_out = inspect_eval(
|
||||
task,
|
||||
model=f"vllm/{args.model_path}",
|
||||
model_args=model_args,
|
||||
score_display=False,
|
||||
log_realtime=False,
|
||||
log_format='json',
|
||||
timeout=18000000,
|
||||
attempt_timeout=18000000,
|
||||
max_tokens=args.max_tokens,
|
||||
max_connections=args.max_connections,
|
||||
**other_kwargs,
|
||||
)
|
||||
|
||||
if args.json_output_file is not None:
|
||||
assert len(eval_out) == 1, eval_out
|
||||
assert len(eval_out[0].results.scores) == 1, eval_out[0].results.scores
|
||||
metrics = {}
|
||||
for k, v in eval_out[0].results.scores[0].metrics.items():
|
||||
metrics[k] = v.value
|
||||
|
||||
with open(args.json_output_file, 'w') as f:
|
||||
json.dump(metrics, f, indent=2)
|
||||
|
||||
def model_type(args) -> str:
|
||||
if 'qwen' in args.model_path.lower():
|
||||
return 'qwen'
|
||||
if 'llama' in args.model_path.lower():
|
||||
return 'llama'
|
||||
if 'gemma' in args.model_path.lower():
|
||||
return 'gemma'
|
||||
if 'smollm' in args.model_path.lower():
|
||||
return 'smollm'
|
||||
|
||||
with open(os.path.join(args.model_path, "config.json"), 'r') as f:
|
||||
config = json.load(f)
|
||||
architecture = config['architectures'][0].lower()
|
||||
if 'gemma' in architecture:
|
||||
return 'gemma'
|
||||
if 'llama' in architecture:
|
||||
return 'llama'
|
||||
if 'qwen' in architecture:
|
||||
return 'qwen'
|
||||
if 'smollm' in architecture:
|
||||
return 'smollm'
|
||||
raise ValueError(architecture)
|
||||
|
||||
def template_kwargs(args) -> dict:
|
||||
model_type_str = model_type(args)
|
||||
if model_type_str == 'qwen':
|
||||
template = 'qwen3.jinja'
|
||||
elif model_type_str == 'llama':
|
||||
template = 'llama3.jinja'
|
||||
elif model_type_str == 'gemma':
|
||||
template = 'gemma3.jinja'
|
||||
elif model_type_str == 'smollm':
|
||||
template = 'smollm.jinja'
|
||||
else:
|
||||
raise ValueError(model_type_str)
|
||||
return {
|
||||
'chat_template': os.path.join(args.templates_dir, template)
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
159
scripts/evaluate_math500.py
Normal file
159
scripts/evaluate_math500.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Evaluate model on MATH-500 dataset (harder math problems)."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from datasets import load_dataset
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
def extract_answer(response: str) -> str:
|
||||
"""Extract the final answer from model response."""
|
||||
# Look for boxed answer first (common in MATH format)
|
||||
boxed_match = re.search(r'\\boxed\{([^}]+)\}', response)
|
||||
if boxed_match:
|
||||
return boxed_match.group(1).strip()
|
||||
|
||||
# Look for "The answer is X" pattern
|
||||
answer_match = re.search(r'[Tt]he (?:final )?answer is[:\s]*([^\n.]+)', response)
|
||||
if answer_match:
|
||||
return answer_match.group(1).strip()
|
||||
|
||||
# Look for "= X" at the end
|
||||
equals_match = re.search(r'=\s*([^\n=]+?)\s*$', response)
|
||||
if equals_match:
|
||||
return equals_match.group(1).strip()
|
||||
|
||||
# Return last line as fallback
|
||||
lines = [l.strip() for l in response.strip().split('\n') if l.strip()]
|
||||
return lines[-1] if lines else ""
|
||||
|
||||
def normalize_answer(answer: str) -> str:
|
||||
"""Normalize answer for comparison."""
|
||||
# Remove common formatting
|
||||
answer = answer.strip()
|
||||
answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer)
|
||||
answer = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', answer)
|
||||
answer = re.sub(r'\\left|\\right', '', answer)
|
||||
answer = re.sub(r'\$', '', answer)
|
||||
answer = answer.strip()
|
||||
return answer.lower()
|
||||
|
||||
def answers_match(predicted: str, expected: str) -> bool:
|
||||
"""Check if answers match (with some tolerance)."""
|
||||
pred_norm = normalize_answer(predicted)
|
||||
exp_norm = normalize_answer(expected)
|
||||
|
||||
# Direct match
|
||||
if pred_norm == exp_norm:
|
||||
return True
|
||||
|
||||
# Try numeric comparison
|
||||
try:
|
||||
pred_num = float(re.sub(r'[^\d.-]', '', pred_norm))
|
||||
exp_num = float(re.sub(r'[^\d.-]', '', exp_norm))
|
||||
if abs(pred_num - exp_num) < 1e-6:
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check if one contains the other
|
||||
if exp_norm in pred_norm or pred_norm in exp_norm:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-path", type=str, default="final_model")
|
||||
parser.add_argument("--limit", type=int, default=100)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Loading MATH-500 dataset...")
|
||||
dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
|
||||
|
||||
if args.limit:
|
||||
dataset = dataset.select(range(min(args.limit, len(dataset))))
|
||||
|
||||
print(f"Evaluating {len(dataset)} problems...")
|
||||
|
||||
# Load model
|
||||
print(f"Loading model from {args.model_path}...")
|
||||
llm = LLM(
|
||||
model=args.model_path,
|
||||
dtype="bfloat16",
|
||||
max_model_len=4096,
|
||||
gpu_memory_utilization=0.9,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=2048,
|
||||
stop=["<|im_end|>", "<|endoftext|>"],
|
||||
)
|
||||
|
||||
# Prepare prompts
|
||||
prompts = []
|
||||
for item in dataset:
|
||||
problem = item["problem"]
|
||||
prompt = f"<|im_start|>user\n{problem}<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompts.append(prompt)
|
||||
|
||||
# Generate
|
||||
print("Generating responses...")
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Evaluate
|
||||
correct = 0
|
||||
results_by_level = {}
|
||||
results_by_subject = {}
|
||||
|
||||
for i, (item, output) in enumerate(zip(dataset, outputs)):
|
||||
response = output.outputs[0].text
|
||||
predicted = extract_answer(response)
|
||||
expected = item["answer"]
|
||||
level = item["level"]
|
||||
subject = item["subject"]
|
||||
|
||||
is_correct = answers_match(predicted, expected)
|
||||
if is_correct:
|
||||
correct += 1
|
||||
|
||||
# Track by level
|
||||
if level not in results_by_level:
|
||||
results_by_level[level] = {"correct": 0, "total": 0}
|
||||
results_by_level[level]["total"] += 1
|
||||
if is_correct:
|
||||
results_by_level[level]["correct"] += 1
|
||||
|
||||
# Track by subject
|
||||
if subject not in results_by_subject:
|
||||
results_by_subject[subject] = {"correct": 0, "total": 0}
|
||||
results_by_subject[subject]["total"] += 1
|
||||
if is_correct:
|
||||
results_by_subject[subject]["correct"] += 1
|
||||
|
||||
if (i + 1) % 20 == 0:
|
||||
print(f"Progress: {i+1}/{len(dataset)}, Accuracy so far: {correct/(i+1)*100:.1f}%")
|
||||
|
||||
# Print results
|
||||
accuracy = correct / len(dataset) * 100
|
||||
print(f"\n{'='*60}")
|
||||
print(f"MATH-500 Results ({len(dataset)} problems)")
|
||||
print(f"{'='*60}")
|
||||
print(f"Overall Accuracy: {accuracy:.1f}% ({correct}/{len(dataset)})")
|
||||
|
||||
print(f"\nBy Level:")
|
||||
for level in sorted(results_by_level.keys()):
|
||||
stats = results_by_level[level]
|
||||
acc = stats["correct"] / stats["total"] * 100
|
||||
print(f" {level}: {acc:.1f}% ({stats['correct']}/{stats['total']})")
|
||||
|
||||
print(f"\nBy Subject:")
|
||||
for subject in sorted(results_by_subject.keys()):
|
||||
stats = results_by_subject[subject]
|
||||
acc = stats["correct"] / stats["total"] * 100
|
||||
print(f" {subject}: {acc:.1f}% ({stats['correct']}/{stats['total']})")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
154
scripts/prepare_combined_data.py
Normal file
154
scripts/prepare_combined_data.py
Normal file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Combine GSM8K training data with MetaMathQA GSM-related examples.
|
||||
This creates a larger, more diverse training set.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datasets import load_dataset
|
||||
|
||||
def extract_answer_gsm8k(answer_text):
|
||||
"""Extract the final numerical answer from GSM8K answer format."""
|
||||
match = re.search(r'####\s*(-?[\d,]+\.?\d*)', answer_text)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
return None
|
||||
|
||||
def format_reasoning_gsm8k(answer_text):
|
||||
"""Convert GSM8K step-by-step format to thinking format."""
|
||||
reasoning = re.sub(r'####\s*-?[\d,]+\.?\d*\s*$', '', answer_text).strip()
|
||||
reasoning = re.sub(r'<<[^>]+>>', '', reasoning)
|
||||
return reasoning
|
||||
|
||||
def extract_answer_metamath(response):
|
||||
"""Extract answer from MetaMathQA format (usually ends with boxed answer)."""
|
||||
# Try to find boxed answer
|
||||
match = re.search(r'\\boxed\{([^}]+)\}', response)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
# Try to find "the answer is X" pattern
|
||||
match = re.search(r'the answer is[:\s]*\$?(-?[\d,]+\.?\d*)', response, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
# Try to find "= X" at the end
|
||||
match = re.search(r'=\s*\$?(-?[\d,]+\.?\d*)\s*(?:dollars?|\.)?$', response)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
return None
|
||||
|
||||
def create_gsm8k_example(question, answer):
|
||||
"""Create a training example from GSM8K format."""
|
||||
final_answer = extract_answer_gsm8k(answer)
|
||||
reasoning = format_reasoning_gsm8k(answer)
|
||||
|
||||
if final_answer is None:
|
||||
return None
|
||||
|
||||
assistant_content = f"""<think>
|
||||
Let me solve this step by step.
|
||||
|
||||
{reasoning}
|
||||
|
||||
Therefore, the answer is {final_answer}.
|
||||
</think>
|
||||
|
||||
The answer is {final_answer}"""
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{question}"},
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
]
|
||||
}
|
||||
|
||||
def create_metamath_example(query, response):
|
||||
"""Create a training example from MetaMathQA format."""
|
||||
# Clean up the response - remove LaTeX formatting artifacts
|
||||
clean_response = response.replace('\\n', '\n').strip()
|
||||
|
||||
# Extract the answer
|
||||
final_answer = extract_answer_metamath(clean_response)
|
||||
if final_answer is None:
|
||||
return None
|
||||
|
||||
# Remove the boxed answer and everything after for reasoning
|
||||
reasoning = re.sub(r'\\boxed\{[^}]+\}.*$', '', clean_response, flags=re.DOTALL).strip()
|
||||
reasoning = re.sub(r'The answer is.*$', '', reasoning, flags=re.IGNORECASE | re.DOTALL).strip()
|
||||
|
||||
# Skip if reasoning is too short
|
||||
if len(reasoning) < 50:
|
||||
return None
|
||||
|
||||
assistant_content = f"""<think>
|
||||
Let me solve this step by step.
|
||||
|
||||
{reasoning}
|
||||
|
||||
Therefore, the answer is {final_answer}.
|
||||
</think>
|
||||
|
||||
The answer is {final_answer}"""
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{query}"},
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
]
|
||||
}
|
||||
|
||||
def main():
|
||||
training_data = []
|
||||
|
||||
# Load GSM8K training data
|
||||
print("Loading GSM8K dataset...")
|
||||
gsm8k = load_dataset("openai/gsm8k", "main", split="train")
|
||||
print(f"Loaded {len(gsm8k)} GSM8K examples")
|
||||
|
||||
gsm8k_count = 0
|
||||
for example in gsm8k:
|
||||
formatted = create_gsm8k_example(example['question'], example['answer'])
|
||||
if formatted:
|
||||
training_data.append(formatted)
|
||||
gsm8k_count += 1
|
||||
print(f"Added {gsm8k_count} GSM8K examples")
|
||||
|
||||
# Load MetaMathQA - only GSM-related examples
|
||||
print("\nLoading MetaMathQA dataset...")
|
||||
metamath = load_dataset("meta-math/MetaMathQA", split="train")
|
||||
print(f"Loaded {len(metamath)} MetaMathQA examples")
|
||||
|
||||
# Filter for GSM-related examples only
|
||||
metamath_count = 0
|
||||
for example in metamath:
|
||||
if 'GSM' in example['type']: # GSM_Rephrased, GSM_SV, GSM_AnsAug, etc.
|
||||
formatted = create_metamath_example(example['query'], example['response'])
|
||||
if formatted:
|
||||
training_data.append(formatted)
|
||||
metamath_count += 1
|
||||
print(f"Added {metamath_count} MetaMathQA GSM examples")
|
||||
|
||||
print(f"\nTotal training examples: {len(training_data)}")
|
||||
|
||||
# Shuffle the data
|
||||
import random
|
||||
random.seed(42)
|
||||
random.shuffle(training_data)
|
||||
|
||||
# Save to JSONL
|
||||
output_file = "combined_math_train.jsonl"
|
||||
with open(output_file, 'w') as f:
|
||||
for item in training_data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
|
||||
print(f"Saved to {output_file}")
|
||||
|
||||
# Show samples
|
||||
print("\n=== Sample GSM8K example ===")
|
||||
for item in training_data[:10]:
|
||||
if "Natalia" in item['messages'][0]['content']:
|
||||
print(json.dumps(item, indent=2)[:500])
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
90
scripts/train_continued.py
Normal file
90
scripts/train_continued.py
Normal file
@@ -0,0 +1,90 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Continue training the already fine-tuned model with lower learning rate for additional refinement.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
import os
|
||||
|
||||
def main():
|
||||
# Load from our previously trained model
|
||||
print("Loading previously trained model from final_model/...")
|
||||
model_name = "./final_model"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Ensure pad token is set
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa",
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
print("Loading dataset...")
|
||||
dataset = load_dataset("json", data_files="combined_math_train.jsonl", split="train")
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Training config - lower LR for refinement, 1 more epoch
|
||||
training_args = SFTConfig(
|
||||
output_dir="./sft_output_continued",
|
||||
num_train_epochs=1,
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=5e-6, # Lower LR for continued training
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.01,
|
||||
weight_decay=0.01,
|
||||
logging_steps=100,
|
||||
save_steps=2000,
|
||||
save_total_limit=2,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False},
|
||||
max_length=1024,
|
||||
packing=True,
|
||||
report_to="none",
|
||||
seed=42,
|
||||
dataloader_num_workers=4,
|
||||
optim="adamw_torch_fused",
|
||||
)
|
||||
|
||||
# Create trainer
|
||||
print("Creating trainer...")
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
|
||||
# Print training info
|
||||
print(f"\n=== Continued Training Configuration ===")
|
||||
print(f"Model: {model_name} (previously fine-tuned)")
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
print(f"Batch size: {training_args.per_device_train_batch_size}")
|
||||
print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
|
||||
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
|
||||
print(f"Learning rate: {training_args.learning_rate}")
|
||||
print(f"Epochs: {training_args.num_train_epochs}")
|
||||
print("="*40)
|
||||
|
||||
# Train
|
||||
print("\nStarting continued training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print("\nSaving model to final_model/...")
|
||||
trainer.save_model("final_model")
|
||||
tokenizer.save_pretrained("final_model")
|
||||
|
||||
print("Continued training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
99
scripts/train_improved.py
Normal file
99
scripts/train_improved.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Improved SFT training for GSM8K performance.
|
||||
Key improvements:
|
||||
1. More training data (247K examples from GSM8K + MetaMathQA)
|
||||
2. Multiple epochs with cosine LR schedule
|
||||
3. Proper batch size and gradient accumulation for H100
|
||||
4. Gradient checkpointing for memory efficiency
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
import os
|
||||
|
||||
def main():
|
||||
# Load model and tokenizer
|
||||
print("Loading model and tokenizer...")
|
||||
model_name = "Qwen/Qwen3-1.7B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Ensure pad token is set
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa", # Use SDPA instead of flash_attention_2
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
print("Loading dataset...")
|
||||
dataset = load_dataset("json", data_files="combined_math_train.jsonl", split="train")
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Training config - optimized for H100 and GSM8K task
|
||||
# With 247K examples and batch_size 8 * grad_accum 4 = effective batch 32
|
||||
# Steps per epoch: 247467 / 32 ≈ 7733 steps
|
||||
# 2 epochs ≈ 15466 steps
|
||||
training_args = SFTConfig(
|
||||
output_dir="./sft_output_improved",
|
||||
num_train_epochs=2,
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-5,
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.03,
|
||||
weight_decay=0.01,
|
||||
logging_steps=100,
|
||||
save_steps=2000,
|
||||
save_total_limit=3,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False},
|
||||
max_length=1024, # Math problems don't need very long context
|
||||
packing=True,
|
||||
report_to="none",
|
||||
seed=42,
|
||||
dataloader_num_workers=4,
|
||||
optim="adamw_torch_fused",
|
||||
)
|
||||
|
||||
# Create trainer
|
||||
print("Creating trainer...")
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
|
||||
# Print training info
|
||||
print(f"\n=== Training Configuration ===")
|
||||
print(f"Model: {model_name}")
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
print(f"Batch size: {training_args.per_device_train_batch_size}")
|
||||
print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
|
||||
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
|
||||
print(f"Learning rate: {training_args.learning_rate}")
|
||||
print(f"Epochs: {training_args.num_train_epochs}")
|
||||
print(f"Max length: {training_args.max_length}")
|
||||
print("="*30)
|
||||
|
||||
# Train
|
||||
print("\nStarting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print("\nSaving model to final_model/...")
|
||||
trainer.save_model("final_model")
|
||||
tokenizer.save_pretrained("final_model")
|
||||
|
||||
print("Training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user