初始化项目,由ModelHub XC社区提供模型

Model: openbmb/BitCPM-CANN-1B-unquantized
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-04 14:32:59 +08:00
commit a94919dde4
28 changed files with 181544 additions and 0 deletions

35
.gitattributes vendored Normal file
View File

@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

126
README.md Normal file
View File

@@ -0,0 +1,126 @@
---
license: apache-2.0
language:
- zh
- en
pipeline_tag: text-generation
library_name: transformers
---
<div align="center">
<img src="https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm_logo.png?raw=true" width="500em" ></img>
</div>
<p align="center">
<a href="https://github.com/OpenBMB/MiniCPM/" target="_blank">GitHub Repo</a> |
<a href="https://github.com/OpenBMB/MiniCPM/blob/main/docs/BitCPM_CANN.pdf" target="_blank">Technical Report</a>
</p>
<p align="center">
👋 Join us on <a href="https://discord.gg/3cGQn9b3YM" target="_blank">Discord</a> and <a href="https://github.com/OpenBMB/MiniCPM/blob/main/assets/wechat.jpg" target="_blank">WeChat</a>
</p>
## Overview
BitCPM-CANN-1B-unquantized is the **unquantized QAT (Quantization-Aware Training) checkpoint** of BitCPM-CANN-1B, designed for **continued pre-training and fine-tuning**. It preserves full-precision latent weights with ternary fake quantizers (weights → {-1, 0, 1} with group-wise scaling, trained via STE) defined in `modeling.py`, enabling the model to keep learning under quantization constraints. For technical details, see our [Technical Report](https://github.com/OpenBMB/MiniCPM/blob/main/docs/BitCPM_CANN.pdf).
> ⚠️ **This model is NOT for direct inference.** For inference, use the pseudo-quantized version: [openbmb/BitCPM-CANN-1B](https://huggingface.co/openbmb/BitCPM-CANN-1B).
## Continued Pre-training & Fine-tuning
The **only requirement** is that the forward pass must go through the bundled `modeling.py` (which contains the ternary fake quantizer). Load with `trust_remote_code=True` and do NOT replace or bypass the model's forward logic.
### Option 1: DeepSpeed (Recommended)
We provide ready-to-use training scripts in the [example](https://huggingface.co/openbmb/BitCPM-CANN-1B-unquantized/tree/main/example) directory (using the 1B model as an example):
- **Continued pre-training**: `example/run.sh` + `example/train.py`
- **SFT (Supervised Fine-tuning)**: `example/run_sft.sh` + `example/train_sft.py`
Quick start:
```bash
# Continued pre-training
cd example && bash run.sh
# Supervised fine-tuning
cd example && bash run_sft.sh
```
### Option 2: HuggingFace-compatible Frameworks
Any framework that supports HuggingFace model loading with custom code can be used, such as **LLaMA Factory**, **HuggingFace Trainer**, etc. The key is to ensure `trust_remote_code=True`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
path = 'openbmb/BitCPM-CANN-1B-unquantized'
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
path,
torch_dtype=torch.bfloat16,
trust_remote_code=True
)
# Use with your preferred framework (LLaMA Factory, HF Trainer, etc.)
# The ternary fake quantizer in modeling.py is applied automatically during forward pass.
```
## Post-Training Conversion
After training, use `qat-convert.py` to fuse the fake quantizer and produce inference-ready pseudo-quantized weights:
```bash
python qat-convert.py \
--input_bin <path-to-finetuned-pytorch.bin> \
--output <path-to-output-pseudo-quantized-pytorch.bin> \
--quant_type ternary \
--group_size -1
```
The converted model can be loaded for inference in the same way as [openbmb/BitCPM-CANN-1B](https://huggingface.co/openbmb/BitCPM-CANN-1B)—no special quantization libraries required.
## Workflow
```
┌─────────────────────────────────┐
│ BitCPM-CANN-1B-unquantized │ ← This model (QAT checkpoint + fake quantizer in modeling.py)
└───────────────┬─────────────────┘
▼ Train (DeepSpeed / LLaMA Factory / HF Trainer / ...)
┌─────────────────────────────────┐
│ Fine-tuned checkpoint │ ← Still contains un-fused QAT parameters
└───────────────┬─────────────────┘
▼ python qat-convert.py --quant_type ternary --group_size -1
┌─────────────────────────────────┐
│ Pseudo-quantized model │ ← Ready for inference (same format as BitCPM-CANN-1B)
└─────────────────────────────────┘
```
## BitCPM-CANN Model Family
| Model | HuggingFace (Inference) | HuggingFace (Fine-tuning) |
|-------|-------------------------|---------------------------|
| BitCPM-CANN-0.5B | [openbmb/BitCPM-CANN-0.5B](https://huggingface.co/openbmb/BitCPM-CANN-0.5B) | [openbmb/BitCPM-CANN-0.5B-unquantized](https://huggingface.co/openbmb/BitCPM-CANN-0.5B-unquantized) |
| BitCPM-CANN-1B | [openbmb/BitCPM-CANN-1B](https://huggingface.co/openbmb/BitCPM-CANN-1B) | [openbmb/BitCPM-CANN-1B-unquantized](https://huggingface.co/openbmb/BitCPM-CANN-1B-unquantized) |
| BitCPM-CANN-3B | [openbmb/BitCPM-CANN-3B](https://huggingface.co/openbmb/BitCPM-CANN-3B) | [openbmb/BitCPM-CANN-3B-unquantized](https://huggingface.co/openbmb/BitCPM-CANN-3B-unquantized) |
| BitCPM-CANN-8B | [openbmb/BitCPM-CANN-8B](https://huggingface.co/openbmb/BitCPM-CANN-8B) | [openbmb/BitCPM-CANN-8B-unquantized](https://huggingface.co/openbmb/BitCPM-CANN-8B-unquantized) |
## Statement
- As a language model, BitCPM-CANN generates content by learning from a vast amount of text.
- However, it does not possess the ability to comprehend or express personal opinions or value judgments.
- Any content generated by BitCPM-CANN does not represent the viewpoints or positions of the model developers.
- Therefore, when using content generated by BitCPM-CANN, users should take full responsibility for evaluating and verifying it on their own.
## LICENSE
- This repository and BitCPM-CANN models are released under the [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) License.
## Citation
- Please cite our technical report if you find our work valuable.
```bibtex
@article{bitcpmcann,
title={{BitCPM-CANN}: Native 1.58-Bit Large Language Model Training on Ascend NPU},
author={BitCPM Team},
year={2026}
}
```

169
config.json Normal file
View File

@@ -0,0 +1,169 @@
{
"_name_or_path": "openbmb/CPM-2B",
"architectures": [
"LlamaForCausalLM"
],
"auto_map": {
"AutoConfig": "configuration_llama.LlamaConfig",
"AutoModel": "modeling_llama.LlamaForCausalLM",
"AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM"
},
"bos_token_id": 1,
"eos_token_id": [
2,
73440
],
"pad_token_id": 2,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.1,
"intermediate_size": 6144,
"head_dim": 128,
"max_position_embeddings": 32768,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 2,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"factor": 1.0,
"rope_type": "longrope",
"long_factor": [
0.9977997200264581,
1.014658295992452,
1.0349680404997148,
1.059429246056193,
1.0888815016813513,
1.1243301355211495,
1.166977103606075,
1.2182568066927284,
1.2798772354275727,
1.3538666751582975,
1.4426259039919596,
1.5489853358570191,
1.6762658237220625,
1.8283407612492941,
2.0096956085876183,
2.225478927469756,
2.481536379650452,
2.784415934557119,
3.1413289096347365,
3.560047844772632,
4.048719380066383,
4.615569542115128,
5.2684819496549835,
6.014438591970396,
6.858830049237097,
7.804668263503327,
8.851768731513417,
9.99600492938444,
11.228766118181639,
12.536757560834843,
13.902257701387796,
15.303885189125953,
16.717837610115794,
18.119465097853947,
19.484965238406907,
20.792956681060105,
22.02571786985731,
23.16995406772833,
24.217054535738416,
25.16289275000465,
26.007284207271347,
26.753240849586767,
27.40615325712662,
27.973003419175363,
28.461674954469114,
28.880393889607006,
29.237306864684626,
29.540186419591297,
29.79624387177199,
30.01202719065413,
30.193382037992453,
30.34545697551969,
30.47273746338473,
30.579096895249787,
30.66785612408345,
30.741845563814174,
30.80346599254902,
30.85474569563567,
30.897392663720595,
30.932841297560394,
30.962293553185553,
30.986754758742034,
31.007064503249293,
31.02392307921529
],
"short_factor": [
0.9977997200264581,
1.014658295992452,
1.0349680404997148,
1.059429246056193,
1.0888815016813513,
1.1243301355211495,
1.166977103606075,
1.2182568066927284,
1.2798772354275727,
1.3538666751582975,
1.4426259039919596,
1.5489853358570191,
1.6762658237220625,
1.8283407612492941,
2.0096956085876183,
2.225478927469756,
2.481536379650452,
2.784415934557119,
3.1413289096347365,
3.560047844772632,
4.048719380066383,
4.615569542115128,
5.2684819496549835,
6.014438591970396,
6.858830049237097,
7.804668263503327,
8.851768731513417,
9.99600492938444,
11.228766118181639,
12.536757560834843,
13.902257701387796,
15.303885189125953,
16.717837610115794,
18.119465097853947,
19.484965238406907,
20.792956681060105,
22.02571786985731,
23.16995406772833,
24.217054535738416,
25.16289275000465,
26.007284207271347,
26.753240849586767,
27.40615325712662,
27.973003419175363,
28.461674954469114,
28.880393889607006,
29.237306864684626,
29.540186419591297,
29.79624387177199,
30.01202719065413,
30.193382037992453,
30.34545697551969,
30.47273746338473,
30.579096895249787,
30.66785612408345,
30.741845563814174,
30.80346599254902,
30.85474569563567,
30.897392663720595,
30.932841297560394,
30.962293553185553,
30.986754758742034,
31.007064503249293,
31.02392307921529
],
"original_max_position_embeddings": 32768
},
"torch_dtype": "bfloat16",
"transformers_version": "4.46.3",
"use_cache": true,
"vocab_size": 73448
}

206
configuration_llama.py Normal file
View File

@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class LlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
Llama 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_heads
```python
>>> from transformers import LlamaModel, LlamaConfig
>>> # Initializing a LLaMA llama-7b style configuration
>>> configuration = LlamaConfig()
>>> # Initializing a model from the llama-7b style configuration
>>> model = LlamaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

103
example/README.md Normal file
View File

@@ -0,0 +1,103 @@
# BitCPM Training Example
This project provides scripts for continue pretraining (CPT) and supervised fine-tuning (SFT) of **BitCPM-CANN-1B-unquantized**.
## File Description
CPT and SFT each have a pair of scripts (training script + launch script) and share DeepSpeed configuration files:
| File | Description |
| --- | --- |
| `train.py` | Continue pretrain script based on HuggingFace Trainer + DeepSpeed |
| `run.sh` | Launch script for CPT with hyperparameter configuration |
| `train_sft.py` | Supervised fine-tuning script based on HuggingFace Trainer + DeepSpeed |
| `run_sft.sh` | Launch script for SFT with hyperparameter configuration |
| `ds_config.json` | DeepSpeed ZeRO-3 configuration (with CPU offload) |
| `ds_config_z2.json` | DeepSpeed ZeRO-2 configuration (used by default) |
| `requirements.txt` | Python dependency list |
## Environment Setup
### Docker Image
Use the following Huawei NPU image:
```
swr.cn-south-1.myhuaweicloud.com/ascendhub/mindspeed-llm:openeuler22.03-mindspeed-llm-2.3.0-a3-arm
```
Other Huawei NPU images may also work but have not been fully tested. For GPU environments, you can skip the Docker image and just install `requirements.txt` directly.
### Install Dependencies
After entering the container, install the Python dependencies:
```bash
pip install -r requirements.txt
```
## Continue Pretrain (CPT)
### Dataset
The test dataset used is [C4-Pro](https://huggingface.co/datasets/gair-prox/c4-pro), stored in parquet format after downloading.
### Usage
Modify the path configuration in `run.sh`:
```bash
MODEL_PATH="/path/to/BitCPM-CANN-1B-unquantized/"
DATA_PATH="/path/to/c4-pro/data/your_file.parquet"
```
Then start training:
```bash
bash run.sh
```
## Supervised Fine-Tuning (SFT)
### Dataset
The test dataset used is [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), stored in parquet format after downloading.
### Usage
Modify the path configuration in `run_sft.sh`:
```bash
MODEL_PATH="/path/to/BitCPM-CANN-1B-unquantized/"
DATA_PATH="/path/to/ultrachat_200k/data/your_file.parquet"
```
Then start training:
```bash
bash run_sft.sh
```
## Training Results Reference
> **Note:** BitCPM has its own training dataset and data mixture. It is expected that the loss continues to decrease when training on open-source datasets.
Below are the loss curves from smoke tests on GPU and NPU for both CPT and SFT tasks. The results are highly consistent across GPU and NPU, indicating that users can continue pre-training or fine-tuning on various compute devices:
| | GPU | NPU |
| --- | --- | --- |
| **CPT** | ![GPU Pretrain Loss](gpu_pretrain_loss.png) | ![NPU Pretrain Loss](npu_pretrain_loss.png) |
| **SFT** | ![GPU SFT Loss](gpu_sft_loss.png) | ![NPU SFT Loss](npu_sft_loss.png) |
Training log CSV files (corresponding to the loss curves above):
| CSV File | Corresponding Loss Curve |
| --- | --- |
| [gpu_pretrain.csv](gpu_pretrain.csv) | GPU CPT |
| [npu_pretrain.csv](npu_pretrain.csv) | NPU CPT |
| [gpu_sft.csv](gpu_sft.csv) | GPU SFT |
| [npu_sft.csv](npu_sft.csv) | NPU SFT |
---
These scripts provide a convenient, ready-to-use toolkit for QAT-aware continued pre-training and fine-tuning of BitCPM-CANN models, so you can quickly adapt the model to your own data and tasks while preserving ternary quantization constraints.

29
example/ds_config.json Normal file
View File

@@ -0,0 +1,29 @@
{
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "none"
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 2e8,
"stage3_prefetch_bucket_size": 2e8,
"stage3_param_persistence_threshold": 1e5,
"stage3_max_live_parameters": 2e9,
"stage3_max_reuse_distance": 2e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

22
example/ds_config_z2.json Normal file
View File

@@ -0,0 +1,22 @@
{
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none"
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

51
example/gpu_pretrain.csv Normal file
View File

@@ -0,0 +1,51 @@
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
2,2.7920000553131104,0.03527498617768288,7.999999979801942e-06,0.010457516647875309,,,,,
4,2.8011999130249023,0.03495891019701958,1.5999999959603883e-05,0.020915033295750618,,,,,
6,2.7964000701904297,0.03271934762597084,2.4000000848900527e-05,0.0313725508749485,,,,,
8,2.763700008392334,0.024968057870864868,3.199999991920777e-05,0.041830066591501236,,,,,
10,3.281599998474121,0.31758183240890503,3.9999998989515007e-05,0.05228758230805397,,,,,
12,2.941200017929077,0.044055406004190445,3.995128281530924e-05,0.062745101749897,,,,,
14,2.851799964904785,0.03649706766009331,3.9805359847377986e-05,0.07320261746644974,,,,,
16,2.7869999408721924,0.022624235600233078,3.9562950405525044e-05,0.08366013318300247,,,,,
18,2.7825000286102295,0.021830420941114426,3.922523319488391e-05,0.0941176488995552,,,,,
20,2.7857000827789307,0.01685911975800991,3.87938525818754e-05,0.10457516461610794,,,,,
22,2.7571001052856445,0.01572061888873577,3.827090768027119e-05,0.11503268033266068,,,,,
24,2.762399911880493,0.016891509294509888,3.7658952351193875e-05,0.125490203499794,,,,,
26,2.7411000728607178,0.015683824196457863,3.6960962461307645e-05,0.13594771921634674,,,,,
28,2.733099937438965,0.012847283855080605,3.6180339520797133e-05,0.14640523493289948,,,,,
30,2.723400115966797,0.015209181234240532,3.532088885549456e-05,0.1568627506494522,,,,,
32,2.7342000007629395,0.01241038367152214,3.4386797779006884e-05,0.16732026636600494,,,,,
34,2.7321999073028564,0.012879018671810627,3.338261376484297e-05,0.17777778208255768,,,,,
36,2.7314000129699707,0.013242729939520359,3.231322989449836e-05,0.1882352977991104,,,,,
38,2.7065999507904053,0.01113435160368681,3.118385939160362e-05,0.19869281351566315,,,,,
40,2.6958999633789062,0.012413726188242435,2.9999999242136255e-05,0.20915032923221588,,,,,
42,2.7516000270843506,0.011661508120596409,2.8767422918463126e-05,0.21960784494876862,,,,,
44,2.713099956512451,0.012248368933796883,2.749213126662653e-05,0.23006536066532135,,,,,
46,2.7102999687194824,0.011450185440480709,2.6180339773418382e-05,0.24052287638187408,,,,,
48,2.7021000385284424,0.011155751533806324,2.483843854861334e-05,0.250980406999588,,,,,
50,2.680500030517578,0.010021247901022434,2.3472963221138343e-05,0.26143792271614075,,,,,
52,2.699199914932251,0.010751751251518726,2.2090569473220967e-05,0.2718954384326935,,,,,
54,2.694200038909912,0.010503941215574741,2.0697989384643734e-05,0.2823529541492462,,,,,
56,2.7091000080108643,0.010059370659291744,1.9302009604871273e-05,0.29281046986579895,,,,,
58,2.699399948120117,0.012161476537585258,1.7909431335283443e-05,0.3032679855823517,,,,,
60,2.7216999530792236,0.010671027936041355,1.6527035768376663e-05,0.3137255012989044,,,,,
62,2.7158000469207764,0.010463157668709755,1.516156225989107e-05,0.32418301701545715,,,,,
64,2.7214999198913574,0.010665320791304111,1.3819660125591327e-05,0.3346405327320099,,,,,
66,2.7116000652313232,0.01046629250049591,1.2507867722888477e-05,0.3450980484485626,,,,,
68,2.6923000812530518,0.010609752498567104,1.1232576980546582e-05,0.35555556416511536,,,,,
70,2.6830999851226807,0.009290814399719238,9.999999747378752e-06,0.3660130798816681,,,,,
72,2.7093000411987305,0.010727670043706894,8.816142326395493e-06,0.3764705955982208,,,,,
74,2.698699951171875,0.0109737953171134,7.686770914006047e-06,0.38692811131477356,,,,,
76,2.712599992752075,0.010320967063307762,6.61738795315614e-06,0.3973856270313263,,,,,
78,2.6993000507354736,0.009841523133218288,5.613203938992228e-06,0.40784314274787903,,,,,
80,2.6861000061035156,0.010179675184190273,4.6791110435151495e-06,0.41830065846443176,,,,,
82,2.6828999519348145,0.009790077805519104,3.819659923465224e-06,0.4287581741809845,,,,,
84,2.699199914932251,0.010508442297577858,3.03903811982309e-06,0.43921568989753723,,,,,
86,2.6988000869750977,0.009589221328496933,2.3410482299368596e-06,0.44967320561408997,,,,,
88,2.688499927520752,0.010065913200378418,1.7290908544964623e-06,0.4601307213306427,,,,,
90,2.6928999423980713,0.010363687761127949,1.206147544507985e-06,0.47058823704719543,,,,,
92,2.714200019836426,0.010142815299332142,7.74766078848188e-07,0.48104575276374817,,,,,
94,2.672300100326538,0.009833029471337795,4.370479871340649e-07,0.4915032684803009,,,,,
96,2.7018001079559326,0.009937037713825703,1.9463863054625108e-07,0.501960813999176,,,,,
98,2.7121999263763428,0.009417451918125153,4.8718995060426096e-08,0.5124183297157288,,,,,
100,2.7028000354766846,0.009256146848201752,0.0,0.5228758454322815,365.8839111328125,139.93499755859375,0.27300000190734863,4.629706395531346e+17,2.7395541667938232
1 step train/loss train/grad_norm train/learning_rate train/epoch train/train_runtime train/train_samples_per_second train/train_steps_per_second train/total_flos train/train_loss
2 2 2.7920000553131104 0.03527498617768288 7.999999979801942e-06 0.010457516647875309
3 4 2.8011999130249023 0.03495891019701958 1.5999999959603883e-05 0.020915033295750618
4 6 2.7964000701904297 0.03271934762597084 2.4000000848900527e-05 0.0313725508749485
5 8 2.763700008392334 0.024968057870864868 3.199999991920777e-05 0.041830066591501236
6 10 3.281599998474121 0.31758183240890503 3.9999998989515007e-05 0.05228758230805397
7 12 2.941200017929077 0.044055406004190445 3.995128281530924e-05 0.062745101749897
8 14 2.851799964904785 0.03649706766009331 3.9805359847377986e-05 0.07320261746644974
9 16 2.7869999408721924 0.022624235600233078 3.9562950405525044e-05 0.08366013318300247
10 18 2.7825000286102295 0.021830420941114426 3.922523319488391e-05 0.0941176488995552
11 20 2.7857000827789307 0.01685911975800991 3.87938525818754e-05 0.10457516461610794
12 22 2.7571001052856445 0.01572061888873577 3.827090768027119e-05 0.11503268033266068
13 24 2.762399911880493 0.016891509294509888 3.7658952351193875e-05 0.125490203499794
14 26 2.7411000728607178 0.015683824196457863 3.6960962461307645e-05 0.13594771921634674
15 28 2.733099937438965 0.012847283855080605 3.6180339520797133e-05 0.14640523493289948
16 30 2.723400115966797 0.015209181234240532 3.532088885549456e-05 0.1568627506494522
17 32 2.7342000007629395 0.01241038367152214 3.4386797779006884e-05 0.16732026636600494
18 34 2.7321999073028564 0.012879018671810627 3.338261376484297e-05 0.17777778208255768
19 36 2.7314000129699707 0.013242729939520359 3.231322989449836e-05 0.1882352977991104
20 38 2.7065999507904053 0.01113435160368681 3.118385939160362e-05 0.19869281351566315
21 40 2.6958999633789062 0.012413726188242435 2.9999999242136255e-05 0.20915032923221588
22 42 2.7516000270843506 0.011661508120596409 2.8767422918463126e-05 0.21960784494876862
23 44 2.713099956512451 0.012248368933796883 2.749213126662653e-05 0.23006536066532135
24 46 2.7102999687194824 0.011450185440480709 2.6180339773418382e-05 0.24052287638187408
25 48 2.7021000385284424 0.011155751533806324 2.483843854861334e-05 0.250980406999588
26 50 2.680500030517578 0.010021247901022434 2.3472963221138343e-05 0.26143792271614075
27 52 2.699199914932251 0.010751751251518726 2.2090569473220967e-05 0.2718954384326935
28 54 2.694200038909912 0.010503941215574741 2.0697989384643734e-05 0.2823529541492462
29 56 2.7091000080108643 0.010059370659291744 1.9302009604871273e-05 0.29281046986579895
30 58 2.699399948120117 0.012161476537585258 1.7909431335283443e-05 0.3032679855823517
31 60 2.7216999530792236 0.010671027936041355 1.6527035768376663e-05 0.3137255012989044
32 62 2.7158000469207764 0.010463157668709755 1.516156225989107e-05 0.32418301701545715
33 64 2.7214999198913574 0.010665320791304111 1.3819660125591327e-05 0.3346405327320099
34 66 2.7116000652313232 0.01046629250049591 1.2507867722888477e-05 0.3450980484485626
35 68 2.6923000812530518 0.010609752498567104 1.1232576980546582e-05 0.35555556416511536
36 70 2.6830999851226807 0.009290814399719238 9.999999747378752e-06 0.3660130798816681
37 72 2.7093000411987305 0.010727670043706894 8.816142326395493e-06 0.3764705955982208
38 74 2.698699951171875 0.0109737953171134 7.686770914006047e-06 0.38692811131477356
39 76 2.712599992752075 0.010320967063307762 6.61738795315614e-06 0.3973856270313263
40 78 2.6993000507354736 0.009841523133218288 5.613203938992228e-06 0.40784314274787903
41 80 2.6861000061035156 0.010179675184190273 4.6791110435151495e-06 0.41830065846443176
42 82 2.6828999519348145 0.009790077805519104 3.819659923465224e-06 0.4287581741809845
43 84 2.699199914932251 0.010508442297577858 3.03903811982309e-06 0.43921568989753723
44 86 2.6988000869750977 0.009589221328496933 2.3410482299368596e-06 0.44967320561408997
45 88 2.688499927520752 0.010065913200378418 1.7290908544964623e-06 0.4601307213306427
46 90 2.6928999423980713 0.010363687761127949 1.206147544507985e-06 0.47058823704719543
47 92 2.714200019836426 0.010142815299332142 7.74766078848188e-07 0.48104575276374817
48 94 2.672300100326538 0.009833029471337795 4.370479871340649e-07 0.4915032684803009
49 96 2.7018001079559326 0.009937037713825703 1.9463863054625108e-07 0.501960813999176
50 98 2.7121999263763428 0.009417451918125153 4.8718995060426096e-08 0.5124183297157288
51 100 2.7028000354766846 0.009256146848201752 0.0 0.5228758454322815 365.8839111328125 139.93499755859375 0.27300000190734863 4.629706395531346e+17 2.7395541667938232

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

51
example/gpu_sft.csv Normal file
View File

@@ -0,0 +1,51 @@
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
2,1.1492999792099,0.6216375231742859,1.9999999949504854e-06,0.0004617871018126607,,,,,
4,1.0979000329971313,0.681877851486206,3.999999989900971e-06,0.0009235742036253214,,,,,
6,1.1269999742507935,0.784303605556488,6.000000212225132e-06,0.001385361305437982,,,,,
8,1.0542000532150269,0.8737029433250427,7.999999979801942e-06,0.0018471484072506428,,,,,
10,1.2440999746322632,0.7068291902542114,9.999999747378752e-06,0.0023089356254786253,,,,,
12,1.2925000190734863,0.6821666955947876,1.2000000424450263e-05,0.002770722610875964,,,,,
14,1.0843000411987305,0.525643527507782,1.4000000192027073e-05,0.0032325098291039467,,,,,
16,1.0961999893188477,0.43757057189941406,1.5999999959603883e-05,0.0036942968145012856,,,,,
18,1.0614999532699585,0.46141618490219116,1.8000000636675395e-05,0.004156084265559912,,,,,
20,1.332900047302246,0.715879499912262,1.9999999494757503e-05,0.004617871250957251,,,,,
22,1.2070000171661377,0.5926885008811951,1.996917308133561e-05,0.0050796582363545895,,,,,
24,1.2043999433517456,0.5833240747451782,1.9876883015967906e-05,0.005541445221751928,,,,,
26,1.0740000009536743,0.44734400510787964,1.9723698642337695e-05,0.0060032326728105545,,,,,
28,1.1162999868392944,0.3701137900352478,1.9510565834934823e-05,0.006465019658207893,,,,,
30,1.0454000234603882,0.43832680583000183,1.9238796085119247e-05,0.006926806643605232,,,,,
32,1.124899983406067,0.4591037631034851,1.8910064682131633e-05,0.007388593629002571,,,,,
34,1.0686999559402466,0.3873400390148163,1.8526401618146338e-05,0.00785038061439991,,,,,
36,1.0291999578475952,0.40313437581062317,1.8090169760398567e-05,0.008312168531119823,,,,,
38,1.1052000522613525,0.3735405504703522,1.7604059394216165e-05,0.008773955516517162,,,,,
40,1.1555999517440796,0.3818407654762268,1.7071068214136176e-05,0.009235742501914501,,,,,
42,1.0235999822616577,0.4255191683769226,1.6494481315021403e-05,0.00969752948731184,,,,,
44,1.0364999771118164,0.4794503152370453,1.5877853002166376e-05,0.010159316472709179,,,,,
46,1.1344000101089478,0.37273937463760376,1.5224985872919206e-05,0.010621103458106518,,,,,
48,1.0866999626159668,0.417492538690567,1.453990535082994e-05,0.011082890443503857,,,,,
50,1.1038000583648682,0.35408055782318115,1.3826834219798911e-05,0.01154467836022377,,,,,
52,1.1478999853134155,0.3930828273296356,1.3090169886709191e-05,0.012006465345621109,,,,,
54,1.1858999729156494,0.3965947926044464,1.2334453458606731e-05,0.012468252331018448,,,,,
56,1.0096999406814575,0.3860221207141876,1.1564344276848715e-05,0.012930039316415787,,,,,
58,1.114799976348877,0.44393691420555115,1.0784590813273098e-05,0.013391826301813126,,,,,
60,1.079300045967102,0.3605058789253235,9.999999747378752e-06,0.013853613287210464,,,,,
62,1.1766999959945679,0.40689122676849365,9.215408681484405e-06,0.014315400272607803,,,,,
64,1.1075999736785889,0.4002344310283661,8.435655217908788e-06,0.014777187258005142,,,,,
66,1.1866999864578247,0.46947163343429565,7.665546036150772e-06,0.015238975174725056,,,,,
68,1.0311000347137451,0.3296957314014435,6.909830062795663e-06,0.01570076122879982,,,,,
70,1.1088999509811401,0.33858785033226013,6.173165729705943e-06,0.01616254821419716,,,,,
72,1.0720000267028809,0.3967427909374237,5.460095053422265e-06,0.016624337062239647,,,,,
74,1.1460000276565552,0.41202062368392944,4.7750145313329995e-06,0.017086124047636986,,,,,
76,1.0425000190734863,0.38334518671035767,4.1221474020858295e-06,0.017547911033034325,,,,,
78,0.9154000282287598,0.40649303793907166,3.505519543978153e-06,0.018009698018431664,,,,,
80,1.1110999584197998,0.35371580719947815,2.9289321901160292e-06,0.018471485003829002,,,,,
82,1.1672999858856201,0.3381657302379608,2.3959403279150138e-06,0.01893327198922634,,,,,
84,1.2374000549316406,0.3815234303474426,1.909829961732612e-06,0.01939505897462368,,,,,
86,1.2151000499725342,0.38446080684661865,1.4735983313585166e-06,0.01985684596002102,,,,,
88,1.163100004196167,0.40419140458106995,1.0899348126258701e-06,0.020318632945418358,,,,,
90,1.1883000135421753,0.4011874198913574,7.612046601934708e-07,0.020780419930815697,,,,,
92,1.1526999473571777,0.3836020231246948,4.894348535344761e-07,0.021242206916213036,,,,,
94,1.15339994430542,0.452364057302475,2.7630079557638965e-07,0.021703993901610374,,,,,
96,1.062000036239624,0.3502688705921173,1.2311659247643547e-07,0.022165780887007713,,,,,
98,1.0271999835968018,0.4022065997123718,3.0826662111849146e-08,0.022627567872405052,,,,,
100,1.0283000469207764,0.38241174817085266,0.0,0.02308935672044754,183.9481964111328,8.697999954223633,0.5440000295639038,1862467846144.0,1.1177252531051636
1 step train/loss train/grad_norm train/learning_rate train/epoch train/train_runtime train/train_samples_per_second train/train_steps_per_second train/total_flos train/train_loss
2 2 1.1492999792099 0.6216375231742859 1.9999999949504854e-06 0.0004617871018126607
3 4 1.0979000329971313 0.681877851486206 3.999999989900971e-06 0.0009235742036253214
4 6 1.1269999742507935 0.784303605556488 6.000000212225132e-06 0.001385361305437982
5 8 1.0542000532150269 0.8737029433250427 7.999999979801942e-06 0.0018471484072506428
6 10 1.2440999746322632 0.7068291902542114 9.999999747378752e-06 0.0023089356254786253
7 12 1.2925000190734863 0.6821666955947876 1.2000000424450263e-05 0.002770722610875964
8 14 1.0843000411987305 0.525643527507782 1.4000000192027073e-05 0.0032325098291039467
9 16 1.0961999893188477 0.43757057189941406 1.5999999959603883e-05 0.0036942968145012856
10 18 1.0614999532699585 0.46141618490219116 1.8000000636675395e-05 0.004156084265559912
11 20 1.332900047302246 0.715879499912262 1.9999999494757503e-05 0.004617871250957251
12 22 1.2070000171661377 0.5926885008811951 1.996917308133561e-05 0.0050796582363545895
13 24 1.2043999433517456 0.5833240747451782 1.9876883015967906e-05 0.005541445221751928
14 26 1.0740000009536743 0.44734400510787964 1.9723698642337695e-05 0.0060032326728105545
15 28 1.1162999868392944 0.3701137900352478 1.9510565834934823e-05 0.006465019658207893
16 30 1.0454000234603882 0.43832680583000183 1.9238796085119247e-05 0.006926806643605232
17 32 1.124899983406067 0.4591037631034851 1.8910064682131633e-05 0.007388593629002571
18 34 1.0686999559402466 0.3873400390148163 1.8526401618146338e-05 0.00785038061439991
19 36 1.0291999578475952 0.40313437581062317 1.8090169760398567e-05 0.008312168531119823
20 38 1.1052000522613525 0.3735405504703522 1.7604059394216165e-05 0.008773955516517162
21 40 1.1555999517440796 0.3818407654762268 1.7071068214136176e-05 0.009235742501914501
22 42 1.0235999822616577 0.4255191683769226 1.6494481315021403e-05 0.00969752948731184
23 44 1.0364999771118164 0.4794503152370453 1.5877853002166376e-05 0.010159316472709179
24 46 1.1344000101089478 0.37273937463760376 1.5224985872919206e-05 0.010621103458106518
25 48 1.0866999626159668 0.417492538690567 1.453990535082994e-05 0.011082890443503857
26 50 1.1038000583648682 0.35408055782318115 1.3826834219798911e-05 0.01154467836022377
27 52 1.1478999853134155 0.3930828273296356 1.3090169886709191e-05 0.012006465345621109
28 54 1.1858999729156494 0.3965947926044464 1.2334453458606731e-05 0.012468252331018448
29 56 1.0096999406814575 0.3860221207141876 1.1564344276848715e-05 0.012930039316415787
30 58 1.114799976348877 0.44393691420555115 1.0784590813273098e-05 0.013391826301813126
31 60 1.079300045967102 0.3605058789253235 9.999999747378752e-06 0.013853613287210464
32 62 1.1766999959945679 0.40689122676849365 9.215408681484405e-06 0.014315400272607803
33 64 1.1075999736785889 0.4002344310283661 8.435655217908788e-06 0.014777187258005142
34 66 1.1866999864578247 0.46947163343429565 7.665546036150772e-06 0.015238975174725056
35 68 1.0311000347137451 0.3296957314014435 6.909830062795663e-06 0.01570076122879982
36 70 1.1088999509811401 0.33858785033226013 6.173165729705943e-06 0.01616254821419716
37 72 1.0720000267028809 0.3967427909374237 5.460095053422265e-06 0.016624337062239647
38 74 1.1460000276565552 0.41202062368392944 4.7750145313329995e-06 0.017086124047636986
39 76 1.0425000190734863 0.38334518671035767 4.1221474020858295e-06 0.017547911033034325
40 78 0.9154000282287598 0.40649303793907166 3.505519543978153e-06 0.018009698018431664
41 80 1.1110999584197998 0.35371580719947815 2.9289321901160292e-06 0.018471485003829002
42 82 1.1672999858856201 0.3381657302379608 2.3959403279150138e-06 0.01893327198922634
43 84 1.2374000549316406 0.3815234303474426 1.909829961732612e-06 0.01939505897462368
44 86 1.2151000499725342 0.38446080684661865 1.4735983313585166e-06 0.01985684596002102
45 88 1.163100004196167 0.40419140458106995 1.0899348126258701e-06 0.020318632945418358
46 90 1.1883000135421753 0.4011874198913574 7.612046601934708e-07 0.020780419930815697
47 92 1.1526999473571777 0.3836020231246948 4.894348535344761e-07 0.021242206916213036
48 94 1.15339994430542 0.452364057302475 2.7630079557638965e-07 0.021703993901610374
49 96 1.062000036239624 0.3502688705921173 1.2311659247643547e-07 0.022165780887007713
50 98 1.0271999835968018 0.4022065997123718 3.0826662111849146e-08 0.022627567872405052
51 100 1.0283000469207764 0.38241174817085266 0.0 0.02308935672044754 183.9481964111328 8.697999954223633 0.5440000295639038 1862467846144.0 1.1177252531051636

BIN
example/gpu_sft_loss.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB

51
example/npu_pretrain.csv Normal file
View File

@@ -0,0 +1,51 @@
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
2,2.7920000553131104,0.035306449979543686,7.999999979801942e-06,0.010457516647875309,,,,,
4,2.8011999130249023,0.03491510450839996,1.5999999959603883e-05,0.020915033295750618,,,,,
6,2.7964000701904297,0.032717395573854446,2.4000000848900527e-05,0.0313725508749485,,,,,
8,2.763700008392334,0.024953875690698624,3.199999991920777e-05,0.041830066591501236,,,,,
10,3.2811999320983887,0.3170815408229828,3.9999998989515007e-05,0.05228758230805397,,,,,
12,2.9409000873565674,0.04423849284648895,3.995128281530924e-05,0.062745101749897,,,,,
14,2.851900100708008,0.03667925298213959,3.9805359847377986e-05,0.07320261746644974,,,,,
16,2.7869999408721924,0.022814607247710228,3.9562950405525044e-05,0.08366013318300247,,,,,
18,2.782599925994873,0.021528413519263268,3.922523319488391e-05,0.0941176488995552,,,,,
20,2.785599946975708,0.017014438286423683,3.87938525818754e-05,0.10457516461610794,,,,,
22,2.7571001052856445,0.015719758346676826,3.827090768027119e-05,0.11503268033266068,,,,,
24,2.762399911880493,0.016948623582720757,3.7658952351193875e-05,0.125490203499794,,,,,
26,2.7411000728607178,0.015535997226834297,3.6960962461307645e-05,0.13594771921634674,,,,,
28,2.7330000400543213,0.012748735956847668,3.6180339520797133e-05,0.14640523493289948,,,,,
30,2.723299980163574,0.014809778891503811,3.532088885549456e-05,0.1568627506494522,,,,,
32,2.7342000007629395,0.01219236571341753,3.4386797779006884e-05,0.16732026636600494,,,,,
34,2.7321999073028564,0.012785322032868862,3.338261376484297e-05,0.17777778208255768,,,,,
36,2.7314000129699707,0.012986919842660427,3.231322989449836e-05,0.1882352977991104,,,,,
38,2.7065999507904053,0.01096824835985899,3.118385939160362e-05,0.19869281351566315,,,,,
40,2.6958999633789062,0.012387535534799099,2.9999999242136255e-05,0.20915032923221588,,,,,
42,2.751499891281128,0.011586200445890427,2.8767422918463126e-05,0.21960784494876862,,,,,
44,2.713099956512451,0.011821281164884567,2.749213126662653e-05,0.23006536066532135,,,,,
46,2.7102999687194824,0.01147585827857256,2.6180339773418382e-05,0.24052287638187408,,,,,
48,2.7019999027252197,0.011368263512849808,2.483843854861334e-05,0.250980406999588,,,,,
50,2.680500030517578,0.009935515932738781,2.3472963221138343e-05,0.26143792271614075,,,,,
52,2.6993000507354736,0.0109846917912364,2.2090569473220967e-05,0.2718954384326935,,,,,
54,2.6940999031066895,0.010465175844728947,2.0697989384643734e-05,0.2823529541492462,,,,,
56,2.7091000080108643,0.01009758748114109,1.9302009604871273e-05,0.29281046986579895,,,,,
58,2.69950008392334,0.01249368954449892,1.7909431335283443e-05,0.3032679855823517,,,,,
60,2.7216999530792236,0.01051376760005951,1.6527035768376663e-05,0.3137255012989044,,,,,
62,2.7158000469207764,0.01054943073540926,1.516156225989107e-05,0.32418301701545715,,,,,
64,2.7214999198913574,0.01076149195432663,1.3819660125591327e-05,0.3346405327320099,,,,,
66,2.7116000652313232,0.010380392894148827,1.2507867722888477e-05,0.3450980484485626,,,,,
68,2.6923000812530518,0.010425001382827759,1.1232576980546582e-05,0.35555556416511536,,,,,
70,2.683199882507324,0.00925016961991787,9.999999747378752e-06,0.3660130798816681,,,,,
72,2.7093000411987305,0.01072422880679369,8.816142326395493e-06,0.3764705955982208,,,,,
74,2.6988000869750977,0.011063243262469769,7.686770914006047e-06,0.38692811131477356,,,,,
76,2.7125000953674316,0.01013101264834404,6.61738795315614e-06,0.3973856270313263,,,,,
78,2.6993000507354736,0.009940676391124725,5.613203938992228e-06,0.40784314274787903,,,,,
80,2.6861000061035156,0.01050259917974472,4.6791110435151495e-06,0.41830065846443176,,,,,
82,2.6828999519348145,0.009912634268403053,3.819659923465224e-06,0.4287581741809845,,,,,
84,2.699199914932251,0.010668900795280933,3.03903811982309e-06,0.43921568989753723,,,,,
86,2.698899984359741,0.009650414809584618,2.3410482299368596e-06,0.44967320561408997,,,,,
88,2.6884000301361084,0.01006452739238739,1.7290908544964623e-06,0.4601307213306427,,,,,
90,2.6928999423980713,0.010409764014184475,1.206147544507985e-06,0.47058823704719543,,,,,
92,2.714200019836426,0.009937116876244545,7.74766078848188e-07,0.48104575276374817,,,,,
94,2.672300100326538,0.009728306904435158,4.370479871340649e-07,0.4915032684803009,,,,,
96,2.7018001079559326,0.010098566301167011,1.9463863054625108e-07,0.501960813999176,,,,,
98,2.7123000621795654,0.009524320252239704,4.8718995060426096e-08,0.5124183297157288,,,,,
100,2.7028000354766846,0.009290286339819431,0.0,0.5228758454322815,788.0635986328125,64.96900177001953,0.12700000405311584,4.629706395531346e+17,2.739542245864868
1 step train/loss train/grad_norm train/learning_rate train/epoch train/train_runtime train/train_samples_per_second train/train_steps_per_second train/total_flos train/train_loss
2 2 2.7920000553131104 0.035306449979543686 7.999999979801942e-06 0.010457516647875309
3 4 2.8011999130249023 0.03491510450839996 1.5999999959603883e-05 0.020915033295750618
4 6 2.7964000701904297 0.032717395573854446 2.4000000848900527e-05 0.0313725508749485
5 8 2.763700008392334 0.024953875690698624 3.199999991920777e-05 0.041830066591501236
6 10 3.2811999320983887 0.3170815408229828 3.9999998989515007e-05 0.05228758230805397
7 12 2.9409000873565674 0.04423849284648895 3.995128281530924e-05 0.062745101749897
8 14 2.851900100708008 0.03667925298213959 3.9805359847377986e-05 0.07320261746644974
9 16 2.7869999408721924 0.022814607247710228 3.9562950405525044e-05 0.08366013318300247
10 18 2.782599925994873 0.021528413519263268 3.922523319488391e-05 0.0941176488995552
11 20 2.785599946975708 0.017014438286423683 3.87938525818754e-05 0.10457516461610794
12 22 2.7571001052856445 0.015719758346676826 3.827090768027119e-05 0.11503268033266068
13 24 2.762399911880493 0.016948623582720757 3.7658952351193875e-05 0.125490203499794
14 26 2.7411000728607178 0.015535997226834297 3.6960962461307645e-05 0.13594771921634674
15 28 2.7330000400543213 0.012748735956847668 3.6180339520797133e-05 0.14640523493289948
16 30 2.723299980163574 0.014809778891503811 3.532088885549456e-05 0.1568627506494522
17 32 2.7342000007629395 0.01219236571341753 3.4386797779006884e-05 0.16732026636600494
18 34 2.7321999073028564 0.012785322032868862 3.338261376484297e-05 0.17777778208255768
19 36 2.7314000129699707 0.012986919842660427 3.231322989449836e-05 0.1882352977991104
20 38 2.7065999507904053 0.01096824835985899 3.118385939160362e-05 0.19869281351566315
21 40 2.6958999633789062 0.012387535534799099 2.9999999242136255e-05 0.20915032923221588
22 42 2.751499891281128 0.011586200445890427 2.8767422918463126e-05 0.21960784494876862
23 44 2.713099956512451 0.011821281164884567 2.749213126662653e-05 0.23006536066532135
24 46 2.7102999687194824 0.01147585827857256 2.6180339773418382e-05 0.24052287638187408
25 48 2.7019999027252197 0.011368263512849808 2.483843854861334e-05 0.250980406999588
26 50 2.680500030517578 0.009935515932738781 2.3472963221138343e-05 0.26143792271614075
27 52 2.6993000507354736 0.0109846917912364 2.2090569473220967e-05 0.2718954384326935
28 54 2.6940999031066895 0.010465175844728947 2.0697989384643734e-05 0.2823529541492462
29 56 2.7091000080108643 0.01009758748114109 1.9302009604871273e-05 0.29281046986579895
30 58 2.69950008392334 0.01249368954449892 1.7909431335283443e-05 0.3032679855823517
31 60 2.7216999530792236 0.01051376760005951 1.6527035768376663e-05 0.3137255012989044
32 62 2.7158000469207764 0.01054943073540926 1.516156225989107e-05 0.32418301701545715
33 64 2.7214999198913574 0.01076149195432663 1.3819660125591327e-05 0.3346405327320099
34 66 2.7116000652313232 0.010380392894148827 1.2507867722888477e-05 0.3450980484485626
35 68 2.6923000812530518 0.010425001382827759 1.1232576980546582e-05 0.35555556416511536
36 70 2.683199882507324 0.00925016961991787 9.999999747378752e-06 0.3660130798816681
37 72 2.7093000411987305 0.01072422880679369 8.816142326395493e-06 0.3764705955982208
38 74 2.6988000869750977 0.011063243262469769 7.686770914006047e-06 0.38692811131477356
39 76 2.7125000953674316 0.01013101264834404 6.61738795315614e-06 0.3973856270313263
40 78 2.6993000507354736 0.009940676391124725 5.613203938992228e-06 0.40784314274787903
41 80 2.6861000061035156 0.01050259917974472 4.6791110435151495e-06 0.41830065846443176
42 82 2.6828999519348145 0.009912634268403053 3.819659923465224e-06 0.4287581741809845
43 84 2.699199914932251 0.010668900795280933 3.03903811982309e-06 0.43921568989753723
44 86 2.698899984359741 0.009650414809584618 2.3410482299368596e-06 0.44967320561408997
45 88 2.6884000301361084 0.01006452739238739 1.7290908544964623e-06 0.4601307213306427
46 90 2.6928999423980713 0.010409764014184475 1.206147544507985e-06 0.47058823704719543
47 92 2.714200019836426 0.009937116876244545 7.74766078848188e-07 0.48104575276374817
48 94 2.672300100326538 0.009728306904435158 4.370479871340649e-07 0.4915032684803009
49 96 2.7018001079559326 0.010098566301167011 1.9463863054625108e-07 0.501960813999176
50 98 2.7123000621795654 0.009524320252239704 4.8718995060426096e-08 0.5124183297157288
51 100 2.7028000354766846 0.009290286339819431 0.0 0.5228758454322815 788.0635986328125 64.96900177001953 0.12700000405311584 4.629706395531346e+17 2.739542245864868

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

51
example/npu_sft.csv Normal file
View File

@@ -0,0 +1,51 @@
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
2,1.1491999626159668,0.6218180060386658,1.9999999949504854e-06,0.0004617871018126607,,,,,
4,1.0981999635696411,0.6825665235519409,3.999999989900971e-06,0.0009235742036253214,,,,,
6,1.1269999742507935,0.7838642001152039,6.000000212225132e-06,0.001385361305437982,,,,,
8,1.0542000532150269,0.8744276762008667,7.999999979801942e-06,0.0018471484072506428,,,,,
10,1.2441999912261963,0.7064258456230164,9.999999747378752e-06,0.0023089356254786253,,,,,
12,1.2927000522613525,0.6829814910888672,1.2000000424450263e-05,0.002770722610875964,,,,,
14,1.0844999551773071,0.5265647172927856,1.4000000192027073e-05,0.0032325098291039467,,,,,
16,1.0963000059127808,0.4373657703399658,1.5999999959603883e-05,0.0036942968145012856,,,,,
18,1.0615999698638916,0.46220508217811584,1.8000000636675395e-05,0.004156084265559912,,,,,
20,1.3325999975204468,0.7157824039459229,1.9999999494757503e-05,0.004617871250957251,,,,,
22,1.2070000171661377,0.5933427214622498,1.996917308133561e-05,0.0050796582363545895,,,,,
24,1.2044999599456787,0.5816172957420349,1.9876883015967906e-05,0.005541445221751928,,,,,
26,1.0740000009536743,0.4489712119102478,1.9723698642337695e-05,0.0060032326728105545,,,,,
28,1.1164000034332275,0.3696516752243042,1.9510565834934823e-05,0.006465019658207893,,,,,
30,1.045199990272522,0.4376335144042969,1.9238796085119247e-05,0.006926806643605232,,,,,
32,1.1247999668121338,0.4589230716228485,1.8910064682131633e-05,0.007388593629002571,,,,,
34,1.0688999891281128,0.3879022002220154,1.8526401618146338e-05,0.00785038061439991,,,,,
36,1.0292999744415283,0.4027869403362274,1.8090169760398567e-05,0.008312168531119823,,,,,
38,1.1052000522613525,0.37394437193870544,1.7604059394216165e-05,0.008773955516517162,,,,,
40,1.1557999849319458,0.3808683753013611,1.7071068214136176e-05,0.009235742501914501,,,,,
42,1.0232000350952148,0.4252733886241913,1.6494481315021403e-05,0.00969752948731184,,,,,
44,1.0364999771118164,0.48068660497665405,1.5877853002166376e-05,0.010159316472709179,,,,,
46,1.1340999603271484,0.37313926219940186,1.5224985872919206e-05,0.010621103458106518,,,,,
48,1.0866999626159668,0.4175492823123932,1.453990535082994e-05,0.011082890443503857,,,,,
50,1.1039999723434448,0.35443660616874695,1.3826834219798911e-05,0.01154467836022377,,,,,
52,1.1480000019073486,0.39232146739959717,1.3090169886709191e-05,0.012006465345621109,,,,,
54,1.1861000061035156,0.396918922662735,1.2334453458606731e-05,0.012468252331018448,,,,,
56,1.0096999406814575,0.3885609209537506,1.1564344276848715e-05,0.012930039316415787,,,,,
58,1.114799976348877,0.4421806335449219,1.0784590813273098e-05,0.013391826301813126,,,,,
60,1.0795999765396118,0.36081990599632263,9.999999747378752e-06,0.013853613287210464,,,,,
62,1.1764999628067017,0.4062329828739166,9.215408681484405e-06,0.014315400272607803,,,,,
64,1.107200026512146,0.39982733130455017,8.435655217908788e-06,0.014777187258005142,,,,,
66,1.1868000030517578,0.4688170254230499,7.665546036150772e-06,0.015238975174725056,,,,,
68,1.0312999486923218,0.3301626741886139,6.909830062795663e-06,0.01570076122879982,,,,,
70,1.1089999675750732,0.3377252221107483,6.173165729705943e-06,0.01616254821419716,,,,,
72,1.0716999769210815,0.39666977524757385,5.460095053422265e-06,0.016624337062239647,,,,,
74,1.1461999416351318,0.4125552177429199,4.7750145313329995e-06,0.017086124047636986,,,,,
76,1.042199969291687,0.3825180232524872,4.1221474020858295e-06,0.017547911033034325,,,,,
78,0.9157000184059143,0.4063441753387451,3.505519543978153e-06,0.018009698018431664,,,,,
80,1.1110999584197998,0.35289037227630615,2.9289321901160292e-06,0.018471485003829002,,,,,
82,1.167199969291687,0.33720290660858154,2.3959403279150138e-06,0.01893327198922634,,,,,
84,1.2375999689102173,0.38099613785743713,1.909829961732612e-06,0.01939505897462368,,,,,
86,1.2151999473571777,0.3848689794540405,1.4735983313585166e-06,0.01985684596002102,,,,,
88,1.1628999710083008,0.40408074855804443,1.0899348126258701e-06,0.020318632945418358,,,,,
90,1.1884000301361084,0.4015007019042969,7.612046601934708e-07,0.020780419930815697,,,,,
92,1.152500033378601,0.38306349515914917,4.894348535344761e-07,0.021242206916213036,,,,,
94,1.154099941253662,0.45273807644844055,2.7630079557638965e-07,0.021703993901610374,,,,,
96,1.0618000030517578,0.35036078095436096,1.2311659247643547e-07,0.022165780887007713,,,,,
98,1.0270999670028687,0.40208569169044495,3.0826662111849146e-08,0.022627567872405052,,,,,
100,1.0285999774932861,0.38247284293174744,0.0,0.02308935672044754,728.7083129882812,2.196000099182129,0.13699999451637268,1862467846144.0,1.117748498916626
1 step train/loss train/grad_norm train/learning_rate train/epoch train/train_runtime train/train_samples_per_second train/train_steps_per_second train/total_flos train/train_loss
2 2 1.1491999626159668 0.6218180060386658 1.9999999949504854e-06 0.0004617871018126607
3 4 1.0981999635696411 0.6825665235519409 3.999999989900971e-06 0.0009235742036253214
4 6 1.1269999742507935 0.7838642001152039 6.000000212225132e-06 0.001385361305437982
5 8 1.0542000532150269 0.8744276762008667 7.999999979801942e-06 0.0018471484072506428
6 10 1.2441999912261963 0.7064258456230164 9.999999747378752e-06 0.0023089356254786253
7 12 1.2927000522613525 0.6829814910888672 1.2000000424450263e-05 0.002770722610875964
8 14 1.0844999551773071 0.5265647172927856 1.4000000192027073e-05 0.0032325098291039467
9 16 1.0963000059127808 0.4373657703399658 1.5999999959603883e-05 0.0036942968145012856
10 18 1.0615999698638916 0.46220508217811584 1.8000000636675395e-05 0.004156084265559912
11 20 1.3325999975204468 0.7157824039459229 1.9999999494757503e-05 0.004617871250957251
12 22 1.2070000171661377 0.5933427214622498 1.996917308133561e-05 0.0050796582363545895
13 24 1.2044999599456787 0.5816172957420349 1.9876883015967906e-05 0.005541445221751928
14 26 1.0740000009536743 0.4489712119102478 1.9723698642337695e-05 0.0060032326728105545
15 28 1.1164000034332275 0.3696516752243042 1.9510565834934823e-05 0.006465019658207893
16 30 1.045199990272522 0.4376335144042969 1.9238796085119247e-05 0.006926806643605232
17 32 1.1247999668121338 0.4589230716228485 1.8910064682131633e-05 0.007388593629002571
18 34 1.0688999891281128 0.3879022002220154 1.8526401618146338e-05 0.00785038061439991
19 36 1.0292999744415283 0.4027869403362274 1.8090169760398567e-05 0.008312168531119823
20 38 1.1052000522613525 0.37394437193870544 1.7604059394216165e-05 0.008773955516517162
21 40 1.1557999849319458 0.3808683753013611 1.7071068214136176e-05 0.009235742501914501
22 42 1.0232000350952148 0.4252733886241913 1.6494481315021403e-05 0.00969752948731184
23 44 1.0364999771118164 0.48068660497665405 1.5877853002166376e-05 0.010159316472709179
24 46 1.1340999603271484 0.37313926219940186 1.5224985872919206e-05 0.010621103458106518
25 48 1.0866999626159668 0.4175492823123932 1.453990535082994e-05 0.011082890443503857
26 50 1.1039999723434448 0.35443660616874695 1.3826834219798911e-05 0.01154467836022377
27 52 1.1480000019073486 0.39232146739959717 1.3090169886709191e-05 0.012006465345621109
28 54 1.1861000061035156 0.396918922662735 1.2334453458606731e-05 0.012468252331018448
29 56 1.0096999406814575 0.3885609209537506 1.1564344276848715e-05 0.012930039316415787
30 58 1.114799976348877 0.4421806335449219 1.0784590813273098e-05 0.013391826301813126
31 60 1.0795999765396118 0.36081990599632263 9.999999747378752e-06 0.013853613287210464
32 62 1.1764999628067017 0.4062329828739166 9.215408681484405e-06 0.014315400272607803
33 64 1.107200026512146 0.39982733130455017 8.435655217908788e-06 0.014777187258005142
34 66 1.1868000030517578 0.4688170254230499 7.665546036150772e-06 0.015238975174725056
35 68 1.0312999486923218 0.3301626741886139 6.909830062795663e-06 0.01570076122879982
36 70 1.1089999675750732 0.3377252221107483 6.173165729705943e-06 0.01616254821419716
37 72 1.0716999769210815 0.39666977524757385 5.460095053422265e-06 0.016624337062239647
38 74 1.1461999416351318 0.4125552177429199 4.7750145313329995e-06 0.017086124047636986
39 76 1.042199969291687 0.3825180232524872 4.1221474020858295e-06 0.017547911033034325
40 78 0.9157000184059143 0.4063441753387451 3.505519543978153e-06 0.018009698018431664
41 80 1.1110999584197998 0.35289037227630615 2.9289321901160292e-06 0.018471485003829002
42 82 1.167199969291687 0.33720290660858154 2.3959403279150138e-06 0.01893327198922634
43 84 1.2375999689102173 0.38099613785743713 1.909829961732612e-06 0.01939505897462368
44 86 1.2151999473571777 0.3848689794540405 1.4735983313585166e-06 0.01985684596002102
45 88 1.1628999710083008 0.40408074855804443 1.0899348126258701e-06 0.020318632945418358
46 90 1.1884000301361084 0.4015007019042969 7.612046601934708e-07 0.020780419930815697
47 92 1.152500033378601 0.38306349515914917 4.894348535344761e-07 0.021242206916213036
48 94 1.154099941253662 0.45273807644844055 2.7630079557638965e-07 0.021703993901610374
49 96 1.0618000030517578 0.35036078095436096 1.2311659247643547e-07 0.022165780887007713
50 98 1.0270999670028687 0.40208569169044495 3.0826662111849146e-08 0.022627567872405052
51 100 1.0285999774932861 0.38247284293174744 0.0 0.02308935672044754 728.7083129882812 2.196000099182129 0.13699999451637268 1862467846144.0 1.117748498916626

BIN
example/npu_sft_loss.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 67 KiB

8
example/requirements.txt Normal file
View File

@@ -0,0 +1,8 @@
transformers==4.46.3
tokenizers==0.20.3
accelerate==1.1.1
deepspeed==0.16.2
datasets==3.1.0
safetensors==0.4.5
pyarrow==17.0.0
tensorboard==2.18.0

38
example/run.sh Normal file
View File

@@ -0,0 +1,38 @@
#!/bin/bash
MODEL_PATH="/model/BitCPM-CANN-1B-unquantized"
DATA_PATH="/dataset/c4-pro/data/000_1_7.parquet"
OUTPUT_DIR="./output"
DS_CONFIG="./ds_config_z2.json"
NUM_GPUS=8
BATCH_SIZE_PER_GPU=8
GRAD_ACCUM_STEPS=8
MAX_SEQ_LENGTH=1024
export ASCEND_RT_VISIBLE_DEVICES=8,9,10,11,12,13,14,15
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export DS_SKIP_CUDA_CHECK=1
torchrun --nproc_per_node=$NUM_GPUS train.py \
--model_name_or_path $MODEL_PATH \
--data_path $DATA_PATH \
--max_seq_length $MAX_SEQ_LENGTH \
--output_dir $OUTPUT_DIR \
--per_device_train_batch_size $BATCH_SIZE_PER_GPU \
--gradient_accumulation_steps $GRAD_ACCUM_STEPS \
--max_steps 100 \
--learning_rate 4e-5 \
--lr_scheduler_type cosine \
--warmup_ratio 0.1 \
--weight_decay 1e-2 \
--logging_steps 2 \
--save_steps 500 \
--save_total_limit 3 \
--bf16 \
--deepspeed $DS_CONFIG \
--gradient_checkpointing \
--seed 42 \
--dataloader_num_workers 4 \
--report_to tensorboard \
--logging_dir /data/tensorboard/pretrain \
--gradient_checkpointing_kwargs '{"use_reentrant": false}'

40
example/run_sft.sh Normal file
View File

@@ -0,0 +1,40 @@
#!/bin/bash
MODEL_PATH="/model/BitCPM-CANN-1B-unquantized"
DATA_PATH="/dataset/HuggingFaceH4_ultrachat_200k/data/train_sft-00000-of-00003-a3ecf92756993583.parquet"
OUTPUT_DIR="./output_sft"
DS_CONFIG="./ds_config.json"
NUM_GPUS=8
BATCH_SIZE_PER_GPU=2
GRAD_ACCUM_STEPS=1
MAX_SEQ_LENGTH=8192
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export DS_SKIP_CUDA_CHECK=1
torchrun --nproc_per_node=$NUM_GPUS train_sft.py \
--model_name_or_path $MODEL_PATH \
--data_path $DATA_PATH \
--max_seq_length $MAX_SEQ_LENGTH \
--output_dir $OUTPUT_DIR \
--per_device_train_batch_size $BATCH_SIZE_PER_GPU \
--gradient_accumulation_steps $GRAD_ACCUM_STEPS \
--max_steps 100 \
--learning_rate 2e-5 \
--lr_scheduler_type cosine \
--warmup_ratio 0.2 \
--weight_decay 0.0 \
--logging_steps 2 \
--save_steps 500 \
--save_total_limit 3 \
--bf16 \
--deepspeed $DS_CONFIG \
--gradient_checkpointing \
--seed 42 \
--dataloader_num_workers 4 \
--report_to tensorboard \
--logging_dir /data/tensorboard/sft \
--train_on_prompt false \
--gradient_checkpointing_kwargs '{"use_reentrant": false}'

203
example/train.py Normal file
View File

@@ -0,0 +1,203 @@
"""
Continual pretraining script for CPM-2B model using DeepSpeed + HuggingFace Trainer.
"""
import os
import json
import math
import logging
from dataclasses import dataclass, field
from typing import Optional
import contextlib
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
Trainer,
TrainingArguments,
HfArgumentParser,
DataCollatorForLanguageModeling,
set_seed,
)
import deepspeed
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
@contextlib.contextmanager
def _patched_no_sync(self):
try:
with _orig_no_sync(self):
yield
except AssertionError:
yield
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier"}
)
torch_dtype: Optional[str] = field(
default="bfloat16",
metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
)
@dataclass
class DataArguments:
data_path: str = field(
metadata={"help": "Path to training data (parquet file or directory)"}
)
max_seq_length: int = field(
default=4096,
metadata={"help": "Maximum sequence length for training"},
)
text_column: str = field(
default="text",
metadata={"help": "Name of the text column in the dataset"},
)
preprocessing_num_workers: int = field(
default=8,
metadata={"help": "Number of workers for data preprocessing"},
)
def tokenize_and_group(dataset, tokenizer, data_args):
"""Tokenize texts and group into chunks of max_seq_length."""
column_names = dataset.column_names
text_column = data_args.text_column
if text_column not in column_names:
candidates = [c for c in column_names if "text" in c.lower()]
if candidates:
text_column = candidates[0]
else:
text_column = column_names[0]
logger.warning(f"Column '{data_args.text_column}' not found, using '{text_column}'")
def tokenize_function(examples):
return tokenizer(examples[text_column], add_special_tokens=False)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Tokenizing",
)
block_size = data_args.max_seq_length
def group_texts(examples):
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
total_length = len(concatenated["input_ids"])
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated.items()
}
result["labels"] = result["input_ids"].copy()
return result
grouped_dataset = tokenized_dataset.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
desc="Grouping texts",
)
return grouped_dataset
def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.info(f"Training args: {training_args}")
set_seed(training_args.seed)
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Loading model from {model_args.model_name_or_path}")
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
torch_dtype=torch_dtype,
trust_remote_code=True,
attn_implementation="sdpa",
)
model.config.use_cache = False
logger.info(f"Loading dataset from {data_args.data_path}")
if os.path.isfile(data_args.data_path):
raw_dataset = load_dataset("parquet", data_files=data_args.data_path, split="train")
elif os.path.isdir(data_args.data_path):
parquet_files = [
os.path.join(data_args.data_path, f)
for f in os.listdir(data_args.data_path)
if f.endswith(".parquet")
]
raw_dataset = load_dataset("parquet", data_files=parquet_files, split="train")
else:
raise ValueError(f"Data path not found: {data_args.data_path}")
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
)
logger.info("Starting training...")
train_result = trainer.train(
resume_from_checkpoint=training_args.resume_from_checkpoint
)
trainer.save_model()
trainer.save_state()
metrics = train_result.metrics
metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
if __name__ == "__main__":
main()

