初始化项目,由ModelHub XC社区提供模型
Model: openbmb/BitCPM-CANN-0.5B-unquantized Source: Original Platform
This commit is contained in:
35
.gitattributes
vendored
Normal file
35
.gitattributes
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
126
README.md
Normal file
126
README.md
Normal file
@@ -0,0 +1,126 @@
|
||||
---
|
||||
license: apache-2.0
|
||||
language:
|
||||
- zh
|
||||
- en
|
||||
pipeline_tag: text-generation
|
||||
library_name: transformers
|
||||
---
|
||||
<div align="center">
|
||||
<img src="https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm_logo.png?raw=true" width="500em" ></img>
|
||||
</div>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/OpenBMB/MiniCPM/" target="_blank">GitHub Repo</a> |
|
||||
<a href="https://github.com/OpenBMB/MiniCPM/blob/main/docs/BitCPM_CANN.pdf" target="_blank">Technical Report</a>
|
||||
</p>
|
||||
<p align="center">
|
||||
👋 Join us on <a href="https://discord.gg/3cGQn9b3YM" target="_blank">Discord</a> and <a href="https://github.com/OpenBMB/MiniCPM/blob/main/assets/wechat.jpg" target="_blank">WeChat</a>
|
||||
</p>
|
||||
|
||||
## Overview
|
||||
|
||||
BitCPM-CANN-0.5B-unquantized is the **unquantized QAT (Quantization-Aware Training) checkpoint** of BitCPM-CANN-0.5B, designed for **continued pre-training and fine-tuning**. It preserves full-precision latent weights with ternary fake quantizers (weights → {-1, 0, 1} with group-wise scaling, trained via STE) defined in `modeling.py`, enabling the model to keep learning under quantization constraints. For technical details, see our [Technical Report](https://github.com/OpenBMB/MiniCPM/blob/main/docs/BitCPM_CANN.pdf).
|
||||
|
||||
> ⚠️ **This model is NOT for direct inference.** For inference, use the pseudo-quantized version: [openbmb/BitCPM-CANN-0.5B](https://huggingface.co/openbmb/BitCPM-CANN-0.5B).
|
||||
|
||||
## Continued Pre-training & Fine-tuning
|
||||
|
||||
The **only requirement** is that the forward pass must go through the bundled `modeling.py` (which contains the ternary fake quantizer). Load with `trust_remote_code=True` and do NOT replace or bypass the model's forward logic.
|
||||
|
||||
### Option 1: DeepSpeed (Recommended)
|
||||
|
||||
We provide ready-to-use training scripts in the [example](https://huggingface.co/openbmb/BitCPM-CANN-0.5B-unquantized/tree/main/example) directory (using the 1B model as an example):
|
||||
|
||||
- **Continued pre-training**: `example/run.sh` + `example/train.py`
|
||||
- **SFT (Supervised Fine-tuning)**: `example/run_sft.sh` + `example/train_sft.py`
|
||||
|
||||
Quick start:
|
||||
|
||||
```bash
|
||||
# Continued pre-training
|
||||
cd example && bash run.sh
|
||||
|
||||
# Supervised fine-tuning
|
||||
cd example && bash run_sft.sh
|
||||
```
|
||||
|
||||
### Option 2: HuggingFace-compatible Frameworks
|
||||
|
||||
Any framework that supports HuggingFace model loading with custom code can be used, such as **LLaMA Factory**, **HuggingFace Trainer**, etc. The key is to ensure `trust_remote_code=True`:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
path = 'openbmb/BitCPM-CANN-0.5B-unquantized'
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Use with your preferred framework (LLaMA Factory, HF Trainer, etc.)
|
||||
# The ternary fake quantizer in modeling.py is applied automatically during forward pass.
|
||||
```
|
||||
|
||||
## Post-Training Conversion
|
||||
|
||||
After training, use `qat-convert.py` to fuse the fake quantizer and produce inference-ready pseudo-quantized weights:
|
||||
|
||||
```bash
|
||||
python qat-convert.py \
|
||||
--input_bin <path-to-finetuned-pytorch.bin> \
|
||||
--output <path-to-output-pseudo-quantized-pytorch.bin> \
|
||||
--quant_type ternary \
|
||||
--group_size -1
|
||||
```
|
||||
|
||||
The converted model can be loaded for inference in the same way as [openbmb/BitCPM-CANN-0.5B](https://huggingface.co/openbmb/BitCPM-CANN-0.5B)—no special quantization libraries required.
|
||||
|
||||
## Workflow
|
||||
|
||||
```
|
||||
┌─────────────────────────────────┐
|
||||
│ BitCPM-CANN-0.5B-unquantized │ ← This model (QAT checkpoint + fake quantizer in modeling.py)
|
||||
└───────────────┬─────────────────┘
|
||||
│
|
||||
▼ Train (DeepSpeed / LLaMA Factory / HF Trainer / ...)
|
||||
┌─────────────────────────────────┐
|
||||
│ Fine-tuned checkpoint │ ← Still contains un-fused QAT parameters
|
||||
└───────────────┬─────────────────┘
|
||||
│
|
||||
▼ python qat-convert.py --quant_type ternary --group_size -1
|
||||
┌─────────────────────────────────┐
|
||||
│ Pseudo-quantized model │ ← Ready for inference (same format as BitCPM-CANN-0.5B)
|
||||
└─────────────────────────────────┘
|
||||
```
|
||||
|
||||
## BitCPM-CANN Model Family
|
||||
|
||||
| Model | HuggingFace (Inference) | HuggingFace (Fine-tuning) |
|
||||
|-------|-------------------------|---------------------------|
|
||||
| BitCPM-CANN-0.5B | [openbmb/BitCPM-CANN-0.5B](https://huggingface.co/openbmb/BitCPM-CANN-0.5B) | [openbmb/BitCPM-CANN-0.5B-unquantized](https://huggingface.co/openbmb/BitCPM-CANN-0.5B-unquantized) |
|
||||
| BitCPM-CANN-1B | [openbmb/BitCPM-CANN-1B](https://huggingface.co/openbmb/BitCPM-CANN-1B) | [openbmb/BitCPM-CANN-1B-unquantized](https://huggingface.co/openbmb/BitCPM-CANN-1B-unquantized) |
|
||||
| BitCPM-CANN-3B | [openbmb/BitCPM-CANN-3B](https://huggingface.co/openbmb/BitCPM-CANN-3B) | [openbmb/BitCPM-CANN-3B-unquantized](https://huggingface.co/openbmb/BitCPM-CANN-3B-unquantized) |
|
||||
| BitCPM-CANN-8B | [openbmb/BitCPM-CANN-8B](https://huggingface.co/openbmb/BitCPM-CANN-8B) | [openbmb/BitCPM-CANN-8B-unquantized](https://huggingface.co/openbmb/BitCPM-CANN-8B-unquantized) |
|
||||
|
||||
## Statement
|
||||
- As a language model, BitCPM-CANN generates content by learning from a vast amount of text.
|
||||
- However, it does not possess the ability to comprehend or express personal opinions or value judgments.
|
||||
- Any content generated by BitCPM-CANN does not represent the viewpoints or positions of the model developers.
|
||||
- Therefore, when using content generated by BitCPM-CANN, users should take full responsibility for evaluating and verifying it on their own.
|
||||
|
||||
## LICENSE
|
||||
- This repository and BitCPM-CANN models are released under the [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) License.
|
||||
|
||||
## Citation
|
||||
- Please cite our technical report if you find our work valuable.
|
||||
|
||||
```bibtex
|
||||
@article{bitcpmcann,
|
||||
title={{BitCPM-CANN}: Native 1.58-Bit Large Language Model Training on Ascend NPU},
|
||||
author={BitCPM Team},
|
||||
year={2026}
|
||||
}
|
||||
```
|
||||
10
added_tokens.json
Normal file
10
added_tokens.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"<|execute_end|>": 73444,
|
||||
"<|execute_start|>": 73443,
|
||||
"<|fim_middle|>": 73446,
|
||||
"<|fim_prefix|>": 73445,
|
||||
"<|fim_suffix|>": 73447,
|
||||
"<|im_end|>": 73440,
|
||||
"<|im_start|>": 73441,
|
||||
"<|tool_call|>": 73442
|
||||
}
|
||||
37
config.json
Normal file
37
config.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"_name_or_path": "openbmb/MiniCPM4-0.5B",
|
||||
"architectures": [
|
||||
"MiniCPMForCausalLM"
|
||||
],
|
||||
"auto_map": {
|
||||
"AutoConfig": "configuration_minicpm.MiniCPMConfig",
|
||||
"AutoModel": "modeling_minicpm.MiniCPMModel",
|
||||
"AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
|
||||
"AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
|
||||
"AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
|
||||
},
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": [2, 73440],
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.1,
|
||||
"intermediate_size": 4096,
|
||||
"max_position_embeddings": 32768,
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 24,
|
||||
"num_key_value_heads": 2,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": {
|
||||
"rope_type": "longrope",
|
||||
"long_factor": [1.0004360675811768, 1.0668443441390991, 1.1631425619125366, 1.3025742769241333, 1.5040205717086792, 1.7941505908966064, 2.2101221084594727, 2.802666664123535, 3.6389970779418945, 4.804192543029785, 6.39855432510376, 8.527148246765137, 11.277542114257812, 14.684998512268066, 18.69317054748535, 23.13019371032715, 27.72362518310547, 32.1606559753418, 36.168827056884766, 39.57627868652344, 42.32667541503906, 44.45526885986328, 46.04962921142578, 47.21482849121094, 48.05115509033203, 48.64370346069336, 49.05967712402344, 49.34980392456055, 49.551246643066406, 49.69068145751953, 49.78697967529297, 49.85338592529297],
|
||||
"short_factor": [1.0004360675811768, 1.0668443441390991, 1.1631425619125366, 1.3025742769241333, 1.5040205717086792, 1.7941505908966064, 2.2101221084594727, 2.802666664123535, 3.6389970779418945, 4.804192543029785, 6.39855432510376, 8.527148246765137, 11.277542114257812, 14.684998512268066, 18.69317054748535, 23.13019371032715, 27.72362518310547, 32.1606559753418, 36.168827056884766, 39.57627868652344, 42.32667541503906, 44.45526885986328, 46.04962921142578, 47.21482849121094, 48.05115509033203, 48.64370346069336, 49.05967712402344, 49.34980392456055, 49.551246643066406, 49.69068145751953, 49.78697967529297, 49.85338592529297],
|
||||
"original_max_position_embeddings": 32768
|
||||
},
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.46.3",
|
||||
"use_cache": true,
|
||||
"vocab_size": 73448,
|
||||
"scale_emb": 12,
|
||||
"dim_model_base": 256,
|
||||
"scale_depth": 1.4
|
||||
}
|
||||
203
configuration_minicpm.py
Normal file
203
configuration_minicpm.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The OpenBMB Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" MiniCPM model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||
|
||||
|
||||
class MiniCPMConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the MiniCPM-7B.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`MiniCPMModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
|
||||
MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
||||
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
||||
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
||||
issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||
experimental feature, subject to breaking API changes in future versions.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
```python
|
||||
>>> from transformers import MiniCPMModel, MiniCPMConfig
|
||||
|
||||
>>> # Initializing a MiniCPM minicpm-7b style configuration
|
||||
>>> configuration = MiniCPMConfig()
|
||||
|
||||
>>> # Initializing a model from the minicpm-7b style configuration
|
||||
>>> model = MiniCPMModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = 'minicpm'
|
||||
keys_to_ignore_at_inference = ['past_key_values']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act='silu',
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
scale_emb=1,
|
||||
dim_model_base=1,
|
||||
scale_depth=1,
|
||||
mup_denominator=None,
|
||||
sparse_config=None,
|
||||
**kwargs):
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
# self._rope_scaling_validation()
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.scale_emb = scale_emb
|
||||
self.dim_model_base = dim_model_base
|
||||
self.scale_depth = scale_depth
|
||||
# only used for Eagle Head
|
||||
self.mup_denominator = mup_denominator
|
||||
|
||||
# sparse config
|
||||
self.sparse_config = sparse_config
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
import flash_attn
|
||||
self._attn_implementation = 'flash_attention_2'
|
||||
except:
|
||||
pass
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
'`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, '
|
||||
f'got {self.rope_scaling}'
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get('type', None)
|
||||
rope_scaling_factor = self.rope_scaling.get('factor', None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
||||
105
example/README.md
Normal file
105
example/README.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# BitCPM Training Example
|
||||
|
||||
This project provides scripts for continue pretraining (CPT) and supervised fine-tuning (SFT) of **BitCPM-CANN-1B-unquantized**.
|
||||
|
||||
## File Description
|
||||
|
||||
CPT and SFT each have a pair of scripts (training script + launch script) and share DeepSpeed configuration files:
|
||||
|
||||
| File | Description |
|
||||
| --- | --- |
|
||||
| `run.sh` | Launch script for CPT with hyperparameter configuration |
|
||||
| `run_sft.sh` | Launch script for SFT with hyperparameter configuration |
|
||||
| `train.py` | Continue pretrain script based on HuggingFace Trainer + DeepSpeed |
|
||||
| `train_sft.py` | Supervised fine-tuning script based on HuggingFace Trainer + DeepSpeed |
|
||||
| `ds_config.json` | DeepSpeed ZeRO-3 configuration (with CPU offload) |
|
||||
| `ds_config_z2.json` | DeepSpeed ZeRO-2 configuration (used by default) |
|
||||
| `requirements.txt` | Python dependency list |
|
||||
|
||||
## Environment Setup
|
||||
|
||||
### Docker Image
|
||||
|
||||
Use the following Huawei NPU image on 910C:
|
||||
|
||||
```
|
||||
swr.cn-south-1.myhuaweicloud.com/ascendhub/mindspeed-llm:openeuler22.03-mindspeed-llm-2.3.0-a3-arm
|
||||
```
|
||||
|
||||
Other Huawei NPU images may also work but have not been fully tested.
|
||||
|
||||
For GPU environments, there are no special image requirements — just install `requirements.txt` directly.
|
||||
|
||||
### Install Dependencies
|
||||
|
||||
After entering the container, install the Python dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Continue Pretrain (CPT)
|
||||
|
||||
### Dataset
|
||||
|
||||
The test dataset used is [C4-Pro](https://huggingface.co/datasets/gair-prox/c4-pro), stored in parquet format after downloading.
|
||||
|
||||
### Usage
|
||||
|
||||
Modify the path configuration in `run.sh`:
|
||||
|
||||
```bash
|
||||
MODEL_PATH="/path/to/BitCPM-CANN-1B-unquantized/"
|
||||
DATA_PATH="/path/to/c4-pro/data/your_file.parquet"
|
||||
```
|
||||
|
||||
Then start training:
|
||||
|
||||
```bash
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
## Supervised Fine-Tuning (SFT)
|
||||
|
||||
### Dataset
|
||||
|
||||
The test dataset used is [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), stored in parquet format after downloading.
|
||||
|
||||
### Usage
|
||||
|
||||
Modify the path configuration in `run_sft.sh`:
|
||||
|
||||
```bash
|
||||
MODEL_PATH="/path/to/BitCPM-CANN-1B-unquantized/"
|
||||
DATA_PATH="/path/to/ultrachat_200k/data/your_file.parquet"
|
||||
```
|
||||
|
||||
Then start training:
|
||||
|
||||
```bash
|
||||
bash run_sft.sh
|
||||
```
|
||||
|
||||
## Training Results Reference
|
||||
|
||||
> **Note:** BitCPM has its own training dataset and data mixture. It is expected that the loss continues to decrease when training on open-source datasets.
|
||||
|
||||
Below are the loss curves from smoke tests on GPU and NPU for both CPT and SFT tasks. The results are highly consistent across GPU and NPU, indicating that users can continue pre-training or fine-tuning on various compute devices:
|
||||
|
||||
| | GPU | NPU |
|
||||
| --- | --- | --- |
|
||||
| **CPT** |  |  |
|
||||
| **SFT** |  |  |
|
||||
|
||||
Training log CSV files (corresponding to the loss curves above):
|
||||
|
||||
| CSV File | Corresponding Loss Curve |
|
||||
| --- | --- |
|
||||
| [gpu_pretrain.csv](gpu_pretrain.csv) | GPU CPT |
|
||||
| [npu_pretrain.csv](npu_pretrain.csv) | NPU CPT |
|
||||
| [gpu_sft.csv](gpu_sft.csv) | GPU SFT |
|
||||
| [npu_sft.csv](npu_sft.csv) | NPU SFT |
|
||||
|
||||
---
|
||||
|
||||
These scripts provide a convenient, ready-to-use toolkit for QAT-aware continued pre-training and fine-tuning of BitCPM-CANN models, so you can quickly adapt the model to your own data and tasks while preserving ternary quantization constraints.
|
||||
29
example/ds_config.json
Normal file
29
example/ds_config.json
Normal file
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "none"
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"stage3_prefetch_bucket_size": 2e8,
|
||||
"stage3_param_persistence_threshold": 1e5,
|
||||
"stage3_max_live_parameters": 2e9,
|
||||
"stage3_max_reuse_distance": 2e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
22
example/ds_config_z2.json
Normal file
22
example/ds_config_z2.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "none"
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
51
example/gpu_pretrain.csv
Normal file
51
example/gpu_pretrain.csv
Normal file
@@ -0,0 +1,51 @@
|
||||
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
|
||||
2,2.7920000553131104,0.03527498617768288,7.999999979801942e-06,0.010457516647875309,,,,,
|
||||
4,2.8011999130249023,0.03495891019701958,1.5999999959603883e-05,0.020915033295750618,,,,,
|
||||
6,2.7964000701904297,0.03271934762597084,2.4000000848900527e-05,0.0313725508749485,,,,,
|
||||
8,2.763700008392334,0.024968057870864868,3.199999991920777e-05,0.041830066591501236,,,,,
|
||||
10,3.281599998474121,0.31758183240890503,3.9999998989515007e-05,0.05228758230805397,,,,,
|
||||
12,2.941200017929077,0.044055406004190445,3.995128281530924e-05,0.062745101749897,,,,,
|
||||
14,2.851799964904785,0.03649706766009331,3.9805359847377986e-05,0.07320261746644974,,,,,
|
||||
16,2.7869999408721924,0.022624235600233078,3.9562950405525044e-05,0.08366013318300247,,,,,
|
||||
18,2.7825000286102295,0.021830420941114426,3.922523319488391e-05,0.0941176488995552,,,,,
|
||||
20,2.7857000827789307,0.01685911975800991,3.87938525818754e-05,0.10457516461610794,,,,,
|
||||
22,2.7571001052856445,0.01572061888873577,3.827090768027119e-05,0.11503268033266068,,,,,
|
||||
24,2.762399911880493,0.016891509294509888,3.7658952351193875e-05,0.125490203499794,,,,,
|
||||
26,2.7411000728607178,0.015683824196457863,3.6960962461307645e-05,0.13594771921634674,,,,,
|
||||
28,2.733099937438965,0.012847283855080605,3.6180339520797133e-05,0.14640523493289948,,,,,
|
||||
30,2.723400115966797,0.015209181234240532,3.532088885549456e-05,0.1568627506494522,,,,,
|
||||
32,2.7342000007629395,0.01241038367152214,3.4386797779006884e-05,0.16732026636600494,,,,,
|
||||
34,2.7321999073028564,0.012879018671810627,3.338261376484297e-05,0.17777778208255768,,,,,
|
||||
36,2.7314000129699707,0.013242729939520359,3.231322989449836e-05,0.1882352977991104,,,,,
|
||||
38,2.7065999507904053,0.01113435160368681,3.118385939160362e-05,0.19869281351566315,,,,,
|
||||
40,2.6958999633789062,0.012413726188242435,2.9999999242136255e-05,0.20915032923221588,,,,,
|
||||
42,2.7516000270843506,0.011661508120596409,2.8767422918463126e-05,0.21960784494876862,,,,,
|
||||
44,2.713099956512451,0.012248368933796883,2.749213126662653e-05,0.23006536066532135,,,,,
|
||||
46,2.7102999687194824,0.011450185440480709,2.6180339773418382e-05,0.24052287638187408,,,,,
|
||||
48,2.7021000385284424,0.011155751533806324,2.483843854861334e-05,0.250980406999588,,,,,
|
||||
50,2.680500030517578,0.010021247901022434,2.3472963221138343e-05,0.26143792271614075,,,,,
|
||||
52,2.699199914932251,0.010751751251518726,2.2090569473220967e-05,0.2718954384326935,,,,,
|
||||
54,2.694200038909912,0.010503941215574741,2.0697989384643734e-05,0.2823529541492462,,,,,
|
||||
56,2.7091000080108643,0.010059370659291744,1.9302009604871273e-05,0.29281046986579895,,,,,
|
||||
58,2.699399948120117,0.012161476537585258,1.7909431335283443e-05,0.3032679855823517,,,,,
|
||||
60,2.7216999530792236,0.010671027936041355,1.6527035768376663e-05,0.3137255012989044,,,,,
|
||||
62,2.7158000469207764,0.010463157668709755,1.516156225989107e-05,0.32418301701545715,,,,,
|
||||
64,2.7214999198913574,0.010665320791304111,1.3819660125591327e-05,0.3346405327320099,,,,,
|
||||
66,2.7116000652313232,0.01046629250049591,1.2507867722888477e-05,0.3450980484485626,,,,,
|
||||
68,2.6923000812530518,0.010609752498567104,1.1232576980546582e-05,0.35555556416511536,,,,,
|
||||
70,2.6830999851226807,0.009290814399719238,9.999999747378752e-06,0.3660130798816681,,,,,
|
||||
72,2.7093000411987305,0.010727670043706894,8.816142326395493e-06,0.3764705955982208,,,,,
|
||||
74,2.698699951171875,0.0109737953171134,7.686770914006047e-06,0.38692811131477356,,,,,
|
||||
76,2.712599992752075,0.010320967063307762,6.61738795315614e-06,0.3973856270313263,,,,,
|
||||
78,2.6993000507354736,0.009841523133218288,5.613203938992228e-06,0.40784314274787903,,,,,
|
||||
80,2.6861000061035156,0.010179675184190273,4.6791110435151495e-06,0.41830065846443176,,,,,
|
||||
82,2.6828999519348145,0.009790077805519104,3.819659923465224e-06,0.4287581741809845,,,,,
|
||||
84,2.699199914932251,0.010508442297577858,3.03903811982309e-06,0.43921568989753723,,,,,
|
||||
86,2.6988000869750977,0.009589221328496933,2.3410482299368596e-06,0.44967320561408997,,,,,
|
||||
88,2.688499927520752,0.010065913200378418,1.7290908544964623e-06,0.4601307213306427,,,,,
|
||||
90,2.6928999423980713,0.010363687761127949,1.206147544507985e-06,0.47058823704719543,,,,,
|
||||
92,2.714200019836426,0.010142815299332142,7.74766078848188e-07,0.48104575276374817,,,,,
|
||||
94,2.672300100326538,0.009833029471337795,4.370479871340649e-07,0.4915032684803009,,,,,
|
||||
96,2.7018001079559326,0.009937037713825703,1.9463863054625108e-07,0.501960813999176,,,,,
|
||||
98,2.7121999263763428,0.009417451918125153,4.8718995060426096e-08,0.5124183297157288,,,,,
|
||||
100,2.7028000354766846,0.009256146848201752,0.0,0.5228758454322815,365.8839111328125,139.93499755859375,0.27300000190734863,4.629706395531346e+17,2.7395541667938232
|
||||
|
BIN
example/gpu_pretrain_loss.png
Normal file
BIN
example/gpu_pretrain_loss.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 49 KiB |
51
example/gpu_sft.csv
Normal file
51
example/gpu_sft.csv
Normal file
@@ -0,0 +1,51 @@
|
||||
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
|
||||
2,1.1492999792099,0.6216375231742859,1.9999999949504854e-06,0.0004617871018126607,,,,,
|
||||
4,1.0979000329971313,0.681877851486206,3.999999989900971e-06,0.0009235742036253214,,,,,
|
||||
6,1.1269999742507935,0.784303605556488,6.000000212225132e-06,0.001385361305437982,,,,,
|
||||
8,1.0542000532150269,0.8737029433250427,7.999999979801942e-06,0.0018471484072506428,,,,,
|
||||
10,1.2440999746322632,0.7068291902542114,9.999999747378752e-06,0.0023089356254786253,,,,,
|
||||
12,1.2925000190734863,0.6821666955947876,1.2000000424450263e-05,0.002770722610875964,,,,,
|
||||
14,1.0843000411987305,0.525643527507782,1.4000000192027073e-05,0.0032325098291039467,,,,,
|
||||
16,1.0961999893188477,0.43757057189941406,1.5999999959603883e-05,0.0036942968145012856,,,,,
|
||||
18,1.0614999532699585,0.46141618490219116,1.8000000636675395e-05,0.004156084265559912,,,,,
|
||||
20,1.332900047302246,0.715879499912262,1.9999999494757503e-05,0.004617871250957251,,,,,
|
||||
22,1.2070000171661377,0.5926885008811951,1.996917308133561e-05,0.0050796582363545895,,,,,
|
||||
24,1.2043999433517456,0.5833240747451782,1.9876883015967906e-05,0.005541445221751928,,,,,
|
||||
26,1.0740000009536743,0.44734400510787964,1.9723698642337695e-05,0.0060032326728105545,,,,,
|
||||
28,1.1162999868392944,0.3701137900352478,1.9510565834934823e-05,0.006465019658207893,,,,,
|
||||
30,1.0454000234603882,0.43832680583000183,1.9238796085119247e-05,0.006926806643605232,,,,,
|
||||
32,1.124899983406067,0.4591037631034851,1.8910064682131633e-05,0.007388593629002571,,,,,
|
||||
34,1.0686999559402466,0.3873400390148163,1.8526401618146338e-05,0.00785038061439991,,,,,
|
||||
36,1.0291999578475952,0.40313437581062317,1.8090169760398567e-05,0.008312168531119823,,,,,
|
||||
38,1.1052000522613525,0.3735405504703522,1.7604059394216165e-05,0.008773955516517162,,,,,
|
||||
40,1.1555999517440796,0.3818407654762268,1.7071068214136176e-05,0.009235742501914501,,,,,
|
||||
42,1.0235999822616577,0.4255191683769226,1.6494481315021403e-05,0.00969752948731184,,,,,
|
||||
44,1.0364999771118164,0.4794503152370453,1.5877853002166376e-05,0.010159316472709179,,,,,
|
||||
46,1.1344000101089478,0.37273937463760376,1.5224985872919206e-05,0.010621103458106518,,,,,
|
||||
48,1.0866999626159668,0.417492538690567,1.453990535082994e-05,0.011082890443503857,,,,,
|
||||
50,1.1038000583648682,0.35408055782318115,1.3826834219798911e-05,0.01154467836022377,,,,,
|
||||
52,1.1478999853134155,0.3930828273296356,1.3090169886709191e-05,0.012006465345621109,,,,,
|
||||
54,1.1858999729156494,0.3965947926044464,1.2334453458606731e-05,0.012468252331018448,,,,,
|
||||
56,1.0096999406814575,0.3860221207141876,1.1564344276848715e-05,0.012930039316415787,,,,,
|
||||
58,1.114799976348877,0.44393691420555115,1.0784590813273098e-05,0.013391826301813126,,,,,
|
||||
60,1.079300045967102,0.3605058789253235,9.999999747378752e-06,0.013853613287210464,,,,,
|
||||
62,1.1766999959945679,0.40689122676849365,9.215408681484405e-06,0.014315400272607803,,,,,
|
||||
64,1.1075999736785889,0.4002344310283661,8.435655217908788e-06,0.014777187258005142,,,,,
|
||||
66,1.1866999864578247,0.46947163343429565,7.665546036150772e-06,0.015238975174725056,,,,,
|
||||
68,1.0311000347137451,0.3296957314014435,6.909830062795663e-06,0.01570076122879982,,,,,
|
||||
70,1.1088999509811401,0.33858785033226013,6.173165729705943e-06,0.01616254821419716,,,,,
|
||||
72,1.0720000267028809,0.3967427909374237,5.460095053422265e-06,0.016624337062239647,,,,,
|
||||
74,1.1460000276565552,0.41202062368392944,4.7750145313329995e-06,0.017086124047636986,,,,,
|
||||
76,1.0425000190734863,0.38334518671035767,4.1221474020858295e-06,0.017547911033034325,,,,,
|
||||
78,0.9154000282287598,0.40649303793907166,3.505519543978153e-06,0.018009698018431664,,,,,
|
||||
80,1.1110999584197998,0.35371580719947815,2.9289321901160292e-06,0.018471485003829002,,,,,
|
||||
82,1.1672999858856201,0.3381657302379608,2.3959403279150138e-06,0.01893327198922634,,,,,
|
||||
84,1.2374000549316406,0.3815234303474426,1.909829961732612e-06,0.01939505897462368,,,,,
|
||||
86,1.2151000499725342,0.38446080684661865,1.4735983313585166e-06,0.01985684596002102,,,,,
|
||||
88,1.163100004196167,0.40419140458106995,1.0899348126258701e-06,0.020318632945418358,,,,,
|
||||
90,1.1883000135421753,0.4011874198913574,7.612046601934708e-07,0.020780419930815697,,,,,
|
||||
92,1.1526999473571777,0.3836020231246948,4.894348535344761e-07,0.021242206916213036,,,,,
|
||||
94,1.15339994430542,0.452364057302475,2.7630079557638965e-07,0.021703993901610374,,,,,
|
||||
96,1.062000036239624,0.3502688705921173,1.2311659247643547e-07,0.022165780887007713,,,,,
|
||||
98,1.0271999835968018,0.4022065997123718,3.0826662111849146e-08,0.022627567872405052,,,,,
|
||||
100,1.0283000469207764,0.38241174817085266,0.0,0.02308935672044754,183.9481964111328,8.697999954223633,0.5440000295639038,1862467846144.0,1.1177252531051636
|
||||
|
BIN
example/gpu_sft_loss.png
Normal file
BIN
example/gpu_sft_loss.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
51
example/npu_pretrain.csv
Normal file
51
example/npu_pretrain.csv
Normal file
@@ -0,0 +1,51 @@
|
||||
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
|
||||
2,2.7920000553131104,0.035306449979543686,7.999999979801942e-06,0.010457516647875309,,,,,
|
||||
4,2.8011999130249023,0.03491510450839996,1.5999999959603883e-05,0.020915033295750618,,,,,
|
||||
6,2.7964000701904297,0.032717395573854446,2.4000000848900527e-05,0.0313725508749485,,,,,
|
||||
8,2.763700008392334,0.024953875690698624,3.199999991920777e-05,0.041830066591501236,,,,,
|
||||
10,3.2811999320983887,0.3170815408229828,3.9999998989515007e-05,0.05228758230805397,,,,,
|
||||
12,2.9409000873565674,0.04423849284648895,3.995128281530924e-05,0.062745101749897,,,,,
|
||||
14,2.851900100708008,0.03667925298213959,3.9805359847377986e-05,0.07320261746644974,,,,,
|
||||
16,2.7869999408721924,0.022814607247710228,3.9562950405525044e-05,0.08366013318300247,,,,,
|
||||
18,2.782599925994873,0.021528413519263268,3.922523319488391e-05,0.0941176488995552,,,,,
|
||||
20,2.785599946975708,0.017014438286423683,3.87938525818754e-05,0.10457516461610794,,,,,
|
||||
22,2.7571001052856445,0.015719758346676826,3.827090768027119e-05,0.11503268033266068,,,,,
|
||||
24,2.762399911880493,0.016948623582720757,3.7658952351193875e-05,0.125490203499794,,,,,
|
||||
26,2.7411000728607178,0.015535997226834297,3.6960962461307645e-05,0.13594771921634674,,,,,
|
||||
28,2.7330000400543213,0.012748735956847668,3.6180339520797133e-05,0.14640523493289948,,,,,
|
||||
30,2.723299980163574,0.014809778891503811,3.532088885549456e-05,0.1568627506494522,,,,,
|
||||
32,2.7342000007629395,0.01219236571341753,3.4386797779006884e-05,0.16732026636600494,,,,,
|
||||
34,2.7321999073028564,0.012785322032868862,3.338261376484297e-05,0.17777778208255768,,,,,
|
||||
36,2.7314000129699707,0.012986919842660427,3.231322989449836e-05,0.1882352977991104,,,,,
|
||||
38,2.7065999507904053,0.01096824835985899,3.118385939160362e-05,0.19869281351566315,,,,,
|
||||
40,2.6958999633789062,0.012387535534799099,2.9999999242136255e-05,0.20915032923221588,,,,,
|
||||
42,2.751499891281128,0.011586200445890427,2.8767422918463126e-05,0.21960784494876862,,,,,
|
||||
44,2.713099956512451,0.011821281164884567,2.749213126662653e-05,0.23006536066532135,,,,,
|
||||
46,2.7102999687194824,0.01147585827857256,2.6180339773418382e-05,0.24052287638187408,,,,,
|
||||
48,2.7019999027252197,0.011368263512849808,2.483843854861334e-05,0.250980406999588,,,,,
|
||||
50,2.680500030517578,0.009935515932738781,2.3472963221138343e-05,0.26143792271614075,,,,,
|
||||
52,2.6993000507354736,0.0109846917912364,2.2090569473220967e-05,0.2718954384326935,,,,,
|
||||
54,2.6940999031066895,0.010465175844728947,2.0697989384643734e-05,0.2823529541492462,,,,,
|
||||
56,2.7091000080108643,0.01009758748114109,1.9302009604871273e-05,0.29281046986579895,,,,,
|
||||
58,2.69950008392334,0.01249368954449892,1.7909431335283443e-05,0.3032679855823517,,,,,
|
||||
60,2.7216999530792236,0.01051376760005951,1.6527035768376663e-05,0.3137255012989044,,,,,
|
||||
62,2.7158000469207764,0.01054943073540926,1.516156225989107e-05,0.32418301701545715,,,,,
|
||||
64,2.7214999198913574,0.01076149195432663,1.3819660125591327e-05,0.3346405327320099,,,,,
|
||||
66,2.7116000652313232,0.010380392894148827,1.2507867722888477e-05,0.3450980484485626,,,,,
|
||||
68,2.6923000812530518,0.010425001382827759,1.1232576980546582e-05,0.35555556416511536,,,,,
|
||||
70,2.683199882507324,0.00925016961991787,9.999999747378752e-06,0.3660130798816681,,,,,
|
||||
72,2.7093000411987305,0.01072422880679369,8.816142326395493e-06,0.3764705955982208,,,,,
|
||||
74,2.6988000869750977,0.011063243262469769,7.686770914006047e-06,0.38692811131477356,,,,,
|
||||
76,2.7125000953674316,0.01013101264834404,6.61738795315614e-06,0.3973856270313263,,,,,
|
||||
78,2.6993000507354736,0.009940676391124725,5.613203938992228e-06,0.40784314274787903,,,,,
|
||||
80,2.6861000061035156,0.01050259917974472,4.6791110435151495e-06,0.41830065846443176,,,,,
|
||||
82,2.6828999519348145,0.009912634268403053,3.819659923465224e-06,0.4287581741809845,,,,,
|
||||
84,2.699199914932251,0.010668900795280933,3.03903811982309e-06,0.43921568989753723,,,,,
|
||||
86,2.698899984359741,0.009650414809584618,2.3410482299368596e-06,0.44967320561408997,,,,,
|
||||
88,2.6884000301361084,0.01006452739238739,1.7290908544964623e-06,0.4601307213306427,,,,,
|
||||
90,2.6928999423980713,0.010409764014184475,1.206147544507985e-06,0.47058823704719543,,,,,
|
||||
92,2.714200019836426,0.009937116876244545,7.74766078848188e-07,0.48104575276374817,,,,,
|
||||
94,2.672300100326538,0.009728306904435158,4.370479871340649e-07,0.4915032684803009,,,,,
|
||||
96,2.7018001079559326,0.010098566301167011,1.9463863054625108e-07,0.501960813999176,,,,,
|
||||
98,2.7123000621795654,0.009524320252239704,4.8718995060426096e-08,0.5124183297157288,,,,,
|
||||
100,2.7028000354766846,0.009290286339819431,0.0,0.5228758454322815,788.0635986328125,64.96900177001953,0.12700000405311584,4.629706395531346e+17,2.739542245864868
|
||||
|
BIN
example/npu_pretrain_loss.png
Normal file
BIN
example/npu_pretrain_loss.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 46 KiB |
51
example/npu_sft.csv
Normal file
51
example/npu_sft.csv
Normal file
@@ -0,0 +1,51 @@
|
||||
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
|
||||
2,1.1491999626159668,0.6218180060386658,1.9999999949504854e-06,0.0004617871018126607,,,,,
|
||||
4,1.0981999635696411,0.6825665235519409,3.999999989900971e-06,0.0009235742036253214,,,,,
|
||||
6,1.1269999742507935,0.7838642001152039,6.000000212225132e-06,0.001385361305437982,,,,,
|
||||
8,1.0542000532150269,0.8744276762008667,7.999999979801942e-06,0.0018471484072506428,,,,,
|
||||
10,1.2441999912261963,0.7064258456230164,9.999999747378752e-06,0.0023089356254786253,,,,,
|
||||
12,1.2927000522613525,0.6829814910888672,1.2000000424450263e-05,0.002770722610875964,,,,,
|
||||
14,1.0844999551773071,0.5265647172927856,1.4000000192027073e-05,0.0032325098291039467,,,,,
|
||||
16,1.0963000059127808,0.4373657703399658,1.5999999959603883e-05,0.0036942968145012856,,,,,
|
||||
18,1.0615999698638916,0.46220508217811584,1.8000000636675395e-05,0.004156084265559912,,,,,
|
||||
20,1.3325999975204468,0.7157824039459229,1.9999999494757503e-05,0.004617871250957251,,,,,
|
||||
22,1.2070000171661377,0.5933427214622498,1.996917308133561e-05,0.0050796582363545895,,,,,
|
||||
24,1.2044999599456787,0.5816172957420349,1.9876883015967906e-05,0.005541445221751928,,,,,
|
||||
26,1.0740000009536743,0.4489712119102478,1.9723698642337695e-05,0.0060032326728105545,,,,,
|
||||
28,1.1164000034332275,0.3696516752243042,1.9510565834934823e-05,0.006465019658207893,,,,,
|
||||
30,1.045199990272522,0.4376335144042969,1.9238796085119247e-05,0.006926806643605232,,,,,
|
||||
32,1.1247999668121338,0.4589230716228485,1.8910064682131633e-05,0.007388593629002571,,,,,
|
||||
34,1.0688999891281128,0.3879022002220154,1.8526401618146338e-05,0.00785038061439991,,,,,
|
||||
36,1.0292999744415283,0.4027869403362274,1.8090169760398567e-05,0.008312168531119823,,,,,
|
||||
38,1.1052000522613525,0.37394437193870544,1.7604059394216165e-05,0.008773955516517162,,,,,
|
||||
40,1.1557999849319458,0.3808683753013611,1.7071068214136176e-05,0.009235742501914501,,,,,
|
||||
42,1.0232000350952148,0.4252733886241913,1.6494481315021403e-05,0.00969752948731184,,,,,
|
||||
44,1.0364999771118164,0.48068660497665405,1.5877853002166376e-05,0.010159316472709179,,,,,
|
||||
46,1.1340999603271484,0.37313926219940186,1.5224985872919206e-05,0.010621103458106518,,,,,
|
||||
48,1.0866999626159668,0.4175492823123932,1.453990535082994e-05,0.011082890443503857,,,,,
|
||||
50,1.1039999723434448,0.35443660616874695,1.3826834219798911e-05,0.01154467836022377,,,,,
|
||||
52,1.1480000019073486,0.39232146739959717,1.3090169886709191e-05,0.012006465345621109,,,,,
|
||||
54,1.1861000061035156,0.396918922662735,1.2334453458606731e-05,0.012468252331018448,,,,,
|
||||
56,1.0096999406814575,0.3885609209537506,1.1564344276848715e-05,0.012930039316415787,,,,,
|
||||
58,1.114799976348877,0.4421806335449219,1.0784590813273098e-05,0.013391826301813126,,,,,
|
||||
60,1.0795999765396118,0.36081990599632263,9.999999747378752e-06,0.013853613287210464,,,,,
|
||||
62,1.1764999628067017,0.4062329828739166,9.215408681484405e-06,0.014315400272607803,,,,,
|
||||
64,1.107200026512146,0.39982733130455017,8.435655217908788e-06,0.014777187258005142,,,,,
|
||||
66,1.1868000030517578,0.4688170254230499,7.665546036150772e-06,0.015238975174725056,,,,,
|
||||
68,1.0312999486923218,0.3301626741886139,6.909830062795663e-06,0.01570076122879982,,,,,
|
||||
70,1.1089999675750732,0.3377252221107483,6.173165729705943e-06,0.01616254821419716,,,,,
|
||||
72,1.0716999769210815,0.39666977524757385,5.460095053422265e-06,0.016624337062239647,,,,,
|
||||
74,1.1461999416351318,0.4125552177429199,4.7750145313329995e-06,0.017086124047636986,,,,,
|
||||
76,1.042199969291687,0.3825180232524872,4.1221474020858295e-06,0.017547911033034325,,,,,
|
||||
78,0.9157000184059143,0.4063441753387451,3.505519543978153e-06,0.018009698018431664,,,,,
|
||||
80,1.1110999584197998,0.35289037227630615,2.9289321901160292e-06,0.018471485003829002,,,,,
|
||||
82,1.167199969291687,0.33720290660858154,2.3959403279150138e-06,0.01893327198922634,,,,,
|
||||
84,1.2375999689102173,0.38099613785743713,1.909829961732612e-06,0.01939505897462368,,,,,
|
||||
86,1.2151999473571777,0.3848689794540405,1.4735983313585166e-06,0.01985684596002102,,,,,
|
||||
88,1.1628999710083008,0.40408074855804443,1.0899348126258701e-06,0.020318632945418358,,,,,
|
||||
90,1.1884000301361084,0.4015007019042969,7.612046601934708e-07,0.020780419930815697,,,,,
|
||||
92,1.152500033378601,0.38306349515914917,4.894348535344761e-07,0.021242206916213036,,,,,
|
||||
94,1.154099941253662,0.45273807644844055,2.7630079557638965e-07,0.021703993901610374,,,,,
|
||||
96,1.0618000030517578,0.35036078095436096,1.2311659247643547e-07,0.022165780887007713,,,,,
|
||||
98,1.0270999670028687,0.40208569169044495,3.0826662111849146e-08,0.022627567872405052,,,,,
|
||||
100,1.0285999774932861,0.38247284293174744,0.0,0.02308935672044754,728.7083129882812,2.196000099182129,0.13699999451637268,1862467846144.0,1.117748498916626
|
||||
|
BIN
example/npu_sft_loss.png
Normal file
BIN
example/npu_sft_loss.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 67 KiB |
8
example/requirements.txt
Normal file
8
example/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
transformers==4.46.3
|
||||
tokenizers==0.20.3
|
||||
accelerate==1.1.1
|
||||
deepspeed==0.16.2
|
||||
datasets==3.1.0
|
||||
safetensors==0.4.5
|
||||
pyarrow==17.0.0
|
||||
tensorboard==2.18.0
|
||||
38
example/run.sh
Normal file
38
example/run.sh
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/bin/bash
|
||||
|
||||
MODEL_PATH="/model/BitCPM-CANN-1B-unquantized"
|
||||
DATA_PATH="/dataset/c4-pro/data/000_1_7.parquet"
|
||||
OUTPUT_DIR="./output"
|
||||
DS_CONFIG="./ds_config_z2.json"
|
||||
|
||||
NUM_GPUS=8
|
||||
BATCH_SIZE_PER_GPU=8
|
||||
GRAD_ACCUM_STEPS=8
|
||||
MAX_SEQ_LENGTH=1024
|
||||
|
||||
export ASCEND_RT_VISIBLE_DEVICES=8,9,10,11,12,13,14,15
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
export DS_SKIP_CUDA_CHECK=1
|
||||
torchrun --nproc_per_node=$NUM_GPUS train.py \
|
||||
--model_name_or_path $MODEL_PATH \
|
||||
--data_path $DATA_PATH \
|
||||
--max_seq_length $MAX_SEQ_LENGTH \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--per_device_train_batch_size $BATCH_SIZE_PER_GPU \
|
||||
--gradient_accumulation_steps $GRAD_ACCUM_STEPS \
|
||||
--max_steps 100 \
|
||||
--learning_rate 4e-5 \
|
||||
--lr_scheduler_type cosine \
|
||||
--warmup_ratio 0.1 \
|
||||
--weight_decay 1e-2 \
|
||||
--logging_steps 2 \
|
||||
--save_steps 500 \
|
||||
--save_total_limit 3 \
|
||||
--bf16 \
|
||||
--deepspeed $DS_CONFIG \
|
||||
--gradient_checkpointing \
|
||||
--seed 42 \
|
||||
--dataloader_num_workers 4 \
|
||||
--report_to tensorboard \
|
||||
--logging_dir /data/tensorboard/pretrain \
|
||||
--gradient_checkpointing_kwargs '{"use_reentrant": false}'
|
||||
40
example/run_sft.sh
Normal file
40
example/run_sft.sh
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
|
||||
MODEL_PATH="/model/BitCPM-CANN-1B-unquantized"
|
||||
DATA_PATH="/dataset/HuggingFaceH4_ultrachat_200k/data/train_sft-00000-of-00003-a3ecf92756993583.parquet"
|
||||
OUTPUT_DIR="./output_sft"
|
||||
DS_CONFIG="./ds_config.json"
|
||||
|
||||
NUM_GPUS=8
|
||||
BATCH_SIZE_PER_GPU=2
|
||||
GRAD_ACCUM_STEPS=1
|
||||
MAX_SEQ_LENGTH=8192
|
||||
|
||||
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
export DS_SKIP_CUDA_CHECK=1
|
||||
|
||||
torchrun --nproc_per_node=$NUM_GPUS train_sft.py \
|
||||
--model_name_or_path $MODEL_PATH \
|
||||
--data_path $DATA_PATH \
|
||||
--max_seq_length $MAX_SEQ_LENGTH \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--per_device_train_batch_size $BATCH_SIZE_PER_GPU \
|
||||
--gradient_accumulation_steps $GRAD_ACCUM_STEPS \
|
||||
--max_steps 100 \
|
||||
--learning_rate 2e-5 \
|
||||
--lr_scheduler_type cosine \
|
||||
--warmup_ratio 0.2 \
|
||||
--weight_decay 0.0 \
|
||||
--logging_steps 2 \
|
||||
--save_steps 500 \
|
||||
--save_total_limit 3 \
|
||||
--bf16 \
|
||||
--deepspeed $DS_CONFIG \
|
||||
--gradient_checkpointing \
|
||||
--seed 42 \
|
||||
--dataloader_num_workers 4 \
|
||||
--report_to tensorboard \
|
||||
--logging_dir /data/tensorboard/sft \
|
||||
--train_on_prompt false \
|
||||
--gradient_checkpointing_kwargs '{"use_reentrant": false}'
|
||||
203
example/train.py
Normal file
203
example/train.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Continual pretraining script for CPM-2B model using DeepSpeed + HuggingFace Trainer.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
HfArgumentParser,
|
||||
DataCollatorForLanguageModeling,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
import deepspeed
|
||||
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patched_no_sync(self):
|
||||
try:
|
||||
with _orig_no_sync(self):
|
||||
yield
|
||||
except AssertionError:
|
||||
yield
|
||||
|
||||
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier"}
|
||||
)
|
||||
torch_dtype: Optional[str] = field(
|
||||
default="bfloat16",
|
||||
metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
data_path: str = field(
|
||||
metadata={"help": "Path to training data (parquet file or directory)"}
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=4096,
|
||||
metadata={"help": "Maximum sequence length for training"},
|
||||
)
|
||||
text_column: str = field(
|
||||
default="text",
|
||||
metadata={"help": "Name of the text column in the dataset"},
|
||||
)
|
||||
preprocessing_num_workers: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of workers for data preprocessing"},
|
||||
)
|
||||
|
||||
|
||||
def tokenize_and_group(dataset, tokenizer, data_args):
|
||||
"""Tokenize texts and group into chunks of max_seq_length."""
|
||||
|
||||
column_names = dataset.column_names
|
||||
text_column = data_args.text_column
|
||||
if text_column not in column_names:
|
||||
candidates = [c for c in column_names if "text" in c.lower()]
|
||||
if candidates:
|
||||
text_column = candidates[0]
|
||||
else:
|
||||
text_column = column_names[0]
|
||||
logger.warning(f"Column '{data_args.text_column}' not found, using '{text_column}'")
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples[text_column], add_special_tokens=False)
|
||||
|
||||
tokenized_dataset = dataset.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Tokenizing",
|
||||
)
|
||||
|
||||
block_size = data_args.max_seq_length
|
||||
|
||||
def group_texts(examples):
|
||||
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
|
||||
total_length = len(concatenated["input_ids"])
|
||||
total_length = (total_length // block_size) * block_size
|
||||
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
grouped_dataset = tokenized_dataset.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
desc="Grouping texts",
|
||||
)
|
||||
|
||||
return grouped_dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.info(f"Training args: {training_args}")
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
dtype_map = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
|
||||
|
||||
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
logger.info(f"Loading model from {model_args.model_name_or_path}")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
logger.info(f"Loading dataset from {data_args.data_path}")
|
||||
if os.path.isfile(data_args.data_path):
|
||||
raw_dataset = load_dataset("parquet", data_files=data_args.data_path, split="train")
|
||||
elif os.path.isdir(data_args.data_path):
|
||||
parquet_files = [
|
||||
os.path.join(data_args.data_path, f)
|
||||
for f in os.listdir(data_args.data_path)
|
||||
if f.endswith(".parquet")
|
||||
]
|
||||
raw_dataset = load_dataset("parquet", data_files=parquet_files, split="train")
|
||||
else:
|
||||
raise ValueError(f"Data path not found: {data_args.data_path}")
|
||||
|
||||
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
|
||||
|
||||
train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
|
||||
logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=False,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
logger.info("Starting training...")
|
||||
train_result = trainer.train(
|
||||
resume_from_checkpoint=training_args.resume_from_checkpoint
|
||||
)
|
||||
|
||||
trainer.save_model()
|
||||
trainer.save_state()
|
||||
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(train_dataset)
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
424
example/train_sft.py
Normal file
424
example/train_sft.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
Supervised fine-tuning script using DeepSpeed + HuggingFace Trainer.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
import deepspeed
|
||||
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patched_no_sync(self):
|
||||
try:
|
||||
with _orig_no_sync(self):
|
||||
yield
|
||||
except AssertionError:
|
||||
yield
|
||||
|
||||
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier"}
|
||||
)
|
||||
torch_dtype: Optional[str] = field(
|
||||
default="bfloat16",
|
||||
metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
data_path: str = field(metadata={"help": "Path to SFT data file or directory"})
|
||||
max_seq_length: int = field(
|
||||
default=4096,
|
||||
metadata={"help": "Maximum sequence length for training"},
|
||||
)
|
||||
prompt_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Prompt/instruction column name. Auto-detected if omitted."},
|
||||
)
|
||||
input_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Optional extra input/context column name"},
|
||||
)
|
||||
response_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Response/output column name. Auto-detected if omitted."},
|
||||
)
|
||||
messages_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat messages column name. Auto-detected if omitted."},
|
||||
)
|
||||
system_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Optional system prompt column name"},
|
||||
)
|
||||
train_on_prompt: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to compute loss on prompt/user tokens"},
|
||||
)
|
||||
add_eos_token: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Append eos_token to plain prompt/response examples"},
|
||||
)
|
||||
preprocessing_num_workers: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of workers for data preprocessing"},
|
||||
)
|
||||
|
||||
|
||||
class SFTDataCollator:
|
||||
def __init__(self, tokenizer, pad_to_multiple_of: Optional[int] = 8):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_to_multiple_of = pad_to_multiple_of
|
||||
|
||||
def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
||||
max_length = max(len(feature["input_ids"]) for feature in features)
|
||||
if self.pad_to_multiple_of:
|
||||
multiple = self.pad_to_multiple_of
|
||||
max_length = ((max_length + multiple - 1) // multiple) * multiple
|
||||
|
||||
input_ids = []
|
||||
attention_mask = []
|
||||
labels = []
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
|
||||
for feature in features:
|
||||
length = len(feature["input_ids"])
|
||||
pad_length = max_length - length
|
||||
input_ids.append(feature["input_ids"] + [pad_token_id] * pad_length)
|
||||
attention_mask.append([1] * length + [0] * pad_length)
|
||||
labels.append(feature["labels"] + [IGNORE_INDEX] * pad_length)
|
||||
|
||||
return {
|
||||
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
||||
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
|
||||
"labels": torch.tensor(labels, dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def load_sft_dataset(data_path: str):
|
||||
if os.path.isfile(data_path):
|
||||
extension = os.path.splitext(data_path)[1].lstrip(".").lower()
|
||||
if extension == "jsonl":
|
||||
extension = "json"
|
||||
if extension not in {"parquet", "json", "csv", "txt"}:
|
||||
raise ValueError(f"Unsupported data file extension: {extension}")
|
||||
return load_dataset(extension, data_files=data_path, split="train")
|
||||
|
||||
if os.path.isdir(data_path):
|
||||
data_files = []
|
||||
extension = None
|
||||
for name in os.listdir(data_path):
|
||||
current_extension = os.path.splitext(name)[1].lstrip(".").lower()
|
||||
if current_extension == "jsonl":
|
||||
current_extension = "json"
|
||||
if current_extension in {"parquet", "json", "csv", "txt"}:
|
||||
extension = extension or current_extension
|
||||
if current_extension == extension:
|
||||
data_files.append(os.path.join(data_path, name))
|
||||
if not data_files or extension is None:
|
||||
raise ValueError(f"No supported data files found in: {data_path}")
|
||||
return load_dataset(extension, data_files=sorted(data_files), split="train")
|
||||
|
||||
raise ValueError(f"Data path not found: {data_path}")
|
||||
|
||||
|
||||
def choose_column(
|
||||
column_names: List[str], explicit: Optional[str], candidates: List[str]
|
||||
) -> Optional[str]:
|
||||
if explicit:
|
||||
if explicit not in column_names:
|
||||
raise ValueError(f"Column '{explicit}' not found. Available columns: {column_names}")
|
||||
return explicit
|
||||
for name in candidates:
|
||||
if name in column_names:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def parse_messages(value: Any) -> List[Dict[str, str]]:
|
||||
if isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("messages/conversations column must be a list or JSON string")
|
||||
|
||||
messages = []
|
||||
for item in value:
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError("Each message must be a dict")
|
||||
|
||||
role = item.get("role", item.get("from"))
|
||||
content = item.get("content", item.get("value"))
|
||||
if role == "human":
|
||||
role = "user"
|
||||
elif role == "gpt":
|
||||
role = "assistant"
|
||||
|
||||
if role is None or content is None:
|
||||
raise ValueError("Each message must contain role/from and content/value")
|
||||
messages.append({"role": str(role), "content": str(content)})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def tokenize_text(tokenizer, text: str) -> List[int]:
|
||||
return tokenizer(text, add_special_tokens=False)["input_ids"]
|
||||
|
||||
|
||||
def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool) -> str:
|
||||
if tokenizer.chat_template is None:
|
||||
raise ValueError(
|
||||
"The tokenizer has no chat_template. Use prompt/response columns or set a chat_template."
|
||||
)
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
|
||||
|
||||
def encode_prompt_response(
|
||||
example: Dict[str, Any],
|
||||
tokenizer,
|
||||
data_args: DataArguments,
|
||||
prompt_column: str,
|
||||
input_column: Optional[str],
|
||||
response_column: str,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
prompt = str(example[prompt_column])
|
||||
if input_column and example.get(input_column):
|
||||
prompt = prompt + "\n" + str(example[input_column])
|
||||
response = str(example[response_column])
|
||||
|
||||
messages = []
|
||||
if data_args.system_column and example.get(data_args.system_column):
|
||||
messages.append({"role": "system", "content": str(example[data_args.system_column])})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
if tokenizer.chat_template is not None:
|
||||
full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
|
||||
prompt_text = apply_chat_template(tokenizer, messages[:-1], add_generation_prompt=True)
|
||||
input_ids = tokenize_text(tokenizer, full_text)
|
||||
prompt_length = len(tokenize_text(tokenizer, prompt_text))
|
||||
else:
|
||||
response_text = response
|
||||
if data_args.add_eos_token and tokenizer.eos_token:
|
||||
response_text += tokenizer.eos_token
|
||||
full_text = prompt + "\n" + response_text
|
||||
input_ids = tokenize_text(tokenizer, full_text)
|
||||
prompt_length = len(tokenize_text(tokenizer, prompt + "\n"))
|
||||
|
||||
labels = input_ids.copy()
|
||||
if not data_args.train_on_prompt:
|
||||
labels[:prompt_length] = [IGNORE_INDEX] * min(prompt_length, len(labels))
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
def encode_messages(
|
||||
example: Dict[str, Any],
|
||||
tokenizer,
|
||||
data_args: DataArguments,
|
||||
messages_column: str,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
messages = parse_messages(example[messages_column])
|
||||
|
||||
if tokenizer.chat_template is not None:
|
||||
full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
|
||||
input_ids = tokenize_text(tokenizer, full_text)
|
||||
labels = [IGNORE_INDEX] * len(input_ids)
|
||||
|
||||
if data_args.train_on_prompt:
|
||||
labels = input_ids.copy()
|
||||
else:
|
||||
for index, message in enumerate(messages):
|
||||
if message["role"] != "assistant":
|
||||
continue
|
||||
before_text = apply_chat_template(
|
||||
tokenizer, messages[:index], add_generation_prompt=True
|
||||
)
|
||||
after_text = apply_chat_template(
|
||||
tokenizer, messages[: index + 1], add_generation_prompt=False
|
||||
)
|
||||
start = len(tokenize_text(tokenizer, before_text))
|
||||
end = len(tokenize_text(tokenizer, after_text))
|
||||
labels[start:end] = input_ids[start:end]
|
||||
else:
|
||||
labels = []
|
||||
input_ids = []
|
||||
for message in messages:
|
||||
part = f"{message['role']}: {message['content']}\n"
|
||||
if data_args.add_eos_token and message["role"] == "assistant" and tokenizer.eos_token:
|
||||
part += tokenizer.eos_token
|
||||
part_ids = tokenize_text(tokenizer, part)
|
||||
input_ids.extend(part_ids)
|
||||
if data_args.train_on_prompt or message["role"] == "assistant":
|
||||
labels.extend(part_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(part_ids))
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
def preprocess_sft_dataset(raw_dataset, tokenizer, data_args: DataArguments):
|
||||
column_names = raw_dataset.column_names
|
||||
messages_column = choose_column(
|
||||
column_names, data_args.messages_column, ["messages", "conversations"]
|
||||
)
|
||||
prompt_column = choose_column(
|
||||
column_names,
|
||||
data_args.prompt_column,
|
||||
["prompt", "instruction", "question"],
|
||||
)
|
||||
input_column = choose_column(
|
||||
column_names,
|
||||
data_args.input_column,
|
||||
["input", "context"],
|
||||
)
|
||||
response_column = choose_column(
|
||||
column_names,
|
||||
data_args.response_column,
|
||||
["response", "output", "answer", "chosen"],
|
||||
)
|
||||
|
||||
if messages_column:
|
||||
logger.info(f"Using chat messages column: {messages_column}")
|
||||
elif prompt_column and response_column:
|
||||
logger.info(f"Using prompt column '{prompt_column}' and response column '{response_column}'")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot infer SFT data format. Provide either messages/conversations or "
|
||||
"prompt/instruction plus response/output columns."
|
||||
)
|
||||
|
||||
def encode_batch(examples):
|
||||
batch_input_ids = []
|
||||
batch_labels = []
|
||||
batch_attention_mask = []
|
||||
|
||||
batch_size = len(next(iter(examples.values())))
|
||||
for i in range(batch_size):
|
||||
example = {name: values[i] for name, values in examples.items()}
|
||||
if messages_column:
|
||||
input_ids, labels = encode_messages(example, tokenizer, data_args, messages_column)
|
||||
else:
|
||||
input_ids, labels = encode_prompt_response(
|
||||
example, tokenizer, data_args, prompt_column, input_column, response_column
|
||||
)
|
||||
|
||||
input_ids = input_ids[: data_args.max_seq_length]
|
||||
labels = labels[: data_args.max_seq_length]
|
||||
if not input_ids or all(label == IGNORE_INDEX for label in labels):
|
||||
continue
|
||||
|
||||
batch_input_ids.append(input_ids)
|
||||
batch_labels.append(labels)
|
||||
batch_attention_mask.append([1] * len(input_ids))
|
||||
|
||||
return {
|
||||
"input_ids": batch_input_ids,
|
||||
"attention_mask": batch_attention_mask,
|
||||
"labels": batch_labels,
|
||||
}
|
||||
|
||||
return raw_dataset.map(
|
||||
encode_batch,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Tokenizing SFT data",
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.info(f"Training args: {training_args}")
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
dtype_map = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
|
||||
|
||||
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
logger.info(f"Loading model from {model_args.model_name_or_path}")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
logger.info(f"Loading SFT dataset from {data_args.data_path}")
|
||||
raw_dataset = load_sft_dataset(data_args.data_path)
|
||||
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
|
||||
|
||||
train_dataset = preprocess_sft_dataset(raw_dataset, tokenizer, data_args)
|
||||
logger.info(f"Processed dataset: {len(train_dataset)} samples")
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
data_collator=SFTDataCollator(tokenizer),
|
||||
)
|
||||
|
||||
logger.info("Starting SFT training...")
|
||||
train_result = trainer.train(
|
||||
resume_from_checkpoint=training_args.resume_from_checkpoint
|
||||
)
|
||||
|
||||
trainer.save_model()
|
||||
trainer.save_state()
|
||||
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(train_dataset)
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
12
generation_config.json
Normal file
12
generation_config.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"bos_token_id": 1,
|
||||
"do_sample": true,
|
||||
"eos_token_id": [
|
||||
2,
|
||||
73440
|
||||
],
|
||||
"pad_token_id": 2,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.8,
|
||||
"transformers_version": "4.46.1"
|
||||
}
|
||||
1615
modeling_minicpm.py
Normal file
1615
modeling_minicpm.py
Normal file
File diff suppressed because it is too large
Load Diff
3
pytorch_model.bin
Normal file
3
pytorch_model.bin
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fe236eb8d3fd7e6bea58f8e44529318687d6be0921df0c1e9cfd8050d01e6808
|
||||
size 867818482
|
||||
176
qat-convert.py
Normal file
176
qat-convert.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import safetensors
|
||||
|
||||
class SteTernaryQuantizer(nn.Module):
|
||||
def __init__(self, group_size):
|
||||
super().__init__()
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x):
|
||||
org_w_shape = x.shape
|
||||
if self.group_size > 0:
|
||||
assert x.shape[-1] % self.group_size == 0
|
||||
x = x.reshape(-1, self.group_size)
|
||||
elif self.group_size == -1:
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
assert x.dim() == 2
|
||||
scales = 1.0 / (x.abs().mean(dim=1, keepdim=True).clamp_(min=1e-5))
|
||||
x_q = (torch.clamp(torch.round(x * scales),-1,1) / scales)
|
||||
assert torch.isnan(x_q).sum() == 0
|
||||
x = x.reshape(org_w_shape)
|
||||
x_q = x_q.reshape(org_w_shape)
|
||||
return x_q
|
||||
|
||||
class SteIntQuantizer(nn.Module):
|
||||
def __init__(self, bit, group_size):
|
||||
super().__init__()
|
||||
self.bit = bit
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x):
|
||||
org_w_shape = x.shape
|
||||
if self.group_size > 0:
|
||||
assert org_w_shape[-1] % self.group_size == 0
|
||||
x = x.reshape(-1, self.group_size)
|
||||
elif self.group_size == -1:
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
assert x.dim() == 2
|
||||
|
||||
abs_max_val = x.abs().amax(dim=1, keepdim=True)
|
||||
max_int = 2 ** (self.bit - 1) - 1
|
||||
min_int = - (2 ** (self.bit - 1))
|
||||
scales = abs_max_val.clamp(min=1e-5) / max_int
|
||||
|
||||
assert torch.isnan(scales).sum() == 0
|
||||
|
||||
x_q = (torch.clamp(torch.round(x / scales), min_int, max_int)) * scales
|
||||
|
||||
assert torch.isnan(x_q).sum() == 0
|
||||
|
||||
x = x.reshape(org_w_shape)
|
||||
x_q = x_q.reshape(org_w_shape)
|
||||
|
||||
return x_q
|
||||
|
||||
class SteInt2Quantizer(nn.Module):
|
||||
def __init__(self, group_size):
|
||||
super().__init__()
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x):
|
||||
org_w_shape = x.shape
|
||||
if self.group_size > 0:
|
||||
assert x.shape[-1] % self.group_size == 0
|
||||
x = x.reshape(-1, self.group_size)
|
||||
elif self.group_size == -1:
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
assert x.dim() == 2
|
||||
|
||||
scales = 1.0 / (x.abs().mean(dim=1, keepdim=True).clamp_(min=1e-5) * 1)
|
||||
x_q = (torch.clamp(torch.round(x * scales),-2,1) / scales)
|
||||
|
||||
assert torch.isnan(x_q).sum() == 0
|
||||
|
||||
x = x.reshape(org_w_shape)
|
||||
x_q = x_q.reshape(org_w_shape)
|
||||
|
||||
return x_q
|
||||
|
||||
def quantize_model_bin(input_bin_path, output_bin_path, quant_type="ternary", bit=2, group_size=128, device="cuda" if torch.cuda.is_available() else "cpu"):
|
||||
"""
|
||||
直接对PyTorch模型bin文件进行量化。
|
||||
|
||||
Args:
|
||||
input_bin_path: 输入模型bin文件路径
|
||||
output_bin_path: 输出量化后的模型bin文件路径
|
||||
quant_type: 量化类型 ("ternary" 或 "int")
|
||||
bit: 整数量化的位数 (仅在 quant_type="int" 时使用)
|
||||
group_size: 量化分组大小
|
||||
device: 运行设备
|
||||
"""
|
||||
print(f"加载模型文件: {input_bin_path}...")
|
||||
if input_bin_path.endswith(".bin"):
|
||||
state_dict = torch.load(input_bin_path, map_location=device)
|
||||
elif input_bin_path.endswith(".safetensors"):
|
||||
state_dict = safetensors.load_file(input_bin_path)
|
||||
elif os.path.isdir(input_bin_path) and os.path.exists(os.path.join(input_bin_path, "pytorch_model.bin")):
|
||||
state_dict = torch.load(os.path.join(input_bin_path, "pytorch_model.bin"), map_location=device)
|
||||
elif os.path.isdir(input_bin_path) and os.path.exists(os.path.join(input_bin_path, "model.safetensors")):
|
||||
state_dict = safetensors.load_file(os.path.join(input_bin_path, "model.safetensors"))
|
||||
else:
|
||||
raise ValueError(f"不支持的模型文件类型: {input_bin_path}")
|
||||
|
||||
print(f"应用 {quant_type} 量化...")
|
||||
if quant_type == "ternary":
|
||||
quantizer = SteTernaryQuantizer(group_size=group_size)
|
||||
elif quant_type == "int":
|
||||
quantizer = SteIntQuantizer(bit=bit, group_size=group_size)
|
||||
elif quant_type == "int2":
|
||||
quantizer = SteInt2Quantizer(group_size=group_size)
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
# 统计需要量化的参数数量
|
||||
total_params = sum(1 for k, v in state_dict.items() if ("weight" in k and "layer" in k) or ("fc" in k))
|
||||
|
||||
# 应用量化
|
||||
with torch.no_grad():
|
||||
for name, param in tqdm(state_dict.items(), total=total_params, desc="量化中"):
|
||||
if (("weight" in name and "layer" in name and param.dim() == 2) or ("fc" in name and param.dim() == 2)):
|
||||
# 对权重进行量化
|
||||
original_weight = param.data.clone()
|
||||
quantized_weight = quantizer(original_weight)
|
||||
state_dict[name] = quantized_weight
|
||||
|
||||
# 打印前几个层的统计信息
|
||||
if total_params > 0:
|
||||
total_params -= 1
|
||||
if total_params > total_params - 5:
|
||||
print(f"层: {name}")
|
||||
print(f" 原始范围: {original_weight.min():.4f} 到 {original_weight.max():.4f}")
|
||||
print(f" 量化后范围: {quantized_weight.min():.4f} 到 {quantized_weight.max():.4f}")
|
||||
print(f" 均方误差: {((original_weight - quantized_weight)**2).mean():.8f}")
|
||||
|
||||
# 保存量化后的模型
|
||||
print(f"保存量化后的模型到: {output_bin_path}...")
|
||||
if output_bin_path.endswith(".bin"):
|
||||
torch.save(state_dict, output_bin_path)
|
||||
elif output_bin_path.endswith(".safetensors"):
|
||||
safetensors.save_file(state_dict, output_bin_path)
|
||||
else:
|
||||
os.makedirs(os.path.dirname(output_bin_path), exist_ok=True)
|
||||
output_bin_path = os.path.join(output_bin_path, "pytorch_model.bin")
|
||||
torch.save(state_dict, output_bin_path)
|
||||
print("完成!")
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="量化PyTorch模型bin文件")
|
||||
parser.add_argument("--input_bin", type=str, required=True, help="输入模型bin文件路径")
|
||||
parser.add_argument("--output", type=str, required=True, help="输出量化后的模型bin文件路径")
|
||||
parser.add_argument("--quant_type", type=str, default="ternary", choices=["ternary", "int", "int2"], help="量化类型")
|
||||
parser.add_argument("--bit", type=int, default=2, help="整数量化的位数")
|
||||
parser.add_argument("--group_size", type=int, default=-1, help="量化分组大小")
|
||||
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="运行设备")
|
||||
parser.add_argument("--config", type=str, default="", help="model config file")
|
||||
|
||||
args = parser.parse_args()
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
quantize_model_bin(
|
||||
input_bin_path=args.input_bin,
|
||||
output_bin_path=os.path.join(args.output, "pytorch_model.bin"),
|
||||
quant_type=args.quant_type,
|
||||
bit=args.bit,
|
||||
group_size=args.group_size,
|
||||
device=args.device
|
||||
)
|
||||
if args.config:
|
||||
os.system(f"cp {args.config}/* {args.output}")
|
||||
print(f"复制{args.config}文件到{args.output}")
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
33
special_tokens_map.json
Normal file
33
special_tokens_map.json
Normal file
@@ -0,0 +1,33 @@
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<|im_end|>",
|
||||
"<|im_start|>",
|
||||
"<|tool_call|>",
|
||||
"<|execute_start|>",
|
||||
"<|execute_end|>",
|
||||
"<|fim_prefix|>",
|
||||
"<|fim_middle|>",
|
||||
"<|fim_suffix|>"
|
||||
],
|
||||
"bos_token": {
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
490843
tokenizer.json
Normal file
490843
tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
3
tokenizer.model
Normal file
3
tokenizer.model
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bb74d51116831c3bf65db812c553f94ab0c88dcf97a5bbb37e3504f6d359c530
|
||||
size 1181204
|
||||
117
tokenizer_config.json
Normal file
117
tokenizer_config.json
Normal file
@@ -0,0 +1,117 @@
|
||||
{
|
||||
"add_bos_token": true,
|
||||
"add_eos_token": false,
|
||||
"add_prefix_space": null,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"73440": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"73441": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"73442": {
|
||||
"content": "<|tool_call|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"73443": {
|
||||
"content": "<|execute_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"73444": {
|
||||
"content": "<|execute_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"73445": {
|
||||
"content": "<|fim_prefix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"73446": {
|
||||
"content": "<|fim_middle|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"73447": {
|
||||
"content": "<|fim_suffix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<|im_end|>",
|
||||
"<|im_start|>",
|
||||
"<|tool_call|>",
|
||||
"<|execute_start|>",
|
||||
"<|execute_end|>",
|
||||
"<|fim_prefix|>",
|
||||
"<|fim_middle|>",
|
||||
"<|fim_suffix|>"
|
||||
],
|
||||
"bos_token": "<s>",
|
||||
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|im_end|>",
|
||||
"legacy": true,
|
||||
"model_max_length": 1000000000000000019884624838656,
|
||||
"pad_token": null,
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": false,
|
||||
"tokenizer_class": "LlamaTokenizer",
|
||||
"unk_token": "<unk>",
|
||||
"use_default_system_prompt": false
|
||||
}
|
||||
Reference in New Issue
Block a user