From 0ea2008ea23cc07e16cef4ef8326998ad1e0f055 Mon Sep 17 00:00:00 2001 From: ModelHub XC Date: Sun, 3 May 2026 01:27:00 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E9=A1=B9=E7=9B=AE?= =?UTF-8?q?=EF=BC=8C=E7=94=B1ModelHub=20XC=E7=A4=BE=E5=8C=BA=E6=8F=90?= =?UTF-8?q?=E4=BE=9B=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Model: xu1998hz/instructscore_caption Source: Original Platform --- .gitattributes | 35 + added_tokens.json | 3 + config.json | 24 + generation_config.json | 7 + id_rsa.pub | 1 + latest | 1 + length_check.py | 7 + pytorch_model.bin | 3 + rng_state_0.pth | 3 + rng_state_1.pth | 3 + rng_state_2.pth | 3 + rng_state_3.pth | 3 + rng_state_4.pth | 3 + rng_state_5.pth | 3 + rng_state_6.pth | 3 + rng_state_7.pth | 3 + special_tokens_map.json | 6 + tion_ssh_key.pub | 1 + tokenizer.model | 3 + tokenizer_config.json | 34 + trainer_state.json | 2788 +++++++++++++++++++++++++++++++++++++++ training_args.bin | 3 + zero_to_fp32.py | 578 ++++++++ 23 files changed, 3518 insertions(+) create mode 100644 .gitattributes create mode 100644 added_tokens.json create mode 100644 config.json create mode 100644 generation_config.json create mode 100644 id_rsa.pub create mode 100644 latest create mode 100644 length_check.py create mode 100644 pytorch_model.bin create mode 100644 rng_state_0.pth create mode 100644 rng_state_1.pth create mode 100644 rng_state_2.pth create mode 100644 rng_state_3.pth create mode 100644 rng_state_4.pth create mode 100644 rng_state_5.pth create mode 100644 rng_state_6.pth create mode 100644 rng_state_7.pth create mode 100644 special_tokens_map.json create mode 100644 tion_ssh_key.pub create mode 100644 tokenizer.model create mode 100644 tokenizer_config.json create mode 100644 trainer_state.json create mode 100644 training_args.bin create mode 100755 zero_to_fp32.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a6344aa --- /dev/null 
+++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/added_tokens.json b/added_tokens.json new file mode 100644 index 0000000..e41416d --- /dev/null +++ b/added_tokens.json @@ -0,0 +1,3 @@ +{ + "[PAD]": 32000 +} diff --git a/config.json b/config.json new file mode 100644 index 0000000..bb719f4 --- /dev/null +++ b/config.json @@ -0,0 +1,24 @@ +{ + "_name_or_path": "decapoda-research/llama-7b-hf", + "architectures": [ + "LlamaForCausalLM" + ], + 
"bos_token_id": 0, + "eos_token_id": 1, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "max_sequence_length": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "pad_token_id": -1, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.29.2", + "use_cache": true, + "vocab_size": 32001 +} diff --git a/generation_config.json b/generation_config.json new file mode 100644 index 0000000..517f415 --- /dev/null +++ b/generation_config.json @@ -0,0 +1,7 @@ +{ + "_from_model_config": true, + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 0, + "transformers_version": "4.29.2" +} diff --git a/id_rsa.pub b/id_rsa.pub new file mode 100644 index 0000000..28577d5 --- /dev/null +++ b/id_rsa.pub @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDVKQ2ytnBTZdaIptWwmQDQIEXNyMPQBYNGC6k0OclGETZwVXy7SsQ5DSGudPqjqadbyjcP7Sxc8EykTD9qafi76CVBKljClOlueJkL0aqAFfGamn1kyJ2qAjw7+WUiehYYCRMdKW9q1by89wqMqY5CE/b900ietdr6JYPMBRPoS/yt63HrwgXVdGDk7BFQ6V2C+pRfdj6OLsM5/8ARmFNBSvQmX96ToGvK3xSILIqu3lzV92HGsdealUnBdMwXtX2vG9Ituw9rcYwPyhGqa/RlrsYVEp4B7Q+5858ZD0Pwe3H0sWxQ2uBGZWa6KeE3VXs2NxwLzdjB89VdZ9dPXJFT+VCjuvui5cy6r1+sgKJZaGnYUDOA5D1ipjjh7g/oWeGj+8rqVscrALdDOggwLA1ejtS5hy1ZvRMrAI1M/ueWoPm0HRc64LAZoUjjWhEoBsjfLXnva/G/ai6DmR/2qkePvoJipulrWjgBv7q0lEAb4WnjOdX9XNc6ZQiY1hzfcmM= wenda@wenda-Inspiron-3593 diff --git a/latest b/latest new file mode 100644 index 0000000..5035238 --- /dev/null +++ b/latest @@ -0,0 +1 @@ +global_step462 \ No newline at end of file diff --git a/length_check.py b/length_check.py new file mode 100644 index 0000000..fbf8ab7 --- /dev/null +++ b/length_check.py @@ -0,0 +1,7 @@ +from datasets import load_dataset +from transformers import AutoTokenizer + +data_cnn = load_dataset("cnn_dailymail", '3.0.0') +data_xsum = load_dataset("xsum") + +print(data_xsum) diff --git a/pytorch_model.bin 
b/pytorch_model.bin new file mode 100644 index 0000000..9af1eb0 --- /dev/null +++ b/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:446726817ff6e57a8ecd26703077946f1f09ebc6a56993df42a08fcb1492657c +size 13476958625 diff --git a/rng_state_0.pth b/rng_state_0.pth new file mode 100644 index 0000000..8235363 --- /dev/null +++ b/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c1d4aab692d41c18c919a673e506a5264ad9de9ff132b0de5dab52510157c75 +size 21687 diff --git a/rng_state_1.pth b/rng_state_1.pth new file mode 100644 index 0000000..bffd984 --- /dev/null +++ b/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72575f2b4522908157bba855533d15354ced6598ecdc09ca8ce5417fa99c9b92 +size 21687 diff --git a/rng_state_2.pth b/rng_state_2.pth new file mode 100644 index 0000000..cff1daa --- /dev/null +++ b/rng_state_2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eabb43b1fac99cd8e34e58482a942b31a1855cd53120560e1428e8a2aae185e9 +size 21687 diff --git a/rng_state_3.pth b/rng_state_3.pth new file mode 100644 index 0000000..cfbd0b7 --- /dev/null +++ b/rng_state_3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cfd9568c2d39f428f86eb5a99beb6f19ac57f3f024ec4494b8cadd107d148298 +size 21687 diff --git a/rng_state_4.pth b/rng_state_4.pth new file mode 100644 index 0000000..1a9b886 --- /dev/null +++ b/rng_state_4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21ab6cb42d703108752ec781af0a5227b721321c411bb7f7bca00b6ec4e15324 +size 21687 diff --git a/rng_state_5.pth b/rng_state_5.pth new file mode 100644 index 0000000..58ecc9d --- /dev/null +++ b/rng_state_5.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9fe828193c13a4a1adccaf24735d85428cb925bfa6204603f3a0a4a34c8cec9 +size 21687 diff --git a/rng_state_6.pth b/rng_state_6.pth new file mode 100644 index 
0000000..84a982c --- /dev/null +++ b/rng_state_6.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36e15c4dcb706b97a74d0f8259d6466dd2e46e60fd71dcb2c6d99e1e9ff2f204 +size 21687 diff --git a/rng_state_7.pth b/rng_state_7.pth new file mode 100644 index 0000000..2f84f3f --- /dev/null +++ b/rng_state_7.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a467ddda52bc75625e56adc924121f89adb1e65a5733eec05f6284586a9a926 +size 21687 diff --git a/special_tokens_map.json b/special_tokens_map.json new file mode 100644 index 0000000..318f913 --- /dev/null +++ b/special_tokens_map.json @@ -0,0 +1,6 @@ +{ + "bos_token": "", + "eos_token": "", + "pad_token": "[PAD]", + "unk_token": "" +} diff --git a/tion_ssh_key.pub b/tion_ssh_key.pub new file mode 100644 index 0000000..05aea10 --- /dev/null +++ b/tion_ssh_key.pub @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDVgp/wj0CXuAA2dP/swXwcu5e0snmQETPF/HIXKNteUZ46+wOSpvwpuk/D2hrl86bj3GOdiGLtZC51nUWnOxrHv7a3MUbFlbKYpKKv2LtEQuSoY5nhSVWIJZnhio12tStvuqr6eat4RZqHm5jZ+vHNfb8GMiPxi5jM59k7nbZm4Fc6XaqJXeOxRsMCTyDI0T7RR2FA0T+CJTyQsDeGXAxclq2P9cXUy9vKfrC68bI87k5ZzXbNDT9HcWKLoTWYHxTHvRwIQh7GKaZiXQ3Q6cph9Jcn7Xr3sbWA8adTzCCVxcfO7X50tBBO1SXNBZvTHEfKYuXSvqCIw/034+0QQ1WEc9rse8y/7KKoOyenf1rcQ/RadmsxWHu0oGgOQIs6gyXmzS38cE31qnHzNducK+yw7XBVFUrUWTqAdaXsFaTL39nbXGSYgaCfQKWOeJXl2cM0F6PG3eWAEiZvpgMS9+Yl+DZzqksCR/uQGLOX/A5YtPpSn5W1HDl7FLly9b5xWrc= wenda@wenda-Inspiron-3593 diff --git a/tokenizer.model b/tokenizer.model new file mode 100644 index 0000000..6c00c74 --- /dev/null +++ b/tokenizer.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347 +size 499723 diff --git a/tokenizer_config.json b/tokenizer_config.json new file mode 100644 index 0000000..747187e --- /dev/null +++ b/tokenizer_config.json @@ -0,0 +1,34 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "bos_token": { + "__type": "AddedToken", + 
"content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "clean_up_tokenization_spaces": false, + "eos_token": { + "__type": "AddedToken", + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "model_max_length": 720, + "pad_token": null, + "padding_side": "right", + "sp_model_kwargs": {}, + "tokenizer_class": "LlamaTokenizer", + "unk_token": { + "__type": "AddedToken", + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + } +} diff --git a/trainer_state.json b/trainer_state.json new file mode 100644 index 0000000..ba229e9 --- /dev/null +++ b/trainer_state.json @@ -0,0 +1,2788 @@ +{ + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 2.9890820865345735, + "global_step": 462, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + "epoch": 0.01, + "learning_rate": 0, + "loss": 10.8463, + "step": 1 + }, + { + "epoch": 0.01, + "learning_rate": 0, + "loss": 11.3118, + "step": 2 + }, + { + "epoch": 0.02, + "learning_rate": 0, + "loss": 11.0038, + "step": 3 + }, + { + "epoch": 0.03, + "learning_rate": 0, + "loss": 11.1465, + "step": 4 + }, + { + "epoch": 0.03, + "learning_rate": 0, + "loss": 11.3677, + "step": 5 + }, + { + "epoch": 0.04, + "learning_rate": 0, + "loss": 11.3668, + "step": 6 + }, + { + "epoch": 0.05, + "learning_rate": 0.0, + "loss": 10.8335, + "step": 7 + }, + { + "epoch": 0.05, + "learning_rate": 0.0, + "loss": 11.1042, + "step": 8 + }, + { + "epoch": 0.06, + "learning_rate": 2e-05, + "loss": 11.0963, + "step": 9 + }, + { + "epoch": 0.06, + "learning_rate": 2e-05, + "loss": 11.3511, + "step": 10 + }, + { + "epoch": 0.07, + "learning_rate": 2e-05, + "loss": 9.441, + "step": 11 + }, + { + "epoch": 0.08, + "learning_rate": 2e-05, + "loss": 9.1196, + "step": 12 + }, + { + "epoch": 0.08, + "learning_rate": 2e-05, + "loss": 8.403, 
+ "step": 13 + }, + { + "epoch": 0.09, + "learning_rate": 2e-05, + "loss": 8.6426, + "step": 14 + }, + { + "epoch": 0.1, + "learning_rate": 2e-05, + "loss": 8.7776, + "step": 15 + }, + { + "epoch": 0.1, + "learning_rate": 2e-05, + "loss": 8.9274, + "step": 16 + }, + { + "epoch": 0.11, + "learning_rate": 2e-05, + "loss": 8.2598, + "step": 17 + }, + { + "epoch": 0.12, + "learning_rate": 2e-05, + "loss": 8.2732, + "step": 18 + }, + { + "epoch": 0.12, + "learning_rate": 2e-05, + "loss": 8.2438, + "step": 19 + }, + { + "epoch": 0.13, + "learning_rate": 2e-05, + "loss": 8.0727, + "step": 20 + }, + { + "epoch": 0.14, + "learning_rate": 2e-05, + "loss": 8.1909, + "step": 21 + }, + { + "epoch": 0.14, + "learning_rate": 2e-05, + "loss": 7.7707, + "step": 22 + }, + { + "epoch": 0.15, + "learning_rate": 2e-05, + "loss": 7.4861, + "step": 23 + }, + { + "epoch": 0.16, + "learning_rate": 2e-05, + "loss": 7.4279, + "step": 24 + }, + { + "epoch": 0.16, + "learning_rate": 2e-05, + "loss": 7.3696, + "step": 25 + }, + { + "epoch": 0.17, + "learning_rate": 2e-05, + "loss": 7.243, + "step": 26 + }, + { + "epoch": 0.17, + "learning_rate": 2e-05, + "loss": 6.9497, + "step": 27 + }, + { + "epoch": 0.18, + "learning_rate": 2e-05, + "loss": 6.9924, + "step": 28 + }, + { + "epoch": 0.19, + "learning_rate": 2e-05, + "loss": 6.5594, + "step": 29 + }, + { + "epoch": 0.19, + "learning_rate": 2e-05, + "loss": 6.7199, + "step": 30 + }, + { + "epoch": 0.2, + "learning_rate": 2e-05, + "loss": 6.2285, + "step": 31 + }, + { + "epoch": 0.21, + "learning_rate": 2e-05, + "loss": 6.0328, + "step": 32 + }, + { + "epoch": 0.21, + "learning_rate": 2e-05, + "loss": 5.821, + "step": 33 + }, + { + "epoch": 0.22, + "learning_rate": 2e-05, + "loss": 5.7585, + "step": 34 + }, + { + "epoch": 0.23, + "learning_rate": 2e-05, + "loss": 5.6209, + "step": 35 + }, + { + "epoch": 0.23, + "learning_rate": 2e-05, + "loss": 5.3515, + "step": 36 + }, + { + "epoch": 0.24, + "learning_rate": 2e-05, + "loss": 5.2036, + "step": 37 
+ }, + { + "epoch": 0.25, + "learning_rate": 2e-05, + "loss": 5.0071, + "step": 38 + }, + { + "epoch": 0.25, + "learning_rate": 2e-05, + "loss": 4.8688, + "step": 39 + }, + { + "epoch": 0.26, + "learning_rate": 2e-05, + "loss": 4.6462, + "step": 40 + }, + { + "epoch": 0.27, + "learning_rate": 2e-05, + "loss": 4.536, + "step": 41 + }, + { + "epoch": 0.27, + "learning_rate": 2e-05, + "loss": 4.0728, + "step": 42 + }, + { + "epoch": 0.28, + "learning_rate": 2e-05, + "loss": 3.8805, + "step": 43 + }, + { + "epoch": 0.28, + "learning_rate": 2e-05, + "loss": 3.9002, + "step": 44 + }, + { + "epoch": 0.29, + "learning_rate": 2e-05, + "loss": 3.3943, + "step": 45 + }, + { + "epoch": 0.3, + "learning_rate": 2e-05, + "loss": 3.3925, + "step": 46 + }, + { + "epoch": 0.3, + "learning_rate": 2e-05, + "loss": 3.0984, + "step": 47 + }, + { + "epoch": 0.31, + "learning_rate": 2e-05, + "loss": 2.8044, + "step": 48 + }, + { + "epoch": 0.32, + "learning_rate": 2e-05, + "loss": 2.6679, + "step": 49 + }, + { + "epoch": 0.32, + "learning_rate": 2e-05, + "loss": 2.4378, + "step": 50 + }, + { + "epoch": 0.33, + "learning_rate": 2e-05, + "loss": 2.1403, + "step": 51 + }, + { + "epoch": 0.34, + "learning_rate": 2e-05, + "loss": 2.0139, + "step": 52 + }, + { + "epoch": 0.34, + "learning_rate": 2e-05, + "loss": 1.8027, + "step": 53 + }, + { + "epoch": 0.35, + "learning_rate": 2e-05, + "loss": 1.5626, + "step": 54 + }, + { + "epoch": 0.36, + "learning_rate": 2e-05, + "loss": 1.3481, + "step": 55 + }, + { + "epoch": 0.36, + "learning_rate": 2e-05, + "loss": 1.17, + "step": 56 + }, + { + "epoch": 0.37, + "learning_rate": 2e-05, + "loss": 1.0311, + "step": 57 + }, + { + "epoch": 0.38, + "learning_rate": 2e-05, + "loss": 0.8989, + "step": 58 + }, + { + "epoch": 0.38, + "learning_rate": 2e-05, + "loss": 0.7375, + "step": 59 + }, + { + "epoch": 0.39, + "learning_rate": 2e-05, + "loss": 0.6133, + "step": 60 + }, + { + "epoch": 0.39, + "learning_rate": 2e-05, + "loss": 0.4984, + "step": 61 + }, + { + 
"epoch": 0.4, + "learning_rate": 2e-05, + "loss": 0.4197, + "step": 62 + }, + { + "epoch": 0.41, + "learning_rate": 2e-05, + "loss": 0.349, + "step": 63 + }, + { + "epoch": 0.41, + "learning_rate": 2e-05, + "loss": 0.2899, + "step": 64 + }, + { + "epoch": 0.42, + "learning_rate": 2e-05, + "loss": 0.2383, + "step": 65 + }, + { + "epoch": 0.43, + "learning_rate": 2e-05, + "loss": 0.2003, + "step": 66 + }, + { + "epoch": 0.43, + "learning_rate": 2e-05, + "loss": 0.1653, + "step": 67 + }, + { + "epoch": 0.44, + "learning_rate": 2e-05, + "loss": 0.1448, + "step": 68 + }, + { + "epoch": 0.45, + "learning_rate": 2e-05, + "loss": 0.118, + "step": 69 + }, + { + "epoch": 0.45, + "learning_rate": 2e-05, + "loss": 0.1192, + "step": 70 + }, + { + "epoch": 0.46, + "learning_rate": 2e-05, + "loss": 0.1045, + "step": 71 + }, + { + "epoch": 0.47, + "learning_rate": 2e-05, + "loss": 0.1026, + "step": 72 + }, + { + "epoch": 0.47, + "learning_rate": 2e-05, + "loss": 0.0894, + "step": 73 + }, + { + "epoch": 0.48, + "learning_rate": 2e-05, + "loss": 0.0847, + "step": 74 + }, + { + "epoch": 0.49, + "learning_rate": 2e-05, + "loss": 0.0838, + "step": 75 + }, + { + "epoch": 0.49, + "learning_rate": 2e-05, + "loss": 0.0812, + "step": 76 + }, + { + "epoch": 0.5, + "learning_rate": 2e-05, + "loss": 0.0852, + "step": 77 + }, + { + "epoch": 0.5, + "learning_rate": 2e-05, + "loss": 0.0834, + "step": 78 + }, + { + "epoch": 0.51, + "learning_rate": 2e-05, + "loss": 0.0805, + "step": 79 + }, + { + "epoch": 0.52, + "learning_rate": 2e-05, + "loss": 0.0743, + "step": 80 + }, + { + "epoch": 0.52, + "learning_rate": 2e-05, + "loss": 0.0747, + "step": 81 + }, + { + "epoch": 0.53, + "learning_rate": 2e-05, + "loss": 0.0738, + "step": 82 + }, + { + "epoch": 0.54, + "learning_rate": 2e-05, + "loss": 0.0845, + "step": 83 + }, + { + "epoch": 0.54, + "learning_rate": 2e-05, + "loss": 0.0687, + "step": 84 + }, + { + "epoch": 0.55, + "learning_rate": 2e-05, + "loss": 0.0515, + "step": 85 + }, + { + "epoch": 
0.56, + "learning_rate": 2e-05, + "loss": 0.0695, + "step": 86 + }, + { + "epoch": 0.56, + "learning_rate": 2e-05, + "loss": 0.0615, + "step": 87 + }, + { + "epoch": 0.57, + "learning_rate": 2e-05, + "loss": 0.0677, + "step": 88 + }, + { + "epoch": 0.58, + "learning_rate": 2e-05, + "loss": 0.0692, + "step": 89 + }, + { + "epoch": 0.58, + "learning_rate": 2e-05, + "loss": 0.0564, + "step": 90 + }, + { + "epoch": 0.59, + "learning_rate": 2e-05, + "loss": 0.058, + "step": 91 + }, + { + "epoch": 0.6, + "learning_rate": 2e-05, + "loss": 0.0637, + "step": 92 + }, + { + "epoch": 0.6, + "learning_rate": 2e-05, + "loss": 0.0711, + "step": 93 + }, + { + "epoch": 0.61, + "learning_rate": 2e-05, + "loss": 0.0525, + "step": 94 + }, + { + "epoch": 0.61, + "learning_rate": 2e-05, + "loss": 0.0518, + "step": 95 + }, + { + "epoch": 0.62, + "learning_rate": 2e-05, + "loss": 0.0603, + "step": 96 + }, + { + "epoch": 0.63, + "learning_rate": 2e-05, + "loss": 0.0509, + "step": 97 + }, + { + "epoch": 0.63, + "learning_rate": 2e-05, + "loss": 0.0599, + "step": 98 + }, + { + "epoch": 0.64, + "learning_rate": 2e-05, + "loss": 0.0639, + "step": 99 + }, + { + "epoch": 0.65, + "learning_rate": 2e-05, + "loss": 0.0611, + "step": 100 + }, + { + "epoch": 0.65, + "learning_rate": 2e-05, + "loss": 0.0462, + "step": 101 + }, + { + "epoch": 0.66, + "learning_rate": 2e-05, + "loss": 0.051, + "step": 102 + }, + { + "epoch": 0.67, + "learning_rate": 2e-05, + "loss": 0.0621, + "step": 103 + }, + { + "epoch": 0.67, + "learning_rate": 2e-05, + "loss": 0.0597, + "step": 104 + }, + { + "epoch": 0.68, + "learning_rate": 2e-05, + "loss": 0.0574, + "step": 105 + }, + { + "epoch": 0.69, + "learning_rate": 2e-05, + "loss": 0.0573, + "step": 106 + }, + { + "epoch": 0.69, + "learning_rate": 2e-05, + "loss": 0.0575, + "step": 107 + }, + { + "epoch": 0.7, + "learning_rate": 2e-05, + "loss": 0.0563, + "step": 108 + }, + { + "epoch": 0.71, + "learning_rate": 2e-05, + "loss": 0.0454, + "step": 109 + }, + { + "epoch": 
0.71, + "learning_rate": 2e-05, + "loss": 0.0442, + "step": 110 + }, + { + "epoch": 0.72, + "learning_rate": 2e-05, + "loss": 0.0602, + "step": 111 + }, + { + "epoch": 0.72, + "learning_rate": 2e-05, + "loss": 0.0578, + "step": 112 + }, + { + "epoch": 0.73, + "learning_rate": 2e-05, + "loss": 0.0524, + "step": 113 + }, + { + "epoch": 0.74, + "learning_rate": 2e-05, + "loss": 0.0567, + "step": 114 + }, + { + "epoch": 0.74, + "learning_rate": 2e-05, + "loss": 0.0476, + "step": 115 + }, + { + "epoch": 0.75, + "learning_rate": 2e-05, + "loss": 0.0576, + "step": 116 + }, + { + "epoch": 0.76, + "learning_rate": 2e-05, + "loss": 0.0481, + "step": 117 + }, + { + "epoch": 0.76, + "learning_rate": 2e-05, + "loss": 0.0496, + "step": 118 + }, + { + "epoch": 0.77, + "learning_rate": 2e-05, + "loss": 0.0485, + "step": 119 + }, + { + "epoch": 0.78, + "learning_rate": 2e-05, + "loss": 0.0538, + "step": 120 + }, + { + "epoch": 0.78, + "learning_rate": 2e-05, + "loss": 0.0508, + "step": 121 + }, + { + "epoch": 0.79, + "learning_rate": 2e-05, + "loss": 0.0444, + "step": 122 + }, + { + "epoch": 0.8, + "learning_rate": 2e-05, + "loss": 0.0494, + "step": 123 + }, + { + "epoch": 0.8, + "learning_rate": 2e-05, + "loss": 0.0505, + "step": 124 + }, + { + "epoch": 0.81, + "learning_rate": 2e-05, + "loss": 0.045, + "step": 125 + }, + { + "epoch": 0.82, + "learning_rate": 2e-05, + "loss": 0.0492, + "step": 126 + }, + { + "epoch": 0.82, + "learning_rate": 2e-05, + "loss": 0.0608, + "step": 127 + }, + { + "epoch": 0.83, + "learning_rate": 2e-05, + "loss": 0.0549, + "step": 128 + }, + { + "epoch": 0.83, + "learning_rate": 2e-05, + "loss": 0.0526, + "step": 129 + }, + { + "epoch": 0.84, + "learning_rate": 2e-05, + "loss": 0.0525, + "step": 130 + }, + { + "epoch": 0.85, + "learning_rate": 2e-05, + "loss": 0.0455, + "step": 131 + }, + { + "epoch": 0.85, + "learning_rate": 2e-05, + "loss": 0.0528, + "step": 132 + }, + { + "epoch": 0.86, + "learning_rate": 2e-05, + "loss": 0.0613, + "step": 133 + }, + 
{ + "epoch": 0.87, + "learning_rate": 2e-05, + "loss": 0.0612, + "step": 134 + }, + { + "epoch": 0.87, + "learning_rate": 2e-05, + "loss": 0.0449, + "step": 135 + }, + { + "epoch": 0.88, + "learning_rate": 2e-05, + "loss": 0.0572, + "step": 136 + }, + { + "epoch": 0.89, + "learning_rate": 2e-05, + "loss": 0.0513, + "step": 137 + }, + { + "epoch": 0.89, + "learning_rate": 2e-05, + "loss": 0.0637, + "step": 138 + }, + { + "epoch": 0.9, + "learning_rate": 2e-05, + "loss": 0.0539, + "step": 139 + }, + { + "epoch": 0.91, + "learning_rate": 2e-05, + "loss": 0.0575, + "step": 140 + }, + { + "epoch": 0.91, + "learning_rate": 2e-05, + "loss": 0.056, + "step": 141 + }, + { + "epoch": 0.92, + "learning_rate": 2e-05, + "loss": 0.0574, + "step": 142 + }, + { + "epoch": 0.93, + "learning_rate": 2e-05, + "loss": 0.0588, + "step": 143 + }, + { + "epoch": 0.93, + "learning_rate": 2e-05, + "loss": 0.0495, + "step": 144 + }, + { + "epoch": 0.94, + "learning_rate": 2e-05, + "loss": 0.0494, + "step": 145 + }, + { + "epoch": 0.94, + "learning_rate": 2e-05, + "loss": 0.0497, + "step": 146 + }, + { + "epoch": 0.95, + "learning_rate": 2e-05, + "loss": 0.0583, + "step": 147 + }, + { + "epoch": 0.96, + "learning_rate": 2e-05, + "loss": 0.0469, + "step": 148 + }, + { + "epoch": 0.96, + "learning_rate": 2e-05, + "loss": 0.0572, + "step": 149 + }, + { + "epoch": 0.97, + "learning_rate": 2e-05, + "loss": 0.0605, + "step": 150 + }, + { + "epoch": 0.98, + "learning_rate": 2e-05, + "loss": 0.0509, + "step": 151 + }, + { + "epoch": 0.98, + "learning_rate": 2e-05, + "loss": 0.0498, + "step": 152 + }, + { + "epoch": 0.99, + "learning_rate": 2e-05, + "loss": 0.0519, + "step": 153 + }, + { + "epoch": 1.0, + "learning_rate": 2e-05, + "loss": 0.0564, + "step": 154 + }, + { + "epoch": 1.0, + "learning_rate": 2e-05, + "loss": 0.0535, + "step": 155 + }, + { + "epoch": 1.01, + "learning_rate": 2e-05, + "loss": 0.0367, + "step": 156 + }, + { + "epoch": 1.02, + "learning_rate": 2e-05, + "loss": 0.0502, + 
"step": 157 + }, + { + "epoch": 1.02, + "learning_rate": 2e-05, + "loss": 0.0503, + "step": 158 + }, + { + "epoch": 1.03, + "learning_rate": 2e-05, + "loss": 0.0424, + "step": 159 + }, + { + "epoch": 1.04, + "learning_rate": 2e-05, + "loss": 0.049, + "step": 160 + }, + { + "epoch": 1.04, + "learning_rate": 2e-05, + "loss": 0.0421, + "step": 161 + }, + { + "epoch": 1.05, + "learning_rate": 2e-05, + "loss": 0.0429, + "step": 162 + }, + { + "epoch": 1.05, + "learning_rate": 2e-05, + "loss": 0.0517, + "step": 163 + }, + { + "epoch": 1.06, + "learning_rate": 2e-05, + "loss": 0.0521, + "step": 164 + }, + { + "epoch": 1.07, + "learning_rate": 2e-05, + "loss": 0.0446, + "step": 165 + }, + { + "epoch": 1.07, + "learning_rate": 2e-05, + "loss": 0.0411, + "step": 166 + }, + { + "epoch": 1.08, + "learning_rate": 2e-05, + "loss": 0.0334, + "step": 167 + }, + { + "epoch": 1.09, + "learning_rate": 2e-05, + "loss": 0.0381, + "step": 168 + }, + { + "epoch": 1.09, + "learning_rate": 2e-05, + "loss": 0.0511, + "step": 169 + }, + { + "epoch": 1.1, + "learning_rate": 2e-05, + "loss": 0.0424, + "step": 170 + }, + { + "epoch": 1.11, + "learning_rate": 2e-05, + "loss": 0.0434, + "step": 171 + }, + { + "epoch": 1.11, + "learning_rate": 2e-05, + "loss": 0.0452, + "step": 172 + }, + { + "epoch": 1.12, + "learning_rate": 2e-05, + "loss": 0.0433, + "step": 173 + }, + { + "epoch": 1.13, + "learning_rate": 2e-05, + "loss": 0.0424, + "step": 174 + }, + { + "epoch": 1.13, + "learning_rate": 2e-05, + "loss": 0.0476, + "step": 175 + }, + { + "epoch": 1.14, + "learning_rate": 2e-05, + "loss": 0.0532, + "step": 176 + }, + { + "epoch": 1.15, + "learning_rate": 2e-05, + "loss": 0.0447, + "step": 177 + }, + { + "epoch": 1.15, + "learning_rate": 2e-05, + "loss": 0.0449, + "step": 178 + }, + { + "epoch": 1.16, + "learning_rate": 2e-05, + "loss": 0.0421, + "step": 179 + }, + { + "epoch": 1.16, + "learning_rate": 2e-05, + "loss": 0.0439, + "step": 180 + }, + { + "epoch": 1.17, + "learning_rate": 2e-05, + 
"loss": 0.0416, + "step": 181 + }, + { + "epoch": 1.18, + "learning_rate": 2e-05, + "loss": 0.0422, + "step": 182 + }, + { + "epoch": 1.18, + "learning_rate": 2e-05, + "loss": 0.0448, + "step": 183 + }, + { + "epoch": 1.19, + "learning_rate": 2e-05, + "loss": 0.0469, + "step": 184 + }, + { + "epoch": 1.2, + "learning_rate": 2e-05, + "loss": 0.0471, + "step": 185 + }, + { + "epoch": 1.2, + "learning_rate": 2e-05, + "loss": 0.0436, + "step": 186 + }, + { + "epoch": 1.21, + "learning_rate": 2e-05, + "loss": 0.0447, + "step": 187 + }, + { + "epoch": 1.22, + "learning_rate": 2e-05, + "loss": 0.0512, + "step": 188 + }, + { + "epoch": 1.22, + "learning_rate": 2e-05, + "loss": 0.0463, + "step": 189 + }, + { + "epoch": 1.23, + "learning_rate": 2e-05, + "loss": 0.0434, + "step": 190 + }, + { + "epoch": 1.24, + "learning_rate": 2e-05, + "loss": 0.0486, + "step": 191 + }, + { + "epoch": 1.24, + "learning_rate": 2e-05, + "loss": 0.0387, + "step": 192 + }, + { + "epoch": 1.25, + "learning_rate": 2e-05, + "loss": 0.0379, + "step": 193 + }, + { + "epoch": 1.26, + "learning_rate": 2e-05, + "loss": 0.0389, + "step": 194 + }, + { + "epoch": 1.26, + "learning_rate": 2e-05, + "loss": 0.0408, + "step": 195 + }, + { + "epoch": 1.27, + "learning_rate": 2e-05, + "loss": 0.0449, + "step": 196 + }, + { + "epoch": 1.27, + "learning_rate": 2e-05, + "loss": 0.0433, + "step": 197 + }, + { + "epoch": 1.28, + "learning_rate": 2e-05, + "loss": 0.0487, + "step": 198 + }, + { + "epoch": 1.29, + "learning_rate": 2e-05, + "loss": 0.0314, + "step": 199 + }, + { + "epoch": 1.29, + "learning_rate": 2e-05, + "loss": 0.0352, + "step": 200 + }, + { + "epoch": 1.3, + "learning_rate": 2e-05, + "loss": 0.0428, + "step": 201 + }, + { + "epoch": 1.31, + "learning_rate": 2e-05, + "loss": 0.0397, + "step": 202 + }, + { + "epoch": 1.31, + "learning_rate": 2e-05, + "loss": 0.0341, + "step": 203 + }, + { + "epoch": 1.32, + "learning_rate": 2e-05, + "loss": 0.0373, + "step": 204 + }, + { + "epoch": 1.33, + 
"learning_rate": 2e-05, + "loss": 0.0402, + "step": 205 + }, + { + "epoch": 1.33, + "learning_rate": 2e-05, + "loss": 0.0408, + "step": 206 + }, + { + "epoch": 1.34, + "learning_rate": 2e-05, + "loss": 0.043, + "step": 207 + }, + { + "epoch": 1.35, + "learning_rate": 2e-05, + "loss": 0.0405, + "step": 208 + }, + { + "epoch": 1.35, + "learning_rate": 2e-05, + "loss": 0.0444, + "step": 209 + }, + { + "epoch": 1.36, + "learning_rate": 2e-05, + "loss": 0.0412, + "step": 210 + }, + { + "epoch": 1.37, + "learning_rate": 2e-05, + "loss": 0.0368, + "step": 211 + }, + { + "epoch": 1.37, + "learning_rate": 2e-05, + "loss": 0.0524, + "step": 212 + }, + { + "epoch": 1.38, + "learning_rate": 2e-05, + "loss": 0.0439, + "step": 213 + }, + { + "epoch": 1.38, + "learning_rate": 2e-05, + "loss": 0.0444, + "step": 214 + }, + { + "epoch": 1.39, + "learning_rate": 2e-05, + "loss": 0.0415, + "step": 215 + }, + { + "epoch": 1.4, + "learning_rate": 2e-05, + "loss": 0.041, + "step": 216 + }, + { + "epoch": 1.4, + "learning_rate": 2e-05, + "loss": 0.0416, + "step": 217 + }, + { + "epoch": 1.41, + "learning_rate": 2e-05, + "loss": 0.0458, + "step": 218 + }, + { + "epoch": 1.42, + "learning_rate": 2e-05, + "loss": 0.0399, + "step": 219 + }, + { + "epoch": 1.42, + "learning_rate": 2e-05, + "loss": 0.0468, + "step": 220 + }, + { + "epoch": 1.43, + "learning_rate": 2e-05, + "loss": 0.0516, + "step": 221 + }, + { + "epoch": 1.44, + "learning_rate": 2e-05, + "loss": 0.0434, + "step": 222 + }, + { + "epoch": 1.44, + "learning_rate": 2e-05, + "loss": 0.0492, + "step": 223 + }, + { + "epoch": 1.45, + "learning_rate": 2e-05, + "loss": 0.0432, + "step": 224 + }, + { + "epoch": 1.46, + "learning_rate": 2e-05, + "loss": 0.0451, + "step": 225 + }, + { + "epoch": 1.46, + "learning_rate": 2e-05, + "loss": 0.0486, + "step": 226 + }, + { + "epoch": 1.47, + "learning_rate": 2e-05, + "loss": 0.0359, + "step": 227 + }, + { + "epoch": 1.48, + "learning_rate": 2e-05, + "loss": 0.0423, + "step": 228 + }, + { + 
"epoch": 1.48, + "learning_rate": 2e-05, + "loss": 0.0395, + "step": 229 + }, + { + "epoch": 1.49, + "learning_rate": 2e-05, + "loss": 0.0513, + "step": 230 + }, + { + "epoch": 1.49, + "learning_rate": 2e-05, + "loss": 0.0381, + "step": 231 + }, + { + "epoch": 1.5, + "learning_rate": 2e-05, + "loss": 0.0384, + "step": 232 + }, + { + "epoch": 1.51, + "learning_rate": 2e-05, + "loss": 0.0471, + "step": 233 + }, + { + "epoch": 1.51, + "learning_rate": 2e-05, + "loss": 0.041, + "step": 234 + }, + { + "epoch": 1.52, + "learning_rate": 2e-05, + "loss": 0.0479, + "step": 235 + }, + { + "epoch": 1.53, + "learning_rate": 2e-05, + "loss": 0.047, + "step": 236 + }, + { + "epoch": 1.53, + "learning_rate": 2e-05, + "loss": 0.0441, + "step": 237 + }, + { + "epoch": 1.54, + "learning_rate": 2e-05, + "loss": 0.0505, + "step": 238 + }, + { + "epoch": 1.55, + "learning_rate": 2e-05, + "loss": 0.0419, + "step": 239 + }, + { + "epoch": 1.55, + "learning_rate": 2e-05, + "loss": 0.0409, + "step": 240 + }, + { + "epoch": 1.56, + "learning_rate": 2e-05, + "loss": 0.0444, + "step": 241 + }, + { + "epoch": 1.57, + "learning_rate": 2e-05, + "loss": 0.0392, + "step": 242 + }, + { + "epoch": 1.57, + "learning_rate": 2e-05, + "loss": 0.037, + "step": 243 + }, + { + "epoch": 1.58, + "learning_rate": 2e-05, + "loss": 0.0444, + "step": 244 + }, + { + "epoch": 1.59, + "learning_rate": 2e-05, + "loss": 0.0361, + "step": 245 + }, + { + "epoch": 1.59, + "learning_rate": 2e-05, + "loss": 0.0468, + "step": 246 + }, + { + "epoch": 1.6, + "learning_rate": 2e-05, + "loss": 0.052, + "step": 247 + }, + { + "epoch": 1.6, + "learning_rate": 2e-05, + "loss": 0.0453, + "step": 248 + }, + { + "epoch": 1.61, + "learning_rate": 2e-05, + "loss": 0.0423, + "step": 249 + }, + { + "epoch": 1.62, + "learning_rate": 2e-05, + "loss": 0.0385, + "step": 250 + }, + { + "epoch": 1.62, + "learning_rate": 2e-05, + "loss": 0.0462, + "step": 251 + }, + { + "epoch": 1.63, + "learning_rate": 2e-05, + "loss": 0.0362, + "step": 252 + 
}, + { + "epoch": 1.64, + "learning_rate": 2e-05, + "loss": 0.0388, + "step": 253 + }, + { + "epoch": 1.64, + "learning_rate": 2e-05, + "loss": 0.0433, + "step": 254 + }, + { + "epoch": 1.65, + "learning_rate": 2e-05, + "loss": 0.051, + "step": 255 + }, + { + "epoch": 1.66, + "learning_rate": 2e-05, + "loss": 0.0362, + "step": 256 + }, + { + "epoch": 1.66, + "learning_rate": 2e-05, + "loss": 0.0444, + "step": 257 + }, + { + "epoch": 1.67, + "learning_rate": 2e-05, + "loss": 0.0413, + "step": 258 + }, + { + "epoch": 1.68, + "learning_rate": 2e-05, + "loss": 0.0331, + "step": 259 + }, + { + "epoch": 1.68, + "learning_rate": 2e-05, + "loss": 0.0425, + "step": 260 + }, + { + "epoch": 1.69, + "learning_rate": 2e-05, + "loss": 0.0508, + "step": 261 + }, + { + "epoch": 1.7, + "learning_rate": 2e-05, + "loss": 0.0445, + "step": 262 + }, + { + "epoch": 1.7, + "learning_rate": 2e-05, + "loss": 0.0366, + "step": 263 + }, + { + "epoch": 1.71, + "learning_rate": 2e-05, + "loss": 0.0401, + "step": 264 + }, + { + "epoch": 1.71, + "learning_rate": 2e-05, + "loss": 0.037, + "step": 265 + }, + { + "epoch": 1.72, + "learning_rate": 2e-05, + "loss": 0.0464, + "step": 266 + }, + { + "epoch": 1.73, + "learning_rate": 2e-05, + "loss": 0.0405, + "step": 267 + }, + { + "epoch": 1.73, + "learning_rate": 2e-05, + "loss": 0.0455, + "step": 268 + }, + { + "epoch": 1.74, + "learning_rate": 2e-05, + "loss": 0.0385, + "step": 269 + }, + { + "epoch": 1.75, + "learning_rate": 2e-05, + "loss": 0.0459, + "step": 270 + }, + { + "epoch": 1.75, + "learning_rate": 2e-05, + "loss": 0.0464, + "step": 271 + }, + { + "epoch": 1.76, + "learning_rate": 2e-05, + "loss": 0.0381, + "step": 272 + }, + { + "epoch": 1.77, + "learning_rate": 2e-05, + "loss": 0.0356, + "step": 273 + }, + { + "epoch": 1.77, + "learning_rate": 2e-05, + "loss": 0.0454, + "step": 274 + }, + { + "epoch": 1.78, + "learning_rate": 2e-05, + "loss": 0.0398, + "step": 275 + }, + { + "epoch": 1.79, + "learning_rate": 2e-05, + "loss": 0.0393, + 
"step": 276 + }, + { + "epoch": 1.79, + "learning_rate": 2e-05, + "loss": 0.0399, + "step": 277 + }, + { + "epoch": 1.8, + "learning_rate": 2e-05, + "loss": 0.045, + "step": 278 + }, + { + "epoch": 1.81, + "learning_rate": 2e-05, + "loss": 0.0448, + "step": 279 + }, + { + "epoch": 1.81, + "learning_rate": 2e-05, + "loss": 0.0441, + "step": 280 + }, + { + "epoch": 1.82, + "learning_rate": 2e-05, + "loss": 0.042, + "step": 281 + }, + { + "epoch": 1.82, + "learning_rate": 2e-05, + "loss": 0.0387, + "step": 282 + }, + { + "epoch": 1.83, + "learning_rate": 2e-05, + "loss": 0.0383, + "step": 283 + }, + { + "epoch": 1.84, + "learning_rate": 2e-05, + "loss": 0.0377, + "step": 284 + }, + { + "epoch": 1.84, + "learning_rate": 2e-05, + "loss": 0.0433, + "step": 285 + }, + { + "epoch": 1.85, + "learning_rate": 2e-05, + "loss": 0.0424, + "step": 286 + }, + { + "epoch": 1.86, + "learning_rate": 2e-05, + "loss": 0.0459, + "step": 287 + }, + { + "epoch": 1.86, + "learning_rate": 2e-05, + "loss": 0.0402, + "step": 288 + }, + { + "epoch": 1.87, + "learning_rate": 2e-05, + "loss": 0.0389, + "step": 289 + }, + { + "epoch": 1.88, + "learning_rate": 2e-05, + "loss": 0.0394, + "step": 290 + }, + { + "epoch": 1.88, + "learning_rate": 2e-05, + "loss": 0.0412, + "step": 291 + }, + { + "epoch": 1.89, + "learning_rate": 2e-05, + "loss": 0.0493, + "step": 292 + }, + { + "epoch": 1.9, + "learning_rate": 2e-05, + "loss": 0.0468, + "step": 293 + }, + { + "epoch": 1.9, + "learning_rate": 2e-05, + "loss": 0.0407, + "step": 294 + }, + { + "epoch": 1.91, + "learning_rate": 2e-05, + "loss": 0.0376, + "step": 295 + }, + { + "epoch": 1.92, + "learning_rate": 2e-05, + "loss": 0.042, + "step": 296 + }, + { + "epoch": 1.92, + "learning_rate": 2e-05, + "loss": 0.0367, + "step": 297 + }, + { + "epoch": 1.93, + "learning_rate": 2e-05, + "loss": 0.0365, + "step": 298 + }, + { + "epoch": 1.93, + "learning_rate": 2e-05, + "loss": 0.0449, + "step": 299 + }, + { + "epoch": 1.94, + "learning_rate": 2e-05, + "loss": 
0.0452, + "step": 300 + }, + { + "epoch": 1.95, + "learning_rate": 2e-05, + "loss": 0.0427, + "step": 301 + }, + { + "epoch": 1.95, + "learning_rate": 2e-05, + "loss": 0.0453, + "step": 302 + }, + { + "epoch": 1.96, + "learning_rate": 2e-05, + "loss": 0.0351, + "step": 303 + }, + { + "epoch": 1.97, + "learning_rate": 2e-05, + "loss": 0.036, + "step": 304 + }, + { + "epoch": 1.97, + "learning_rate": 2e-05, + "loss": 0.0443, + "step": 305 + }, + { + "epoch": 1.98, + "learning_rate": 2e-05, + "loss": 0.0369, + "step": 306 + }, + { + "epoch": 1.99, + "learning_rate": 2e-05, + "loss": 0.0376, + "step": 307 + }, + { + "epoch": 1.99, + "learning_rate": 2e-05, + "loss": 0.048, + "step": 308 + }, + { + "epoch": 2.0, + "learning_rate": 2e-05, + "loss": 0.0366, + "step": 309 + }, + { + "epoch": 2.01, + "learning_rate": 2e-05, + "loss": 0.0337, + "step": 310 + }, + { + "epoch": 2.01, + "learning_rate": 2e-05, + "loss": 0.0278, + "step": 311 + }, + { + "epoch": 2.02, + "learning_rate": 2e-05, + "loss": 0.0279, + "step": 312 + }, + { + "epoch": 2.03, + "learning_rate": 2e-05, + "loss": 0.0307, + "step": 313 + }, + { + "epoch": 2.03, + "learning_rate": 2e-05, + "loss": 0.0303, + "step": 314 + }, + { + "epoch": 2.04, + "learning_rate": 2e-05, + "loss": 0.0286, + "step": 315 + }, + { + "epoch": 2.04, + "learning_rate": 2e-05, + "loss": 0.0233, + "step": 316 + }, + { + "epoch": 2.05, + "learning_rate": 2e-05, + "loss": 0.0285, + "step": 317 + }, + { + "epoch": 2.06, + "learning_rate": 2e-05, + "loss": 0.0258, + "step": 318 + }, + { + "epoch": 2.06, + "learning_rate": 2e-05, + "loss": 0.0262, + "step": 319 + }, + { + "epoch": 2.07, + "learning_rate": 2e-05, + "loss": 0.0268, + "step": 320 + }, + { + "epoch": 2.08, + "learning_rate": 2e-05, + "loss": 0.0273, + "step": 321 + }, + { + "epoch": 2.08, + "learning_rate": 2e-05, + "loss": 0.0258, + "step": 322 + }, + { + "epoch": 2.09, + "learning_rate": 2e-05, + "loss": 0.0261, + "step": 323 + }, + { + "epoch": 2.1, + "learning_rate": 
2e-05, + "loss": 0.0269, + "step": 324 + }, + { + "epoch": 2.1, + "learning_rate": 2e-05, + "loss": 0.0268, + "step": 325 + }, + { + "epoch": 2.11, + "learning_rate": 2e-05, + "loss": 0.0244, + "step": 326 + }, + { + "epoch": 2.12, + "learning_rate": 2e-05, + "loss": 0.026, + "step": 327 + }, + { + "epoch": 2.12, + "learning_rate": 2e-05, + "loss": 0.023, + "step": 328 + }, + { + "epoch": 2.13, + "learning_rate": 2e-05, + "loss": 0.0259, + "step": 329 + }, + { + "epoch": 2.14, + "learning_rate": 2e-05, + "loss": 0.0288, + "step": 330 + }, + { + "epoch": 2.14, + "learning_rate": 2e-05, + "loss": 0.0264, + "step": 331 + }, + { + "epoch": 2.15, + "learning_rate": 2e-05, + "loss": 0.0243, + "step": 332 + }, + { + "epoch": 2.15, + "learning_rate": 2e-05, + "loss": 0.0319, + "step": 333 + }, + { + "epoch": 2.16, + "learning_rate": 2e-05, + "loss": 0.0265, + "step": 334 + }, + { + "epoch": 2.17, + "learning_rate": 2e-05, + "loss": 0.0241, + "step": 335 + }, + { + "epoch": 2.17, + "learning_rate": 2e-05, + "loss": 0.028, + "step": 336 + }, + { + "epoch": 2.18, + "learning_rate": 2e-05, + "loss": 0.0318, + "step": 337 + }, + { + "epoch": 2.19, + "learning_rate": 2e-05, + "loss": 0.0307, + "step": 338 + }, + { + "epoch": 2.19, + "learning_rate": 2e-05, + "loss": 0.026, + "step": 339 + }, + { + "epoch": 2.2, + "learning_rate": 2e-05, + "loss": 0.0276, + "step": 340 + }, + { + "epoch": 2.21, + "learning_rate": 2e-05, + "loss": 0.023, + "step": 341 + }, + { + "epoch": 2.21, + "learning_rate": 2e-05, + "loss": 0.028, + "step": 342 + }, + { + "epoch": 2.22, + "learning_rate": 2e-05, + "loss": 0.0299, + "step": 343 + }, + { + "epoch": 2.23, + "learning_rate": 2e-05, + "loss": 0.0239, + "step": 344 + }, + { + "epoch": 2.23, + "learning_rate": 2e-05, + "loss": 0.0269, + "step": 345 + }, + { + "epoch": 2.24, + "learning_rate": 2e-05, + "loss": 0.0284, + "step": 346 + }, + { + "epoch": 2.25, + "learning_rate": 2e-05, + "loss": 0.0243, + "step": 347 + }, + { + "epoch": 2.25, + 
"learning_rate": 2e-05, + "loss": 0.0237, + "step": 348 + }, + { + "epoch": 2.26, + "learning_rate": 2e-05, + "loss": 0.0185, + "step": 349 + }, + { + "epoch": 2.26, + "learning_rate": 2e-05, + "loss": 0.0235, + "step": 350 + }, + { + "epoch": 2.27, + "learning_rate": 2e-05, + "loss": 0.0199, + "step": 351 + }, + { + "epoch": 2.28, + "learning_rate": 2e-05, + "loss": 0.0263, + "step": 352 + }, + { + "epoch": 2.28, + "learning_rate": 2e-05, + "loss": 0.0271, + "step": 353 + }, + { + "epoch": 2.29, + "learning_rate": 2e-05, + "loss": 0.0282, + "step": 354 + }, + { + "epoch": 2.3, + "learning_rate": 2e-05, + "loss": 0.0248, + "step": 355 + }, + { + "epoch": 2.3, + "learning_rate": 2e-05, + "loss": 0.0281, + "step": 356 + }, + { + "epoch": 2.31, + "learning_rate": 2e-05, + "loss": 0.0247, + "step": 357 + }, + { + "epoch": 2.32, + "learning_rate": 2e-05, + "loss": 0.024, + "step": 358 + }, + { + "epoch": 2.32, + "learning_rate": 2e-05, + "loss": 0.0311, + "step": 359 + }, + { + "epoch": 2.33, + "learning_rate": 2e-05, + "loss": 0.0338, + "step": 360 + }, + { + "epoch": 2.34, + "learning_rate": 2e-05, + "loss": 0.026, + "step": 361 + }, + { + "epoch": 2.34, + "learning_rate": 2e-05, + "loss": 0.0255, + "step": 362 + }, + { + "epoch": 2.35, + "learning_rate": 2e-05, + "loss": 0.0273, + "step": 363 + }, + { + "epoch": 2.36, + "learning_rate": 2e-05, + "loss": 0.0302, + "step": 364 + }, + { + "epoch": 2.36, + "learning_rate": 2e-05, + "loss": 0.0248, + "step": 365 + }, + { + "epoch": 2.37, + "learning_rate": 2e-05, + "loss": 0.0266, + "step": 366 + }, + { + "epoch": 2.37, + "learning_rate": 2e-05, + "loss": 0.0288, + "step": 367 + }, + { + "epoch": 2.38, + "learning_rate": 2e-05, + "loss": 0.0344, + "step": 368 + }, + { + "epoch": 2.39, + "learning_rate": 2e-05, + "loss": 0.0244, + "step": 369 + }, + { + "epoch": 2.39, + "learning_rate": 2e-05, + "loss": 0.0238, + "step": 370 + }, + { + "epoch": 2.4, + "learning_rate": 2e-05, + "loss": 0.0279, + "step": 371 + }, + { + 
"epoch": 2.41, + "learning_rate": 2e-05, + "loss": 0.0299, + "step": 372 + }, + { + "epoch": 2.41, + "learning_rate": 2e-05, + "loss": 0.0211, + "step": 373 + }, + { + "epoch": 2.42, + "learning_rate": 2e-05, + "loss": 0.0257, + "step": 374 + }, + { + "epoch": 2.43, + "learning_rate": 2e-05, + "loss": 0.0342, + "step": 375 + }, + { + "epoch": 2.43, + "learning_rate": 2e-05, + "loss": 0.0236, + "step": 376 + }, + { + "epoch": 2.44, + "learning_rate": 2e-05, + "loss": 0.035, + "step": 377 + }, + { + "epoch": 2.45, + "learning_rate": 2e-05, + "loss": 0.0316, + "step": 378 + }, + { + "epoch": 2.45, + "learning_rate": 2e-05, + "loss": 0.0254, + "step": 379 + }, + { + "epoch": 2.46, + "learning_rate": 2e-05, + "loss": 0.0281, + "step": 380 + }, + { + "epoch": 2.47, + "learning_rate": 2e-05, + "loss": 0.0246, + "step": 381 + }, + { + "epoch": 2.47, + "learning_rate": 2e-05, + "loss": 0.0299, + "step": 382 + }, + { + "epoch": 2.48, + "learning_rate": 2e-05, + "loss": 0.0269, + "step": 383 + }, + { + "epoch": 2.48, + "learning_rate": 2e-05, + "loss": 0.0285, + "step": 384 + }, + { + "epoch": 2.49, + "learning_rate": 2e-05, + "loss": 0.0346, + "step": 385 + }, + { + "epoch": 2.5, + "learning_rate": 2e-05, + "loss": 0.024, + "step": 386 + }, + { + "epoch": 2.5, + "learning_rate": 2e-05, + "loss": 0.0353, + "step": 387 + }, + { + "epoch": 2.51, + "learning_rate": 2e-05, + "loss": 0.0303, + "step": 388 + }, + { + "epoch": 2.52, + "learning_rate": 2e-05, + "loss": 0.03, + "step": 389 + }, + { + "epoch": 2.52, + "learning_rate": 2e-05, + "loss": 0.0277, + "step": 390 + }, + { + "epoch": 2.53, + "learning_rate": 2e-05, + "loss": 0.0232, + "step": 391 + }, + { + "epoch": 2.54, + "learning_rate": 2e-05, + "loss": 0.0314, + "step": 392 + }, + { + "epoch": 2.54, + "learning_rate": 2e-05, + "loss": 0.0259, + "step": 393 + }, + { + "epoch": 2.55, + "learning_rate": 2e-05, + "loss": 0.0304, + "step": 394 + }, + { + "epoch": 2.56, + "learning_rate": 2e-05, + "loss": 0.0274, + "step": 395 
+ }, + { + "epoch": 2.56, + "learning_rate": 2e-05, + "loss": 0.0303, + "step": 396 + }, + { + "epoch": 2.57, + "learning_rate": 2e-05, + "loss": 0.0279, + "step": 397 + }, + { + "epoch": 2.58, + "learning_rate": 2e-05, + "loss": 0.0281, + "step": 398 + }, + { + "epoch": 2.58, + "learning_rate": 2e-05, + "loss": 0.0285, + "step": 399 + }, + { + "epoch": 2.59, + "learning_rate": 2e-05, + "loss": 0.0355, + "step": 400 + }, + { + "epoch": 2.59, + "learning_rate": 2e-05, + "loss": 0.0279, + "step": 401 + }, + { + "epoch": 2.6, + "learning_rate": 2e-05, + "loss": 0.0312, + "step": 402 + }, + { + "epoch": 2.61, + "learning_rate": 2e-05, + "loss": 0.0258, + "step": 403 + }, + { + "epoch": 2.61, + "learning_rate": 2e-05, + "loss": 0.0304, + "step": 404 + }, + { + "epoch": 2.62, + "learning_rate": 2e-05, + "loss": 0.0271, + "step": 405 + }, + { + "epoch": 2.63, + "learning_rate": 2e-05, + "loss": 0.0271, + "step": 406 + }, + { + "epoch": 2.63, + "learning_rate": 2e-05, + "loss": 0.0248, + "step": 407 + }, + { + "epoch": 2.64, + "learning_rate": 2e-05, + "loss": 0.0261, + "step": 408 + }, + { + "epoch": 2.65, + "learning_rate": 2e-05, + "loss": 0.0284, + "step": 409 + }, + { + "epoch": 2.65, + "learning_rate": 2e-05, + "loss": 0.0288, + "step": 410 + }, + { + "epoch": 2.66, + "learning_rate": 2e-05, + "loss": 0.0296, + "step": 411 + }, + { + "epoch": 2.67, + "learning_rate": 2e-05, + "loss": 0.028, + "step": 412 + }, + { + "epoch": 2.67, + "learning_rate": 2e-05, + "loss": 0.0232, + "step": 413 + }, + { + "epoch": 2.68, + "learning_rate": 2e-05, + "loss": 0.0281, + "step": 414 + }, + { + "epoch": 2.68, + "learning_rate": 2e-05, + "loss": 0.0337, + "step": 415 + }, + { + "epoch": 2.69, + "learning_rate": 2e-05, + "loss": 0.0228, + "step": 416 + }, + { + "epoch": 2.7, + "learning_rate": 2e-05, + "loss": 0.0319, + "step": 417 + }, + { + "epoch": 2.7, + "learning_rate": 2e-05, + "loss": 0.0321, + "step": 418 + }, + { + "epoch": 2.71, + "learning_rate": 2e-05, + "loss": 0.0312, + 
"step": 419 + }, + { + "epoch": 2.72, + "learning_rate": 2e-05, + "loss": 0.0276, + "step": 420 + }, + { + "epoch": 2.72, + "learning_rate": 2e-05, + "loss": 0.0241, + "step": 421 + }, + { + "epoch": 2.73, + "learning_rate": 2e-05, + "loss": 0.0302, + "step": 422 + }, + { + "epoch": 2.74, + "learning_rate": 2e-05, + "loss": 0.0271, + "step": 423 + }, + { + "epoch": 2.74, + "learning_rate": 2e-05, + "loss": 0.0319, + "step": 424 + }, + { + "epoch": 2.75, + "learning_rate": 2e-05, + "loss": 0.0269, + "step": 425 + }, + { + "epoch": 2.76, + "learning_rate": 2e-05, + "loss": 0.0247, + "step": 426 + }, + { + "epoch": 2.76, + "learning_rate": 2e-05, + "loss": 0.0241, + "step": 427 + }, + { + "epoch": 2.77, + "learning_rate": 2e-05, + "loss": 0.0245, + "step": 428 + }, + { + "epoch": 2.78, + "learning_rate": 2e-05, + "loss": 0.0262, + "step": 429 + }, + { + "epoch": 2.78, + "learning_rate": 2e-05, + "loss": 0.0307, + "step": 430 + }, + { + "epoch": 2.79, + "learning_rate": 2e-05, + "loss": 0.0293, + "step": 431 + }, + { + "epoch": 2.79, + "learning_rate": 2e-05, + "loss": 0.0235, + "step": 432 + }, + { + "epoch": 2.8, + "learning_rate": 2e-05, + "loss": 0.0257, + "step": 433 + }, + { + "epoch": 2.81, + "learning_rate": 2e-05, + "loss": 0.0301, + "step": 434 + }, + { + "epoch": 2.81, + "learning_rate": 2e-05, + "loss": 0.0266, + "step": 435 + }, + { + "epoch": 2.82, + "learning_rate": 2e-05, + "loss": 0.024, + "step": 436 + }, + { + "epoch": 2.83, + "learning_rate": 2e-05, + "loss": 0.0279, + "step": 437 + }, + { + "epoch": 2.83, + "learning_rate": 2e-05, + "loss": 0.0285, + "step": 438 + }, + { + "epoch": 2.84, + "learning_rate": 2e-05, + "loss": 0.0264, + "step": 439 + }, + { + "epoch": 2.85, + "learning_rate": 2e-05, + "loss": 0.0307, + "step": 440 + }, + { + "epoch": 2.85, + "learning_rate": 2e-05, + "loss": 0.0273, + "step": 441 + }, + { + "epoch": 2.86, + "learning_rate": 2e-05, + "loss": 0.0306, + "step": 442 + }, + { + "epoch": 2.87, + "learning_rate": 2e-05, + 
"loss": 0.0273, + "step": 443 + }, + { + "epoch": 2.87, + "learning_rate": 2e-05, + "loss": 0.0241, + "step": 444 + }, + { + "epoch": 2.88, + "learning_rate": 2e-05, + "loss": 0.0333, + "step": 445 + }, + { + "epoch": 2.89, + "learning_rate": 2e-05, + "loss": 0.0251, + "step": 446 + }, + { + "epoch": 2.89, + "learning_rate": 2e-05, + "loss": 0.026, + "step": 447 + }, + { + "epoch": 2.9, + "learning_rate": 2e-05, + "loss": 0.0309, + "step": 448 + }, + { + "epoch": 2.9, + "learning_rate": 2e-05, + "loss": 0.0317, + "step": 449 + }, + { + "epoch": 2.91, + "learning_rate": 2e-05, + "loss": 0.0298, + "step": 450 + }, + { + "epoch": 2.92, + "learning_rate": 2e-05, + "loss": 0.0303, + "step": 451 + }, + { + "epoch": 2.92, + "learning_rate": 2e-05, + "loss": 0.0317, + "step": 452 + }, + { + "epoch": 2.93, + "learning_rate": 2e-05, + "loss": 0.0271, + "step": 453 + }, + { + "epoch": 2.94, + "learning_rate": 2e-05, + "loss": 0.0301, + "step": 454 + }, + { + "epoch": 2.94, + "learning_rate": 2e-05, + "loss": 0.0272, + "step": 455 + }, + { + "epoch": 2.95, + "learning_rate": 2e-05, + "loss": 0.0302, + "step": 456 + }, + { + "epoch": 2.96, + "learning_rate": 2e-05, + "loss": 0.031, + "step": 457 + }, + { + "epoch": 2.96, + "learning_rate": 2e-05, + "loss": 0.0293, + "step": 458 + }, + { + "epoch": 2.97, + "learning_rate": 2e-05, + "loss": 0.023, + "step": 459 + }, + { + "epoch": 2.98, + "learning_rate": 2e-05, + "loss": 0.0263, + "step": 460 + }, + { + "epoch": 2.98, + "learning_rate": 2e-05, + "loss": 0.0279, + "step": 461 + }, + { + "epoch": 2.99, + "learning_rate": 2e-05, + "loss": 0.0272, + "step": 462 + } + ], + "max_steps": 462, + "num_train_epochs": 3, + "total_flos": 68006471270400.0, + "trial_name": null, + "trial_params": null +} diff --git a/training_args.bin b/training_args.bin new file mode 100644 index 0000000..fc1aa98 --- /dev/null +++ b/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:ee7b5cbd6e2c82fff914ea83dcdc84ac059d8e2a06be53e1109b69f320111393 +size 5371 diff --git a/zero_to_fp32.py b/zero_to_fp32.py new file mode 100755 index 0000000..c5246ff --- /dev/null +++ b/zero_to_fp32.py @@ -0,0 +1,578 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def 
parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data 
parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage 
{zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in 
range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. 
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + 
# --- tail of _zero3_merge_frozen_params (its `def` line precedes this chunk) ---
# NOTE(review): the lines down to the "Reconstructed Frozen ..." print belong to a
# function whose signature is not visible here; indentation below is reconstructed
# from the visible loop/branch structure — confirm against the full file.
            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    # Each rank holds a fragment of every frozen param, so total available elements
    # is the per-rank fragment total times world_size.
    avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in zero_model_states[0].frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # Concatenate this param's per-rank fragments, drop any tail padding
        # (narrow to the true element count), and restore the original shape.
        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
        state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")


def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Merge the sharded ZeRO-3 trainable params from all ranks into ``state_dict``.

    Each param lives as a fixed-size slice at the same ``offset`` inside every
    rank's flat fp32 group; slices are concatenated across ranks, trailing
    padding is dropped, and the original shape is restored.
    Raises ``ValueError`` if the consumed element count does not match what the
    flat groups provide.
    """
    param_shapes = zero_model_states[0].param_shapes
    avail_numel = fp32_flat_groups[0].numel() * world_size
    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}

    if debug:
        for i in range(world_size):
            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding
    avail_numel = fp32_flat_groups[0].numel() * world_size
    print(f"Trainable params: Have {avail_numel} numels to process.")
    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    for name, shape in param_shapes.items():

        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # XXX: memory usage doubles here
        # Take the slice [offset, offset+partitioned_numel) from every rank's flat
        # group, stitch them together, trim padding, and reshape to the full param.
        state_dict[name] = torch.cat(
            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
            0).narrow(0, 0, unpartitioned_numel).view(shape)
        offset += partitioned_numel

    # `offset` advanced by the per-rank slice size each iteration; multiply by
    # world_size to compare against the total elements across all ranks.
    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")


def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
    """Assemble a full fp32 state_dict from ZeRO-3 shards: buffers first, then
    frozen params, then trainable params, then re-link shared params."""
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    # Each pair is (alias_name, source_name): point the alias at the already
    # reconstructed source tensor.
    for pair in zero_model_states[0].shared_params:
        if pair[1] in state_dict:
            state_dict[pair[0]] = state_dict[pair[1]]

    return state_dict


def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``

    Returns:
        - pytorch ``state_dict``

    Raises:
        - ``ValueError``: if ``tag`` is not given and no ``latest`` file exists
        - ``FileNotFoundError``: if the resolved tag directory does not exist

    Note: this approach may not work if your application doesn't have sufficient free CPU memory and
    you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
    the checkpoint.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

    """
    if tag is None:
        # Resolve the checkpoint tag from the 'latest' marker file DeepSpeed writes.
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if os.path.isfile(latest_path):
            with open(latest_path, 'r') as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)

    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)


def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
    """

    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
    print(f"Saving fp32 state dict to {output_file}")
    torch.save(state_dict, output_file)


def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note that once this has run, the ``model`` will no longer be usable in the deepspeed context
    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    """
    logger.info(f"Extracting fp32 weights")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info(f"Overwriting model with fp32 weights")
    model = model.cpu()
    # strict=False: the consolidated dict may omit keys (e.g. params the
    # checkpoint never tracked); presumably intentional — verify for your model.
    model.load_state_dict(state_dict, strict=False)

    return model


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_dir",
                        type=str,
                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument(
        "output_file",
        type=str,
        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    # This runs at module scope, so it sets the module-level `debug` flag the
    # helper functions above consult.
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)