424
example/train_sft.py Normal file
View File

@@ -0,0 +1,424 @@
"""
Supervised fine-tuning script using DeepSpeed + HuggingFace Trainer.
"""
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import contextlib
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
)
import deepspeed
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
@contextlib.contextmanager
def _patched_no_sync(self):
try:
with _orig_no_sync(self):
yield
except AssertionError:
yield
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
logger = logging.getLogger(__name__)
IGNORE_INDEX = -100
@dataclass
class ModelArguments:
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier"}
)
torch_dtype: Optional[str] = field(
default="bfloat16",
metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
)
@dataclass
class DataArguments:
data_path: str = field(metadata={"help": "Path to SFT data file or directory"})
max_seq_length: int = field(
default=4096,
metadata={"help": "Maximum sequence length for training"},
)
prompt_column: Optional[str] = field(
default=None,
metadata={"help": "Prompt/instruction column name. Auto-detected if omitted."},
)
input_column: Optional[str] = field(
default=None,
metadata={"help": "Optional extra input/context column name"},
)
response_column: Optional[str] = field(
default=None,
metadata={"help": "Response/output column name. Auto-detected if omitted."},
)
messages_column: Optional[str] = field(
default=None,
metadata={"help": "Chat messages column name. Auto-detected if omitted."},
)
system_column: Optional[str] = field(
default=None,
metadata={"help": "Optional system prompt column name"},
)
train_on_prompt: bool = field(
default=False,
metadata={"help": "Whether to compute loss on prompt/user tokens"},
)
add_eos_token: bool = field(
default=True,
metadata={"help": "Append eos_token to plain prompt/response examples"},
)
preprocessing_num_workers: int = field(
default=8,
metadata={"help": "Number of workers for data preprocessing"},
)
class SFTDataCollator:
def __init__(self, tokenizer, pad_to_multiple_of: Optional[int] = 8):
self.tokenizer = tokenizer
self.pad_to_multiple_of = pad_to_multiple_of
def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
max_length = max(len(feature["input_ids"]) for feature in features)
if self.pad_to_multiple_of:
multiple = self.pad_to_multiple_of
max_length = ((max_length + multiple - 1) // multiple) * multiple
input_ids = []
attention_mask = []
labels = []
pad_token_id = self.tokenizer.pad_token_id
for feature in features:
length = len(feature["input_ids"])
pad_length = max_length - length
input_ids.append(feature["input_ids"] + [pad_token_id] * pad_length)
attention_mask.append([1] * length + [0] * pad_length)
labels.append(feature["labels"] + [IGNORE_INDEX] * pad_length)
return {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
def load_sft_dataset(data_path: str):
if os.path.isfile(data_path):
extension = os.path.splitext(data_path)[1].lstrip(".").lower()
if extension == "jsonl":
extension = "json"
if extension not in {"parquet", "json", "csv", "txt"}:
raise ValueError(f"Unsupported data file extension: {extension}")
return load_dataset(extension, data_files=data_path, split="train")
if os.path.isdir(data_path):
data_files = []
extension = None
for name in os.listdir(data_path):
current_extension = os.path.splitext(name)[1].lstrip(".").lower()
if current_extension == "jsonl":
current_extension = "json"
if current_extension in {"parquet", "json", "csv", "txt"}:
extension = extension or current_extension
if current_extension == extension:
data_files.append(os.path.join(data_path, name))
if not data_files or extension is None:
raise ValueError(f"No supported data files found in: {data_path}")
return load_dataset(extension, data_files=sorted(data_files), split="train")
raise ValueError(f"Data path not found: {data_path}")
def choose_column(
column_names: List[str], explicit: Optional[str], candidates: List[str]
) -> Optional[str]:
if explicit:
if explicit not in column_names:
raise ValueError(f"Column '{explicit}' not found. Available columns: {column_names}")
return explicit
for name in candidates:
if name in column_names:
return name
return None
def parse_messages(value: Any) -> List[Dict[str, str]]:
if isinstance(value, str):
value = json.loads(value)
if not isinstance(value, list):
raise ValueError("messages/conversations column must be a list or JSON string")
messages = []
for item in value:
if not isinstance(item, dict):
raise ValueError("Each message must be a dict")
role = item.get("role", item.get("from"))
content = item.get("content", item.get("value"))
if role == "human":
role = "user"
elif role == "gpt":
role = "assistant"
if role is None or content is None:
raise ValueError("Each message must contain role/from and content/value")
messages.append({"role": str(role), "content": str(content)})
return messages
def tokenize_text(tokenizer, text: str) -> List[int]:
return tokenizer(text, add_special_tokens=False)["input_ids"]
def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool) -> str:
if tokenizer.chat_template is None:
raise ValueError(
"The tokenizer has no chat_template. Use prompt/response columns or set a chat_template."
)
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
def encode_prompt_response(
example: Dict[str, Any],
tokenizer,
data_args: DataArguments,
prompt_column: str,
input_column: Optional[str],
response_column: str,
) -> Tuple[List[int], List[int]]:
prompt = str(example[prompt_column])
if input_column and example.get(input_column):
prompt = prompt + "\n" + str(example[input_column])
response = str(example[response_column])
messages = []
if data_args.system_column and example.get(data_args.system_column):
messages.append({"role": "system", "content": str(example[data_args.system_column])})
messages.append({"role": "user", "content": prompt})
messages.append({"role": "assistant", "content": response})
if tokenizer.chat_template is not None:
full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
prompt_text = apply_chat_template(tokenizer, messages[:-1], add_generation_prompt=True)
input_ids = tokenize_text(tokenizer, full_text)
prompt_length = len(tokenize_text(tokenizer, prompt_text))
else:
response_text = response
if data_args.add_eos_token and tokenizer.eos_token:
response_text += tokenizer.eos_token
full_text = prompt + "\n" + response_text
input_ids = tokenize_text(tokenizer, full_text)
prompt_length = len(tokenize_text(tokenizer, prompt + "\n"))
labels = input_ids.copy()
if not data_args.train_on_prompt:
labels[:prompt_length] = [IGNORE_INDEX] * min(prompt_length, len(labels))
return input_ids, labels
def encode_messages(
example: Dict[str, Any],
tokenizer,
data_args: DataArguments,
messages_column: str,
) -> Tuple[List[int], List[int]]:
messages = parse_messages(example[messages_column])
if tokenizer.chat_template is not None:
full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
input_ids = tokenize_text(tokenizer, full_text)
labels = [IGNORE_INDEX] * len(input_ids)
if data_args.train_on_prompt:
labels = input_ids.copy()
else:
for index, message in enumerate(messages):
if message["role"] != "assistant":
continue
before_text = apply_chat_template(
tokenizer, messages[:index], add_generation_prompt=True
)
after_text = apply_chat_template(
tokenizer, messages[: index + 1], add_generation_prompt=False
)
start = len(tokenize_text(tokenizer, before_text))
end = len(tokenize_text(tokenizer, after_text))
labels[start:end] = input_ids[start:end]
else:
labels = []
input_ids = []
for message in messages:
part = f"{message['role']}: {message['content']}\n"
if data_args.add_eos_token and message["role"] == "assistant" and tokenizer.eos_token:
part += tokenizer.eos_token
part_ids = tokenize_text(tokenizer, part)
input_ids.extend(part_ids)
if data_args.train_on_prompt or message["role"] == "assistant":
labels.extend(part_ids)
else:
labels.extend([IGNORE_INDEX] * len(part_ids))
return input_ids, labels
def preprocess_sft_dataset(raw_dataset, tokenizer, data_args: DataArguments):
column_names = raw_dataset.column_names
messages_column = choose_column(
column_names, data_args.messages_column, ["messages", "conversations"]
)
prompt_column = choose_column(
column_names,
data_args.prompt_column,
["prompt", "instruction", "question"],
)
input_column = choose_column(
column_names,
data_args.input_column,
["input", "context"],
)
response_column = choose_column(
column_names,
data_args.response_column,
["response", "output", "answer", "chosen"],
)
if messages_column:
logger.info(f"Using chat messages column: {messages_column}")
elif prompt_column and response_column:
logger.info(f"Using prompt column '{prompt_column}' and response column '{response_column}'")
else:
raise ValueError(
"Cannot infer SFT data format. Provide either messages/conversations or "
"prompt/instruction plus response/output columns."
)
def encode_batch(examples):
batch_input_ids = []
batch_labels = []
batch_attention_mask = []
batch_size = len(next(iter(examples.values())))
for i in range(batch_size):
example = {name: values[i] for name, values in examples.items()}
if messages_column:
input_ids, labels = encode_messages(example, tokenizer, data_args, messages_column)
else:
input_ids, labels = encode_prompt_response(
example, tokenizer, data_args, prompt_column, input_column, response_column
)
input_ids = input_ids[: data_args.max_seq_length]
labels = labels[: data_args.max_seq_length]
if not input_ids or all(label == IGNORE_INDEX for label in labels):
continue
batch_input_ids.append(input_ids)
batch_labels.append(labels)
batch_attention_mask.append([1] * len(input_ids))
return {
"input_ids": batch_input_ids,
"attention_mask": batch_attention_mask,
"labels": batch_labels,
}
return raw_dataset.map(
encode_batch,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Tokenizing SFT data",
)
def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.info(f"Training args: {training_args}")
set_seed(training_args.seed)
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Loading model from {model_args.model_name_or_path}")
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
torch_dtype=torch_dtype,
trust_remote_code=True,
attn_implementation="sdpa",
)
model.config.use_cache = False
logger.info(f"Loading SFT dataset from {data_args.data_path}")
raw_dataset = load_sft_dataset(data_args.data_path)
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
train_dataset = preprocess_sft_dataset(raw_dataset, tokenizer, data_args)
logger.info(f"Processed dataset: {len(train_dataset)} samples")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=SFTDataCollator(tokenizer),
)
logger.info("Starting SFT training...")
train_result = trainer.train(
resume_from_checkpoint=training_args.resume_from_checkpoint
)
trainer.save_model()
trainer.save_state()
metrics = train_result.metrics
metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
if __name__ == "__main__":
main()

