初始化项目,由ModelHub XC社区提供模型
Model: Shekswess/trlm-135m Source: Original Platform
This commit is contained in:
51
.gitattributes
vendored
Normal file
51
.gitattributes
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
||||
*.tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
*.db* filter=lfs diff=lfs merge=lfs -text
|
||||
*.ark* filter=lfs diff=lfs merge=lfs -text
|
||||
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
|
||||
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
|
||||
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.gguf* filter=lfs diff=lfs merge=lfs -text
|
||||
*.ggml filter=lfs diff=lfs merge=lfs -text
|
||||
*.llamafile* filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
||||
training_args.bin filter=lfs diff=lfs merge=lfs -text
|
||||
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
158
README.md
Normal file
158
README.md
Normal file
@@ -0,0 +1,158 @@
|
||||
---
|
||||
library_name: transformers
|
||||
license: apache-2.0
|
||||
base_model: Shekswess/trlm-stage-2-sft-final-2
|
||||
tags:
|
||||
- trl
|
||||
- dpo
|
||||
- preference-alignment
|
||||
- reasoning
|
||||
- generated_from_trainer
|
||||
model-index:
|
||||
- name: trlm-stage-3-dpo-final-2
|
||||
results: []
|
||||
---
|
||||
|
||||
# Tiny Reasoning Language Model (trlm-135)
|
||||
|
||||

