Initialize project; model provided by the ModelHub XC community

Model: HuggingFaceTB/qwen3-1.7b-gsm8k-sft
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-05-12 18:21:40 +08:00
commit 665d418665
20 changed files with 152756 additions and 0 deletions

38
.gitattributes vendored Normal file

@@ -0,0 +1,38 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
tokenizer.json filter=lfs diff=lfs merge=lfs -text
qwen3-1.7b-gsm8k-q8_0.gguf filter=lfs diff=lfs merge=lfs -text
qwen3-1.7b-gsm8k-f16.gguf filter=lfs diff=lfs merge=lfs -text

215
README.md Normal file

@@ -0,0 +1,215 @@
---
license: apache-2.0
base_model: Qwen/Qwen3-1.7B
tags:
- math
- gsm8k
- fine-tuned
- chain-of-thought
- reasoning
- gguf
datasets:
- openai/gsm8k
- meta-math/MetaMathQA
language:
- en
pipeline_tag: text-generation
model-index:
- name: qwen3-1.7b-gsm8k-sft
  results:
  - task:
      type: text-generation
      name: Math Reasoning
    dataset:
      name: GSM8K
      type: openai/gsm8k
      split: test
    metrics:
    - type: accuracy
      value: 77.2
      name: Accuracy
  - task:
      type: text-generation
      name: Math Reasoning
    dataset:
      name: MATH-500
      type: HuggingFaceH4/MATH-500
      split: test
    metrics:
    - type: accuracy
      value: 55.2
      name: Accuracy
---
# Qwen3-1.7B Fine-tuned for GSM8K Math Reasoning
This model is a fine-tuned version of [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) optimized for mathematical reasoning on the GSM8K benchmark.
## Performance
| Benchmark | Accuracy | Notes |
|-----------|----------|-------|
| **GSM8K** | **77.2%** | Grade school math (1,319 test problems) |
| **MATH-500** | **55.2%** | Competition math (500 test problems) |
| Baseline GSM8K | 20% | Original Qwen3-1.7B |
### MATH-500 Breakdown by Difficulty Level
| Level | Accuracy |
|-------|----------|
| Level 1 (Easiest) | 86.0% |
| Level 2 | 68.9% |
| Level 3 | 64.8% |
| Level 4 | 54.7% |
| Level 5 (Hardest) | 29.1% |
### MATH-500 Breakdown by Subject
| Subject | Accuracy |
|---------|----------|
| Algebra | 71.8% |
| Prealgebra | 68.3% |
| Number Theory | 61.3% |
| Counting & Probability | 55.3% |
| Geometry | 43.9% |
| Precalculus | 41.1% |
| Intermediate Algebra | 32.0% |
### Baseline Comparison
| Model | GSM8K | MATH-500 | Notes |
|-------|-------|----------|-------|
| **This model (SFT)** | **77.2%** | 55.2% | Optimized for GSM8K |
| Qwen3-1.7B (base) | ~20% | 62.0% | Pre-training only |

Note: Fine-tuning raises GSM8K accuracy by about 57 percentage points (from ~20% to 77.2%) but lowers MATH-500 accuracy by 6.8 points relative to the base model (55.2% vs. 62.0%). This trade-off is expected, since training focused on GSM8K-style problems.
## GGUF Quantized Versions
For deployment with llama.cpp, Ollama, or other GGUF-compatible runtimes:
| File | Size | Description |
|------|------|-------------|
| [qwen3-1.7b-gsm8k-q8_0.gguf](https://huggingface.co/HuggingFaceTB/qwen3-1.7b-gsm8k-sft/blob/main/qwen3-1.7b-gsm8k-q8_0.gguf) | 1.8 GB | 8-bit quantized (recommended) |
| [qwen3-1.7b-gsm8k-f16.gguf](https://huggingface.co/HuggingFaceTB/qwen3-1.7b-gsm8k-sft/blob/main/qwen3-1.7b-gsm8k-f16.gguf) | 3.3 GB | Full FP16 precision |
### Usage with Ollama
```bash
# Download and run
ollama run hf.co/HuggingFaceTB/qwen3-1.7b-gsm8k-sft:q8_0
```
### Usage with llama.cpp
```bash
# Download the GGUF file
huggingface-cli download HuggingFaceTB/qwen3-1.7b-gsm8k-sft qwen3-1.7b-gsm8k-q8_0.gguf
# Run inference
./llama-cli -m qwen3-1.7b-gsm8k-q8_0.gguf -p "Solve: If a train travels 120 miles in 2 hours, what is its average speed?"
```
## Training Details
### Dataset
- **Size**: 247,467 examples
- **Sources**:
  - [GSM8K](https://huggingface.co/datasets/openai/gsm8k) train set (7,473 examples)
  - [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) GSM-related examples (239,994 examples)
- **Format**: Conversational messages with `<think>...</think>` chain-of-thought reasoning
### Training Configuration
- **Stage 1** (2 epochs): lr=2e-5, loss 0.30 → 0.17
- **Stage 2** (1 epoch): lr=5e-6, loss 0.17 → 0.167
- **Batch size**: 8 per device, gradient accumulation 4 (effective batch size 32)
- **Hardware**: NVIDIA H100 80GB GPU
- **Total training time**: ~7 hours
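
As a quick sanity check on these settings, the sketch below works out the effective batch size and an upper bound on optimizer steps per epoch. It assumes a single GPU (one H100 is listed under Hardware) and ignores `packing=True`, which merges short examples into shared sequences and so lowers the real step count. The variable names are illustrative.

```python
import math

# Values taken from the training configuration above.
num_examples = 247_467
per_device_batch_size = 8
gradient_accumulation_steps = 4
num_gpus = 1  # assumption: a single H100, as listed under Hardware

effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_gpus

# Upper bound: with packing=True several short examples can share one
# 1024-token sequence, so the true number of steps is smaller.
steps_per_epoch = math.ceil(num_examples / effective_batch_size)

print(f"effective batch size: {effective_batch_size}")
print(f"steps per epoch (upper bound, no packing): {steps_per_epoch}")
```
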
### Hyperparameters
```python
SFTConfig(
    num_train_epochs=2,  # Stage 1; 1 for Stage 2
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,  # 5e-6 for Stage 2
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,  # 0.01 for Stage 2
    weight_decay=0.01,
    max_length=1024,
    packing=True,
    bf16=True,
    gradient_checkpointing=True,
)
```
## Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/qwen3-1.7b-gsm8k-sft",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/qwen3-1.7b-gsm8k-sft")

# For math problems, the model uses chain-of-thought reasoning
messages = [
    {"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
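
For runtimes that take a raw prompt string (for example the llama.cpp command above), the prompt can be assembled by hand. The sketch below mirrors what the bundled chat template appears to produce for a single user turn with no tools and no system message. `build_prompt` is a hypothetical helper, not part of this repository, and `tokenizer.apply_chat_template` remains the reference.

```python
def build_prompt(problem: str, enable_thinking: bool = True) -> str:
    """Render a single-turn ChatML prompt (hypothetical helper)."""
    prompt = f"<|im_start|>user\n{problem}<|im_end|>\n<|im_start|>assistant\n"
    if not enable_thinking:
        # The chat template emits an empty think block when thinking is disabled.
        prompt += "<think>\n\n</think>\n\n"
    return prompt


print(build_prompt("What is 15% of 80?"))
```
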
## Evaluation
### GSM8K
- **Accuracy**: 77.2% ± 1.2% (standard error)
- Test set: 1,319 grade school math word problems
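
The ± figure is consistent with a binomial standard error over 1,319 independent problems. A minimal check, assuming that is how the error was computed (the evaluation harness may use a slightly different estimator):

```python
import math

accuracy = 0.772   # reported GSM8K accuracy
n_problems = 1319  # size of the GSM8K test set

# Standard error of a proportion: sqrt(p * (1 - p) / n)
std_err = math.sqrt(accuracy * (1 - accuracy) / n_problems)
print(f"standard error: {std_err * 100:.1f} percentage points")
```
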
### MATH-500
- **Accuracy**: 55.2%
- Test set: 500 competition-level math problems
- Best performance on Algebra (71.8%) and Prealgebra (68.3%)
- Model uses chain-of-thought reasoning enclosed in `<think>...</think>` tags
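
Because the reasoning is wrapped in `<think>...</think>`, downstream code usually needs to separate it from the final answer. The following is a small sketch of one way to do that; `split_reasoning` is a hypothetical helper and is not how the evaluation scripts in this repository parse output.

```python
def split_reasoning(text: str) -> tuple[str, str]:
    """Split generated text into (reasoning, final_answer)."""
    head, sep, tail = text.partition("</think>")
    if not sep:
        # No closing tag found: treat the whole text as the answer.
        return "", text.strip()
    # Keep only what follows the last opening tag, if one is present.
    reasoning = head.split("<think>")[-1].strip()
    return reasoning, tail.strip()


sample = "<think>\n48 / 2 = 24, and 48 + 24 = 72.\n</think>\nThe answer is 72"
reasoning, answer = split_reasoning(sample)
print(answer)
```
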
## Key Learnings
1. **Chain-of-thought format is crucial** - The `<think>...</think>` reasoning format significantly improves math performance
2. **Large diverse dataset works better** - MetaMathQA (240K examples) outperforms small task-specific data
3. **Two-stage training** - Starting with higher LR (2e-5) then refining with lower LR (5e-6) works well
4. **Diminishing returns after ~3 epochs** - Additional fine-tuning showed minimal improvement
5. **Limited transfer to harder problems** - On MATH-500 the model is strongest on Algebra (71.8%) and Prealgebra (68.3%), but its overall accuracy (55.2%) is below the base model's (62.0%)
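
To make the two-stage schedule in item 3 concrete, the sketch below evaluates a cosine schedule with linear warmup at a few points. It assumes the common "linear warmup, then cosine decay to zero" shape; the exact rounding of warmup steps in the training library may differ, and `cosine_lr` and the step count of 1000 are illustrative only.

```python
import math


def cosine_lr(step: int, total_steps: int, peak_lr: float, warmup_ratio: float) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero (assumed shape)."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total_steps = 1000  # illustrative, not the real step count
for name, peak_lr, warmup_ratio in [("stage 1", 2e-5, 0.03), ("stage 2", 5e-6, 0.01)]:
    samples = [cosine_lr(s, total_steps, peak_lr, warmup_ratio) for s in (0, 100, 500, 1000)]
    print(name, [f"{lr:.2e}" for lr in samples])
```

Stage 2 restarts the schedule at a peak four times lower than Stage 1, so it makes much smaller updates to the already fine-tuned weights.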
## Training Scripts
Training scripts are available in the `scripts/` directory:
- `train_improved.py` - Main training script (Stage 1)
- `train_continued.py` - Continued training script (Stage 2)
- `evaluate.py` - GSM8K evaluation script
- `evaluate_math500.py` - MATH-500 evaluation script
- `prepare_combined_data.py` - Data preparation script
## Citation
If you use this model, please cite:
```bibtex
@misc{qwen3-gsm8k-sft,
title={Qwen3-1.7B Fine-tuned for GSM8K},
author={HuggingFaceTB},
year={2026},
publisher={Hugging Face},
url={https://huggingface.co/HuggingFaceTB/qwen3-1.7b-gsm8k-sft}
}
```
## License
This model inherits the license from the base model [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) (Apache 2.0).

28
added_tokens.json Normal file

@@ -0,0 +1,28 @@
{
"</think>": 151668,
"</tool_call>": 151658,
"</tool_response>": 151666,
"<think>": 151667,
"<tool_call>": 151657,
"<tool_response>": 151665,
"<|box_end|>": 151649,
"<|box_start|>": 151648,
"<|endoftext|>": 151643,
"<|file_sep|>": 151664,
"<|fim_middle|>": 151660,
"<|fim_pad|>": 151662,
"<|fim_prefix|>": 151659,
"<|fim_suffix|>": 151661,
"<|im_end|>": 151645,
"<|im_start|>": 151644,
"<|image_pad|>": 151655,
"<|object_ref_end|>": 151647,
"<|object_ref_start|>": 151646,
"<|quad_end|>": 151651,
"<|quad_start|>": 151650,
"<|repo_name|>": 151663,
"<|video_pad|>": 151656,
"<|vision_end|>": 151653,
"<|vision_pad|>": 151654,
"<|vision_start|>": 151652
}

89
chat_template.jinja Normal file

@@ -0,0 +1,89 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.last or (not loop.last and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}

60
config.json Normal file

@@ -0,0 +1,60 @@
{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"layer_types": [
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention"
],
"max_position_embeddings": 40960,
"max_window_layers": 28,
"model_type": "qwen3",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"pad_token_id": 151643,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000,
"sliding_window": null,
"tie_word_embeddings": true,
"transformers_version": "4.57.6",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
}

12
generation_config.json Normal file

@@ -0,0 +1,12 @@
{
"do_sample": true,
"eos_token_id": [
151645,
151643
],
"pad_token_id": 151643,
"temperature": 0.6,
"top_k": 20,
"top_p": 0.95,
"transformers_version": "4.57.6"
}

151388
merges.txt Normal file

File diff suppressed because it is too large

3
model.safetensors Normal file

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1febbe96ad1e0ad1ac19919d0eb5362bb013b0493063de241d1ecdd7b67bb291
size 3441185608

3
qwen3-1.7b-gsm8k-f16.gguf Normal file

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd6b433c2c52329f688810ded7e66d38c5313f8ca93ca4ec0e45fc9a63b9a79b
size 3447348928

3
qwen3-1.7b-gsm8k-q8_0.gguf Normal file

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9508fd895d10cb39fd4799d709c3599efeb00a61d4afbc05ecbfa34b02998e1a
size 1834426048

138
scripts/evaluate.py Normal file

@@ -0,0 +1,138 @@
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os

from inspect_ai.log._log import EvalLog, EvalMetric, EvalSample  # noqa: F401
from inspect_ai import eval as inspect_eval  # type: ignore  # noqa: E402
from inspect_ai.util._display import init_display_type  # noqa: E402

import inspect_evals.gsm8k  # noqa: F401, E402 (registers task definitions)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run Inspect AI eval without banners.")
    parser.add_argument(
        "--model-path",
        type=str,
        default="final_model",
        help="Path to the Hugging Face model (directory or model identifier).",
    )
    # 150 samples is a reasonable default for this task. Use a smaller value
    # for faster tests, or -1 to evaluate the full test set.
    parser.add_argument(
        "--limit",
        type=int,
        default=150,
        help="Optional limit for number of samples to evaluate (-1 for no limit).",
    )
    parser.add_argument(
        '--json-output-file',
        type=str,
        default=None,
        help="Optional path to output the metrics as a separate JSON file.",
    )
    parser.add_argument(
        '--templates-dir',
        type=str,
        default="templates/",
    )
    # Raise --max-connections for faster runs if no errors occur; lower it if
    # vLLM runs into problems.
    parser.add_argument(
        "--max-connections",
        type=int,
        default=2,
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=4000,
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.3,
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    init_display_type("plain")

    # A limit of None or -1 means "evaluate every sample".
    other_kwargs = {}
    if (args.limit is not None) and (args.limit != -1):
        other_kwargs["limit"] = args.limit

    task = "inspect_evals/gsm8k"
    model_args = {
        'gpu_memory_utilization': args.gpu_memory_utilization,
    }
    model_args.update(template_kwargs(args))

    eval_out = inspect_eval(
        task,
        model=f"vllm/{args.model_path}",
        model_args=model_args,
        score_display=False,
        log_realtime=False,
        log_format='json',
        timeout=18000000,
        attempt_timeout=18000000,
        max_tokens=args.max_tokens,
        max_connections=args.max_connections,
        **other_kwargs,
    )

    # Optionally dump the single scorer's metrics to a JSON file.
    if args.json_output_file is not None:
        assert len(eval_out) == 1, eval_out
        assert len(eval_out[0].results.scores) == 1, eval_out[0].results.scores
        metrics = {}
        for k, v in eval_out[0].results.scores[0].metrics.items():
            metrics[k] = v.value
        with open(args.json_output_file, 'w') as f:
            json.dump(metrics, f, indent=2)


def model_type(args) -> str:
    # First try to infer the model family from the path or identifier.
    if 'qwen' in args.model_path.lower():
        return 'qwen'
    if 'llama' in args.model_path.lower():
        return 'llama'
    if 'gemma' in args.model_path.lower():
        return 'gemma'
    if 'smollm' in args.model_path.lower():
        return 'smollm'

    # Otherwise fall back to the architecture named in a local config.json.
    with open(os.path.join(args.model_path, "config.json"), 'r') as f:
        config = json.load(f)
    architecture = config['architectures'][0].lower()
    if 'gemma' in architecture:
        return 'gemma'
    if 'llama' in architecture:
        return 'llama'
    if 'qwen' in architecture:
        return 'qwen'
    if 'smollm' in architecture:
        return 'smollm'
    raise ValueError(architecture)


def template_kwargs(args) -> dict:
    model_type_str = model_type(args)
    if model_type_str == 'qwen':
        template = 'qwen3.jinja'
    elif model_type_str == 'llama':
        template = 'llama3.jinja'
    elif model_type_str == 'gemma':
        template = 'gemma3.jinja'
    elif model_type_str == 'smollm':
        template = 'smollm.jinja'
    else:
        raise ValueError(model_type_str)
    return {
        'chat_template': os.path.join(args.templates_dir, template)
    }


if __name__ == "__main__":
    main()

159
scripts/evaluate_math500.py Normal file

@@ -0,0 +1,159 @@
#!/usr/bin/env python3
"""Evaluate model on MATH-500 dataset (harder math problems)."""
import argparse
import json
import re

from datasets import load_dataset
from vllm import LLM, SamplingParams


def extract_answer(response: str) -> str:
    """Extract the final answer from model response."""
    # Look for boxed answer first (common in MATH format)
    boxed_match = re.search(r'\\boxed\{([^}]+)\}', response)
    if boxed_match:
        return boxed_match.group(1).strip()
    # Look for "The answer is X" pattern
    answer_match = re.search(r'[Tt]he (?:final )?answer is[:\s]*([^\n.]+)', response)
    if answer_match:
        return answer_match.group(1).strip()
    # Look for "= X" at the end
    equals_match = re.search(r'=\s*([^\n=]+?)\s*$', response)
    if equals_match:
        return equals_match.group(1).strip()
    # Return last line as fallback
    lines = [line.strip() for line in response.strip().split('\n') if line.strip()]
    return lines[-1] if lines else ""


def normalize_answer(answer: str) -> str:
    """Normalize answer for comparison."""
    # Remove common formatting
    answer = answer.strip()
    answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'\\left|\\right', '', answer)
    answer = re.sub(r'\$', '', answer)
    answer = answer.strip()
    return answer.lower()


def answers_match(predicted: str, expected: str) -> bool:
    """Check if answers match (with some tolerance)."""
    pred_norm = normalize_answer(predicted)
    exp_norm = normalize_answer(expected)
    # Direct match
    if pred_norm == exp_norm:
        return True
    # Try numeric comparison
    try:
        pred_num = float(re.sub(r'[^\d.-]', '', pred_norm))
        exp_num = float(re.sub(r'[^\d.-]', '', exp_norm))
        if abs(pred_num - exp_num) < 1e-6:
            return True
    except ValueError:
        # At least one side is not a plain number; fall through.
        pass
    # Check if one contains the other. This is a lenient substring check;
    # both sides must be non-empty, because "" is a substring of every
    # string and an empty prediction would otherwise count as correct.
    if pred_norm and exp_norm and (exp_norm in pred_norm or pred_norm in exp_norm):
        return True
    return False


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="final_model")
    parser.add_argument("--limit", type=int, default=100)
    args = parser.parse_args()

    print("Loading MATH-500 dataset...")
    dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
    if args.limit:
        dataset = dataset.select(range(min(args.limit, len(dataset))))
    print(f"Evaluating {len(dataset)} problems...")

    # Load model
    print(f"Loading model from {args.model_path}...")
    llm = LLM(
        model=args.model_path,
        dtype="bfloat16",
        max_model_len=4096,
        gpu_memory_utilization=0.9,
    )
    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=2048,
        stop=["<|im_end|>", "<|endoftext|>"],
    )

    # Prepare prompts
    prompts = []
    for item in dataset:
        problem = item["problem"]
        prompt = f"<|im_start|>user\n{problem}<|im_end|>\n<|im_start|>assistant\n"
        prompts.append(prompt)

    # Generate
    print("Generating responses...")
    outputs = llm.generate(prompts, sampling_params)

    # Evaluate
    correct = 0
    results_by_level = {}
    results_by_subject = {}
    for i, (item, output) in enumerate(zip(dataset, outputs)):
        response = output.outputs[0].text
        predicted = extract_answer(response)
        expected = item["answer"]
        level = item["level"]
        subject = item["subject"]
        is_correct = answers_match(predicted, expected)
        if is_correct:
            correct += 1
        # Track by level
        if level not in results_by_level:
            results_by_level[level] = {"correct": 0, "total": 0}
        results_by_level[level]["total"] += 1
        if is_correct:
            results_by_level[level]["correct"] += 1
        # Track by subject
        if subject not in results_by_subject:
            results_by_subject[subject] = {"correct": 0, "total": 0}
        results_by_subject[subject]["total"] += 1
        if is_correct:
            results_by_subject[subject]["correct"] += 1
        if (i + 1) % 20 == 0:
            print(f"Progress: {i+1}/{len(dataset)}, Accuracy so far: {correct/(i+1)*100:.1f}%")

    # Print results
    accuracy = correct / len(dataset) * 100
    print(f"\n{'='*60}")
    print(f"MATH-500 Results ({len(dataset)} problems)")
    print(f"{'='*60}")
    print(f"Overall Accuracy: {accuracy:.1f}% ({correct}/{len(dataset)})")
    print("\nBy Level:")
    for level in sorted(results_by_level.keys()):
        stats = results_by_level[level]
        acc = stats["correct"] / stats["total"] * 100
        print(f" {level}: {acc:.1f}% ({stats['correct']}/{stats['total']})")
    print("\nBy Subject:")
    for subject in sorted(results_by_subject.keys()):
        stats = results_by_subject[subject]
        acc = stats["correct"] / stats["total"] * 100
        print(f" {subject}: {acc:.1f}% ({stats['correct']}/{stats['total']})")


if __name__ == "__main__":
    main()

154
scripts/prepare_combined_data.py Normal file

@@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
Combine GSM8K training data with MetaMathQA GSM-related examples.
This creates a larger, more diverse training set.
"""
import json
import random
import re

from datasets import load_dataset


def extract_answer_gsm8k(answer_text):
    """Extract the final numerical answer from GSM8K answer format."""
    match = re.search(r'####\s*(-?[\d,]+\.?\d*)', answer_text)
    if match:
        return match.group(1).replace(',', '')
    return None


def format_reasoning_gsm8k(answer_text):
    """Convert GSM8K step-by-step format to thinking format."""
    # Drop the trailing "#### <answer>" marker and the <<...>> calculator annotations.
    reasoning = re.sub(r'####\s*-?[\d,]+\.?\d*\s*$', '', answer_text).strip()
    reasoning = re.sub(r'<<[^>]+>>', '', reasoning)
    return reasoning


def extract_answer_metamath(response):
    """Extract answer from MetaMathQA format (usually ends with boxed answer)."""
    # Try to find boxed answer
    match = re.search(r'\\boxed\{([^}]+)\}', response)
    if match:
        return match.group(1).strip()
    # Try to find "the answer is X" pattern
    match = re.search(r'the answer is[:\s]*\$?(-?[\d,]+\.?\d*)', response, re.IGNORECASE)
    if match:
        return match.group(1).replace(',', '')
    # Try to find "= X" at the end
    match = re.search(r'=\s*\$?(-?[\d,]+\.?\d*)\s*(?:dollars?|\.)?$', response)
    if match:
        return match.group(1).replace(',', '')
    return None


def create_gsm8k_example(question, answer):
    """Create a training example from GSM8K format."""
    final_answer = extract_answer_gsm8k(answer)
    reasoning = format_reasoning_gsm8k(answer)
    if final_answer is None:
        return None
    assistant_content = f"""<think>
Let me solve this step by step.
{reasoning}
Therefore, the answer is {final_answer}.
</think>
The answer is {final_answer}"""
    return {
        "messages": [
            {"role": "user", "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{question}"},
            {"role": "assistant", "content": assistant_content}
        ]
    }


def create_metamath_example(query, response):
    """Create a training example from MetaMathQA format."""
    # Clean up the response - turn literal "\n" sequences into real newlines
    clean_response = response.replace('\\n', '\n').strip()
    # Extract the answer
    final_answer = extract_answer_metamath(clean_response)
    if final_answer is None:
        return None
    # Remove the boxed answer and everything after for reasoning
    reasoning = re.sub(r'\\boxed\{[^}]+\}.*$', '', clean_response, flags=re.DOTALL).strip()
    reasoning = re.sub(r'The answer is.*$', '', reasoning, flags=re.IGNORECASE | re.DOTALL).strip()
    # Skip if reasoning is too short
    if len(reasoning) < 50:
        return None
    assistant_content = f"""<think>
Let me solve this step by step.
{reasoning}
Therefore, the answer is {final_answer}.
</think>
The answer is {final_answer}"""
    return {
        "messages": [
            {"role": "user", "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{query}"},
            {"role": "assistant", "content": assistant_content}
        ]
    }


def main():
    training_data = []

    # Load GSM8K training data
    print("Loading GSM8K dataset...")
    gsm8k = load_dataset("openai/gsm8k", "main", split="train")
    print(f"Loaded {len(gsm8k)} GSM8K examples")
    gsm8k_count = 0
    for example in gsm8k:
        formatted = create_gsm8k_example(example['question'], example['answer'])
        if formatted:
            training_data.append(formatted)
            gsm8k_count += 1
    print(f"Added {gsm8k_count} GSM8K examples")

    # Load MetaMathQA - only GSM-related examples
    print("\nLoading MetaMathQA dataset...")
    metamath = load_dataset("meta-math/MetaMathQA", split="train")
    print(f"Loaded {len(metamath)} MetaMathQA examples")
    # Filter for GSM-related examples only
    metamath_count = 0
    for example in metamath:
        if 'GSM' in example['type']:  # GSM_Rephrased, GSM_SV, GSM_AnsAug, etc.
            formatted = create_metamath_example(example['query'], example['response'])
            if formatted:
                training_data.append(formatted)
                metamath_count += 1
    print(f"Added {metamath_count} MetaMathQA GSM examples")
    print(f"\nTotal training examples: {len(training_data)}")

    # Shuffle the data
    random.seed(42)
    random.shuffle(training_data)

    # Save to JSONL
    output_file = "combined_math_train.jsonl"
    with open(output_file, 'w') as f:
        for item in training_data:
            f.write(json.dumps(item) + '\n')
    print(f"Saved to {output_file}")

    # Show samples
    print("\n=== Sample GSM8K example ===")
    for item in training_data[:10]:
        if "Natalia" in item['messages'][0]['content']:
            print(json.dumps(item, indent=2)[:500])
            break


if __name__ == "__main__":
    main()

90
scripts/train_continued.py Normal file

@@ -0,0 +1,90 @@
#!/usr/bin/env python3
"""
Continue training the already fine-tuned model with lower learning rate for additional refinement.
"""
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig


def main():
    # Load from our previously trained model
    print("Loading previously trained model from final_model/...")
    model_name = "./final_model"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Ensure pad token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="auto",
    )

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("json", data_files="combined_math_train.jsonl", split="train")
    print(f"Dataset size: {len(dataset)}")

    # Training config - lower LR for refinement, 1 more epoch
    training_args = SFTConfig(
        output_dir="./sft_output_continued",
        num_train_epochs=1,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        learning_rate=5e-6,  # Lower LR for continued training
        lr_scheduler_type="cosine",
        warmup_ratio=0.01,
        weight_decay=0.01,
        logging_steps=100,
        save_steps=2000,
        save_total_limit=2,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=1024,
        packing=True,
        report_to="none",
        seed=42,
        dataloader_num_workers=4,
        optim="adamw_torch_fused",
    )

    # Create trainer
    print("Creating trainer...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # Print training info
    print("\n=== Continued Training Configuration ===")
    print(f"Model: {model_name} (previously fine-tuned)")
    print(f"Dataset size: {len(dataset)}")
    print(f"Batch size: {training_args.per_device_train_batch_size}")
    print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
    print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
    print(f"Learning rate: {training_args.learning_rate}")
    print(f"Epochs: {training_args.num_train_epochs}")
    print("="*40)

    # Train
    print("\nStarting continued training...")
    trainer.train()

    # Save final model
    print("\nSaving model to final_model/...")
    trainer.save_model("final_model")
    tokenizer.save_pretrained("final_model")
    print("Continued training complete!")


if __name__ == "__main__":
    main()

99
scripts/train_improved.py Normal file

@@ -0,0 +1,99 @@
#!/usr/bin/env python3
"""
Improved SFT training for GSM8K performance.
Key improvements:
1. More training data (247K examples from GSM8K + MetaMathQA)
2. Multiple epochs with cosine LR schedule
3. Proper batch size and gradient accumulation for H100
4. Gradient checkpointing for memory efficiency
"""
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
import os
def main():
# Load model and tokenizer
print("Loading model and tokenizer...")
    model_name = "Qwen/Qwen3-1.7B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Ensure pad token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",  # Use SDPA instead of flash_attention_2
        device_map="auto",
    )
    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("json", data_files="combined_math_train.jsonl", split="train")
    print(f"Dataset size: {len(dataset)}")
    # Training config - optimized for H100 and GSM8K task
    # With 247K examples and batch_size 8 * grad_accum 4 = effective batch 32
    # Steps per epoch: 247467 / 32 ≈ 7733 steps
    # 2 epochs ≈ 15466 steps
    # Note: these counts assume one sequence per example. With packing=True,
    # several examples share one max_length sequence, so the real step count
    # is lower.
    training_args = SFTConfig(
        output_dir="./sft_output_improved",
        num_train_epochs=2,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.01,
        logging_steps=100,
        save_steps=2000,
        save_total_limit=3,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=1024,  # Math problems don't need very long context
        packing=True,
        report_to="none",
        seed=42,
        dataloader_num_workers=4,
        optim="adamw_torch_fused",
    )
    # Create trainer
    print("Creating trainer...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )
    # Print training info
    print("\n=== Training Configuration ===")
    print(f"Model: {model_name}")
    print(f"Dataset size: {len(dataset)}")
    print(f"Batch size: {training_args.per_device_train_batch_size}")
    print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
    print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
    print(f"Learning rate: {training_args.learning_rate}")
    print(f"Epochs: {training_args.num_train_epochs}")
    print(f"Max length: {training_args.max_length}")
    print("=" * 30)
    # Train
    print("\nStarting training...")
    trainer.train()
    # Save final model
    print("\nSaving model to final_model/...")
    trainer.save_model("final_model")
    tokenizer.save_pretrained("final_model")
    print("Training complete!")
if __name__ == "__main__":
    main()
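
The step counts quoted in the script's comments can be reproduced with a few lines of arithmetic. The sketch below is illustrative only: `estimate_steps` is a hypothetical helper, not part of the training script or of TRL, and it assumes a single device, one sequence per example (no packing), floor division for steps per epoch, and warmup steps rounded up. With `packing=True` the real figures would be lower.

```python
import math


def estimate_steps(num_sequences, per_device_batch, grad_accum, epochs,
                   warmup_ratio, num_devices=1):
    """Rough optimizer-step estimate (hypothetical helper, assumptions above)."""
    # Sequences consumed per optimizer step.
    effective_batch = per_device_batch * grad_accum * num_devices
    # Assumption: a trailing partial batch does not count as a full step.
    steps_per_epoch = num_sequences // effective_batch
    total_steps = steps_per_epoch * epochs
    # Assumption: warmup length is the ratio of total steps, rounded up.
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, steps_per_epoch, total_steps, warmup_steps


# Values taken from the script: 247467 examples, batch 8, grad_accum 4,
# 2 epochs, warmup_ratio 0.03.
result = estimate_steps(247467, 8, 4, 2, 0.03)
print(result)  # (32, 7733, 15466, 464)
```

Under these assumptions, `save_steps=2000` would produce a checkpoint roughly every quarter epoch, and `save_total_limit=3` keeps only the three most recent ones.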

31
special_tokens_map.json Normal file

@@ -0,0 +1,31 @@
{
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>"
],
"eos_token": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

BIN
tokenizer.json (Stored with Git LFS) Normal file

Binary file not shown.

239
tokenizer_config.json Normal file

@@ -0,0 +1,239 @@
{
"add_bos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"151643": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151644": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151645": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151646": {
"content": "<|object_ref_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151647": {
"content": "<|object_ref_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151648": {
"content": "<|box_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151649": {
"content": "<|box_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151650": {
"content": "<|quad_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151651": {
"content": "<|quad_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151652": {
"content": "<|vision_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151653": {
"content": "<|vision_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151654": {
"content": "<|vision_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151655": {
"content": "<|image_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151656": {
"content": "<|video_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151657": {
"content": "<tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151658": {
"content": "</tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151659": {
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151660": {
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151661": {
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151662": {
"content": "<|fim_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151663": {
"content": "<|repo_name|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151664": {
"content": "<|file_sep|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151665": {
"content": "<tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151666": {
"content": "</tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151667": {
"content": "<think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151668": {
"content": "</think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
}
},
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>"
],
"bos_token": null,
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"errors": "replace",
"extra_special_tokens": {},
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"split_special_tokens": false,
"tokenizer_class": "Qwen2Tokenizer",
"unk_token": null
}
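
The config above names `eos_token` and `pad_token` by their string content, while `added_tokens_decoder` is keyed by integer ID. The following sketch shows one way to resolve the names to IDs using only the standard library. `CONFIG_EXCERPT` is a trimmed copy of the fields shown above and `token_id` is a hypothetical helper, not a Transformers API.

```python
import json

# Trimmed copy of tokenizer_config.json: only the two decoder entries needed
# here. The full file lists IDs 151643 through 151668.
CONFIG_EXCERPT = """
{
  "added_tokens_decoder": {
    "151643": {"content": "<|endoftext|>", "special": true},
    "151645": {"content": "<|im_end|>", "special": true}
  },
  "eos_token": "<|im_end|>",
  "pad_token": "<|endoftext|>"
}
"""


def token_id(config, token):
    """Return the ID whose decoder entry has this content, or None."""
    for raw_id, entry in config["added_tokens_decoder"].items():
        if entry["content"] == token:
            return int(raw_id)
    return None


config = json.loads(CONFIG_EXCERPT)
eos_id = token_id(config, config["eos_token"])
pad_id = token_id(config, config["pad_token"])
print(eos_id, pad_id)  # 151645 151643
```

Because the pad and EOS tokens resolve to different IDs, padding positions can be masked without also masking the end-of-turn token.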

3
training_args.bin Normal file

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4cc4183a399e46771d11513377493de112e1c85dfa95d53dc6b122e5befe0449
size 6289
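
The three lines above are a Git LFS pointer rather than the actual `training_args.bin` payload: a spec version, a SHA-256 object ID, and the payload size in bytes. As a rough sketch, the pointer can be parsed and a downloaded file checked against it as follows. `parse_lfs_pointer` and `matches_pointer` are hypothetical helper names, and the sketch assumes the simple `key value` line layout shown above.

```python
import hashlib

# Pointer text copied from the diff above.
POINTER_TEXT = """version https://git-lfs.github.com/spec/v1
oid sha256:4cc4183a399e46771d11513377493de112e1c85dfa95d53dc6b122e5befe0449
size 6289
"""


def parse_lfs_pointer(text):
    """Split each 'key value' line and break the oid into algorithm and digest."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def matches_pointer(data, pointer):
    """Check that raw bytes have the size and hash recorded in the pointer."""
    if len(data) != pointer["size"]:
        return False
    return hashlib.new(pointer["algo"], data).hexdigest() == pointer["digest"]


pointer = parse_lfs_pointer(POINTER_TEXT)
print(pointer["algo"], pointer["size"])  # sha256 6289
```

To verify a real download, read the file in binary mode and pass its bytes to `matches_pointer`. A size mismatch usually means the pointer file itself was fetched instead of the LFS object.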

1
vocab.json Normal file

File diff suppressed because one or more lines are too long