8
generation_config.json Normal file
View File

@@ -0,0 +1,8 @@
{
"do_sample": true,
"top_p": 0.8,
"temperature": 0.8,
"bos_token_id": 1,
"eos_token_id": [2,73440],
"pad_token_id": 2
}

1598
modeling_llama.py Normal file

File diff suppressed because it is too large Load Diff

3
pytorch_model.bin Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ad186deeb7a0b84bcd58a5b9ad90baa300c719cf363b215dc999f9d11ef031f
size 3244413702

176
qat-convert.py Normal file
View File

@@ -0,0 +1,176 @@
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import safetensors
class SteTernaryQuantizer(nn.Module):
def __init__(self, group_size):
super().__init__()
self.group_size = group_size
def forward(self, x):
org_w_shape = x.shape
if self.group_size > 0:
assert x.shape[-1] % self.group_size == 0
x = x.reshape(-1, self.group_size)
elif self.group_size == -1:
x = x.reshape(-1, x.shape[-1])
assert x.dim() == 2
scales = 1.0 / (x.abs().mean(dim=1, keepdim=True).clamp_(min=1e-5))
x_q = (torch.clamp(torch.round(x * scales),-1,1) / scales)
assert torch.isnan(x_q).sum() == 0
x = x.reshape(org_w_shape)
x_q = x_q.reshape(org_w_shape)
return x_q
class SteIntQuantizer(nn.Module):
def __init__(self, bit, group_size):
super().__init__()
self.bit = bit
self.group_size = group_size
def forward(self, x):
org_w_shape = x.shape
if self.group_size > 0:
assert org_w_shape[-1] % self.group_size == 0
x = x.reshape(-1, self.group_size)
elif self.group_size == -1:
x = x.reshape(-1, x.shape[-1])
assert x.dim() == 2
abs_max_val = x.abs().amax(dim=1, keepdim=True)
max_int = 2 ** (self.bit - 1) - 1
min_int = - (2 ** (self.bit - 1))
scales = abs_max_val.clamp(min=1e-5) / max_int
assert torch.isnan(scales).sum() == 0
x_q = (torch.clamp(torch.round(x / scales), min_int, max_int)) * scales
assert torch.isnan(x_q).sum() == 0
x = x.reshape(org_w_shape)
x_q = x_q.reshape(org_w_shape)
return x_q
class SteInt2Quantizer(nn.Module):
def __init__(self, group_size):
super().__init__()
self.group_size = group_size
def forward(self, x):
org_w_shape = x.shape
if self.group_size > 0:
assert x.shape[-1] % self.group_size == 0
x = x.reshape(-1, self.group_size)
elif self.group_size == -1:
x = x.reshape(-1, x.shape[-1])
assert x.dim() == 2
scales = 1.0 / (x.abs().mean(dim=1, keepdim=True).clamp_(min=1e-5) * 1)
x_q = (torch.clamp(torch.round(x * scales),-2,1) / scales)
assert torch.isnan(x_q).sum() == 0
x = x.reshape(org_w_shape)
x_q = x_q.reshape(org_w_shape)
return x_q
def quantize_model_bin(input_bin_path, output_bin_path, quant_type="ternary", bit=2, group_size=128, device="cuda" if torch.cuda.is_available() else "cpu"):
"""
直接对PyTorch模型bin文件进行量化。
Args:
input_bin_path: 输入模型bin文件路径
output_bin_path: 输出量化后的模型bin文件路径
quant_type: 量化类型 ("ternary""int")
bit: 整数量化的位数 (仅在 quant_type="int" 时使用)
group_size: 量化分组大小
device: 运行设备
"""
print(f"加载模型文件: {input_bin_path}...")
if input_bin_path.endswith(".bin"):
state_dict = torch.load(input_bin_path, map_location=device)
elif input_bin_path.endswith(".safetensors"):
state_dict = safetensors.load_file(input_bin_path)
elif os.path.isdir(input_bin_path) and os.path.exists(os.path.join(input_bin_path, "pytorch_model.bin")):
state_dict = torch.load(os.path.join(input_bin_path, "pytorch_model.bin"), map_location=device)
elif os.path.isdir(input_bin_path) and os.path.exists(os.path.join(input_bin_path, "model.safetensors")):
state_dict = safetensors.load_file(os.path.join(input_bin_path, "model.safetensors"))
else:
raise ValueError(f"不支持的模型文件类型: {input_bin_path}")
print(f"应用 {quant_type} 量化...")
if quant_type == "ternary":
quantizer = SteTernaryQuantizer(group_size=group_size)
elif quant_type == "int":
quantizer = SteIntQuantizer(bit=bit, group_size=group_size)
elif quant_type == "int2":
quantizer = SteInt2Quantizer(group_size=group_size)
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
# 统计需要量化的参数数量
total_params = sum(1 for k, v in state_dict.items() if ("weight" in k and "layer" in k) or ("fc" in k))
# 应用量化
with torch.no_grad():
for name, param in tqdm(state_dict.items(), total=total_params, desc="量化中"):
if (("weight" in name and "layer" in name and param.dim() == 2) or ("fc" in name and param.dim() == 2)):
# 对权重进行量化
original_weight = param.data.clone()
quantized_weight = quantizer(original_weight)
state_dict[name] = quantized_weight
# 打印前几个层的统计信息
if total_params > 0:
total_params -= 1
if total_params > total_params - 5:
print(f"层: {name}")
print(f" 原始范围: {original_weight.min():.4f}{original_weight.max():.4f}")
print(f" 量化后范围: {quantized_weight.min():.4f}{quantized_weight.max():.4f}")
print(f" 均方误差: {((original_weight - quantized_weight)**2).mean():.8f}")
# 保存量化后的模型
print(f"保存量化后的模型到: {output_bin_path}...")
if output_bin_path.endswith(".bin"):
torch.save(state_dict, output_bin_path)
elif output_bin_path.endswith(".safetensors"):
safetensors.save_file(state_dict, output_bin_path)
else:
os.makedirs(os.path.dirname(output_bin_path), exist_ok=True)
output_bin_path = os.path.join(output_bin_path, "pytorch_model.bin")
torch.save(state_dict, output_bin_path)
print("完成!")
def main():
import argparse
parser = argparse.ArgumentParser(description="量化PyTorch模型bin文件")
parser.add_argument("--input_bin", type=str, required=True, help="输入模型bin文件路径")
parser.add_argument("--output", type=str, required=True, help="输出量化后的模型bin文件路径")
parser.add_argument("--quant_type", type=str, default="ternary", choices=["ternary", "int", "int2"], help="量化类型")
parser.add_argument("--bit", type=int, default=2, help="整数量化的位数")
parser.add_argument("--group_size", type=int, default=-1, help="量化分组大小")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="运行设备")
parser.add_argument("--config", type=str, default="", help="model config file")
args = parser.parse_args()
os.makedirs(args.output, exist_ok=True)
quantize_model_bin(
input_bin_path=args.input_bin,
output_bin_path=os.path.join(args.output, "pytorch_model.bin"),
quant_type=args.quant_type,
bit=args.bit,
group_size=args.group_size,
device=args.device
)
if args.config:
os.system(f"cp {args.config}/* {args.output}")
print(f"复制{args.config}文件到{args.output}")
if __name__ == "__main__":
main()