|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Model Summary](#model-summary)
|
||||
2. [Post-Training Pipeline](#post-training-pipeline)
|
||||
3. [How to use](#how-to-use)
|
||||
4. [Training](#training)
|
||||
5. [Evaluation](#evaluation)
|
||||
6. [Limitations](#limitations)
|
||||
7. [Acknowledgements](#acknowledgements)
|
||||
8. [License](#license)
|
||||
---
|
||||
|
||||
## Model Summary
|
||||
|
||||
The **Tiny Reasoning Language Model (trlm-135)** is a **135M parameter** research prototype designed to study how small models can learn step-by-step reasoning.
|
||||
It was built on top of [SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) and fine-tuned through a **3-stage pipeline**:
|
||||
|
||||
* **[Stage 1 SFT](https://huggingface.co/Shekswess/trlm-stage-1-sft-final-2)**: general instruction tuning (non-reasoning).
|
||||
* **[Stage 2 SFT](https://huggingface.co/Shekswess/trlm-stage-2-sft-final-2)**: reasoning traces with `<think>` tags.
|
||||
* **[Stage 3 DPO](https://huggingface.co/Shekswess/trlm-stage-3-dpo-final-2)**: preference alignment for reasoning style.
|
||||
|
||||
The **code** for everything can be found **[here](https://github.com/Shekswess/tiny-reasoning-language-model/blob/main/README.md)**
|
||||
|
||||
---
|
||||
|
||||
## Post-Training Pipeline
|
||||
<img width="1014" height="563" alt="image" src="https://github.com/user-attachments/assets/195ef389-6aa9-4527-b4f0-bea68c0841ae" />
|
||||
|
||||
---
|
||||
|
||||
## How to use
|
||||
|
||||
```bash
|
||||
pip install -U transformers accelerate
|
||||
```
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = "Shekswess/trlm-135m"
|
||||
device = "cuda" # or "cpu"
|
||||
|
||||
# Load tokenizer & model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
).to(device)
|
||||
|
||||
# Example prompt
|
||||
prompt = "Give me a brief explanation of gravity in simple terms."
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
# Apply chat template
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
# Generate
|
||||
outputs = model.generate(**inputs, max_new_tokens=256)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> For reasoning-heavy tasks, set `temperature=0.6` and `top_p=0.95`.
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
### Model
|
||||
|
||||
* **Architecture**: Decoder-only transformer (SmolLM2 backbone which infact is Llama 3 based model).
|
||||
* **Parameters**: ~135M.
|
||||
* **Precision**: mix-precision (bfloat16) during training.
|
||||
|
||||
### Software & Hardware
|
||||
|
||||
* **Training Frameworks**: PyTorch (ROCm), Hugging Face Transformers & TRL.
|
||||
* **Hardware**: AMD MI300X (192GB VRAM, 224GB RAM).
|
||||
|
||||
**Special thanks to [@HotAisle](https://x.com/HotAisle)**
|
||||
|
||||
### Training Stages
|
||||
|
||||
1. **Stage 1 – SFT (non-reasoning)**
|
||||
* ~58k samples, everyday conversations & instruction following.
|
||||
2. **Stage 2 – SFT (reasoning)**
|
||||
* ~78k samples with `<think>` segments.
|
||||
3. **Stage 3 – DPO (alignment)**
|
||||
* ~50k preference pairs (chosen vs. rejected reasoning traces).
|
||||
---
|
||||
|
||||
## Evaluation
|
||||
|
||||
Evaluation was done with `lm-eval-harness`:
|
||||
|
||||
| **Benchmark** | **Tiny Reasoning Language Model (trlm-135M)** | **SmolLM2-135M-Instruct** | **Improvements** |
|
||||
| -------------------- | ---------------------------- | ------------------------- | ---------------------------- |
|
||||
| **ARC Challenge** | **40.61** (avg) | 37.3 (avg) | **+3.31** |
|
||||
| **BBH** | **36.80** (3-shot) | 28.2 (3-shot) | **+8.6** |
|
||||
| **BoolQ** | **62.17** | – | N/A |
|
||||
| **GSM8K** | **2.59** (5-shot) | 1.4 (5-shot) | **+1.19** |
|
||||
| **IFEval** | **35.49** (avg) | 29.9 (avg) | **+5.59** |
|
||||
| **MMLU** | **34.95** | 29.3 | **+5.65** |
|
||||
| **PIQA** | **64.91** | 66.3 | **–1.39** |
|
||||
| **HellaSwag** | – | 40.9 | N/A |
|
||||
| **MT-Bench** | – | 19.8 | N/A |
|
||||
|
||||
---
|
||||
|
||||
## Limitations
|
||||
|
||||
* **Not production-ready**: hallucinations and logical errors are frequent.
|
||||
* **Small size**: limited general knowledge and reasoning depth.
|
||||
* **English-only**: multilingual capabilities not explored.
|
||||
|
||||
---
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
- [@HotAisle](https://x.com/HotAisle) for providing the compute resources to train all three stages on a awesome AMD MI300x setup.
|
||||
- [@mkurman88](https://x.com/mkurman88) for ideas, feedback and code samples.
|
||||
- [HuggingFaceTB team](https://huggingface.co/HuggingFaceTB) for SmolLM2-135M-Instruct model and the Smoltalk2 dataset collection.
|
||||
- [@scottgeng00](https://huggingface.co/scottgeng00) for the OLmO-3-Preference-Mix-Deltas dataset.
|
||||
- [@eliebakouchi](https://x.com/eliebakouch) for help with the tokenization.
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
[Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
---
|
||||
4
added_tokens.json
Normal file
4
added_tokens.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"</think>": 49153,
|
||||
"<think>": 49152
|
||||
}
|
||||
9
chat_template.jinja
Normal file
9
chat_template.jinja
Normal file
@@ -0,0 +1,9 @@
|
||||
{% for message in messages %}
|
||||
{% if loop.first and messages[0]['role'] != 'system' %}
|
||||
{{ '<|im_start|>system\nYou are a helpful AI assistant named Tiny Reasoning Language Model, trained by Shekswess. You are an assistant, with the ability to do reasoning. When performing reasoning always perform your full chain of thought inside <think>...</think> before giving a final answer. You are always reasoning so always use <think> </think> tags.<|im_end|>\n' }}
|
||||
{% endif %}
|
||||
{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt %}
|
||||
{{ '<|im_start|>assistant\n' }}
|
||||
{% endif %}
|
||||
38
config.json
Normal file
38
config.json
Normal file
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"architectures": [
|
||||
"LlamaForCausalLM"
|
||||
],
|
||||
"attention_bias": false,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 1,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 2,
|
||||
"head_dim": 64,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 576,
|
||||
"initializer_range": 0.041666666666666664,
|
||||
"intermediate_size": 1536,
|
||||
"is_llama_config": true,
|
||||
"max_position_embeddings": 8192,
|
||||
"mlp_bias": false,
|
||||
"model_type": "llama",
|
||||
"num_attention_heads": 9,
|
||||
"num_hidden_layers": 30,
|
||||
"num_key_value_heads": 3,
|
||||
"pad_token_id": 2,
|
||||
"pretraining_tp": 1,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_interleaved": false,
|
||||
"rope_scaling": null,
|
||||
"rope_theta": 100000,
|
||||
"tie_word_embeddings": true,
|
||||
"transformers.js_config": {
|
||||
"kv_cache_dtype": {
|
||||
"fp16": "float16",
|
||||
"q4f16": "float16"
|
||||
}
|
||||
},
|
||||
"transformers_version": "4.56.2",
|
||||
"use_cache": true,
|
||||
"vocab_size": 49154
|
||||
}
|
||||
1
configuration.json
Normal file
1
configuration.json
Normal file
@@ -0,0 +1 @@
|
||||
{"framework": "pytorch", "task": "text-generation", "allow_remote": true}
|
||||
9
generation_config.json
Normal file
9
generation_config.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"_from_model_config": true,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": [
|
||||
2
|
||||
],
|
||||
"pad_token_id": 2,
|
||||
"transformers_version": "4.56.2"
|
||||
}
|
||||
48901
merges.txt
Normal file
48901
merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
3
model.safetensors
Normal file
3
model.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f72be34d84f9d95d2f1ba9e6aa5d352ec841ce1e7aed81998135ce5bf96e9a09
|
||||
size 269062856
|
||||
34
special_tokens_map.json
Normal file
34
special_tokens_map.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<think>",
|
||||
"</think>"
|
||||
],
|
||||
"bos_token": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
3
tokenizer.json
Normal file
3
tokenizer.json
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:67afd7bf266e40695a8d209874c82a88ff5ade32c7a96967ffa712e519757986
|
||||
size 3523023
|
||||
170
tokenizer_config.json
Normal file
170
tokenizer_config.json
Normal file
@@ -0,0 +1,170 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"3": {
|
||||
"content": "<repo_name>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"4": {
|
||||
"content": "<reponame>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"5": {
|
||||
"content": "<file_sep>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"6": {
|
||||
"content": "<filename>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"7": {
|
||||
"content": "<gh_stars>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"8": {
|
||||
"content": "<issue_start>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"9": {
|
||||
"content": "<issue_comment>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"10": {
|
||||
"content": "<issue_closed>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"11": {
|
||||
"content": "<jupyter_start>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"12": {
|
||||
"content": "<jupyter_text>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"13": {
|
||||
"content": "<jupyter_code>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"14": {
|
||||
"content": "<jupyter_output>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"15": {
|
||||
"content": "<jupyter_script>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"16": {
|
||||
"content": "<empty_output>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49152": {
|
||||
"content": "<think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49153": {
|
||||
"content": "</think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<think>",
|
||||
"</think>"
|
||||
],
|
||||
"bos_token": "<|im_start|>",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|im_end|>",
|
||||
"extra_special_tokens": {},
|
||||
"model_max_length": 8192,
|
||||
"pad_token": "<|im_end|>",
|
||||
"tokenizer_class": "GPT2Tokenizer",
|
||||
"unk_token": "<|endoftext|>",
|
||||
"vocab_size": 49152
|
||||
}
|
||||
3
training_args.bin
Normal file
3
training_args.bin
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:247b792c868f8245ceddd15c2b2486a99317401202045e562ed34e945a36ed82
|
||||
size 6865
|
||||
1
vocab.json
Normal file
1
vocab.json
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user