---
license: llama3.2
language:
- en
library_name: transformers
tags:
- llama
- llama-3.2
- gsm8k
- math
- reasoning
- full-parameter
- fine-tuning
base_model: meta-llama/Llama-3.2-3B-Instruct
---

# GSM8K Full Parameter Fine-tuned Llama 3.2 3B Instruct

Llama 3.2 3B Instruct fine-tuned on the GSM8K dataset using **full parameter fine-tuning** for improved mathematical reasoning.

## Model Details

- **Base Model**: meta-llama/Llama-3.2-3B-Instruct
- **Training Method**: Full parameter fine-tuning (all weights updated)
- **Training Dataset**: [GSM8K](https://huggingface.co/datasets/openai/gsm8k)
- **Training Date**: 2026-02-23
- **Model Type**: Causal language model
- **Framework**: Transformers + TRL (SFTTrainer)

## Training Configuration

### Full Parameter Training

- **Method**: All model parameters updated (no LoRA adapters)
- **Total Parameters**: ~3B (all trainable)
- **Training Samples**: 7,473
- **Epochs**: 3
- **Per-Device Batch Size**: 2
- **Gradient Accumulation Steps**: 4
- **Effective Batch Size**: 8
- **Learning Rate**: 2e-5
- **Optimizer**: AdamW 8-bit
- **Scheduler**: Cosine
- **Warmup Ratio**: 0.0
- **Max Sequence Length**: 512
- **Dtype**: bfloat16
- **Gradient Checkpointing**: Enabled
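
For reference, these hyperparameters map onto TRL roughly as in the sketch below. This is a minimal reconstruction, not the authors' actual training script; the formatting helper and the per-example `formatting_func` style are assumptions.

```python
# Minimal training sketch with TRL's SFTTrainer, mirroring the
# configuration listed above. Illustrative only.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

train_dataset = load_dataset("openai/gsm8k", "main", split="train")  # 7,473 samples

def format_sample(example):
    # Hypothetical formatter matching the "Dataset Format" section below;
    # GSM8K's 'answer' field already ends with '#### <number>'.
    # Recent TRL applies formatting_func per example; older releases
    # expect a batched function returning a list of strings.
    return f"Question: {example['question']}\nAnswer: {example['answer']}"

config = SFTConfig(
    output_dir="./gsm8k_llama3_full_finetune",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,   # effective batch size 2 * 4 = 8
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.0,
    optim="adamw_bnb_8bit",          # 8-bit AdamW via bitsandbytes
    bf16=True,
    gradient_checkpointing=True,
    max_seq_length=512,              # renamed to max_length in recent TRL releases
)

trainer = SFTTrainer(
    model="meta-llama/Llama-3.2-3B-Instruct",
    args=config,
    train_dataset=train_dataset,
    formatting_func=format_sample,
)
trainer.train()
```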

## Performance

- **GSM8K Test Accuracy**: 40.00% (20/50 samples)
- **Training Time**: ~44 minutes
- **Hardware**: NVIDIA GPU (CUDA)

## Usage

### Basic Inference

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "kmseong/Llama3.2-3B-gsm8k-fullft-like-sn",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("kmseong/Llama3.2-3B-gsm8k-fullft-like-sn")

# Prepare prompt
question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"

prompt = f"""Solve this math problem step by step:

{question}

Provide your final answer in the format:
[reasoning steps]
####
[final answer (just the number)]"""

# Generate with greedy decoding (temperature/top_p have no effect
# when do_sample=False, so they are omitted)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=False,
)

response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(response)
```
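
Since the base model is an Instruct variant, the same prompt can also be wrapped in the tokenizer's chat template. A minimal sketch, assuming the fine-tune preserved the Llama 3.2 chat format (the plain-text prompt above matches the training format more closely):

```python
# Optional: chat-template variant of the same request.
messages = [{"role": "user", "content": prompt}]
chat_inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)
chat_outputs = model.generate(chat_inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(chat_outputs[0][chat_inputs.shape[-1]:], skip_special_tokens=True))
```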

### Extract Answer

```python
import re
from typing import Optional

def extract_answer(text: str) -> Optional[str]:
    """Extract the numerical answer from a model response."""
    # Prefer the number right after the GSM8K-style '####' delimiter.
    if "####" in text:
        answer_part = text.split("####")[-1].strip()
        numbers = re.findall(r"-?\d+\.?\d*", answer_part)
        if numbers:
            return numbers[0]

    # Fall back to the last number anywhere in the text.
    # Note: neither branch handles thousands separators (e.g. "1,000").
    numbers = re.findall(r"-?\d+\.?\d*", text)
    if numbers:
        return numbers[-1]
    return None

# Use after generation
answer = extract_answer(response)
print(f"Final Answer: {answer}")
```
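
A quick sanity check of the delimiter branch:

```python
sample = "She sells 16 - 3 - 4 = 9 eggs at $2 each, so she makes 9 * 2 = 18 dollars.\n#### 18"
assert extract_answer(sample) == "18"
```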

### Batch Inference

```python
from datasets import load_dataset
from tqdm import tqdm

def create_prompt(question: str) -> str:
    """Build the same prompt used in Basic Inference above."""
    return (
        "Solve this math problem step by step:\n\n"
        f"{question}\n\n"
        "Provide your final answer in the format:\n"
        "[reasoning steps]\n"
        "####\n"
        "[final answer (just the number)]"
    )

# Load GSM8K test set
test_dataset = load_dataset("openai/gsm8k", "main", split="test")

correct = 0
total = 0

for sample in tqdm(test_dataset.select(range(100))):
    question = sample["question"]
    expected = extract_answer(sample["answer"])

    # Generate
    prompt = create_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

    predicted = extract_answer(response)

    if predicted is not None and expected is not None and float(predicted) == float(expected):
        correct += 1
    total += 1

accuracy = (correct / total) * 100
print(f"Accuracy: {accuracy:.2f}%")
```

## Training Details

### Dataset Format

The model was trained on GSM8K examples rendered in the following format:

```
Question: [math problem]
Answer: [step-by-step solution]
####
[final numerical answer]
```
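
As a concrete illustration, a GSM8K record renders under this template roughly as follows (the solution wording here is illustrative, not the verbatim dataset text):

```
Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer: Natalia sold 48 / 2 = 24 clips in May.
In April and May together she sold 48 + 24 = 72 clips.
####
72
```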

### Training Script

```bash
python finetune_gsm8k_full_params.py \
    --num_train_samples 7473 \
    --num_eval_samples 0 \
    --batch_size 2 \
    --epochs 3 \
    --learning_rate 2e-5 \
    --max_length 512 \
    --output_dir ./gsm8k_llama3_full_finetune \
    --cache_dir ./cache \
    --model_path meta-llama/Llama-3.2-3B-Instruct \
    --dtype bfloat16
```

## Model Architecture

This is a **full parameter fine-tuned** model, meaning:

- ✅ All 3B parameters were updated during training
- ✅ No adapter/LoRA weights - this is the complete model
- ✅ Can be used directly, without the PEFT library
- ✅ Better performance than LoRA given sufficient training data
- ❌ Larger file size (~6GB)
- ❌ Longer training time

## Differences from LoRA

| Aspect | Full Parameter | LoRA |
|--------|---------------|------|
| **Trainable Params** | 100% (~3B) | ~0.1% (~3M) |
| **Training Speed** | Slower | Faster |
| **Memory Usage** | Higher | Lower |
| **Model Size** | ~6GB | Base + ~10MB |
| **Performance** | Better with enough data | Good with limited data |
| **Use Case** | Production, large datasets | Research, quick experiments |
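
For comparison, a LoRA version of the same recipe would only swap the trainer setup, roughly as below (a sketch using PEFT; the rank and target modules are illustrative, not from this training run):

```python
from peft import LoraConfig

# Illustrative LoRA alternative: reuses the SFTConfig and dataset from
# the training sketch above, but trains only small adapter matrices.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model="meta-llama/Llama-3.2-3B-Instruct",
    args=config,                # same SFTConfig as the full-parameter sketch
    train_dataset=train_dataset,
    formatting_func=format_sample,
    peft_config=lora_config,    # this one argument switches to LoRA training
)
```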

## Limitations

- Trained only on GSM8K (grade school math problems)
- May not generalize well to other mathematical domains
- Performance degrades on non-math tasks
- Requires a GPU for inference (recommended: 16GB+ VRAM); a lower-memory loading sketch follows below
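
If VRAM is tight, 4-bit loading through bitsandbytes is one way to shrink the footprint. A minimal sketch, assuming bitsandbytes is installed (quantized inference was not part of the reported evaluation and may cost some accuracy):

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Load the model in 4-bit NF4; a 3B model fits in roughly 2-3 GB of VRAM.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "kmseong/Llama3.2-3B-gsm8k-fullft-like-sn",
    quantization_config=bnb_config,
    device_map="auto",
)
```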

## Evaluation Results

### GSM8K Test Set (50 samples)

- ✅ Correct: 20
- ❌ Incorrect: 30
- 📊 Accuracy: 40.00%

### Example Predictions

**Correct Example:**
```
Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Expected: 18
Predicted: 18 ✅
```

**Incorrect Example:**
```
Question: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?
Expected: 3
Predicted: 267 ❌
```

## Citation

```bibtex
@misc{gsm8k-fullparam-llama32-3b,
  title={GSM8K Full Parameter Fine-tuned Llama 3.2 3B Instruct},
  author={Kim, Min-Seong},
  year={2026},
  publisher={HuggingFace},
  howpublished={\url{https://huggingface.co/kmseong/Llama3.2-3B-gsm8k-fullft-like-sn}}
}
```

## License

This model is built on Llama 3.2 3B Instruct and follows the [Llama 3.2 Community License](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/LICENSE).

## Acknowledgments

- **Base Model**: Meta AI's Llama 3.2 3B Instruct
- **Dataset**: OpenAI's GSM8K
- **Framework**: HuggingFace Transformers & TRL

## Contact

For questions or issues, please open an issue on the model repository.

---

**Note**: This is a full parameter fine-tuned model. Unlike LoRA models, all weights have been updated and the model can be used directly without any adapter libraries.