初始化项目,由ModelHub XC社区提供模型
Model: openbmb/BitCPM-CANN-1B-unquantized Source: Original Platform
This commit is contained in:
103
example/README.md
Normal file
103
example/README.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# BitCPM Training Example
|
||||
|
||||
This project provides scripts for continue pretraining (CPT) and supervised fine-tuning (SFT) of **BitCPM-CANN-1B-unquantized**.
|
||||
|
||||
## File Description
|
||||
|
||||
CPT and SFT each have a pair of scripts (training script + launch script) and share DeepSpeed configuration files:
|
||||
|
||||
| File | Description |
|
||||
| --- | --- |
|
||||
| `train.py` | Continue pretrain script based on HuggingFace Trainer + DeepSpeed |
|
||||
| `run.sh` | Launch script for CPT with hyperparameter configuration |
|
||||
| `train_sft.py` | Supervised fine-tuning script based on HuggingFace Trainer + DeepSpeed |
|
||||
| `run_sft.sh` | Launch script for SFT with hyperparameter configuration |
|
||||
| `ds_config.json` | DeepSpeed ZeRO-3 configuration (with CPU offload) |
|
||||
| `ds_config_z2.json` | DeepSpeed ZeRO-2 configuration (used by default) |
|
||||
| `requirements.txt` | Python dependency list |
|
||||
|
||||
## Environment Setup
|
||||
|
||||
### Docker Image
|
||||
|
||||
Use the following Huawei NPU image:
|
||||
|
||||
```
|
||||
swr.cn-south-1.myhuaweicloud.com/ascendhub/mindspeed-llm:openeuler22.03-mindspeed-llm-2.3.0-a3-arm
|
||||
```
|
||||
|
||||
Other Huawei NPU images may also work but have not been fully tested. For GPU environments, you can skip the Docker image and just install `requirements.txt` directly.
|
||||
|
||||
### Install Dependencies
|
||||
|
||||
After entering the container, install the Python dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Continue Pretrain (CPT)
|
||||
|
||||
### Dataset
|
||||
|
||||
The test dataset used is [C4-Pro](https://huggingface.co/datasets/gair-prox/c4-pro), stored in parquet format after downloading.
|
||||
|
||||
### Usage
|
||||
|
||||
Modify the path configuration in `run.sh`:
|
||||
|
||||
```bash
|
||||
MODEL_PATH="/path/to/BitCPM-CANN-1B-unquantized/"
|
||||
DATA_PATH="/path/to/c4-pro/data/your_file.parquet"
|
||||
```
|
||||
|
||||
Then start training:
|
||||
|
||||
```bash
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
## Supervised Fine-Tuning (SFT)
|
||||
|
||||
### Dataset
|
||||
|
||||
The test dataset used is [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), stored in parquet format after downloading.
|
||||
|
||||
### Usage
|
||||
|
||||
Modify the path configuration in `run_sft.sh`:
|
||||
|
||||
```bash
|
||||
MODEL_PATH="/path/to/BitCPM-CANN-1B-unquantized/"
|
||||
DATA_PATH="/path/to/ultrachat_200k/data/your_file.parquet"
|
||||
```
|
||||
|
||||
Then start training:
|
||||
|
||||
```bash
|
||||
bash run_sft.sh
|
||||
```
|
||||
|
||||
## Training Results Reference
|
||||
|
||||
> **Note:** BitCPM has its own training dataset and data mixture. It is expected that the loss continues to decrease when training on open-source datasets.
|
||||
|
||||
Below are the loss curves from smoke tests on GPU and NPU for both CPT and SFT tasks. The results are highly consistent across GPU and NPU, indicating that users can continue pre-training or fine-tuning on various compute devices:
|
||||
|
||||
| | GPU | NPU |
|
||||
| --- | --- | --- |
|
||||
| **CPT** |  |  |
|
||||
| **SFT** |  |  |
|
||||
|
||||
Training log CSV files (corresponding to the loss curves above):
|
||||
|
||||
| CSV File | Corresponding Loss Curve |
|
||||
| --- | --- |
|
||||
| [gpu_pretrain.csv](gpu_pretrain.csv) | GPU CPT |
|
||||
| [npu_pretrain.csv](npu_pretrain.csv) | NPU CPT |
|
||||
| [gpu_sft.csv](gpu_sft.csv) | GPU SFT |
|
||||
| [npu_sft.csv](npu_sft.csv) | NPU SFT |
|
||||
|
||||
---
|
||||
|
||||
These scripts provide a convenient, ready-to-use toolkit for QAT-aware continued pre-training and fine-tuning of BitCPM-CANN models, so you can quickly adapt the model to your own data and tasks while preserving ternary quantization constraints.
|
||||
29
example/ds_config.json
Normal file
29
example/ds_config.json
Normal file
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "none"
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"stage3_prefetch_bucket_size": 2e8,
|
||||
"stage3_param_persistence_threshold": 1e5,
|
||||
"stage3_max_live_parameters": 2e9,
|
||||
"stage3_max_reuse_distance": 2e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
22
example/ds_config_z2.json
Normal file
22
example/ds_config_z2.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "none"
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
51
example/gpu_pretrain.csv
Normal file
51
example/gpu_pretrain.csv
Normal file
@@ -0,0 +1,51 @@
|
||||
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
|
||||
2,2.7920000553131104,0.03527498617768288,7.999999979801942e-06,0.010457516647875309,,,,,
|
||||
4,2.8011999130249023,0.03495891019701958,1.5999999959603883e-05,0.020915033295750618,,,,,
|
||||
6,2.7964000701904297,0.03271934762597084,2.4000000848900527e-05,0.0313725508749485,,,,,
|
||||
8,2.763700008392334,0.024968057870864868,3.199999991920777e-05,0.041830066591501236,,,,,
|
||||
10,3.281599998474121,0.31758183240890503,3.9999998989515007e-05,0.05228758230805397,,,,,
|
||||
12,2.941200017929077,0.044055406004190445,3.995128281530924e-05,0.062745101749897,,,,,
|
||||
14,2.851799964904785,0.03649706766009331,3.9805359847377986e-05,0.07320261746644974,,,,,
|
||||
16,2.7869999408721924,0.022624235600233078,3.9562950405525044e-05,0.08366013318300247,,,,,
|
||||
18,2.7825000286102295,0.021830420941114426,3.922523319488391e-05,0.0941176488995552,,,,,
|
||||
20,2.7857000827789307,0.01685911975800991,3.87938525818754e-05,0.10457516461610794,,,,,
|
||||
22,2.7571001052856445,0.01572061888873577,3.827090768027119e-05,0.11503268033266068,,,,,
|
||||
24,2.762399911880493,0.016891509294509888,3.7658952351193875e-05,0.125490203499794,,,,,
|
||||
26,2.7411000728607178,0.015683824196457863,3.6960962461307645e-05,0.13594771921634674,,,,,
|
||||
28,2.733099937438965,0.012847283855080605,3.6180339520797133e-05,0.14640523493289948,,,,,
|
||||
30,2.723400115966797,0.015209181234240532,3.532088885549456e-05,0.1568627506494522,,,,,
|
||||
32,2.7342000007629395,0.01241038367152214,3.4386797779006884e-05,0.16732026636600494,,,,,
|
||||
34,2.7321999073028564,0.012879018671810627,3.338261376484297e-05,0.17777778208255768,,,,,
|
||||
36,2.7314000129699707,0.013242729939520359,3.231322989449836e-05,0.1882352977991104,,,,,
|
||||
38,2.7065999507904053,0.01113435160368681,3.118385939160362e-05,0.19869281351566315,,,,,
|
||||
40,2.6958999633789062,0.012413726188242435,2.9999999242136255e-05,0.20915032923221588,,,,,
|
||||
42,2.7516000270843506,0.011661508120596409,2.8767422918463126e-05,0.21960784494876862,,,,,
|
||||
44,2.713099956512451,0.012248368933796883,2.749213126662653e-05,0.23006536066532135,,,,,
|
||||
46,2.7102999687194824,0.011450185440480709,2.6180339773418382e-05,0.24052287638187408,,,,,
|
||||
48,2.7021000385284424,0.011155751533806324,2.483843854861334e-05,0.250980406999588,,,,,
|
||||
50,2.680500030517578,0.010021247901022434,2.3472963221138343e-05,0.26143792271614075,,,,,
|
||||
52,2.699199914932251,0.010751751251518726,2.2090569473220967e-05,0.2718954384326935,,,,,
|
||||
54,2.694200038909912,0.010503941215574741,2.0697989384643734e-05,0.2823529541492462,,,,,
|
||||
56,2.7091000080108643,0.010059370659291744,1.9302009604871273e-05,0.29281046986579895,,,,,
|
||||
58,2.699399948120117,0.012161476537585258,1.7909431335283443e-05,0.3032679855823517,,,,,
|
||||
60,2.7216999530792236,0.010671027936041355,1.6527035768376663e-05,0.3137255012989044,,,,,
|
||||
62,2.7158000469207764,0.010463157668709755,1.516156225989107e-05,0.32418301701545715,,,,,
|
||||
64,2.7214999198913574,0.010665320791304111,1.3819660125591327e-05,0.3346405327320099,,,,,
|
||||
66,2.7116000652313232,0.01046629250049591,1.2507867722888477e-05,0.3450980484485626,,,,,
|
||||
68,2.6923000812530518,0.010609752498567104,1.1232576980546582e-05,0.35555556416511536,,,,,
|
||||
70,2.6830999851226807,0.009290814399719238,9.999999747378752e-06,0.3660130798816681,,,,,
|
||||
72,2.7093000411987305,0.010727670043706894,8.816142326395493e-06,0.3764705955982208,,,,,
|
||||
74,2.698699951171875,0.0109737953171134,7.686770914006047e-06,0.38692811131477356,,,,,
|
||||
76,2.712599992752075,0.010320967063307762,6.61738795315614e-06,0.3973856270313263,,,,,
|
||||
78,2.6993000507354736,0.009841523133218288,5.613203938992228e-06,0.40784314274787903,,,,,
|
||||
80,2.6861000061035156,0.010179675184190273,4.6791110435151495e-06,0.41830065846443176,,,,,
|
||||
82,2.6828999519348145,0.009790077805519104,3.819659923465224e-06,0.4287581741809845,,,,,
|
||||
84,2.699199914932251,0.010508442297577858,3.03903811982309e-06,0.43921568989753723,,,,,
|
||||
86,2.6988000869750977,0.009589221328496933,2.3410482299368596e-06,0.44967320561408997,,,,,
|
||||
88,2.688499927520752,0.010065913200378418,1.7290908544964623e-06,0.4601307213306427,,,,,
|
||||
90,2.6928999423980713,0.010363687761127949,1.206147544507985e-06,0.47058823704719543,,,,,
|
||||
92,2.714200019836426,0.010142815299332142,7.74766078848188e-07,0.48104575276374817,,,,,
|
||||
94,2.672300100326538,0.009833029471337795,4.370479871340649e-07,0.4915032684803009,,,,,
|
||||
96,2.7018001079559326,0.009937037713825703,1.9463863054625108e-07,0.501960813999176,,,,,
|
||||
98,2.7121999263763428,0.009417451918125153,4.8718995060426096e-08,0.5124183297157288,,,,,
|
||||
100,2.7028000354766846,0.009256146848201752,0.0,0.5228758454322815,365.8839111328125,139.93499755859375,0.27300000190734863,4.629706395531346e+17,2.7395541667938232
|
||||
|
BIN
example/gpu_pretrain_loss.png
Normal file
BIN
example/gpu_pretrain_loss.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 49 KiB |
51
example/gpu_sft.csv
Normal file
51
example/gpu_sft.csv
Normal file
@@ -0,0 +1,51 @@
|
||||
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
|
||||
2,1.1492999792099,0.6216375231742859,1.9999999949504854e-06,0.0004617871018126607,,,,,
|
||||
4,1.0979000329971313,0.681877851486206,3.999999989900971e-06,0.0009235742036253214,,,,,
|
||||
6,1.1269999742507935,0.784303605556488,6.000000212225132e-06,0.001385361305437982,,,,,
|
||||
8,1.0542000532150269,0.8737029433250427,7.999999979801942e-06,0.0018471484072506428,,,,,
|
||||
10,1.2440999746322632,0.7068291902542114,9.999999747378752e-06,0.0023089356254786253,,,,,
|
||||
12,1.2925000190734863,0.6821666955947876,1.2000000424450263e-05,0.002770722610875964,,,,,
|
||||
14,1.0843000411987305,0.525643527507782,1.4000000192027073e-05,0.0032325098291039467,,,,,
|
||||
16,1.0961999893188477,0.43757057189941406,1.5999999959603883e-05,0.0036942968145012856,,,,,
|
||||
18,1.0614999532699585,0.46141618490219116,1.8000000636675395e-05,0.004156084265559912,,,,,
|
||||
20,1.332900047302246,0.715879499912262,1.9999999494757503e-05,0.004617871250957251,,,,,
|
||||
22,1.2070000171661377,0.5926885008811951,1.996917308133561e-05,0.0050796582363545895,,,,,
|
||||
24,1.2043999433517456,0.5833240747451782,1.9876883015967906e-05,0.005541445221751928,,,,,
|
||||
26,1.0740000009536743,0.44734400510787964,1.9723698642337695e-05,0.0060032326728105545,,,,,
|
||||
28,1.1162999868392944,0.3701137900352478,1.9510565834934823e-05,0.006465019658207893,,,,,
|
||||
30,1.0454000234603882,0.43832680583000183,1.9238796085119247e-05,0.006926806643605232,,,,,
|
||||
32,1.124899983406067,0.4591037631034851,1.8910064682131633e-05,0.007388593629002571,,,,,
|
||||
34,1.0686999559402466,0.3873400390148163,1.8526401618146338e-05,0.00785038061439991,,,,,
|
||||
36,1.0291999578475952,0.40313437581062317,1.8090169760398567e-05,0.008312168531119823,,,,,
|
||||
38,1.1052000522613525,0.3735405504703522,1.7604059394216165e-05,0.008773955516517162,,,,,
|
||||
40,1.1555999517440796,0.3818407654762268,1.7071068214136176e-05,0.009235742501914501,,,,,
|
||||
42,1.0235999822616577,0.4255191683769226,1.6494481315021403e-05,0.00969752948731184,,,,,
|
||||
44,1.0364999771118164,0.4794503152370453,1.5877853002166376e-05,0.010159316472709179,,,,,
|
||||
46,1.1344000101089478,0.37273937463760376,1.5224985872919206e-05,0.010621103458106518,,,,,
|
||||
48,1.0866999626159668,0.417492538690567,1.453990535082994e-05,0.011082890443503857,,,,,
|
||||
50,1.1038000583648682,0.35408055782318115,1.3826834219798911e-05,0.01154467836022377,,,,,
|
||||
52,1.1478999853134155,0.3930828273296356,1.3090169886709191e-05,0.012006465345621109,,,,,
|
||||
54,1.1858999729156494,0.3965947926044464,1.2334453458606731e-05,0.012468252331018448,,,,,
|
||||
56,1.0096999406814575,0.3860221207141876,1.1564344276848715e-05,0.012930039316415787,,,,,
|
||||
58,1.114799976348877,0.44393691420555115,1.0784590813273098e-05,0.013391826301813126,,,,,
|
||||
60,1.079300045967102,0.3605058789253235,9.999999747378752e-06,0.013853613287210464,,,,,
|
||||
62,1.1766999959945679,0.40689122676849365,9.215408681484405e-06,0.014315400272607803,,,,,
|
||||
64,1.1075999736785889,0.4002344310283661,8.435655217908788e-06,0.014777187258005142,,,,,
|
||||
66,1.1866999864578247,0.46947163343429565,7.665546036150772e-06,0.015238975174725056,,,,,
|
||||
68,1.0311000347137451,0.3296957314014435,6.909830062795663e-06,0.01570076122879982,,,,,
|
||||
70,1.1088999509811401,0.33858785033226013,6.173165729705943e-06,0.01616254821419716,,,,,
|
||||
72,1.0720000267028809,0.3967427909374237,5.460095053422265e-06,0.016624337062239647,,,,,
|
||||
74,1.1460000276565552,0.41202062368392944,4.7750145313329995e-06,0.017086124047636986,,,,,
|
||||
76,1.0425000190734863,0.38334518671035767,4.1221474020858295e-06,0.017547911033034325,,,,,
|
||||
78,0.9154000282287598,0.40649303793907166,3.505519543978153e-06,0.018009698018431664,,,,,
|
||||
80,1.1110999584197998,0.35371580719947815,2.9289321901160292e-06,0.018471485003829002,,,,,
|
||||
82,1.1672999858856201,0.3381657302379608,2.3959403279150138e-06,0.01893327198922634,,,,,
|
||||
84,1.2374000549316406,0.3815234303474426,1.909829961732612e-06,0.01939505897462368,,,,,
|
||||
86,1.2151000499725342,0.38446080684661865,1.4735983313585166e-06,0.01985684596002102,,,,,
|
||||
88,1.163100004196167,0.40419140458106995,1.0899348126258701e-06,0.020318632945418358,,,,,
|
||||
90,1.1883000135421753,0.4011874198913574,7.612046601934708e-07,0.020780419930815697,,,,,
|
||||
92,1.1526999473571777,0.3836020231246948,4.894348535344761e-07,0.021242206916213036,,,,,
|
||||
94,1.15339994430542,0.452364057302475,2.7630079557638965e-07,0.021703993901610374,,,,,
|
||||
96,1.062000036239624,0.3502688705921173,1.2311659247643547e-07,0.022165780887007713,,,,,
|
||||
98,1.0271999835968018,0.4022065997123718,3.0826662111849146e-08,0.022627567872405052,,,,,
|
||||
100,1.0283000469207764,0.38241174817085266,0.0,0.02308935672044754,183.9481964111328,8.697999954223633,0.5440000295639038,1862467846144.0,1.1177252531051636
|
||||
|
BIN
example/gpu_sft_loss.png
Normal file
BIN
example/gpu_sft_loss.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
51
example/npu_pretrain.csv
Normal file
51
example/npu_pretrain.csv
Normal file
@@ -0,0 +1,51 @@
|
||||
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
|
||||
2,2.7920000553131104,0.035306449979543686,7.999999979801942e-06,0.010457516647875309,,,,,
|
||||
4,2.8011999130249023,0.03491510450839996,1.5999999959603883e-05,0.020915033295750618,,,,,
|
||||
6,2.7964000701904297,0.032717395573854446,2.4000000848900527e-05,0.0313725508749485,,,,,
|
||||
8,2.763700008392334,0.024953875690698624,3.199999991920777e-05,0.041830066591501236,,,,,
|
||||
10,3.2811999320983887,0.3170815408229828,3.9999998989515007e-05,0.05228758230805397,,,,,
|
||||
12,2.9409000873565674,0.04423849284648895,3.995128281530924e-05,0.062745101749897,,,,,
|
||||
14,2.851900100708008,0.03667925298213959,3.9805359847377986e-05,0.07320261746644974,,,,,
|
||||
16,2.7869999408721924,0.022814607247710228,3.9562950405525044e-05,0.08366013318300247,,,,,
|
||||
18,2.782599925994873,0.021528413519263268,3.922523319488391e-05,0.0941176488995552,,,,,
|
||||
20,2.785599946975708,0.017014438286423683,3.87938525818754e-05,0.10457516461610794,,,,,
|
||||
22,2.7571001052856445,0.015719758346676826,3.827090768027119e-05,0.11503268033266068,,,,,
|
||||
24,2.762399911880493,0.016948623582720757,3.7658952351193875e-05,0.125490203499794,,,,,
|
||||
26,2.7411000728607178,0.015535997226834297,3.6960962461307645e-05,0.13594771921634674,,,,,
|
||||
28,2.7330000400543213,0.012748735956847668,3.6180339520797133e-05,0.14640523493289948,,,,,
|
||||
30,2.723299980163574,0.014809778891503811,3.532088885549456e-05,0.1568627506494522,,,,,
|
||||
32,2.7342000007629395,0.01219236571341753,3.4386797779006884e-05,0.16732026636600494,,,,,
|
||||
34,2.7321999073028564,0.012785322032868862,3.338261376484297e-05,0.17777778208255768,,,,,
|
||||
36,2.7314000129699707,0.012986919842660427,3.231322989449836e-05,0.1882352977991104,,,,,
|
||||
38,2.7065999507904053,0.01096824835985899,3.118385939160362e-05,0.19869281351566315,,,,,
|
||||
40,2.6958999633789062,0.012387535534799099,2.9999999242136255e-05,0.20915032923221588,,,,,
|
||||
42,2.751499891281128,0.011586200445890427,2.8767422918463126e-05,0.21960784494876862,,,,,
|
||||
44,2.713099956512451,0.011821281164884567,2.749213126662653e-05,0.23006536066532135,,,,,
|
||||
46,2.7102999687194824,0.01147585827857256,2.6180339773418382e-05,0.24052287638187408,,,,,
|
||||
48,2.7019999027252197,0.011368263512849808,2.483843854861334e-05,0.250980406999588,,,,,
|
||||
50,2.680500030517578,0.009935515932738781,2.3472963221138343e-05,0.26143792271614075,,,,,
|
||||
52,2.6993000507354736,0.0109846917912364,2.2090569473220967e-05,0.2718954384326935,,,,,
|
||||
54,2.6940999031066895,0.010465175844728947,2.0697989384643734e-05,0.2823529541492462,,,,,
|
||||
56,2.7091000080108643,0.01009758748114109,1.9302009604871273e-05,0.29281046986579895,,,,,
|
||||
58,2.69950008392334,0.01249368954449892,1.7909431335283443e-05,0.3032679855823517,,,,,
|
||||
60,2.7216999530792236,0.01051376760005951,1.6527035768376663e-05,0.3137255012989044,,,,,
|
||||
62,2.7158000469207764,0.01054943073540926,1.516156225989107e-05,0.32418301701545715,,,,,
|
||||
64,2.7214999198913574,0.01076149195432663,1.3819660125591327e-05,0.3346405327320099,,,,,
|
||||
66,2.7116000652313232,0.010380392894148827,1.2507867722888477e-05,0.3450980484485626,,,,,
|
||||
68,2.6923000812530518,0.010425001382827759,1.1232576980546582e-05,0.35555556416511536,,,,,
|
||||
70,2.683199882507324,0.00925016961991787,9.999999747378752e-06,0.3660130798816681,,,,,
|
||||
72,2.7093000411987305,0.01072422880679369,8.816142326395493e-06,0.3764705955982208,,,,,
|
||||
74,2.6988000869750977,0.011063243262469769,7.686770914006047e-06,0.38692811131477356,,,,,
|
||||
76,2.7125000953674316,0.01013101264834404,6.61738795315614e-06,0.3973856270313263,,,,,
|
||||
78,2.6993000507354736,0.009940676391124725,5.613203938992228e-06,0.40784314274787903,,,,,
|
||||
80,2.6861000061035156,0.01050259917974472,4.6791110435151495e-06,0.41830065846443176,,,,,
|
||||
82,2.6828999519348145,0.009912634268403053,3.819659923465224e-06,0.4287581741809845,,,,,
|
||||
84,2.699199914932251,0.010668900795280933,3.03903811982309e-06,0.43921568989753723,,,,,
|
||||
86,2.698899984359741,0.009650414809584618,2.3410482299368596e-06,0.44967320561408997,,,,,
|
||||
88,2.6884000301361084,0.01006452739238739,1.7290908544964623e-06,0.4601307213306427,,,,,
|
||||
90,2.6928999423980713,0.010409764014184475,1.206147544507985e-06,0.47058823704719543,,,,,
|
||||
92,2.714200019836426,0.009937116876244545,7.74766078848188e-07,0.48104575276374817,,,,,
|
||||
94,2.672300100326538,0.009728306904435158,4.370479871340649e-07,0.4915032684803009,,,,,
|
||||
96,2.7018001079559326,0.010098566301167011,1.9463863054625108e-07,0.501960813999176,,,,,
|
||||
98,2.7123000621795654,0.009524320252239704,4.8718995060426096e-08,0.5124183297157288,,,,,
|
||||
100,2.7028000354766846,0.009290286339819431,0.0,0.5228758454322815,788.0635986328125,64.96900177001953,0.12700000405311584,4.629706395531346e+17,2.739542245864868
|
||||
|
BIN
example/npu_pretrain_loss.png
Normal file
BIN
example/npu_pretrain_loss.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 46 KiB |
51
example/npu_sft.csv
Normal file
51
example/npu_sft.csv
Normal file
@@ -0,0 +1,51 @@
|
||||
step,train/loss,train/grad_norm,train/learning_rate,train/epoch,train/train_runtime,train/train_samples_per_second,train/train_steps_per_second,train/total_flos,train/train_loss
|
||||
2,1.1491999626159668,0.6218180060386658,1.9999999949504854e-06,0.0004617871018126607,,,,,
|
||||
4,1.0981999635696411,0.6825665235519409,3.999999989900971e-06,0.0009235742036253214,,,,,
|
||||
6,1.1269999742507935,0.7838642001152039,6.000000212225132e-06,0.001385361305437982,,,,,
|
||||
8,1.0542000532150269,0.8744276762008667,7.999999979801942e-06,0.0018471484072506428,,,,,
|
||||
10,1.2441999912261963,0.7064258456230164,9.999999747378752e-06,0.0023089356254786253,,,,,
|
||||
12,1.2927000522613525,0.6829814910888672,1.2000000424450263e-05,0.002770722610875964,,,,,
|
||||
14,1.0844999551773071,0.5265647172927856,1.4000000192027073e-05,0.0032325098291039467,,,,,
|
||||
16,1.0963000059127808,0.4373657703399658,1.5999999959603883e-05,0.0036942968145012856,,,,,
|
||||
18,1.0615999698638916,0.46220508217811584,1.8000000636675395e-05,0.004156084265559912,,,,,
|
||||
20,1.3325999975204468,0.7157824039459229,1.9999999494757503e-05,0.004617871250957251,,,,,
|
||||
22,1.2070000171661377,0.5933427214622498,1.996917308133561e-05,0.0050796582363545895,,,,,
|
||||
24,1.2044999599456787,0.5816172957420349,1.9876883015967906e-05,0.005541445221751928,,,,,
|
||||
26,1.0740000009536743,0.4489712119102478,1.9723698642337695e-05,0.0060032326728105545,,,,,
|
||||
28,1.1164000034332275,0.3696516752243042,1.9510565834934823e-05,0.006465019658207893,,,,,
|
||||
30,1.045199990272522,0.4376335144042969,1.9238796085119247e-05,0.006926806643605232,,,,,
|
||||
32,1.1247999668121338,0.4589230716228485,1.8910064682131633e-05,0.007388593629002571,,,,,
|
||||
34,1.0688999891281128,0.3879022002220154,1.8526401618146338e-05,0.00785038061439991,,,,,
|
||||
36,1.0292999744415283,0.4027869403362274,1.8090169760398567e-05,0.008312168531119823,,,,,
|
||||
38,1.1052000522613525,0.37394437193870544,1.7604059394216165e-05,0.008773955516517162,,,,,
|
||||
40,1.1557999849319458,0.3808683753013611,1.7071068214136176e-05,0.009235742501914501,,,,,
|
||||
42,1.0232000350952148,0.4252733886241913,1.6494481315021403e-05,0.00969752948731184,,,,,
|
||||
44,1.0364999771118164,0.48068660497665405,1.5877853002166376e-05,0.010159316472709179,,,,,
|
||||
46,1.1340999603271484,0.37313926219940186,1.5224985872919206e-05,0.010621103458106518,,,,,
|
||||
48,1.0866999626159668,0.4175492823123932,1.453990535082994e-05,0.011082890443503857,,,,,
|
||||
50,1.1039999723434448,0.35443660616874695,1.3826834219798911e-05,0.01154467836022377,,,,,
|
||||
52,1.1480000019073486,0.39232146739959717,1.3090169886709191e-05,0.012006465345621109,,,,,
|
||||
54,1.1861000061035156,0.396918922662735,1.2334453458606731e-05,0.012468252331018448,,,,,
|
||||
56,1.0096999406814575,0.3885609209537506,1.1564344276848715e-05,0.012930039316415787,,,,,
|
||||
58,1.114799976348877,0.4421806335449219,1.0784590813273098e-05,0.013391826301813126,,,,,
|
||||
60,1.0795999765396118,0.36081990599632263,9.999999747378752e-06,0.013853613287210464,,,,,
|
||||
62,1.1764999628067017,0.4062329828739166,9.215408681484405e-06,0.014315400272607803,,,,,
|
||||
64,1.107200026512146,0.39982733130455017,8.435655217908788e-06,0.014777187258005142,,,,,
|
||||
66,1.1868000030517578,0.4688170254230499,7.665546036150772e-06,0.015238975174725056,,,,,
|
||||
68,1.0312999486923218,0.3301626741886139,6.909830062795663e-06,0.01570076122879982,,,,,
|
||||
70,1.1089999675750732,0.3377252221107483,6.173165729705943e-06,0.01616254821419716,,,,,
|
||||
72,1.0716999769210815,0.39666977524757385,5.460095053422265e-06,0.016624337062239647,,,,,
|
||||
74,1.1461999416351318,0.4125552177429199,4.7750145313329995e-06,0.017086124047636986,,,,,
|
||||
76,1.042199969291687,0.3825180232524872,4.1221474020858295e-06,0.017547911033034325,,,,,
|
||||
78,0.9157000184059143,0.4063441753387451,3.505519543978153e-06,0.018009698018431664,,,,,
|
||||
80,1.1110999584197998,0.35289037227630615,2.9289321901160292e-06,0.018471485003829002,,,,,
|
||||
82,1.167199969291687,0.33720290660858154,2.3959403279150138e-06,0.01893327198922634,,,,,
|
||||
84,1.2375999689102173,0.38099613785743713,1.909829961732612e-06,0.01939505897462368,,,,,
|
||||
86,1.2151999473571777,0.3848689794540405,1.4735983313585166e-06,0.01985684596002102,,,,,
|
||||
88,1.1628999710083008,0.40408074855804443,1.0899348126258701e-06,0.020318632945418358,,,,,
|
||||
90,1.1884000301361084,0.4015007019042969,7.612046601934708e-07,0.020780419930815697,,,,,
|
||||
92,1.152500033378601,0.38306349515914917,4.894348535344761e-07,0.021242206916213036,,,,,
|
||||
94,1.154099941253662,0.45273807644844055,2.7630079557638965e-07,0.021703993901610374,,,,,
|
||||
96,1.0618000030517578,0.35036078095436096,1.2311659247643547e-07,0.022165780887007713,,,,,
|
||||
98,1.0270999670028687,0.40208569169044495,3.0826662111849146e-08,0.022627567872405052,,,,,
|
||||
100,1.0285999774932861,0.38247284293174744,0.0,0.02308935672044754,728.7083129882812,2.196000099182129,0.13699999451637268,1862467846144.0,1.117748498916626
|
||||
|
BIN
example/npu_sft_loss.png
Normal file
BIN
example/npu_sft_loss.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 67 KiB |
8
example/requirements.txt
Normal file
8
example/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
transformers==4.46.3
|
||||
tokenizers==0.20.3
|
||||
accelerate==1.1.1
|
||||
deepspeed==0.16.2
|
||||
datasets==3.1.0
|
||||
safetensors==0.4.5
|
||||
pyarrow==17.0.0
|
||||
tensorboard==2.18.0
|
||||
38
example/run.sh
Normal file
38
example/run.sh
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/bin/bash
|
||||
|
||||
MODEL_PATH="/model/BitCPM-CANN-1B-unquantized"
|
||||
DATA_PATH="/dataset/c4-pro/data/000_1_7.parquet"
|
||||
OUTPUT_DIR="./output"
|
||||
DS_CONFIG="./ds_config_z2.json"
|
||||
|
||||
NUM_GPUS=8
|
||||
BATCH_SIZE_PER_GPU=8
|
||||
GRAD_ACCUM_STEPS=8
|
||||
MAX_SEQ_LENGTH=1024
|
||||
|
||||
export ASCEND_RT_VISIBLE_DEVICES=8,9,10,11,12,13,14,15
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
export DS_SKIP_CUDA_CHECK=1
|
||||
torchrun --nproc_per_node=$NUM_GPUS train.py \
|
||||
--model_name_or_path $MODEL_PATH \
|
||||
--data_path $DATA_PATH \
|
||||
--max_seq_length $MAX_SEQ_LENGTH \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--per_device_train_batch_size $BATCH_SIZE_PER_GPU \
|
||||
--gradient_accumulation_steps $GRAD_ACCUM_STEPS \
|
||||
--max_steps 100 \
|
||||
--learning_rate 4e-5 \
|
||||
--lr_scheduler_type cosine \
|
||||
--warmup_ratio 0.1 \
|
||||
--weight_decay 1e-2 \
|
||||
--logging_steps 2 \
|
||||
--save_steps 500 \
|
||||
--save_total_limit 3 \
|
||||
--bf16 \
|
||||
--deepspeed $DS_CONFIG \
|
||||
--gradient_checkpointing \
|
||||
--seed 42 \
|
||||
--dataloader_num_workers 4 \
|
||||
--report_to tensorboard \
|
||||
--logging_dir /data/tensorboard/pretrain \
|
||||
--gradient_checkpointing_kwargs '{"use_reentrant": false}'
|
||||
40
example/run_sft.sh
Normal file
40
example/run_sft.sh
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
|
||||
MODEL_PATH="/model/BitCPM-CANN-1B-unquantized"
|
||||
DATA_PATH="/dataset/HuggingFaceH4_ultrachat_200k/data/train_sft-00000-of-00003-a3ecf92756993583.parquet"
|
||||
OUTPUT_DIR="./output_sft"
|
||||
DS_CONFIG="./ds_config.json"
|
||||
|
||||
NUM_GPUS=8
|
||||
BATCH_SIZE_PER_GPU=2
|
||||
GRAD_ACCUM_STEPS=1
|
||||
MAX_SEQ_LENGTH=8192
|
||||
|
||||
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
export DS_SKIP_CUDA_CHECK=1
|
||||
|
||||
torchrun --nproc_per_node=$NUM_GPUS train_sft.py \
|
||||
--model_name_or_path $MODEL_PATH \
|
||||
--data_path $DATA_PATH \
|
||||
--max_seq_length $MAX_SEQ_LENGTH \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--per_device_train_batch_size $BATCH_SIZE_PER_GPU \
|
||||
--gradient_accumulation_steps $GRAD_ACCUM_STEPS \
|
||||
--max_steps 100 \
|
||||
--learning_rate 2e-5 \
|
||||
--lr_scheduler_type cosine \
|
||||
--warmup_ratio 0.2 \
|
||||
--weight_decay 0.0 \
|
||||
--logging_steps 2 \
|
||||
--save_steps 500 \
|
||||
--save_total_limit 3 \
|
||||
--bf16 \
|
||||
--deepspeed $DS_CONFIG \
|
||||
--gradient_checkpointing \
|
||||
--seed 42 \
|
||||
--dataloader_num_workers 4 \
|
||||
--report_to tensorboard \
|
||||
--logging_dir /data/tensorboard/sft \
|
||||
--train_on_prompt false \
|
||||
--gradient_checkpointing_kwargs '{"use_reentrant": false}'
|
||||
203
example/train.py
Normal file
203
example/train.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Continual pretraining script for CPM-2B model using DeepSpeed + HuggingFace Trainer.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
HfArgumentParser,
|
||||
DataCollatorForLanguageModeling,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
import deepspeed
|
||||
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patched_no_sync(self):
|
||||
try:
|
||||
with _orig_no_sync(self):
|
||||
yield
|
||||
except AssertionError:
|
||||
yield
|
||||
|
||||
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier"}
|
||||
)
|
||||
torch_dtype: Optional[str] = field(
|
||||
default="bfloat16",
|
||||
metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
data_path: str = field(
|
||||
metadata={"help": "Path to training data (parquet file or directory)"}
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=4096,
|
||||
metadata={"help": "Maximum sequence length for training"},
|
||||
)
|
||||
text_column: str = field(
|
||||
default="text",
|
||||
metadata={"help": "Name of the text column in the dataset"},
|
||||
)
|
||||
preprocessing_num_workers: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of workers for data preprocessing"},
|
||||
)
|
||||
|
||||
|
||||
def tokenize_and_group(dataset, tokenizer, data_args):
|
||||
"""Tokenize texts and group into chunks of max_seq_length."""
|
||||
|
||||
column_names = dataset.column_names
|
||||
text_column = data_args.text_column
|
||||
if text_column not in column_names:
|
||||
candidates = [c for c in column_names if "text" in c.lower()]
|
||||
if candidates:
|
||||
text_column = candidates[0]
|
||||
else:
|
||||
text_column = column_names[0]
|
||||
logger.warning(f"Column '{data_args.text_column}' not found, using '{text_column}'")
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples[text_column], add_special_tokens=False)
|
||||
|
||||
tokenized_dataset = dataset.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Tokenizing",
|
||||
)
|
||||
|
||||
block_size = data_args.max_seq_length
|
||||
|
||||
def group_texts(examples):
|
||||
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
|
||||
total_length = len(concatenated["input_ids"])
|
||||
total_length = (total_length // block_size) * block_size
|
||||
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
grouped_dataset = tokenized_dataset.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
desc="Grouping texts",
|
||||
)
|
||||
|
||||
return grouped_dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.info(f"Training args: {training_args}")
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
dtype_map = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
|
||||
|
||||
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
logger.info(f"Loading model from {model_args.model_name_or_path}")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
logger.info(f"Loading dataset from {data_args.data_path}")
|
||||
if os.path.isfile(data_args.data_path):
|
||||
raw_dataset = load_dataset("parquet", data_files=data_args.data_path, split="train")
|
||||
elif os.path.isdir(data_args.data_path):
|
||||
parquet_files = [
|
||||
os.path.join(data_args.data_path, f)
|
||||
for f in os.listdir(data_args.data_path)
|
||||
if f.endswith(".parquet")
|
||||
]
|
||||
raw_dataset = load_dataset("parquet", data_files=parquet_files, split="train")
|
||||
else:
|
||||
raise ValueError(f"Data path not found: {data_args.data_path}")
|
||||
|
||||
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
|
||||
|
||||
train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
|
||||
logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=False,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
logger.info("Starting training...")
|
||||
train_result = trainer.train(
|
||||
resume_from_checkpoint=training_args.resume_from_checkpoint
|
||||
)
|
||||
|
||||
trainer.save_model()
|
||||
trainer.save_state()
|
||||
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(train_dataset)
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
424
example/train_sft.py
Normal file
424
example/train_sft.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
Supervised fine-tuning script using DeepSpeed + HuggingFace Trainer.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
import deepspeed
|
||||
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patched_no_sync(self):
|
||||
try:
|
||||
with _orig_no_sync(self):
|
||||
yield
|
||||
except AssertionError:
|
||||
yield
|
||||
|
||||
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier"}
|
||||
)
|
||||
torch_dtype: Optional[str] = field(
|
||||
default="bfloat16",
|
||||
metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
data_path: str = field(metadata={"help": "Path to SFT data file or directory"})
|
||||
max_seq_length: int = field(
|
||||
default=4096,
|
||||
metadata={"help": "Maximum sequence length for training"},
|
||||
)
|
||||
prompt_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Prompt/instruction column name. Auto-detected if omitted."},
|
||||
)
|
||||
input_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Optional extra input/context column name"},
|
||||
)
|
||||
response_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Response/output column name. Auto-detected if omitted."},
|
||||
)
|
||||
messages_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat messages column name. Auto-detected if omitted."},
|
||||
)
|
||||
system_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Optional system prompt column name"},
|
||||
)
|
||||
train_on_prompt: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to compute loss on prompt/user tokens"},
|
||||
)
|
||||
add_eos_token: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Append eos_token to plain prompt/response examples"},
|
||||
)
|
||||
preprocessing_num_workers: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of workers for data preprocessing"},
|
||||
)
|
||||
|
||||
|
||||
class SFTDataCollator:
|
||||
def __init__(self, tokenizer, pad_to_multiple_of: Optional[int] = 8):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_to_multiple_of = pad_to_multiple_of
|
||||
|
||||
def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
||||
max_length = max(len(feature["input_ids"]) for feature in features)
|
||||
if self.pad_to_multiple_of:
|
||||
multiple = self.pad_to_multiple_of
|
||||
max_length = ((max_length + multiple - 1) // multiple) * multiple
|
||||
|
||||
input_ids = []
|
||||
attention_mask = []
|
||||
labels = []
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
|
||||
for feature in features:
|
||||
length = len(feature["input_ids"])
|
||||
pad_length = max_length - length
|
||||
input_ids.append(feature["input_ids"] + [pad_token_id] * pad_length)
|
||||
attention_mask.append([1] * length + [0] * pad_length)
|
||||
labels.append(feature["labels"] + [IGNORE_INDEX] * pad_length)
|
||||
|
||||
return {
|
||||
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
||||
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
|
||||
"labels": torch.tensor(labels, dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def load_sft_dataset(data_path: str):
|
||||
if os.path.isfile(data_path):
|
||||
extension = os.path.splitext(data_path)[1].lstrip(".").lower()
|
||||
if extension == "jsonl":
|
||||
extension = "json"
|
||||
if extension not in {"parquet", "json", "csv", "txt"}:
|
||||
raise ValueError(f"Unsupported data file extension: {extension}")
|
||||
return load_dataset(extension, data_files=data_path, split="train")
|
||||
|
||||
if os.path.isdir(data_path):
|
||||
data_files = []
|
||||
extension = None
|
||||
for name in os.listdir(data_path):
|
||||
current_extension = os.path.splitext(name)[1].lstrip(".").lower()
|
||||
if current_extension == "jsonl":
|
||||
current_extension = "json"
|
||||
if current_extension in {"parquet", "json", "csv", "txt"}:
|
||||
extension = extension or current_extension
|
||||
if current_extension == extension:
|
||||
data_files.append(os.path.join(data_path, name))
|
||||
if not data_files or extension is None:
|
||||
raise ValueError(f"No supported data files found in: {data_path}")
|
||||
return load_dataset(extension, data_files=sorted(data_files), split="train")
|
||||
|
||||
raise ValueError(f"Data path not found: {data_path}")
|
||||
|
||||
|
||||
def choose_column(
|
||||
column_names: List[str], explicit: Optional[str], candidates: List[str]
|
||||
) -> Optional[str]:
|
||||
if explicit:
|
||||
if explicit not in column_names:
|
||||
raise ValueError(f"Column '{explicit}' not found. Available columns: {column_names}")
|
||||
return explicit
|
||||
for name in candidates:
|
||||
if name in column_names:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def parse_messages(value: Any) -> List[Dict[str, str]]:
|
||||
if isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("messages/conversations column must be a list or JSON string")
|
||||
|
||||
messages = []
|
||||
for item in value:
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError("Each message must be a dict")
|
||||
|
||||
role = item.get("role", item.get("from"))
|
||||
content = item.get("content", item.get("value"))
|
||||
if role == "human":
|
||||
role = "user"
|
||||
elif role == "gpt":
|
||||
role = "assistant"
|
||||
|
||||
if role is None or content is None:
|
||||
raise ValueError("Each message must contain role/from and content/value")
|
||||
messages.append({"role": str(role), "content": str(content)})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def tokenize_text(tokenizer, text: str) -> List[int]:
|
||||
return tokenizer(text, add_special_tokens=False)["input_ids"]
|
||||
|
||||
|
||||
def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool) -> str:
|
||||
if tokenizer.chat_template is None:
|
||||
raise ValueError(
|
||||
"The tokenizer has no chat_template. Use prompt/response columns or set a chat_template."
|
||||
)
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
|
||||
|
||||
def encode_prompt_response(
|
||||
example: Dict[str, Any],
|
||||
tokenizer,
|
||||
data_args: DataArguments,
|
||||
prompt_column: str,
|
||||
input_column: Optional[str],
|
||||
response_column: str,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
prompt = str(example[prompt_column])
|
||||
if input_column and example.get(input_column):
|
||||
prompt = prompt + "\n" + str(example[input_column])
|
||||
response = str(example[response_column])
|
||||
|
||||
messages = []
|
||||
if data_args.system_column and example.get(data_args.system_column):
|
||||
messages.append({"role": "system", "content": str(example[data_args.system_column])})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
if tokenizer.chat_template is not None:
|
||||
full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
|
||||
prompt_text = apply_chat_template(tokenizer, messages[:-1], add_generation_prompt=True)
|
||||
input_ids = tokenize_text(tokenizer, full_text)
|
||||
prompt_length = len(tokenize_text(tokenizer, prompt_text))
|
||||
else:
|
||||
response_text = response
|
||||
if data_args.add_eos_token and tokenizer.eos_token:
|
||||
response_text += tokenizer.eos_token
|
||||
full_text = prompt + "\n" + response_text
|
||||
input_ids = tokenize_text(tokenizer, full_text)
|
||||
prompt_length = len(tokenize_text(tokenizer, prompt + "\n"))
|
||||
|
||||
labels = input_ids.copy()
|
||||
if not data_args.train_on_prompt:
|
||||
labels[:prompt_length] = [IGNORE_INDEX] * min(prompt_length, len(labels))
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
def encode_messages(
|
||||
example: Dict[str, Any],
|
||||
tokenizer,
|
||||
data_args: DataArguments,
|
||||
messages_column: str,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
messages = parse_messages(example[messages_column])
|
||||
|
||||
if tokenizer.chat_template is not None:
|
||||
full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
|
||||
input_ids = tokenize_text(tokenizer, full_text)
|
||||
labels = [IGNORE_INDEX] * len(input_ids)
|
||||
|
||||
if data_args.train_on_prompt:
|
||||
labels = input_ids.copy()
|
||||
else:
|
||||
for index, message in enumerate(messages):
|
||||
if message["role"] != "assistant":
|
||||
continue
|
||||
before_text = apply_chat_template(
|
||||
tokenizer, messages[:index], add_generation_prompt=True
|
||||
)
|
||||
after_text = apply_chat_template(
|
||||
tokenizer, messages[: index + 1], add_generation_prompt=False
|
||||
)
|
||||
start = len(tokenize_text(tokenizer, before_text))
|
||||
end = len(tokenize_text(tokenizer, after_text))
|
||||
labels[start:end] = input_ids[start:end]
|
||||
else:
|
||||
labels = []
|
||||
input_ids = []
|
||||
for message in messages:
|
||||
part = f"{message['role']}: {message['content']}\n"
|
||||
if data_args.add_eos_token and message["role"] == "assistant" and tokenizer.eos_token:
|
||||
part += tokenizer.eos_token
|
||||
part_ids = tokenize_text(tokenizer, part)
|
||||
input_ids.extend(part_ids)
|
||||
if data_args.train_on_prompt or message["role"] == "assistant":
|
||||
labels.extend(part_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(part_ids))
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
def preprocess_sft_dataset(raw_dataset, tokenizer, data_args: DataArguments):
|
||||
column_names = raw_dataset.column_names
|
||||
messages_column = choose_column(
|
||||
column_names, data_args.messages_column, ["messages", "conversations"]
|
||||
)
|
||||
prompt_column = choose_column(
|
||||
column_names,
|
||||
data_args.prompt_column,
|
||||
["prompt", "instruction", "question"],
|
||||
)
|
||||
input_column = choose_column(
|
||||
column_names,
|
||||
data_args.input_column,
|
||||
["input", "context"],
|
||||
)
|
||||
response_column = choose_column(
|
||||
column_names,
|
||||
data_args.response_column,
|
||||
["response", "output", "answer", "chosen"],
|
||||
)
|
||||
|
||||
if messages_column:
|
||||
logger.info(f"Using chat messages column: {messages_column}")
|
||||
elif prompt_column and response_column:
|
||||
logger.info(f"Using prompt column '{prompt_column}' and response column '{response_column}'")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot infer SFT data format. Provide either messages/conversations or "
|
||||
"prompt/instruction plus response/output columns."
|
||||
)
|
||||
|
||||
def encode_batch(examples):
|
||||
batch_input_ids = []
|
||||
batch_labels = []
|
||||
batch_attention_mask = []
|
||||
|
||||
batch_size = len(next(iter(examples.values())))
|
||||
for i in range(batch_size):
|
||||
example = {name: values[i] for name, values in examples.items()}
|
||||
if messages_column:
|
||||
input_ids, labels = encode_messages(example, tokenizer, data_args, messages_column)
|
||||
else:
|
||||
input_ids, labels = encode_prompt_response(
|
||||
example, tokenizer, data_args, prompt_column, input_column, response_column
|
||||
)
|
||||
|
||||
input_ids = input_ids[: data_args.max_seq_length]
|
||||
labels = labels[: data_args.max_seq_length]
|
||||
if not input_ids or all(label == IGNORE_INDEX for label in labels):
|
||||
continue
|
||||
|
||||
batch_input_ids.append(input_ids)
|
||||
batch_labels.append(labels)
|
||||
batch_attention_mask.append([1] * len(input_ids))
|
||||
|
||||
return {
|
||||
"input_ids": batch_input_ids,
|
||||
"attention_mask": batch_attention_mask,
|
||||
"labels": batch_labels,
|
||||
}
|
||||
|
||||
return raw_dataset.map(
|
||||
encode_batch,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Tokenizing SFT data",
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.info(f"Training args: {training_args}")
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
dtype_map = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
|
||||
|
||||
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
logger.info(f"Loading model from {model_args.model_name_or_path}")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
logger.info(f"Loading SFT dataset from {data_args.data_path}")
|
||||
raw_dataset = load_sft_dataset(data_args.data_path)
|
||||
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
|
||||
|
||||
train_dataset = preprocess_sft_dataset(raw_dataset, tokenizer, data_args)
|
||||
logger.info(f"Processed dataset: {len(train_dataset)} samples")
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
data_collator=SFTDataCollator(tokenizer),
|
||||
)
|
||||
|
||||
logger.info("Starting SFT training...")
|
||||
train_result = trainer.train(
|
||||
resume_from_checkpoint=training_args.resume_from_checkpoint
|
||||
)
|
||||
|
||||
trainer.save_model()
|
||||
trainer.save_state()
|
||||
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(train_dataset)
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user