初始化项目,由ModelHub XC社区提供模型
Model: HuggingFaceTB/qwen3-1.7b-gsm8k-sft Source: Original Platform
This commit is contained in:
38
.gitattributes
vendored
Normal file
38
.gitattributes
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
||||
qwen3-1.7b-gsm8k-q8_0.gguf filter=lfs diff=lfs merge=lfs -text
|
||||
qwen3-1.7b-gsm8k-f16.gguf filter=lfs diff=lfs merge=lfs -text
|
||||
215
README.md
Normal file
215
README.md
Normal file
@@ -0,0 +1,215 @@
|
||||
---
|
||||
license: apache-2.0
|
||||
base_model: Qwen/Qwen3-1.7B
|
||||
tags:
|
||||
- math
|
||||
- gsm8k
|
||||
- fine-tuned
|
||||
- chain-of-thought
|
||||
- reasoning
|
||||
- gguf
|
||||
datasets:
|
||||
- openai/gsm8k
|
||||
- meta-math/MetaMathQA
|
||||
language:
|
||||
- en
|
||||
pipeline_tag: text-generation
|
||||
model-index:
|
||||
- name: qwen3-1.7b-gsm8k-sft
|
||||
results:
|
||||
- task:
|
||||
type: text-generation
|
||||
name: Math Reasoning
|
||||
dataset:
|
||||
name: GSM8K
|
||||
type: openai/gsm8k
|
||||
split: test
|
||||
metrics:
|
||||
- type: accuracy
|
||||
value: 77.2
|
||||
name: Accuracy
|
||||
- task:
|
||||
type: text-generation
|
||||
name: Math Reasoning
|
||||
dataset:
|
||||
name: MATH-500
|
||||
type: HuggingFaceH4/MATH-500
|
||||
split: test
|
||||
metrics:
|
||||
- type: accuracy
|
||||
value: 55.2
|
||||
name: Accuracy
|
||||
---
|
||||
|
||||
# Qwen3-1.7B Fine-tuned for GSM8K Math Reasoning
|
||||
|
||||
This model is a fine-tuned version of [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) optimized for mathematical reasoning on the GSM8K benchmark.
|
||||
|
||||
## Performance
|
||||
|
||||
| Benchmark | Accuracy | Notes |
|
||||
|-----------|----------|-------|
|
||||
| **GSM8K** | **77.2%** | Grade school math (1,319 test problems) |
|
||||
| **MATH-500** | **55.2%** | Competition math (500 test problems) |
|
||||
| Baseline GSM8K | 20% | Original Qwen3-1.7B |
|
||||
|
||||
### MATH-500 Breakdown by Difficulty Level
|
||||
|
||||
| Level | Accuracy |
|
||||
|-------|----------|
|
||||
| Level 1 (Easiest) | 86.0% |
|
||||
| Level 2 | 68.9% |
|
||||
| Level 3 | 64.8% |
|
||||
| Level 4 | 54.7% |
|
||||
| Level 5 (Hardest) | 29.1% |
|
||||
|
||||
### MATH-500 Breakdown by Subject
|
||||
|
||||
| Subject | Accuracy |
|
||||
|---------|----------|
|
||||
| Algebra | 71.8% |
|
||||
| Prealgebra | 68.3% |
|
||||
| Number Theory | 61.3% |
|
||||
| Counting & Probability | 55.3% |
|
||||
| Geometry | 43.9% |
|
||||
| Precalculus | 41.1% |
|
||||
| Intermediate Algebra | 32.0% |
|
||||
|
||||
### Baseline Comparison
|
||||
|
||||
| Model | GSM8K | MATH-500 | Notes |
|
||||
|-------|-------|----------|-------|
|
||||
| **This model (SFT)** | **77.2%** | 55.2% | Optimized for GSM8K |
|
||||
| Qwen3-1.7B (base) | ~20% | 62.0% | Pre-training only |
|
||||
|
||||
Note: The fine-tuned model shows significant improvement on GSM8K (+57pp) but slightly lower performance on MATH-500 compared to the base model. This is expected as the training focused on GSM8K-style problems.
|
||||
|
||||
## GGUF Quantized Versions
|
||||
|
||||
For deployment with llama.cpp, Ollama, or other GGUF-compatible runtimes:
|
||||
|
||||
| File | Size | Description |
|
||||
|------|------|-------------|
|
||||
| [qwen3-1.7b-gsm8k-q8_0.gguf](https://huggingface.co/HuggingFaceTB/qwen3-1.7b-gsm8k-sft/blob/main/qwen3-1.7b-gsm8k-q8_0.gguf) | 1.8 GB | 8-bit quantized (recommended) |
|
||||
| [qwen3-1.7b-gsm8k-f16.gguf](https://huggingface.co/HuggingFaceTB/qwen3-1.7b-gsm8k-sft/blob/main/qwen3-1.7b-gsm8k-f16.gguf) | 3.3 GB | Full FP16 precision |
|
||||
|
||||
### Usage with Ollama
|
||||
|
||||
```bash
|
||||
# Download and run
|
||||
ollama run hf.co/HuggingFaceTB/qwen3-1.7b-gsm8k-sft:q8_0
|
||||
```
|
||||
|
||||
### Usage with llama.cpp
|
||||
|
||||
```bash
|
||||
# Download the GGUF file
|
||||
huggingface-cli download HuggingFaceTB/qwen3-1.7b-gsm8k-sft qwen3-1.7b-gsm8k-q8_0.gguf
|
||||
|
||||
# Run inference
|
||||
./llama-cli -m qwen3-1.7b-gsm8k-q8_0.gguf -p "Solve: If a train travels 120 miles in 2 hours, what is its average speed?"
|
||||
```
|
||||
|
||||
## Training Details
|
||||
|
||||
### Dataset
|
||||
- **Size**: 247,467 examples
|
||||
- **Sources**:
|
||||
- [GSM8K](https://huggingface.co/datasets/openai/gsm8k) train set (7,473 examples)
|
||||
- [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) GSM-related examples (239,994 examples)
|
||||
- **Format**: Conversational messages with `<think>...</think>` chain-of-thought reasoning
|
||||
|
||||
### Training Configuration
|
||||
- **Stage 1** (2 epochs): lr=2e-5, loss 0.30 → 0.17
|
||||
- **Stage 2** (1 epoch): lr=5e-6, loss 0.17 → 0.167
|
||||
- **Batch size**: 8 per device, gradient accumulation 4
|
||||
- **Hardware**: NVIDIA H100 80GB GPU
|
||||
- **Total training time**: ~7 hours
|
||||
|
||||
### Hyperparameters
|
||||
```python
|
||||
SFTConfig(
|
||||
num_train_epochs=2, # Stage 1
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-5, # 5e-6 for Stage 2
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.03,
|
||||
weight_decay=0.01,
|
||||
max_length=1024,
|
||||
packing=True,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"HuggingFaceTB/qwen3-1.7b-gsm8k-sft",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/qwen3-1.7b-gsm8k-sft")
|
||||
|
||||
# For math problems, the model uses chain-of-thought reasoning
|
||||
messages = [
|
||||
{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"}
|
||||
]
|
||||
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = tokenizer(text, return_tensors="pt").to(model.device)
|
||||
outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### GSM8K
|
||||
- **Accuracy**: 77.2% ± 1.2% (standard error)
|
||||
- Test set: 1,319 grade school math word problems
|
||||
|
||||
### MATH-500
|
||||
- **Accuracy**: 55.2%
|
||||
- Test set: 500 competition-level math problems
|
||||
- Best performance on Algebra (71.8%) and Prealgebra (68.3%)
|
||||
- Model uses chain-of-thought reasoning enclosed in `<think>...</think>` tags
|
||||
|
||||
## Key Learnings
|
||||
|
||||
1. **Chain-of-thought format is crucial** - The `<think>...</think>` reasoning format significantly improves math performance
|
||||
2. **Large diverse dataset works better** - MetaMathQA (240K examples) outperforms small task-specific data
|
||||
3. **Two-stage training** - Starting with higher LR (2e-5) then refining with lower LR (5e-6) works well
|
||||
4. **Diminishing returns after ~3 epochs** - Additional fine-tuning showed minimal improvement
|
||||
5. **Transfer to harder problems** - GSM8K training also improves MATH-500 performance, especially on algebra
|
||||
|
||||
## Training Scripts
|
||||
|
||||
Training scripts are available in the `scripts/` directory:
|
||||
- `train_improved.py` - Main training script (Stage 1)
|
||||
- `train_continued.py` - Continued training script (Stage 2)
|
||||
- `evaluate.py` - GSM8K evaluation script
|
||||
- `evaluate_math500.py` - MATH-500 evaluation script
|
||||
- `prepare_combined_data.py` - Data preparation script
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this model, please cite:
|
||||
|
||||
```bibtex
|
||||
@misc{qwen3-gsm8k-sft,
|
||||
title={Qwen3-1.7B Fine-tuned for GSM8K},
|
||||
author={HuggingFaceTB},
|
||||
year={2026},
|
||||
publisher={Hugging Face},
|
||||
url={https://huggingface.co/HuggingFaceTB/qwen3-1.7b-gsm8k-sft}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model inherits the license from the base model [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) (Apache 2.0).
|
||||
28
added_tokens.json
Normal file
28
added_tokens.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"</think>": 151668,
|
||||
"</tool_call>": 151658,
|
||||
"</tool_response>": 151666,
|
||||
"<think>": 151667,
|
||||
"<tool_call>": 151657,
|
||||
"<tool_response>": 151665,
|
||||
"<|box_end|>": 151649,
|
||||
"<|box_start|>": 151648,
|
||||
"<|endoftext|>": 151643,
|
||||
"<|file_sep|>": 151664,
|
||||
"<|fim_middle|>": 151660,
|
||||
"<|fim_pad|>": 151662,
|
||||
"<|fim_prefix|>": 151659,
|
||||
"<|fim_suffix|>": 151661,
|
||||
"<|im_end|>": 151645,
|
||||
"<|im_start|>": 151644,
|
||||
"<|image_pad|>": 151655,
|
||||
"<|object_ref_end|>": 151647,
|
||||
"<|object_ref_start|>": 151646,
|
||||
"<|quad_end|>": 151651,
|
||||
"<|quad_start|>": 151650,
|
||||
"<|repo_name|>": 151663,
|
||||
"<|video_pad|>": 151656,
|
||||
"<|vision_end|>": 151653,
|
||||
"<|vision_pad|>": 151654,
|
||||
"<|vision_start|>": 151652
|
||||
}
|
||||
89
chat_template.jinja
Normal file
89
chat_template.jinja
Normal file
@@ -0,0 +1,89 @@
|
||||
{%- if tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- messages[0].content + '\n\n' }}
|
||||
{%- endif %}
|
||||
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||
{%- for message in messages[::-1] %}
|
||||
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
||||
{%- set ns.multi_step_tool = false %}
|
||||
{%- set ns.last_query_index = index %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- for message in messages %}
|
||||
{%- if message.content is string %}
|
||||
{%- set content = message.content %}
|
||||
{%- else %}
|
||||
{%- set content = '' %}
|
||||
{%- endif %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- if message.reasoning_content is string %}
|
||||
{%- set reasoning_content = message.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if loop.index0 > ns.last_query_index %}
|
||||
{%- if loop.last or (not loop.last and reasoning_content) %}
|
||||
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if (loop.first and content) or (not loop.first) %}
|
||||
{{- '\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{%- if tool_call.arguments is string %}
|
||||
{{- tool_call.arguments }}
|
||||
{%- else %}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{%- endif %}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- if enable_thinking is defined and enable_thinking is false %}
|
||||
{{- '<think>\n\n</think>\n\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
60
config.json
Normal file
60
config.json
Normal file
@@ -0,0 +1,60 @@
|
||||
{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": false,
|
||||
"attention_dropout": 0.0,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2048,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 6144,
|
||||
"layer_types": [
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention"
|
||||
],
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 8,
|
||||
"pad_token_id": 151643,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": null,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": null,
|
||||
"tie_word_embeddings": true,
|
||||
"transformers_version": "4.57.6",
|
||||
"use_cache": true,
|
||||
"use_sliding_window": false,
|
||||
"vocab_size": 151936
|
||||
}
|
||||
12
generation_config.json
Normal file
12
generation_config.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"do_sample": true,
|
||||
"eos_token_id": [
|
||||
151645,
|
||||
151643
|
||||
],
|
||||
"pad_token_id": 151643,
|
||||
"temperature": 0.6,
|
||||
"top_k": 20,
|
||||
"top_p": 0.95,
|
||||
"transformers_version": "4.57.6"
|
||||
}
|
||||
151388
merges.txt
Normal file
151388
merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
3
model.safetensors
Normal file
3
model.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1febbe96ad1e0ad1ac19919d0eb5362bb013b0493063de241d1ecdd7b67bb291
|
||||
size 3441185608
|
||||
3
qwen3-1.7b-gsm8k-f16.gguf
Normal file
3
qwen3-1.7b-gsm8k-f16.gguf
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cd6b433c2c52329f688810ded7e66d38c5313f8ca93ca4ec0e45fc9a63b9a79b
|
||||
size 3447348928
|
||||
3
qwen3-1.7b-gsm8k-q8_0.gguf
Normal file
3
qwen3-1.7b-gsm8k-q8_0.gguf
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9508fd895d10cb39fd4799d709c3599efeb00a61d4afbc05ecbfa34b02998e1a
|
||||
size 1834426048
|
||||
138
scripts/evaluate.py
Normal file
138
scripts/evaluate.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from inspect_ai.log._log import EvalLog, EvalMetric, EvalSample
|
||||
from inspect_ai import eval as inspect_eval # type: ignore # noqa: E402
|
||||
from inspect_ai.util._display import init_display_type # noqa: E402
|
||||
|
||||
import inspect_evals.gsm8k # noqa: F401, E402 (registers task definitions)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Run Inspect AI eval without banners.")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="final_model",
|
||||
help="Path to the Hugging Face model (directory or model identifier).",
|
||||
)
|
||||
# this is a good limit for this task, just keep it like that (or use less in case you want faster tests)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
default=150,
|
||||
help="Optional limit for number of samples to evaluate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--json-output-file',
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional path to output the metrics as a seperate JSON file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--templates-dir',
|
||||
type=str,
|
||||
default="templates/",
|
||||
)
|
||||
# You can adjust --max-connections if you want faster tests and don't receive errors (or if you have issues with vllm, try lowering this value)
|
||||
parser.add_argument(
|
||||
"--max-connections",
|
||||
type=int,
|
||||
default=2,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=4000,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.3,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
init_display_type("plain")
|
||||
|
||||
other_kwargs = {}
|
||||
if (args.limit is not None) and (args.limit != -1):
|
||||
other_kwargs["limit"] = args.limit
|
||||
|
||||
task = "inspect_evals/gsm8k"
|
||||
model_args = {
|
||||
'gpu_memory_utilization': args.gpu_memory_utilization,
|
||||
}
|
||||
model_args.update(template_kwargs(args))
|
||||
|
||||
eval_out = inspect_eval(
|
||||
task,
|
||||
model=f"vllm/{args.model_path}",
|
||||
model_args=model_args,
|
||||
score_display=False,
|
||||
log_realtime=False,
|
||||
log_format='json',
|
||||
timeout=18000000,
|
||||
attempt_timeout=18000000,
|
||||
max_tokens=args.max_tokens,
|
||||
max_connections=args.max_connections,
|
||||
**other_kwargs,
|
||||
)
|
||||
|
||||
if args.json_output_file is not None:
|
||||
assert len(eval_out) == 1, eval_out
|
||||
assert len(eval_out[0].results.scores) == 1, eval_out[0].results.scores
|
||||
metrics = {}
|
||||
for k, v in eval_out[0].results.scores[0].metrics.items():
|
||||
metrics[k] = v.value
|
||||
|
||||
with open(args.json_output_file, 'w') as f:
|
||||
json.dump(metrics, f, indent=2)
|
||||
|
||||
def model_type(args) -> str:
|
||||
if 'qwen' in args.model_path.lower():
|
||||
return 'qwen'
|
||||
if 'llama' in args.model_path.lower():
|
||||
return 'llama'
|
||||
if 'gemma' in args.model_path.lower():
|
||||
return 'gemma'
|
||||
if 'smollm' in args.model_path.lower():
|
||||
return 'smollm'
|
||||
|
||||
with open(os.path.join(args.model_path, "config.json"), 'r') as f:
|
||||
config = json.load(f)
|
||||
architecture = config['architectures'][0].lower()
|
||||
if 'gemma' in architecture:
|
||||
return 'gemma'
|
||||
if 'llama' in architecture:
|
||||
return 'llama'
|
||||
if 'qwen' in architecture:
|
||||
return 'qwen'
|
||||
if 'smollm' in architecture:
|
||||
return 'smollm'
|
||||
raise ValueError(architecture)
|
||||
|
||||
def template_kwargs(args) -> dict:
|
||||
model_type_str = model_type(args)
|
||||
if model_type_str == 'qwen':
|
||||
template = 'qwen3.jinja'
|
||||
elif model_type_str == 'llama':
|
||||
template = 'llama3.jinja'
|
||||
elif model_type_str == 'gemma':
|
||||
template = 'gemma3.jinja'
|
||||
elif model_type_str == 'smollm':
|
||||
template = 'smollm.jinja'
|
||||
else:
|
||||
raise ValueError(model_type_str)
|
||||
return {
|
||||
'chat_template': os.path.join(args.templates_dir, template)
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
159
scripts/evaluate_math500.py
Normal file
159
scripts/evaluate_math500.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Evaluate model on MATH-500 dataset (harder math problems)."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from datasets import load_dataset
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
def extract_answer(response: str) -> str:
|
||||
"""Extract the final answer from model response."""
|
||||
# Look for boxed answer first (common in MATH format)
|
||||
boxed_match = re.search(r'\\boxed\{([^}]+)\}', response)
|
||||
if boxed_match:
|
||||
return boxed_match.group(1).strip()
|
||||
|
||||
# Look for "The answer is X" pattern
|
||||
answer_match = re.search(r'[Tt]he (?:final )?answer is[:\s]*([^\n.]+)', response)
|
||||
if answer_match:
|
||||
return answer_match.group(1).strip()
|
||||
|
||||
# Look for "= X" at the end
|
||||
equals_match = re.search(r'=\s*([^\n=]+?)\s*$', response)
|
||||
if equals_match:
|
||||
return equals_match.group(1).strip()
|
||||
|
||||
# Return last line as fallback
|
||||
lines = [l.strip() for l in response.strip().split('\n') if l.strip()]
|
||||
return lines[-1] if lines else ""
|
||||
|
||||
def normalize_answer(answer: str) -> str:
|
||||
"""Normalize answer for comparison."""
|
||||
# Remove common formatting
|
||||
answer = answer.strip()
|
||||
answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer)
|
||||
answer = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', answer)
|
||||
answer = re.sub(r'\\left|\\right', '', answer)
|
||||
answer = re.sub(r'\$', '', answer)
|
||||
answer = answer.strip()
|
||||
return answer.lower()
|
||||
|
||||
def answers_match(predicted: str, expected: str) -> bool:
|
||||
"""Check if answers match (with some tolerance)."""
|
||||
pred_norm = normalize_answer(predicted)
|
||||
exp_norm = normalize_answer(expected)
|
||||
|
||||
# Direct match
|
||||
if pred_norm == exp_norm:
|
||||
return True
|
||||
|
||||
# Try numeric comparison
|
||||
try:
|
||||
pred_num = float(re.sub(r'[^\d.-]', '', pred_norm))
|
||||
exp_num = float(re.sub(r'[^\d.-]', '', exp_norm))
|
||||
if abs(pred_num - exp_num) < 1e-6:
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check if one contains the other
|
||||
if exp_norm in pred_norm or pred_norm in exp_norm:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-path", type=str, default="final_model")
|
||||
parser.add_argument("--limit", type=int, default=100)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Loading MATH-500 dataset...")
|
||||
dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
|
||||
|
||||
if args.limit:
|
||||
dataset = dataset.select(range(min(args.limit, len(dataset))))
|
||||
|
||||
print(f"Evaluating {len(dataset)} problems...")
|
||||
|
||||
# Load model
|
||||
print(f"Loading model from {args.model_path}...")
|
||||
llm = LLM(
|
||||
model=args.model_path,
|
||||
dtype="bfloat16",
|
||||
max_model_len=4096,
|
||||
gpu_memory_utilization=0.9,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=2048,
|
||||
stop=["<|im_end|>", "<|endoftext|>"],
|
||||
)
|
||||
|
||||
# Prepare prompts
|
||||
prompts = []
|
||||
for item in dataset:
|
||||
problem = item["problem"]
|
||||
prompt = f"<|im_start|>user\n{problem}<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompts.append(prompt)
|
||||
|
||||
# Generate
|
||||
print("Generating responses...")
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Evaluate
|
||||
correct = 0
|
||||
results_by_level = {}
|
||||
results_by_subject = {}
|
||||
|
||||
for i, (item, output) in enumerate(zip(dataset, outputs)):
|
||||
response = output.outputs[0].text
|
||||
predicted = extract_answer(response)
|
||||
expected = item["answer"]
|
||||
level = item["level"]
|
||||
subject = item["subject"]
|
||||
|
||||
is_correct = answers_match(predicted, expected)
|
||||
if is_correct:
|
||||
correct += 1
|
||||
|
||||
# Track by level
|
||||
if level not in results_by_level:
|
||||
results_by_level[level] = {"correct": 0, "total": 0}
|
||||
results_by_level[level]["total"] += 1
|
||||
if is_correct:
|
||||
results_by_level[level]["correct"] += 1
|
||||
|
||||
# Track by subject
|
||||
if subject not in results_by_subject:
|
||||
results_by_subject[subject] = {"correct": 0, "total": 0}
|
||||
results_by_subject[subject]["total"] += 1
|
||||
if is_correct:
|
||||
results_by_subject[subject]["correct"] += 1
|
||||
|
||||
if (i + 1) % 20 == 0:
|
||||
print(f"Progress: {i+1}/{len(dataset)}, Accuracy so far: {correct/(i+1)*100:.1f}%")
|
||||
|
||||
# Print results
|
||||
accuracy = correct / len(dataset) * 100
|
||||
print(f"\n{'='*60}")
|
||||
print(f"MATH-500 Results ({len(dataset)} problems)")
|
||||
print(f"{'='*60}")
|
||||
print(f"Overall Accuracy: {accuracy:.1f}% ({correct}/{len(dataset)})")
|
||||
|
||||
print(f"\nBy Level:")
|
||||
for level in sorted(results_by_level.keys()):
|
||||
stats = results_by_level[level]
|
||||
acc = stats["correct"] / stats["total"] * 100
|
||||
print(f" {level}: {acc:.1f}% ({stats['correct']}/{stats['total']})")
|
||||
|
||||
print(f"\nBy Subject:")
|
||||
for subject in sorted(results_by_subject.keys()):
|
||||
stats = results_by_subject[subject]
|
||||
acc = stats["correct"] / stats["total"] * 100
|
||||
print(f" {subject}: {acc:.1f}% ({stats['correct']}/{stats['total']})")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
154
scripts/prepare_combined_data.py
Normal file
154
scripts/prepare_combined_data.py
Normal file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Combine GSM8K training data with MetaMathQA GSM-related examples.
|
||||
This creates a larger, more diverse training set.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datasets import load_dataset
|
||||
|
||||
def extract_answer_gsm8k(answer_text):
|
||||
"""Extract the final numerical answer from GSM8K answer format."""
|
||||
match = re.search(r'####\s*(-?[\d,]+\.?\d*)', answer_text)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
return None
|
||||
|
||||
def format_reasoning_gsm8k(answer_text):
|
||||
"""Convert GSM8K step-by-step format to thinking format."""
|
||||
reasoning = re.sub(r'####\s*-?[\d,]+\.?\d*\s*$', '', answer_text).strip()
|
||||
reasoning = re.sub(r'<<[^>]+>>', '', reasoning)
|
||||
return reasoning
|
||||
|
||||
def extract_answer_metamath(response):
|
||||
"""Extract answer from MetaMathQA format (usually ends with boxed answer)."""
|
||||
# Try to find boxed answer
|
||||
match = re.search(r'\\boxed\{([^}]+)\}', response)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
# Try to find "the answer is X" pattern
|
||||
match = re.search(r'the answer is[:\s]*\$?(-?[\d,]+\.?\d*)', response, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
# Try to find "= X" at the end
|
||||
match = re.search(r'=\s*\$?(-?[\d,]+\.?\d*)\s*(?:dollars?|\.)?$', response)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
return None
|
||||
|
||||
def create_gsm8k_example(question, answer):
|
||||
"""Create a training example from GSM8K format."""
|
||||
final_answer = extract_answer_gsm8k(answer)
|
||||
reasoning = format_reasoning_gsm8k(answer)
|
||||
|
||||
if final_answer is None:
|
||||
return None
|
||||
|
||||
assistant_content = f"""<think>
|
||||
Let me solve this step by step.
|
||||
|
||||
{reasoning}
|
||||
|
||||
Therefore, the answer is {final_answer}.
|
||||
</think>
|
||||
|
||||
The answer is {final_answer}"""
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{question}"},
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
]
|
||||
}
|
||||
|
||||
def create_metamath_example(query, response):
|
||||
"""Create a training example from MetaMathQA format."""
|
||||
# Clean up the response - remove LaTeX formatting artifacts
|
||||
clean_response = response.replace('\\n', '\n').strip()
|
||||
|
||||
# Extract the answer
|
||||
final_answer = extract_answer_metamath(clean_response)
|
||||
if final_answer is None:
|
||||
return None
|
||||
|
||||
# Remove the boxed answer and everything after for reasoning
|
||||
reasoning = re.sub(r'\\boxed\{[^}]+\}.*$', '', clean_response, flags=re.DOTALL).strip()
|
||||
reasoning = re.sub(r'The answer is.*$', '', reasoning, flags=re.IGNORECASE | re.DOTALL).strip()
|
||||
|
||||
# Skip if reasoning is too short
|
||||
if len(reasoning) < 50:
|
||||
return None
|
||||
|
||||
assistant_content = f"""<think>
|
||||
Let me solve this step by step.
|
||||
|
||||
{reasoning}
|
||||
|
||||
Therefore, the answer is {final_answer}.
|
||||
</think>
|
||||
|
||||
The answer is {final_answer}"""
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{query}"},
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
]
|
||||
}
|
||||
|
||||
def main():
|
||||
training_data = []
|
||||
|
||||
# Load GSM8K training data
|
||||
print("Loading GSM8K dataset...")
|
||||
gsm8k = load_dataset("openai/gsm8k", "main", split="train")
|
||||
print(f"Loaded {len(gsm8k)} GSM8K examples")
|
||||
|
||||
gsm8k_count = 0
|
||||
for example in gsm8k:
|
||||
formatted = create_gsm8k_example(example['question'], example['answer'])
|
||||
if formatted:
|
||||
training_data.append(formatted)
|
||||
gsm8k_count += 1
|
||||
print(f"Added {gsm8k_count} GSM8K examples")
|
||||
|
||||
# Load MetaMathQA - only GSM-related examples
|
||||
print("\nLoading MetaMathQA dataset...")
|
||||
metamath = load_dataset("meta-math/MetaMathQA", split="train")
|
||||
print(f"Loaded {len(metamath)} MetaMathQA examples")
|
||||
|
||||
# Filter for GSM-related examples only
|
||||
metamath_count = 0
|
||||
for example in metamath:
|
||||
if 'GSM' in example['type']: # GSM_Rephrased, GSM_SV, GSM_AnsAug, etc.
|
||||
formatted = create_metamath_example(example['query'], example['response'])
|
||||
if formatted:
|
||||
training_data.append(formatted)
|
||||
metamath_count += 1
|
||||
print(f"Added {metamath_count} MetaMathQA GSM examples")
|
||||
|
||||
print(f"\nTotal training examples: {len(training_data)}")
|
||||
|
||||
# Shuffle the data
|
||||
import random
|
||||
random.seed(42)
|
||||
random.shuffle(training_data)
|
||||
|
||||
# Save to JSONL
|
||||
output_file = "combined_math_train.jsonl"
|
||||
with open(output_file, 'w') as f:
|
||||
for item in training_data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
|
||||
print(f"Saved to {output_file}")
|
||||
|
||||
# Show samples
|
||||
print("\n=== Sample GSM8K example ===")
|
||||
for item in training_data[:10]:
|
||||
if "Natalia" in item['messages'][0]['content']:
|
||||
print(json.dumps(item, indent=2)[:500])
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
90
scripts/train_continued.py
Normal file
90
scripts/train_continued.py
Normal file
@@ -0,0 +1,90 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Continue training the already fine-tuned model with lower learning rate for additional refinement.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
import os
|
||||
|
||||
def main():
|
||||
# Load from our previously trained model
|
||||
print("Loading previously trained model from final_model/...")
|
||||
model_name = "./final_model"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Ensure pad token is set
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa",
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
print("Loading dataset...")
|
||||
dataset = load_dataset("json", data_files="combined_math_train.jsonl", split="train")
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Training config - lower LR for refinement, 1 more epoch
|
||||
training_args = SFTConfig(
|
||||
output_dir="./sft_output_continued",
|
||||
num_train_epochs=1,
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=5e-6, # Lower LR for continued training
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.01,
|
||||
weight_decay=0.01,
|
||||
logging_steps=100,
|
||||
save_steps=2000,
|
||||
save_total_limit=2,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False},
|
||||
max_length=1024,
|
||||
packing=True,
|
||||
report_to="none",
|
||||
seed=42,
|
||||
dataloader_num_workers=4,
|
||||
optim="adamw_torch_fused",
|
||||
)
|
||||
|
||||
# Create trainer
|
||||
print("Creating trainer...")
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
|
||||
# Print training info
|
||||
print(f"\n=== Continued Training Configuration ===")
|
||||
print(f"Model: {model_name} (previously fine-tuned)")
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
print(f"Batch size: {training_args.per_device_train_batch_size}")
|
||||
print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
|
||||
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
|
||||
print(f"Learning rate: {training_args.learning_rate}")
|
||||
print(f"Epochs: {training_args.num_train_epochs}")
|
||||
print("="*40)
|
||||
|
||||
# Train
|
||||
print("\nStarting continued training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print("\nSaving model to final_model/...")
|
||||
trainer.save_model("final_model")
|
||||
tokenizer.save_pretrained("final_model")
|
||||
|
||||
print("Continued training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
99
scripts/train_improved.py
Normal file
99
scripts/train_improved.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Improved SFT training for GSM8K performance.
|
||||
Key improvements:
|
||||
1. More training data (247K examples from GSM8K + MetaMathQA)
|
||||
2. Multiple epochs with cosine LR schedule
|
||||
3. Proper batch size and gradient accumulation for H100
|
||||
4. Gradient checkpointing for memory efficiency
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
import os
|
||||
|
||||
def main():
|
||||
# Load model and tokenizer
|
||||
print("Loading model and tokenizer...")
|
||||
model_name = "Qwen/Qwen3-1.7B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Ensure pad token is set
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa", # Use SDPA instead of flash_attention_2
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
print("Loading dataset...")
|
||||
dataset = load_dataset("json", data_files="combined_math_train.jsonl", split="train")
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Training config - optimized for H100 and GSM8K task
|
||||
# With 247K examples and batch_size 8 * grad_accum 4 = effective batch 32
|
||||
# Steps per epoch: 247467 / 32 ≈ 7733 steps
|
||||
# 2 epochs ≈ 15466 steps
|
||||
training_args = SFTConfig(
|
||||
output_dir="./sft_output_improved",
|
||||
num_train_epochs=2,
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-5,
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.03,
|
||||
weight_decay=0.01,
|
||||
logging_steps=100,
|
||||
save_steps=2000,
|
||||
save_total_limit=3,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False},
|
||||
max_length=1024, # Math problems don't need very long context
|
||||
packing=True,
|
||||
report_to="none",
|
||||
seed=42,
|
||||
dataloader_num_workers=4,
|
||||
optim="adamw_torch_fused",
|
||||
)
|
||||
|
||||
# Create trainer
|
||||
print("Creating trainer...")
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
|
||||
# Print training info
|
||||
print(f"\n=== Training Configuration ===")
|
||||
print(f"Model: {model_name}")
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
print(f"Batch size: {training_args.per_device_train_batch_size}")
|
||||
print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
|
||||
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
|
||||
print(f"Learning rate: {training_args.learning_rate}")
|
||||
print(f"Epochs: {training_args.num_train_epochs}")
|
||||
print(f"Max length: {training_args.max_length}")
|
||||
print("="*30)
|
||||
|
||||
# Train
|
||||
print("\nStarting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print("\nSaving model to final_model/...")
|
||||
trainer.save_model("final_model")
|
||||
tokenizer.save_pretrained("final_model")
|
||||
|
||||
print("Training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
31
special_tokens_map.json
Normal file
31
special_tokens_map.json
Normal file
@@ -0,0 +1,31 @@
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|object_ref_start|>",
|
||||
"<|object_ref_end|>",
|
||||
"<|box_start|>",
|
||||
"<|box_end|>",
|
||||
"<|quad_start|>",
|
||||
"<|quad_end|>",
|
||||
"<|vision_start|>",
|
||||
"<|vision_end|>",
|
||||
"<|vision_pad|>",
|
||||
"<|image_pad|>",
|
||||
"<|video_pad|>"
|
||||
],
|
||||
"eos_token": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
BIN
tokenizer.json
(Stored with Git LFS)
Normal file
BIN
tokenizer.json
(Stored with Git LFS)
Normal file
Binary file not shown.
239
tokenizer_config.json
Normal file
239
tokenizer_config.json
Normal file
@@ -0,0 +1,239 @@
|
||||
{
|
||||
"add_bos_token": false,
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"151643": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151644": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151645": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151646": {
|
||||
"content": "<|object_ref_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151647": {
|
||||
"content": "<|object_ref_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151648": {
|
||||
"content": "<|box_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151649": {
|
||||
"content": "<|box_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151650": {
|
||||
"content": "<|quad_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151651": {
|
||||
"content": "<|quad_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151652": {
|
||||
"content": "<|vision_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151653": {
|
||||
"content": "<|vision_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151654": {
|
||||
"content": "<|vision_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151655": {
|
||||
"content": "<|image_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151656": {
|
||||
"content": "<|video_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151657": {
|
||||
"content": "<tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151658": {
|
||||
"content": "</tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151659": {
|
||||
"content": "<|fim_prefix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151660": {
|
||||
"content": "<|fim_middle|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151661": {
|
||||
"content": "<|fim_suffix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151662": {
|
||||
"content": "<|fim_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151663": {
|
||||
"content": "<|repo_name|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151664": {
|
||||
"content": "<|file_sep|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151665": {
|
||||
"content": "<tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151666": {
|
||||
"content": "</tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151667": {
|
||||
"content": "<think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151668": {
|
||||
"content": "</think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|object_ref_start|>",
|
||||
"<|object_ref_end|>",
|
||||
"<|box_start|>",
|
||||
"<|box_end|>",
|
||||
"<|quad_start|>",
|
||||
"<|quad_end|>",
|
||||
"<|vision_start|>",
|
||||
"<|vision_end|>",
|
||||
"<|vision_pad|>",
|
||||
"<|image_pad|>",
|
||||
"<|video_pad|>"
|
||||
],
|
||||
"bos_token": null,
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|im_end|>",
|
||||
"errors": "replace",
|
||||
"extra_special_tokens": {},
|
||||
"model_max_length": 131072,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"split_special_tokens": false,
|
||||
"tokenizer_class": "Qwen2Tokenizer",
|
||||
"unk_token": null
|
||||
}
|
||||
3
training_args.bin
Normal file
3
training_args.bin
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4cc4183a399e46771d11513377493de112e1c85dfa95d53dc6b122e5befe0449
|
||||
size 6289
|
||||
1
vocab.json
Normal file
1
vocab.json
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user