---
base_model: Qwen/Qwen3-8B
library_name: transformers
license: apache-2.0
language:
- en
tags:
- math
- reasoning
- gsm8k
- synthetic-data
- qwen3
- qlora
- unsloth
- fine-tuned
- grade-school-math
- chain-of-thought
datasets:
- clarkkitchen22/SynthGSM8K-50K
pipeline_tag: text-generation
model-index:
- name: Qwen3-8B-GSM8K-Synth-50K
  results:
  - task:
      type: text-generation
      name: Math Reasoning
    dataset:
      name: GSM8K
      type: openai/gsm8k
    metrics:
    - name: GSM8K Accuracy
      type: accuracy
      value: 86.2
      verified: true
    - name: Training Loss (final)
      type: loss
      value: 0.266
---

# Qwen3-8B-GSM8K-Synth-50K

A **Qwen3-8B** model fine-tuned with QLoRA on **50,418 synthetic grade-school math problems**, designed for step-by-step chain-of-thought mathematical reasoning.

## What This Model Does

Given a math word problem, the model produces a structured reasoning chain inside `<think>` tags, then outputs the final numerical answer.

### Example

**Input:**
> If 3x + 7 = 22, what is x?

**Output:**
```
<think>
Step 1: Subtract 7 from both sides: 3x = 22 - 7
Step 2: Calculate: 3x = 15
Step 3: Divide both sides by 3: x = 15 / 3
Step 4: Calculate: x = 5
</think>

The answer is 5.0.
```

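
The output format above can be separated into its reasoning and final-answer parts with a small helper. This is a minimal sketch, not part of the released code: the name `split_reasoning` is hypothetical, and it assumes the `<think>` and `</think>` tags appear literally in the decoded text.

```python
import re

# Non-greedy match so only the first <think>...</think> block is captured.
THINK_BLOCK = re.compile(r"<think>(.*?)</think>", re.DOTALL)

def split_reasoning(response: str) -> tuple[str, str]:
    """Return (reasoning, final_text) from a response with an optional <think> block."""
    match = THINK_BLOCK.search(response)
    if match is None:
        # No reasoning block: treat the whole response as the final text.
        return "", response.strip()
    reasoning = match.group(1).strip()
    final_text = response[match.end():].strip()
    return reasoning, final_text

sample = "<think>\nStep 1: 3x = 22 - 7\nStep 2: x = 5\n</think>\n\nThe answer is 5.0."
reasoning, final_text = split_reasoning(sample)
print(final_text)  # The answer is 5.0.
```

Keeping the two parts separate makes it easy to show only the final line to a user while logging the reasoning for inspection.
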
## Evaluation Results

Evaluated on the **full GSM8K test set** (1,319 questions) with greedy decoding and 4-bit quantization.

| Model | GSM8K Accuracy | Correct / Total | Eval time |
|---|---|---|---|
| Base Qwen3-8B | 79.4% | 1,047 / 1,319 | 121.8 min |
| **Qwen3-8B-GSM8K-Synth-50K** | **86.2%** | **1,137 / 1,319** | 45.1 min |

**Fine-tuning improvement: +6.8 percentage points** over the base model.

The fine-tuned model also completed the evaluation about 2.7x faster (45.1 vs. 121.8 minutes) because its outputs are shorter and more structured: the base model produces verbose markdown formatting, while the fine-tuned model outputs concise step-by-step solutions.

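
The headline figures follow directly from the raw counts and timings in the table. The short check below recomputes them; the helper name `accuracy_pct` is illustrative only.

```python
def accuracy_pct(correct: int, total: int) -> float:
    """Accuracy as a percentage of the test set."""
    return 100.0 * correct / total

base_acc = accuracy_pct(1047, 1319)
tuned_acc = accuracy_pct(1137, 1319)
gain_pp = tuned_acc - base_acc   # difference in percentage points
speedup = 121.8 / 45.1           # ratio of evaluation wall-clock times (minutes)

print(f"{base_acc:.1f}% -> {tuned_acc:.1f}% (+{gain_pp:.1f} pp), {speedup:.1f}x faster")
# 79.4% -> 86.2% (+6.8 pp), 2.7x faster
```
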
### Cross-Model Comparison

| Model | Params | Training Data | GSM8K Accuracy | vs Base |
|---|---|---|---|---|
| Base Qwen3-4B | 4B | n/a | 74.7% | n/a |
| Qwen3-4B-GSM8K-Synth-35K | 4B | 35K synthetic | 85.0% | +10.3 pp |
| Base Qwen3-8B | 8B | n/a | 79.4% | n/a |
| **Qwen3-8B-GSM8K-Synth-50K** | **8B** | **50K synthetic** | **86.2%** | **+6.8 pp** |

Key takeaways (all differences are in percentage points, pp):
- Synthetic-data fine-tuning provides a substantial accuracy boost at both model scales (+10.3 pp for 4B, +6.8 pp for 8B)
- The 8B fine-tuned model achieves the highest absolute accuracy (86.2%)
- Scaling from 4B to 8B improves base accuracy by 4.7 pp and fine-tuned accuracy by 1.2 pp

## Training Details

### Base Model & Method

| Parameter | Value |
|---|---|
| **Base model** | [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) |
| **Method** | QLoRA (4-bit NF4 quantization) |
| **Framework** | [Unsloth](https://github.com/unslothai/unsloth) + HuggingFace TRL |
| **Merge** | Fully merged to 16-bit (no adapter needed at inference) |

### QLoRA Configuration

| Parameter | Value |
|---|---|
| **LoRA rank** | 16 |
| **LoRA alpha** | 16 |
| **Target modules** | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
| **Dropout** | 0 |
| **Trainable parameters** | 43.6M / 8.23B (0.53%) |

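
The trainable-parameter count can be reproduced from the LoRA rank and the shapes of the targeted projections: each adapted linear layer adds `rank * (in_features + out_features)` parameters. The sketch below assumes the Qwen3-8B dimensions (hidden size 4096, MLP size 12288, 36 layers, 32 query heads and 8 key/value heads of dimension 128); those values are not stated in this card and are an assumption here.

```python
# Assumed Qwen3-8B dimensions (not stated in this card).
HIDDEN = 4096
INTERMEDIATE = 12288
NUM_LAYERS = 36
Q_OUT = 32 * 128   # query heads * head_dim
KV_OUT = 8 * 128   # key/value heads * head_dim
RANK = 16

# (in_features, out_features) for each targeted projection.
SHAPES = {
    "q_proj": (HIDDEN, Q_OUT),
    "k_proj": (HIDDEN, KV_OUT),
    "v_proj": (HIDDEN, KV_OUT),
    "o_proj": (Q_OUT, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(rank: int, shapes: dict, num_layers: int) -> int:
    """Each adapter has A (rank x in) and B (out x rank) matrices."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers

total = lora_params(RANK, SHAPES, NUM_LAYERS)
print(f"{total:,} trainable parameters ({100 * total / 8.23e9:.2f}% of 8.23B)")
# 43,646,976 trainable parameters (0.53% of 8.23B)
```

Under these assumptions the result agrees with the 43.6M (0.53%) reported in the table.
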
### Training Hyperparameters

| Parameter | Value |
|---|---|
| **Epochs** | 3 |
| **Batch size** | 1 (per device) |
| **Gradient accumulation** | 64 (effective batch = 64) |
| **Learning rate** | 2e-4 (cosine schedule) |
| **Warmup steps** | 10 |
| **Optimizer** | AdamW 8-bit |
| **Precision** | bf16 |
| **Max sequence length** | 1024 |
| **Max grad norm** | 1.0 |
| **Seed** | 42 |
| **Total steps** | 2,364 |

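
The total step count follows from the dataset size, the effective batch size, and the number of epochs. The calculation below assumes a single GPU and that the final partial batch of each epoch still counts as one optimizer step.

```python
import math

num_examples = 50_418
per_device_batch = 1
grad_accum = 64
epochs = 3

effective_batch = per_device_batch * grad_accum              # single GPU assumed
steps_per_epoch = math.ceil(num_examples / effective_batch)  # last partial batch counts
total_steps = steps_per_epoch * epochs

print(effective_batch, steps_per_epoch, total_steps)  # 64 788 2364
```
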
### Memory Optimizations (fitting 8B in 12GB VRAM)

Training Qwen3-8B in 4-bit still uses ~7.2GB for weights alone, leaving only ~4.4GB on a 12GB GPU. The following optimizations made training possible:

- **Embedding offloading** (`offload_embedding=True`): input embeddings are kept on the CPU
- **Chunked fused CE loss** (`UNSLOTH_CE_LOSS_N_CHUNKS=8`): splits the large 151,936-vocab logits computation into smaller chunks
- **Unsloth gradient checkpointing**: auto-offloads activations for long sequences
- **Reduced sequence length** (1024 vs 4096): the data is short (median ~100 tokens, max ~260)
- **`PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`**: reduces CUDA memory fragmentation

Peak VRAM usage: **11.8GB / 12.3GB** (96%)

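
To see why chunking the cross-entropy computation matters, it helps to estimate the size of a full logits tensor for one maximum-length sequence. The estimate below is a rough sketch: it assumes float32 logits and chunking along the token dimension, neither of which is confirmed by this card, and Unsloth's actual kernel may work differently.

```python
def logits_bytes(num_tokens: int, vocab_size: int, bytes_per_value: int) -> int:
    """Memory needed to hold one logits matrix of shape (num_tokens, vocab_size)."""
    return num_tokens * vocab_size * bytes_per_value

VOCAB = 151_936
SEQ_LEN = 1024
N_CHUNKS = 8

full = logits_bytes(SEQ_LEN, VOCAB, 4)               # float32 logits (assumed)
chunk = logits_bytes(SEQ_LEN // N_CHUNKS, VOCAB, 4)  # one of 8 token chunks (assumed)

print(f"full: {full / 2**20:.1f} MiB, per chunk: {chunk / 2**20:.1f} MiB")
# full: 593.5 MiB, per chunk: 74.2 MiB
```

With only a few GB of headroom, avoiding a single allocation of several hundred MiB (plus its gradient) is significant, even though typical training sequences here are much shorter than 1024 tokens.
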
### Training Loss Curve

```
Epoch 1: 0.735 → 0.304 (rapid descent)
Epoch 2: 0.292 → 0.277 (steady refinement)
Epoch 3: 0.271 → 0.266 (final polish)

Mean training loss over all 3 epochs: 0.302 (last logged loss: 0.266)
```

| Milestone | Loss | Epoch |
|---|---|---|
| Step 50 | 0.735 | 0.06 |
| Step 250 | 0.333 | 0.32 |
| Step 500 | 0.316 | 0.63 |
| Step 788 (Epoch 1) | 0.304 | 1.02 |
| Step 1000 | 0.292 | 1.27 |
| Step 1250 | 0.291 | 1.59 |
| Step 1576 (Epoch 2) | 0.277 | 2.03 |
| Step 1750 | 0.270 | 2.22 |
| Step 2000 | 0.266 | 2.54 |
| Step 2364 (Epoch 3) | 0.266 | 2.98 |

### Comparison with 4B Model

| Metric | Qwen3-4B (35K) | Qwen3-8B (50K) |
|---|---|---|
| **Training data** | 34,818 examples | 50,418 examples |
| **Final loss** | 0.291 | 0.266 |
| **LoRA rank** | 32 | 16 |
| **Training time** | 3h 18m | 9h 25m |
| **Peak VRAM** | ~8.1 GB | ~11.8 GB |

### Hardware & Time

| Metric | Value |
|---|---|
| **GPU** | NVIDIA RTX 4070 SUPER (12GB VRAM) |
| **Training time** | 9h 25m (33,890 seconds) |
| **Throughput** | 4.46 samples/sec, 0.07 steps/sec |
| **Peak VRAM** | ~11.8 GB |

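
The throughput figures are consistent with the reported wall-clock time, the dataset size, and the step count, as the quick check below shows (durations are rounded to the nearest minute).

```python
train_seconds = 33_890
samples_seen = 50_418 * 3   # examples per epoch * epochs
total_steps = 2_364

hours, minutes = divmod(round(train_seconds / 60), 60)
samples_per_sec = samples_seen / train_seconds
steps_per_sec = total_steps / train_seconds

print(f"{hours}h {minutes}m, {samples_per_sec:.2f} samples/sec, {steps_per_sec:.2f} steps/sec")
# 9h 25m, 4.46 samples/sec, 0.07 steps/sec
```
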
## Training Data

Trained on the **full 50,418 examples** from [clarkkitchen22/SynthGSM8K-50K](https://huggingface.co/datasets/clarkkitchen22/SynthGSM8K-50K), a synthetic grade-school math dataset generated by Claude Haiku 4.5 via Anthropic's Batch API, then filtered through an 8-stage quality pipeline.

### Data Format

Each training example follows the Qwen3 ChatML format with thinking tags:

```
<|im_start|>user
{math word problem}<|im_end|>
<|im_start|>assistant
<think>
{step-by-step solution}
</think>

The answer is {number}.<|im_end|>
```

GSM8K-style calculation annotations (e.g., `<<24*3=72>>`) are stripped from solutions during preprocessing.

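
The preprocessing described above could look roughly like the following. This is an illustrative sketch rather than the project's actual code: the function names `strip_annotations` and `to_chatml` are hypothetical, and the template string simply mirrors the format shown in this section.

```python
import re

# Matches GSM8K calculator annotations such as <<24*3=72>>.
CALC_ANNOTATION = re.compile(r"<<[^<>]*>>")

def strip_annotations(solution: str) -> str:
    """Remove <<...>> calculator annotations, keeping the surrounding text."""
    return CALC_ANNOTATION.sub("", solution)

def to_chatml(question: str, solution: str, answer: str) -> str:
    """Render one training example in the ChatML layout shown above."""
    return (
        f"<|im_start|>user\n{question}<|im_end|>\n"
        f"<|im_start|>assistant\n<think>\n{strip_annotations(solution)}\n</think>\n\n"
        f"The answer is {answer}.<|im_end|>"
    )

print(strip_annotations("She buys 24*3 = <<24*3=72>>72 eggs."))
# She buys 24*3 = 72 eggs.
```
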
### Dataset Highlights

- **50,418 problems**: the full dataset was used for this training run
- Generated via few-shot prompting from 200 real GSM8K seed problems
- 8-stage filter pipeline: structure, answer range, solution quality, AI detection, math verification, exact dedup, fuzzy dedup (TF-IDF @ 0.85), seed overlap
- Average 3.0 math operations per solution
- 92.6% integer answers, range 0 to 225,000
- ~$55 generation cost (Haiku 4.5 Batch API at 50% discount)

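
As an illustration of the fuzzy-dedup stage, the sketch below drops any problem whose TF-IDF cosine similarity to an already kept problem reaches 0.85. It is a pure-Python approximation written for this card, not the pipeline's actual implementation: whitespace tokenization, the smoothed IDF formula, and the greedy keep-first strategy are all assumptions.

```python
import math
from collections import Counter

def tfidf_vectors(docs):
    """Return one sparse TF-IDF vector (dict: term -> weight) per document."""
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))
    vectors = []
    for tokens in tokenized:
        counts = Counter(tokens)
        vectors.append({
            # Smoothed IDF (assumed): log((1 + N) / (1 + df)) + 1
            term: count * (math.log((1 + n_docs) / (1 + doc_freq[term])) + 1.0)
            for term, count in counts.items()
        })
    return vectors

def cosine(a, b):
    """Cosine similarity between two sparse vectors."""
    dot = sum(weight * b.get(term, 0.0) for term, weight in a.items())
    norm_a = math.sqrt(sum(w * w for w in a.values()))
    norm_b = math.sqrt(sum(w * w for w in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def fuzzy_dedup(docs, threshold=0.85):
    """Keep a document only if it is below the threshold against every kept one."""
    vectors = tfidf_vectors(docs)
    kept = []
    for i, vec in enumerate(vectors):
        if all(cosine(vec, vectors[j]) < threshold for j in kept):
            kept.append(i)
    return [docs[i] for i in kept]

docs = [
    "Tom has 5 apples and buys 3 more apples",
    "Tom has 5 apples and buys 3 more apples today",
    "A train travels 60 miles in 2 hours",
]
print(len(fuzzy_dedup(docs)))  # 2
```

The pairwise comparison is quadratic in the number of kept documents, so a dataset of tens of thousands of problems would normally use a vectorized or approximate nearest-neighbour approach instead.
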
## Usage

### With Transformers

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "clarkkitchen22/Qwen3-8B-GSM8K-Synth-50K"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

messages = [{"role": "user", "content": "A store sells apples for $2 each and oranges for $3 each. If Sarah buys 5 apples and 4 oranges, how much does she spend?"}]

# Render the chat template and append the assistant prefix so the model starts its reply.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Pass do_sample=True explicitly so temperature and top_p take effect regardless of
# the saved generation config. For greedy decoding (as used in the evaluation),
# set do_sample=False and drop temperature and top_p.
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.6, top_p=0.95)
# Decode only the newly generated tokens, not the prompt.
response = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
print(response)
```

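### Answer Extraction

```python
import re

# A signed number with optional thousands separators and decimal part.
# The lookbehind keeps the "-" in ranges such as "3-4" from being read as a sign.
NUMBER = r"(?<!\d)-?\d[\d,]*(?:\.\d+)?"

def extract_answer(text):
    """Extract the numerical answer from model output, or None if there is none."""
    # Prefer an explicit "answer is X" or "answer: X" statement.
    match = re.search(rf"answer\s*(?:is|:)\s*\$?({NUMBER})", text, re.IGNORECASE)
    if match:
        return float(match.group(1).replace(",", ""))
    # Otherwise fall back to the last number anywhere in the text.
    matches = re.findall(NUMBER, text)
    return float(matches[-1].replace(",", "")) if matches else None
```
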
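
Once an answer has been extracted, it can be scored against the GSM8K reference, whose solutions end with the final answer after a `####` marker. The card does not describe the exact scoring rule used for the reported numbers, so the tolerance-based comparison below is only one plausible sketch; the helper names are hypothetical.

```python
def gold_answer(reference: str) -> float:
    """Parse the number after '####' in a GSM8K reference solution."""
    return float(reference.split("####")[-1].strip().replace(",", ""))

def is_correct(predicted, gold, tol=1e-6):
    """A prediction counts only if it was extracted and matches within tol."""
    return predicted is not None and abs(predicted - gold) <= tol

def accuracy(predictions, references):
    """Percentage of predictions that match their reference answer."""
    hits = sum(is_correct(p, gold_answer(r)) for p, r in zip(predictions, references))
    return 100.0 * hits / len(references)

refs = ["... 24*3 = 72 eggs #### 72", "... #### 1,250"]
print(accuracy([72.0, None], refs))  # 50.0
```
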
## Intended Use

- **Math tutoring**: Step-by-step solutions to grade-school math problems
- **Research**: Studying the effect of model scale and synthetic data on math reasoning
- **Distillation baseline**: Comparing synthetic-data-trained small models against larger models
- **Further fine-tuning**: Starting point for domain-specific math reasoning tasks

## Limitations

- Trained on synthetic data generated by Haiku 4.5, so it is bounded by that model's math ability
- Optimized for GSM8K-style word problems (arithmetic, basic algebra), not calculus, geometry, or advanced math
- All training answers are non-negative; the model may struggle with problems requiring negative answers
- Solutions use a specific `<think>` tag format; other prompting styles may give worse results
- Evaluated on GSM8K only; performance on other math benchmarks (MATH, MMLU-Math) has not yet been tested

## How It Was Built

### End-to-End Pipeline

```
200 GSM8K seeds → Claude Haiku 4.5 (Batch API) → 83K raw problems
→ 8-stage filter → 50K clean dataset → QLoRA fine-tune Qwen3-8B
→ Merge to 16-bit → Push to HuggingFace
```

### Pipeline Code

The full data generation pipeline and training code are available at:
[github.com/goldbar123467/SynthDataGSM8K](https://github.com/goldbar123467/SynthDataGSM8K)

## Citation

```bibtex
@misc{qwen3_8b_gsm8k_synth_50k,
  title={Qwen3-8B-GSM8K-Synth-50K},
  author={clarkkitchen22},
  year={2026},
  base_model={Qwen/Qwen3-8B},
  training_data={clarkkitchen22/SynthGSM8K-50K},
  url={https://huggingface.co/clarkkitchen22/Qwen3-8B-GSM8K-Synth-50K}
}
```

## Acknowledgements

- **Base model**: [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) by Alibaba
- **Training data**: [SynthGSM8K-50K](https://huggingface.co/datasets/clarkkitchen22/SynthGSM8K-50K), synthetic math problems from Claude Haiku 4.5
- **Training framework**: [Unsloth](https://github.com/unslothai/unsloth) (2x faster QLoRA)
- **Seed data**: [OpenAI GSM8K](https://github.com/openai/grade-school-math)