81
special_tokens_map.json Normal file
View File

@@ -0,0 +1,81 @@
{
"additional_special_tokens": [
{
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<|tool_call|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<|execute_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<|execute_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
],
"bos_token": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

177952
tokenizer.json Normal file

File diff suppressed because it is too large Load Diff

3
tokenizer.model Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bb74d51116831c3bf65db812c553f94ab0c88dcf97a5bbb37e3504f6d359c530
size 1181204

116
tokenizer_config.json Normal file
View File

@@ -0,0 +1,116 @@
{
"add_bos_token": true,
"add_eos_token": false,
"added_tokens_decoder": {
"0": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"73440": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"73441": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"73442": {
"content": "<|tool_call|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"73443": {
"content": "<|execute_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"73444": {
"content": "<|execute_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"73445": {
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"73446": {
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"73447": {
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<|im_end|>",
"<|im_start|>",
"<|tool_call|>",
"<|execute_start|>",
"<|execute_end|>",
"<|fim_prefix|>",
"<|fim_middle|>",
"<|fim_suffix|>"
],
"bos_token": "<s>",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"legacy": true,
"model_max_length": 1000000000000000019884624838656,
"pad_token": null,
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"tokenizer_class": "LlamaTokenizer",
"unk_token": "<unk>",
"use_default_system_prompt": false,
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
}