初始化项目,由ModelHub XC社区提供模型
Model: HuggingFaceTB/SmolLM3-3B-GSM8K-SFT Source: Original Platform
This commit is contained in:
93
training/evaluate_gsm8k.py
Normal file
93
training/evaluate_gsm8k.py
Normal file
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple GSM8K evaluation script for SmolLM3-3B-GSM8K-SFT
|
||||
|
||||
Requirements:
|
||||
pip install transformers datasets vllm
|
||||
|
||||
Usage:
|
||||
python evaluate_gsm8k.py --model HuggingFaceTB/SmolLM3-3B-GSM8K-SFT --samples 100
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from datasets import load_dataset
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
"""Extract numerical answer from model output."""
|
||||
# Look for #### followed by number
|
||||
match = re.search(r'####\s*(-?[\d,]+(?:\.\d+)?)', text)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
|
||||
# Look for "answer is X" pattern
|
||||
match = re.search(r'answer is[:\s]*(-?[\d,]+(?:\.\d+)?)', text.lower())
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
|
||||
# Extract last number
|
||||
numbers = re.findall(r'-?[\d,]+(?:\.\d+)?', text)
|
||||
if numbers:
|
||||
return numbers[-1].replace(',', '')
|
||||
|
||||
return ""
|
||||
|
||||
def extract_gold_answer(text: str) -> str:
|
||||
"""Extract gold answer from GSM8K format."""
|
||||
match = re.search(r'####\s*(-?[\d,]+(?:\.\d+)?)', text)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
return ""
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B-GSM8K-SFT")
|
||||
parser.add_argument("--samples", type=int, default=100)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Loading model: {args.model}")
|
||||
llm = LLM(
|
||||
model=args.model,
|
||||
trust_remote_code=True,
|
||||
gpu_memory_utilization=0.9,
|
||||
)
|
||||
|
||||
tokenizer = llm.get_tokenizer()
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=512,
|
||||
temperature=0.0,
|
||||
stop=["<|im_end|>"],
|
||||
)
|
||||
|
||||
print("Loading GSM8K test set...")
|
||||
dataset = load_dataset("openai/gsm8k", "main", split="test")
|
||||
dataset = dataset.select(range(min(args.samples, len(dataset))))
|
||||
|
||||
# Prepare prompts
|
||||
prompts = []
|
||||
for example in dataset:
|
||||
messages = [{"role": "user", "content": example["question"]}]
|
||||
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
prompts.append(prompt)
|
||||
|
||||
# Generate
|
||||
print(f"Evaluating {len(prompts)} samples...")
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Score
|
||||
correct = 0
|
||||
for i, (example, output) in enumerate(zip(dataset, outputs)):
|
||||
pred = extract_answer(output.outputs[0].text)
|
||||
gold = extract_gold_answer(example["answer"])
|
||||
|
||||
if pred == gold:
|
||||
correct += 1
|
||||
|
||||
accuracy = correct / len(dataset) * 100
|
||||
print(f"\nResults:")
|
||||
print(f" Correct: {correct}/{len(dataset)}")
|
||||
print(f" Accuracy: {accuracy:.1f}%")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
8
training/requirements.txt
Normal file
8
training/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
torch>=2.0.0
|
||||
transformers>=4.40.0
|
||||
trl>=1.0.0
|
||||
datasets>=2.18.0
|
||||
accelerate>=0.27.0
|
||||
vllm>=0.4.0
|
||||
flash-attn>=2.5.0
|
||||
bitsandbytes>=0.42.0
|
||||
139
training/train_sft.py
Normal file
139
training/train_sft.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SFT training script for SmolLM3-3B-GSM8K-SFT
|
||||
Fine-tunes SmolLM3-3B-Base on MetaMathQA for math reasoning.
|
||||
|
||||
Requirements:
|
||||
pip install transformers trl datasets torch
|
||||
|
||||
Usage:
|
||||
python train_sft.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
# Configuration
|
||||
BASE_MODEL = "HuggingFaceTB/SmolLM3-3B-Base"
|
||||
OUTPUT_DIR = "./sft_output"
|
||||
FINAL_MODEL_DIR = "./final_model"
|
||||
NUM_SAMPLES = 100000 # Use 100k samples from MetaMathQA
|
||||
|
||||
# ChatML template for base model (it has no chat template by default)
|
||||
CHAT_TEMPLATE = """{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"""
|
||||
|
||||
def format_example(example):
|
||||
"""Format the MetaMathQA example into conversation format."""
|
||||
question = example['query']
|
||||
answer = example['response']
|
||||
|
||||
text = f"""<|im_start|>system
|
||||
You are a helpful AI assistant that solves math problems step by step. Show your work clearly and end with the final numerical answer after ####.<|im_end|>
|
||||
<|im_start|>user
|
||||
{question}<|im_end|>
|
||||
<|im_start|>assistant
|
||||
{answer}<|im_end|>"""
|
||||
|
||||
return {"text": text}
|
||||
|
||||
def main():
|
||||
print("=" * 50)
|
||||
print("Starting SFT Training for Math Reasoning")
|
||||
print("=" * 50)
|
||||
|
||||
# Load tokenizer and add chat template
|
||||
print("\nLoading tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Add ChatML special tokens
|
||||
special_tokens = {
|
||||
"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]
|
||||
}
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
tokenizer.chat_template = CHAT_TEMPLATE
|
||||
|
||||
# Load model
|
||||
print("\nLoading model...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
BASE_MODEL,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# Resize embeddings for new tokens
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Load and prepare dataset
|
||||
print("\nLoading MetaMathQA dataset...")
|
||||
dataset = load_dataset("meta-math/MetaMathQA", split="train")
|
||||
dataset = dataset.shuffle(seed=42).select(range(min(NUM_SAMPLES, len(dataset))))
|
||||
print(f"Using {len(dataset)} samples")
|
||||
|
||||
# Format dataset
|
||||
print("\nFormatting dataset...")
|
||||
formatted_dataset = dataset.map(
|
||||
format_example,
|
||||
remove_columns=dataset.column_names
|
||||
)
|
||||
|
||||
# Split for train/eval
|
||||
split_dataset = formatted_dataset.train_test_split(test_size=0.01, seed=42)
|
||||
train_dataset = split_dataset['train']
|
||||
eval_dataset = split_dataset['test']
|
||||
|
||||
print(f"Train samples: {len(train_dataset)}")
|
||||
print(f"Eval samples: {len(eval_dataset)}")
|
||||
|
||||
# Training configuration
|
||||
training_args = SFTConfig(
|
||||
output_dir=OUTPUT_DIR,
|
||||
num_train_epochs=1,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=8, # Effective batch size = 16
|
||||
learning_rate=1e-5,
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.05,
|
||||
weight_decay=0.01,
|
||||
logging_steps=50,
|
||||
save_steps=500,
|
||||
save_total_limit=2,
|
||||
eval_strategy="steps",
|
||||
eval_steps=500,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
report_to="none",
|
||||
max_length=2048,
|
||||
seed=42,
|
||||
packing=False,
|
||||
)
|
||||
|
||||
# Create trainer
|
||||
print("\nInitializing SFT Trainer...")
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
|
||||
# Train
|
||||
print("\nStarting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print(f"\nSaving final model to {FINAL_MODEL_DIR}...")
|
||||
trainer.save_model(FINAL_MODEL_DIR)
|
||||
tokenizer.save_pretrained(FINAL_MODEL_DIR)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Training completed!")
|
||||
print("=" * 50)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user