初始化项目,由ModelHub XC社区提供模型
Model: TheBloke/WizardLM-13B-V1.2-AWQ Source: Original Platform
This commit is contained in:
36
.gitattributes
vendored
Normal file
36
.gitattributes
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
126
LICENSE.txt
Normal file
126
LICENSE.txt
Normal file
@@ -0,0 +1,126 @@
|
||||
LLAMA 2 COMMUNITY LICENSE AGREEMENT
|
||||
Llama 2 Version Release Date: July 18, 2023
|
||||
|
||||
"Agreement" means the terms and conditions for use, reproduction, distribution and
|
||||
modification of the Llama Materials set forth herein.
|
||||
|
||||
"Documentation" means the specifications, manuals and documentation
|
||||
accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
|
||||
libraries/llama-downloads/.
|
||||
|
||||
"Licensee" or "you" means you, or your employer or any other person or entity (if
|
||||
you are entering into this Agreement on such person or entity's behalf), of the age
|
||||
required under applicable laws, rules or regulations to provide legal consent and that
|
||||
has legal authority to bind your employer or such other person or entity if you are
|
||||
entering in this Agreement on their behalf.
|
||||
|
||||
"Llama 2" means the foundational large language models and software and
|
||||
algorithms, including machine-learning model code, trained model weights,
|
||||
inference-enabling code, training-enabling code, fine-tuning enabling code and other
|
||||
elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
|
||||
libraries/llama-downloads/.
|
||||
|
||||
"Llama Materials" means, collectively, Meta's proprietary Llama 2 and
|
||||
Documentation (and any portion thereof) made available under this Agreement.
|
||||
|
||||
"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
|
||||
are an entity, your principal place of business is in the EEA or Switzerland) and Meta
|
||||
Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
||||
|
||||
By clicking "I Accept" below or by using or distributing any portion or element of the
|
||||
Llama Materials, you agree to be bound by this Agreement.
|
||||
|
||||
1. License Rights and Redistribution.
|
||||
|
||||
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
|
||||
transferable and royalty-free limited license under Meta's intellectual property or
|
||||
other rights owned by Meta embodied in the Llama Materials to use, reproduce,
|
||||
distribute, copy, create derivative works of, and make modifications to the Llama
|
||||
Materials.
|
||||
|
||||
b. Redistribution and Use.
|
||||
|
||||
i. If you distribute or make the Llama Materials, or any derivative works
|
||||
thereof, available to a third party, you shall provide a copy of this Agreement to such
|
||||
third party.
|
||||
ii. If you receive Llama Materials, or any derivative works thereof, from
|
||||
a Licensee as part of an integrated end user product, then Section 2 of this
|
||||
Agreement will not apply to you.
|
||||
|
||||
iii. You must retain in all copies of the Llama Materials that you
|
||||
distribute the following attribution notice within a "Notice" text file distributed as a
|
||||
part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
|
||||
Copyright (c) Meta Platforms, Inc. All Rights Reserved."
|
||||
|
||||
iv. Your use of the Llama Materials must comply with applicable laws
|
||||
and regulations (including trade compliance laws and regulations) and adhere to the
|
||||
Acceptable Use Policy for the Llama Materials (available at
|
||||
https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
|
||||
this Agreement.
|
||||
|
||||
v. You will not use the Llama Materials or any output or results of the
|
||||
Llama Materials to improve any other large language model (excluding Llama 2 or
|
||||
derivative works thereof).
|
||||
|
||||
2. Additional Commercial Terms. If, on the Llama 2 version release date, the
|
||||
monthly active users of the products or services made available by or for Licensee,
|
||||
or Licensee's affiliates, is greater than 700 million monthly active users in the
|
||||
preceding calendar month, you must request a license from Meta, which Meta may
|
||||
grant to you in its sole discretion, and you are not authorized to exercise any of the
|
||||
rights under this Agreement unless or until Meta otherwise expressly grants you
|
||||
such rights.
|
||||
|
||||
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
|
||||
LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
|
||||
PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
|
||||
WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
|
||||
FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
|
||||
FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
|
||||
THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
|
||||
USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
|
||||
|
||||
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
|
||||
LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
|
||||
NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
|
||||
AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
|
||||
CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
|
||||
IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
|
||||
ANY OF THE FOREGOING.
|
||||
|
||||
5. Intellectual Property.
|
||||
|
||||
a. No trademark licenses are granted under this Agreement, and in
|
||||
connection with the Llama Materials, neither Meta nor Licensee may use any name
|
||||
or mark owned by or associated with the other or any of its affiliates, except as
|
||||
required for reasonable and customary use in describing and redistributing the
|
||||
Llama Materials.
|
||||
|
||||
b. Subject to Meta's ownership of Llama Materials and derivatives made by or
|
||||
for Meta, with respect to any derivative works and modifications of the Llama
|
||||
Materials that are made by you, as between you and Meta, you are and will be the
|
||||
owner of such derivative works and modifications.
|
||||
|
||||
c. If you institute litigation or other proceedings against Meta or any entity
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
|
||||
Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
|
||||
constitutes infringement of intellectual property or other rights owned or licensable
|
||||
by you, then any licenses granted to you under this Agreement shall terminate as of
|
||||
the date such litigation or claim is filed or instituted. You will indemnify and hold
|
||||
harmless Meta from and against any claim by any third party arising out of or related
|
||||
to your use or distribution of the Llama Materials.
|
||||
|
||||
6. Term and Termination. The term of this Agreement will commence upon your
|
||||
acceptance of this Agreement or access to the Llama Materials and will continue in
|
||||
full force and effect until terminated in accordance with the terms and conditions
|
||||
herein. Meta may terminate this Agreement if you are in breach of any term or
|
||||
condition of this Agreement. Upon termination of this Agreement, you shall delete
|
||||
and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
|
||||
termination of this Agreement.
|
||||
|
||||
7. Governing Law and Jurisdiction. This Agreement will be governed and
|
||||
construed under the laws of the State of California without regard to choice of law
|
||||
principles, and the UN Convention on Contracts for the International Sale of Goods
|
||||
does not apply to this Agreement. The courts of California shall have exclusive
|
||||
jurisdiction of any dispute arising out of this Agreement.
|
||||
|
||||
1
Notice
Normal file
1
Notice
Normal file
@@ -0,0 +1 @@
|
||||
Llama 2 is licensed under the LLAMA 2 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.
|
||||
332
README.md
Normal file
332
README.md
Normal file
@@ -0,0 +1,332 @@
|
||||
---
|
||||
license: llama2
|
||||
model_name: WizardLM 13B V1.2
|
||||
base_model: WizardLM/WizardLM-13B-V1.2
|
||||
inference: false
|
||||
model_creator: WizardLM
|
||||
model_type: llama
|
||||
prompt_template: 'A chat between a curious user and an artificial intelligence assistant.
|
||||
The assistant gives helpful, detailed, and polite answers to the user''s questions.
|
||||
USER: {prompt} ASSISTANT:
|
||||
|
||||
'
|
||||
quantized_by: TheBloke
|
||||
---
|
||||
|
||||
<!-- header start -->
|
||||
<!-- 200823 -->
|
||||
<div style="width: auto; margin-left: auto; margin-right: auto">
|
||||
<img src="https://i.imgur.com/EBdldam.jpg" alt="TheBlokeAI" style="width: 100%; min-width: 400px; display: block; margin: auto;">
|
||||
</div>
|
||||
<div style="display: flex; justify-content: space-between; width: 100%;">
|
||||
<div style="display: flex; flex-direction: column; align-items: flex-start;">
|
||||
<p style="margin-top: 0.5em; margin-bottom: 0em;"><a href="https://discord.gg/theblokeai">Chat & support: TheBloke's Discord server</a></p>
|
||||
</div>
|
||||
<div style="display: flex; flex-direction: column; align-items: flex-end;">
|
||||
<p style="margin-top: 0.5em; margin-bottom: 0em;"><a href="https://www.patreon.com/TheBlokeAI">Want to contribute? TheBloke's Patreon page</a></p>
|
||||
</div>
|
||||
</div>
|
||||
<div style="text-align:center; margin-top: 0em; margin-bottom: 0em"><p style="margin-top: 0.25em; margin-bottom: 0em;">TheBloke's LLM work is generously supported by a grant from <a href="https://a16z.com">andreessen horowitz (a16z)</a></p></div>
|
||||
<hr style="margin-top: 1.0em; margin-bottom: 1.0em;">
|
||||
<!-- header end -->
|
||||
|
||||
# WizardLM 13B V1.2 - AWQ
|
||||
- Model creator: [WizardLM](https://huggingface.co/WizardLM)
|
||||
- Original model: [WizardLM 13B V1.2](https://huggingface.co/WizardLM/WizardLM-13B-V1.2)
|
||||
|
||||
<!-- description start -->
|
||||
## Description
|
||||
|
||||
This repo contains AWQ model files for [WizardLM's WizardLM 13B V1.2](https://huggingface.co/WizardLM/WizardLM-13B-V1.2).
|
||||
|
||||
|
||||
### About AWQ
|
||||
|
||||
AWQ is an efficient, accurate and blazing-fast low-bit weight quantization method, currently supporting 4-bit quantization. Compared to GPTQ, it offers faster Transformers-based inference.
|
||||
|
||||
It is also now supported by continuous batching server [vLLM](https://github.com/vllm-project/vllm), allowing use of AWQ models for high-throughput concurrent inference in multi-user server scenarios. Note that, at the time of writing, overall throughput is still lower than running vLLM with unquantised models, however using AWQ enables using much smaller GPUs which can lead to easier deployment and overall cost savings. For example, a 70B model can be run on 1 x 48GB GPU instead of 2 x 80GB.
|
||||
<!-- description end -->
|
||||
<!-- repositories-available start -->
|
||||
## Repositories available
|
||||
|
||||
* [AWQ model(s) for GPU inference.](https://huggingface.co/TheBloke/WizardLM-13B-V1.2-AWQ)
|
||||
* [GPTQ models for GPU inference, with multiple quantisation parameter options.](https://huggingface.co/TheBloke/WizardLM-13B-V1.2-GPTQ)
|
||||
* [2, 3, 4, 5, 6 and 8-bit GGUF models for CPU+GPU inference](https://huggingface.co/TheBloke/WizardLM-13B-V1.2-GGUF)
|
||||
* [WizardLM's original unquantised fp16 model in pytorch format, for GPU inference and for further conversions](https://huggingface.co/WizardLM/WizardLM-13B-V1.2)
|
||||
<!-- repositories-available end -->
|
||||
|
||||
<!-- prompt-template start -->
|
||||
## Prompt template: Vicuna
|
||||
|
||||
```
|
||||
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:
|
||||
|
||||
```
|
||||
|
||||
<!-- prompt-template end -->
|
||||
|
||||
|
||||
<!-- README_AWQ.md-provided-files start -->
|
||||
## Provided files and AWQ parameters
|
||||
|
||||
For my first release of AWQ models, I am releasing 128g models only. I will consider adding 32g as well if there is interest, and once I have done perplexity and evaluation comparisons, but at this time 32g models are still not fully tested with AutoAWQ and vLLM.
|
||||
|
||||
Models are released as sharded safetensors files.
|
||||
|
||||
| Branch | Bits | GS | AWQ Dataset | Seq Len | Size |
|
||||
| ------ | ---- | -- | ----------- | ------- | ---- |
|
||||
| [main](https://huggingface.co/TheBloke/WizardLM-13B-V1.2-AWQ/tree/main) | 4 | 128 | [wikitext](https://huggingface.co/datasets/wikitext/viewer/wikitext-2-v1/test) | 4096 | 7.25 GB
|
||||
|
||||
<!-- README_AWQ.md-provided-files end -->
|
||||
|
||||
<!-- README_AWQ.md-use-from-vllm start -->
|
||||
## Serving this model from vLLM
|
||||
|
||||
Documentation on installing and using vLLM [can be found here](https://vllm.readthedocs.io/en/latest/).
|
||||
|
||||
- When using vLLM as a server, pass the `--quantization awq` parameter, for example:
|
||||
|
||||
```shell
|
||||
python3 python -m vllm.entrypoints.api_server --model TheBloke/WizardLM-13B-V1.2-AWQ --quantization awq
|
||||
```
|
||||
|
||||
When using vLLM from Python code, pass the `quantization=awq` parameter, for example:
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
llm = LLM(model="TheBloke/WizardLM-13B-V1.2-AWQ", quantization="awq")
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
```
|
||||
<!-- README_AWQ.md-use-from-vllm start -->
|
||||
|
||||
<!-- README_AWQ.md-use-from-python start -->
|
||||
## How to use this AWQ model from Python code
|
||||
|
||||
### Install the necessary packages
|
||||
|
||||
Requires: [AutoAWQ](https://github.com/casper-hansen/AutoAWQ) 0.0.2 or later
|
||||
|
||||
```shell
|
||||
pip3 install autoawq
|
||||
```
|
||||
|
||||
If you have problems installing [AutoAWQ](https://github.com/casper-hansen/AutoAWQ) using the pre-built wheels, install it from source instead:
|
||||
|
||||
```shell
|
||||
pip3 uninstall -y autoawq
|
||||
git clone https://github.com/casper-hansen/AutoAWQ
|
||||
cd AutoAWQ
|
||||
pip3 install .
|
||||
```
|
||||
|
||||
### You can then try the following example code
|
||||
|
||||
```python
|
||||
from awq import AutoAWQForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_name_or_path = "TheBloke/WizardLM-13B-V1.2-AWQ"
|
||||
|
||||
# Load model
|
||||
model = AutoAWQForCausalLM.from_quantized(model_name_or_path, fuse_layers=True,
|
||||
trust_remote_code=False, safetensors=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=False)
|
||||
|
||||
prompt = "Tell me about AI"
|
||||
prompt_template=f'''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:
|
||||
|
||||
'''
|
||||
|
||||
print("\n\n*** Generate:")
|
||||
|
||||
tokens = tokenizer(
|
||||
prompt_template,
|
||||
return_tensors='pt'
|
||||
).input_ids.cuda()
|
||||
|
||||
# Generate output
|
||||
generation_output = model.generate(
|
||||
tokens,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
max_new_tokens=512
|
||||
)
|
||||
|
||||
print("Output: ", tokenizer.decode(generation_output[0]))
|
||||
|
||||
# Inference can also be done using transformers' pipeline
|
||||
from transformers import pipeline
|
||||
|
||||
print("*** Pipeline:")
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
repetition_penalty=1.1
|
||||
)
|
||||
|
||||
print(pipe(prompt_template)[0]['generated_text'])
|
||||
```
|
||||
<!-- README_AWQ.md-use-from-python end -->
|
||||
|
||||
<!-- README_AWQ.md-compatibility start -->
|
||||
## Compatibility
|
||||
|
||||
The files provided are tested to work with [AutoAWQ](https://github.com/casper-hansen/AutoAWQ), and [vLLM](https://github.com/vllm-project/vllm).
|
||||
|
||||
[Huggingface Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference) is not yet compatible with AWQ, but a PR is open which should bring support soon: [TGI PR #781](https://github.com/huggingface/text-generation-inference/issues/781).
|
||||
<!-- README_AWQ.md-compatibility end -->
|
||||
|
||||
<!-- footer start -->
|
||||
<!-- 200823 -->
|
||||
## Discord
|
||||
|
||||
For further support, and discussions on these models and AI in general, join us at:
|
||||
|
||||
[TheBloke AI's Discord server](https://discord.gg/theblokeai)
|
||||
|
||||
## Thanks, and how to contribute
|
||||
|
||||
Thanks to the [chirper.ai](https://chirper.ai) team!
|
||||
|
||||
Thanks to Clay from [gpus.llm-utils.org](llm-utils)!
|
||||
|
||||
I've had a lot of people ask if they can contribute. I enjoy providing models and helping people, and would love to be able to spend even more time doing it, as well as expanding into new projects like fine tuning/training.
|
||||
|
||||
If you're able and willing to contribute it will be most gratefully received and will help me to keep providing more models, and to start work on new AI projects.
|
||||
|
||||
Donaters will get priority support on any and all AI/LLM/model questions and requests, access to a private Discord room, plus other benefits.
|
||||
|
||||
* Patreon: https://patreon.com/TheBlokeAI
|
||||
* Ko-Fi: https://ko-fi.com/TheBlokeAI
|
||||
|
||||
**Special thanks to**: Aemon Algiz.
|
||||
|
||||
**Patreon special mentions**: Alicia Loh, Stephen Murray, K, Ajan Kanaga, RoA, Magnesian, Deo Leter, Olakabola, Eugene Pentland, zynix, Deep Realms, Raymond Fosdick, Elijah Stavena, Iucharbius, Erik Bjäreholt, Luis Javier Navarrete Lozano, Nicholas, theTransient, John Detwiler, alfie_i, knownsqashed, Mano Prime, Willem Michiel, Enrico Ros, LangChain4j, OG, Michael Dempsey, Pierre Kircher, Pedro Madruga, James Bentley, Thomas Belote, Luke @flexchar, Leonard Tan, Johann-Peter Hartmann, Illia Dulskyi, Fen Risland, Chadd, S_X, Jeff Scroggin, Ken Nordquist, Sean Connelly, Artur Olbinski, Swaroop Kallakuri, Jack West, Ai Maven, David Ziegler, Russ Johnson, transmissions 11, John Villwock, Alps Aficionado, Clay Pascal, Viktor Bowallius, Subspace Studios, Rainer Wilmers, Trenton Dambrowitz, vamX, Michael Levine, 준교 김, Brandon Frisco, Kalila, Trailburnt, Randy H, Talal Aujan, Nathan Dryer, Vadim, 阿明, ReadyPlayerEmma, Tiffany J. Kim, George Stoitzev, Spencer Kim, Jerry Meng, Gabriel Tamborski, Cory Kujawski, Jeffrey Morgan, Spiking Neurons AB, Edmond Seymore, Alexandros Triantafyllidis, Lone Striker, Cap'n Zoog, Nikolai Manek, danny, ya boyyy, Derek Yates, usrbinkat, Mandus, TL, Nathan LeClaire, subjectnull, Imad Khwaja, webtim, Raven Klaugh, Asp the Wyvern, Gabriel Puliatti, Caitlyn Gatomon, Joseph William Delisle, Jonathan Leane, Luke Pendergrass, SuperWojo, Sebastain Graf, Will Dee, Fred von Graf, Andrey, Dan Guido, Daniel P. Andersen, Nitin Borwankar, Elle, Vitor Caleffi, biorpg, jjj, NimbleBox.ai, Pieter, Matthew Berman, terasurfer, Michael Davis, Alex, Stanislav Ovsiannikov
|
||||
|
||||
|
||||
Thank you to all my generous patrons and donaters!
|
||||
|
||||
And thank you again to a16z for their generous grant.
|
||||
|
||||
<!-- footer end -->
|
||||
|
||||
# Original model card: WizardLM's WizardLM 13B V1.2
|
||||
|
||||
This is the **Full-Weight** of WizardLM-13B V1.2 model, this model is trained from **Llama-2 13b**.
|
||||
|
||||
## WizardLM: Empowering Large Pre-Trained Language Models to Follow Complex Instructions
|
||||
|
||||
|
||||
|
||||
<p align="center">
|
||||
🤗 <a href="https://huggingface.co/WizardLM" target="_blank">HF Repo</a> •🐱 <a href="https://github.com/nlpxucan/WizardLM" target="_blank">Github Repo</a> • 🐦 <a href="https://twitter.com/WizardLM_AI" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2304.12244" target="_blank">[WizardLM]</a> • 📃 <a href="https://arxiv.org/abs/2306.08568" target="_blank">[WizardCoder]</a> • 📃 <a href="https://arxiv.org/abs/2308.09583" target="_blank">[WizardMath]</a> <br>
|
||||
</p>
|
||||
<p align="center">
|
||||
👋 Join our <a href="https://discord.gg/VZjjHtWrKs" target="_blank">Discord</a>
|
||||
</p>
|
||||
|
||||
## News
|
||||
|
||||
- 🔥🔥🔥[2023/08/26] We released **WizardCoder-Python-34B-V1.0** , which achieves the **73.2 pass@1** and surpasses **GPT4 (2023/03/15)**, **ChatGPT-3.5**, and **Claude2** on the [HumanEval Benchmarks](https://github.com/openai/human-eval). For more details, please refer to [WizardCoder](https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder).
|
||||
- [2023/06/16] We released **WizardCoder-15B-V1.0** , which surpasses **Claude-Plus (+6.8)**, **Bard (+15.3)** and **InstructCodeT5+ (+22.3)** on the [HumanEval Benchmarks](https://github.com/openai/human-eval). For more details, please refer to [WizardCoder](https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder).
|
||||
|
||||
| Model | Checkpoint | Paper | HumanEval | MBPP | Demo | License |
|
||||
| ----- |------| ---- |------|-------| ----- | ----- |
|
||||
| WizardCoder-Python-34B-V1.0 | 🤗 <a href="https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2306.08568" target="_blank">[WizardCoder]</a> | 73.2 | 61.2 | [Demo](http://47.103.63.15:50085/) | <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama2</a> |
|
||||
| WizardCoder-15B-V1.0 | 🤗 <a href="https://huggingface.co/WizardLM/WizardCoder-15B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2306.08568" target="_blank">[WizardCoder]</a> | 59.8 |50.6 | -- | <a href="https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement" target="_blank">OpenRAIL-M</a> |
|
||||
| WizardCoder-Python-13B-V1.0 | 🤗 <a href="https://huggingface.co/WizardLM/WizardCoder-Python-13B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2306.08568" target="_blank">[WizardCoder]</a> | 64.0 | 55.6 | -- | <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama2</a> |
|
||||
| WizardCoder-Python-7B-V1.0 | 🤗 <a href="https://huggingface.co/WizardLM/WizardCoder-Python-7B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2306.08568" target="_blank">[WizardCoder]</a> | 55.5 | 51.6 | [Demo](http://47.103.63.15:50088/) | <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama2</a> |
|
||||
| WizardCoder-3B-V1.0 | 🤗 <a href="https://huggingface.co/WizardLM/WizardCoder-3B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2306.08568" target="_blank">[WizardCoder]</a> | 34.8 |37.4 | -- | <a href="https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement" target="_blank">OpenRAIL-M</a> |
|
||||
| WizardCoder-1B-V1.0 | 🤗 <a href="https://huggingface.co/WizardLM/WizardCoder-1B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2306.08568" target="_blank">[WizardCoder]</a> | 23.8 |28.6 | -- | <a href="https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement" target="_blank">OpenRAIL-M</a> |
|
||||
|
||||
|
||||
- 🔥 [08/11/2023] We release **WizardMath** Models.
|
||||
- 🔥 Our **WizardMath-70B-V1.0** model slightly outperforms some closed-source LLMs on the GSM8K, including **ChatGPT 3.5**, **Claude Instant 1** and **PaLM 2 540B**.
|
||||
- 🔥 Our **WizardMath-70B-V1.0** model achieves **81.6 pass@1** on the [GSM8k Benchmarks](https://github.com/openai/grade-school-math), which is **24.8** points higher than the SOTA open-source LLM.
|
||||
- 🔥 Our **WizardMath-70B-V1.0** model achieves **22.7 pass@1** on the [MATH Benchmarks](https://github.com/hendrycks/math), which is **9.2** points higher than the SOTA open-source LLM.
|
||||
|
||||
|
||||
| Model | Checkpoint | Paper | GSM8k | MATH |Online Demo| License|
|
||||
| ----- |------| ---- |------|-------| ----- | ----- |
|
||||
| WizardMath-70B-V1.0 | 🤗 <a href="https://huggingface.co/WizardLM/WizardMath-70B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2308.09583" target="_blank">[WizardMath]</a>| **81.6** | **22.7** |[Demo](http://47.103.63.15:50083/)| <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama 2 </a> |
|
||||
| WizardMath-13B-V1.0 | 🤗 <a href="https://huggingface.co/WizardLM/WizardMath-13B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2308.09583" target="_blank">[WizardMath]</a>| **63.9** | **14.0** |[Demo](http://47.103.63.15:50082/)| <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama 2 </a> |
|
||||
| WizardMath-7B-V1.0 | 🤗 <a href="https://huggingface.co/WizardLM/WizardMath-7B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2308.09583" target="_blank">[WizardMath]</a>| **54.9** | **10.7** | [Demo](http://47.103.63.15:50080/)| <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama 2 </a>|
|
||||
|
||||
|
||||
<font size=4>
|
||||
|
||||
| <sup>Model</sup> | <sup>Checkpoint</sup> | <sup>Paper</sup> |<sup>MT-Bench</sup> | <sup>AlpacaEval</sup> | <sup>WizardEval</sup> | <sup>HumanEval</sup> | <sup>License</sup>|
|
||||
| ----- |------| ---- |------|-------| ----- | ----- | ----- |
|
||||
| <sup>WizardLM-13B-V1.2</sup> | <sup>🤗 <a href="https://huggingface.co/WizardLM/WizardLM-13B-V1.2" target="_blank">HF Link</a> </sup>| | <sup>7.06</sup> | <sup>89.17%</sup> | <sup>101.4% </sup>|<sup>36.6 pass@1</sup>|<sup> <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama 2 License </a></sup> |
|
||||
| <sup>WizardLM-13B-V1.1</sup> |<sup> 🤗 <a href="https://huggingface.co/WizardLM/WizardLM-13B-V1.1" target="_blank">HF Link</a> </sup> | | <sup>6.76</sup> |<sup>86.32%</sup> | <sup>99.3% </sup> |<sup>25.0 pass@1</sup>| <sup>Non-commercial</sup>|
|
||||
| <sup>WizardLM-30B-V1.0</sup> | <sup>🤗 <a href="https://huggingface.co/WizardLM/WizardLM-30B-V1.0" target="_blank">HF Link</a></sup> | | <sup>7.01</sup> | | <sup>97.8% </sup> | <sup>37.8 pass@1</sup>| <sup>Non-commercial</sup> |
|
||||
| <sup>WizardLM-13B-V1.0</sup> | <sup>🤗 <a href="https://huggingface.co/WizardLM/WizardLM-13B-V1.0" target="_blank">HF Link</a> </sup> | | <sup>6.35</sup> | <sup>75.31%</sup> | <sup>89.1% </sup> |<sup> 24.0 pass@1 </sup> | <sup>Non-commercial</sup>|
|
||||
| <sup>WizardLM-7B-V1.0 </sup>| <sup>🤗 <a href="https://huggingface.co/WizardLM/WizardLM-7B-V1.0" target="_blank">HF Link</a> </sup> |<sup> 📃 <a href="https://arxiv.org/abs/2304.12244" target="_blank">[WizardLM]</a> </sup>| | | <sup>78.0% </sup> |<sup>19.1 pass@1 </sup>|<sup> Non-commercial</sup>|
|
||||
</font>
|
||||
|
||||
**Repository**: https://github.com/nlpxucan/WizardLM
|
||||
|
||||
**Twitter**:
|
||||
|
||||
|
||||
- 🔥🔥🔥 [7/25/2023] We released **WizardLM V1.2** models. The **WizardLM-13B-V1.2** is here ([Demo_13B-V1.2](https://b7a19878988c8c73.gradio.app), [Demo_13B-V1.2_bak-1](https://d0a37a76e0ac4b52.gradio.app/), [Full Model Weight](https://huggingface.co/WizardLM/WizardLM-13B-V1.2)). Please checkout the [paper](https://arxiv.org/abs/2304.12244).
|
||||
- 🔥🔥🔥 [7/25/2023] The **WizardLM-13B-V1.2** achieves **7.06** on [MT-Bench Leaderboard](https://chat.lmsys.org/?leaderboard), **89.17%** on [AlpacaEval Leaderboard](https://tatsu-lab.github.io/alpaca_eval/), and **101.4%** on [WizardLM Eval](https://github.com/nlpxucan/WizardLM/blob/main/WizardLM/data/WizardLM_testset.jsonl). (Note: MT-Bench and AlpacaEval are all self-test, will push update and request review. All tests are completed under their official settings.)
|
||||
|
||||
❗<b>Note for model system prompts usage:</b>
|
||||
|
||||
|
||||
<b>WizardLM</b> adopts the prompt format from <b>Vicuna</b> and supports **multi-turn** conversation. The prompt should be as following:
|
||||
|
||||
```
|
||||
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hi ASSISTANT: Hello.</s>USER: Who are you? ASSISTANT: I am WizardLM.</s>......
|
||||
```
|
||||
|
||||
## Inference WizardLM Demo Script
|
||||
|
||||
We provide the inference WizardLM demo code [here](https://github.com/nlpxucan/WizardLM/tree/main/demo).
|
||||
|
||||
Please cite the paper if you use the data or code from WizardLM.
|
||||
|
||||
```
|
||||
@article{xu2023wizardlm,
|
||||
title={Wizardlm: Empowering large language models to follow complex instructions},
|
||||
author={Xu, Can and Sun, Qingfeng and Zheng, Kai and Geng, Xiubo and Zhao, Pu and Feng, Jiazhan and Tao, Chongyang and Jiang, Daxin},
|
||||
journal={arXiv preprint arXiv:2304.12244},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
❗<b>To commen concern about dataset:</b>
|
||||
|
||||
Recently, there have been clear changes in the open-source policy and regulations of our overall organization's code, data, and models.
|
||||
|
||||
|
||||
Despite this, we have still worked hard to obtain opening the weights of the model first, but the data involves stricter auditing and is in review with our legal team .
|
||||
|
||||
Our researchers have no authority to publicly release them without authorization.
|
||||
|
||||
Thank you for your understanding.
|
||||
50
USE_POLICY.md
Normal file
50
USE_POLICY.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# Llama 2 Acceptable Use Policy
|
||||
|
||||
Meta is committed to promoting safe and fair use of its tools and features, including Llama 2. If you access or use Llama 2, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy).
|
||||
|
||||
## Prohibited Uses
|
||||
We want everyone to use Llama 2 safely and responsibly. You agree you will not use, or allow others to use, Llama 2 to:
|
||||
|
||||
1. Violate the law or others’ rights, including to:
|
||||
1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
|
||||
1. Violence or terrorism
|
||||
2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
|
||||
3. Human trafficking, exploitation, and sexual violence
|
||||
4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
|
||||
5. Sexual solicitation
|
||||
6. Any other criminal activity
|
||||
2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
|
||||
3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
|
||||
4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
|
||||
5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
|
||||
6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 2 Materials
|
||||
7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
|
||||
|
||||
|
||||
|
||||
2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 2 related to the following:
|
||||
1. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
|
||||
2. Guns and illegal weapons (including weapon development)
|
||||
3. Illegal drugs and regulated/controlled substances
|
||||
4. Operation of critical infrastructure, transportation technologies, or heavy machinery
|
||||
5. Self-harm or harm to others, including suicide, cutting, and eating disorders
|
||||
6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
|
||||
|
||||
|
||||
|
||||
3. Intentionally deceive or mislead others, including use of Llama 2 related to the following:
|
||||
1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation
|
||||
2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
|
||||
3. Generating, promoting, or further distributing spam
|
||||
4. Impersonating another individual without consent, authorization, or legal right
|
||||
5. Representing that the use of Llama 2 or outputs are human-generated
|
||||
6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
|
||||
4. Fail to appropriately disclose to end users any known dangers of your AI system
|
||||
|
||||
Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means:
|
||||
|
||||
* Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama)
|
||||
* Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback)
|
||||
* Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info)
|
||||
* Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com)
|
||||
|
||||
3
added_tokens.json
Normal file
3
added_tokens.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"<pad>": 32000
|
||||
}
|
||||
33
config.json
Normal file
33
config.json
Normal file
@@ -0,0 +1,33 @@
|
||||
{
|
||||
"_name_or_path": "//workspaceblobstore/caxu/llama_new/Llama-2-13b-chat-hf",
|
||||
"architectures": [
|
||||
"LlamaForCausalLM"
|
||||
],
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 5120,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 13824,
|
||||
"max_position_embeddings": 4096,
|
||||
"model_type": "llama",
|
||||
"num_attention_heads": 40,
|
||||
"num_hidden_layers": 40,
|
||||
"num_key_value_heads": 40,
|
||||
"pad_token_id": 0,
|
||||
"pretraining_tp": 1,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": null,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.29.2",
|
||||
"use_cache": true,
|
||||
"vocab_size": 32000,
|
||||
"quantization_config": {
|
||||
"quant_method": "awq",
|
||||
"zero_point": true,
|
||||
"group_size": 128,
|
||||
"bits": 4,
|
||||
"version": "gemm"
|
||||
}
|
||||
}
|
||||
1
configuration.json
Normal file
1
configuration.json
Normal file
@@ -0,0 +1 @@
|
||||
{"framework": "pytorch", "task": "text-generation", "allow_remote": true}
|
||||
9
generation_config.json
Normal file
9
generation_config.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"_from_model_config": true,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"pad_token_id": 0,
|
||||
"temperature": 0.9,
|
||||
"top_p": 0.6,
|
||||
"transformers_version": "4.29.2"
|
||||
}
|
||||
3
model.safetensors
Normal file
3
model.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3ee0b85bb496b5d3bceb039e271a2acc63f92528f05dca420b8945afae122413
|
||||
size 7247987312
|
||||
6
quant_config.json
Normal file
6
quant_config.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"zero_point": true,
|
||||
"q_group_size": 128,
|
||||
"w_bit": 4,
|
||||
"version": "GEMM"
|
||||
}
|
||||
6
special_tokens_map.json
Normal file
6
special_tokens_map.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"bos_token": "</s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<unk>",
|
||||
"unk_token": "</s>"
|
||||
}
|
||||
93400
tokenizer.json
Normal file
93400
tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
3
tokenizer.model
Normal file
3
tokenizer.model
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
||||
size 499723
|
||||
37
tokenizer_config.json
Normal file
37
tokenizer_config.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"add_bos_token": true,
|
||||
"add_eos_token": false,
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"legacy": false,
|
||||
"model_max_length": 2048,
|
||||
"pad_token": null,
|
||||
"padding_side": "right",
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": false,
|
||||
"tokenizer_class": "LlamaTokenizer",
|
||||
"unk_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"use_default_system_prompt": true
|
||||
}
|
||||
578
zero_to_fp32.py
Normal file
578
zero_to_fp32.py
Normal file
@@ -0,0 +1,578 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
|
||||
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
|
||||
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
|
||||
# application.
|
||||
#
|
||||
# example: python zero_to_fp32.py . pytorch_model.bin
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
||||
# DeepSpeed data structures it has to be available in the current python environment.
|
||||
from deepspeed.utils import logger
|
||||
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
|
||||
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
|
||||
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
|
||||
|
||||
|
||||
@dataclass
|
||||
class zero_model_state:
|
||||
buffers: dict()
|
||||
param_shapes: dict()
|
||||
shared_params: list
|
||||
ds_version: int
|
||||
frozen_param_shapes: dict()
|
||||
frozen_param_fragments: dict()
|
||||
|
||||
|
||||
debug = 0
|
||||
|
||||
# load to cpu
|
||||
device = torch.device('cpu')
|
||||
|
||||
|
||||
def atoi(text):
|
||||
return int(text) if text.isdigit() else text
|
||||
|
||||
|
||||
def natural_keys(text):
|
||||
'''
|
||||
alist.sort(key=natural_keys) sorts in human order
|
||||
http://nedbatchelder.com/blog/200712/human_sorting.html
|
||||
(See Toothy's implementation in the comments)
|
||||
'''
|
||||
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
||||
|
||||
|
||||
def get_model_state_file(checkpoint_dir, zero_stage):
|
||||
if not os.path.isdir(checkpoint_dir):
|
||||
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
|
||||
|
||||
# there should be only one file
|
||||
if zero_stage == 2:
|
||||
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
|
||||
elif zero_stage == 3:
|
||||
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
|
||||
|
||||
if not os.path.exists(file):
|
||||
raise FileNotFoundError(f"can't find model states file at '{file}'")
|
||||
|
||||
return file
|
||||
|
||||
|
||||
def get_checkpoint_files(checkpoint_dir, glob_pattern):
|
||||
# XXX: need to test that this simple glob rule works for multi-node setup too
|
||||
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
|
||||
|
||||
if len(ckpt_files) == 0:
|
||||
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
|
||||
|
||||
return ckpt_files
|
||||
|
||||
|
||||
def get_optim_files(checkpoint_dir):
|
||||
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
|
||||
|
||||
|
||||
def get_model_state_files(checkpoint_dir):
|
||||
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
|
||||
|
||||
|
||||
def parse_model_states(files):
|
||||
zero_model_states = []
|
||||
for file in files:
|
||||
state_dict = torch.load(file, map_location=device)
|
||||
|
||||
if BUFFER_NAMES not in state_dict:
|
||||
raise ValueError(f"{file} is not a model state checkpoint")
|
||||
buffer_names = state_dict[BUFFER_NAMES]
|
||||
if debug:
|
||||
print("Found buffers:", buffer_names)
|
||||
|
||||
# recover just the buffers while restoring them to fp32 if they were saved in fp16
|
||||
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
|
||||
param_shapes = state_dict[PARAM_SHAPES]
|
||||
|
||||
# collect parameters that are included in param_shapes
|
||||
param_names = []
|
||||
for s in param_shapes:
|
||||
for name in s.keys():
|
||||
param_names.append(name)
|
||||
|
||||
# update with frozen parameters
|
||||
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
|
||||
if frozen_param_shapes is not None:
|
||||
if debug:
|
||||
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
|
||||
param_names += list(frozen_param_shapes.keys())
|
||||
|
||||
# handle shared params
|
||||
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
|
||||
|
||||
ds_version = state_dict.get(DS_VERSION, None)
|
||||
|
||||
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
|
||||
|
||||
z_model_state = zero_model_state(buffers=buffers,
|
||||
param_shapes=param_shapes,
|
||||
shared_params=shared_params,
|
||||
ds_version=ds_version,
|
||||
frozen_param_shapes=frozen_param_shapes,
|
||||
frozen_param_fragments=frozen_param_fragments)
|
||||
zero_model_states.append(z_model_state)
|
||||
|
||||
return zero_model_states
|
||||
|
||||
|
||||
def parse_optim_states(files, ds_checkpoint_dir):
|
||||
|
||||
total_files = len(files)
|
||||
state_dicts = []
|
||||
for f in files:
|
||||
state_dicts.append(torch.load(f, map_location=device))
|
||||
|
||||
if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
|
||||
raise ValueError(f"{files[0]} is not a zero checkpoint")
|
||||
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
|
||||
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
|
||||
|
||||
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
|
||||
# parameters can be different from data parallelism for non-expert parameters. So we can just
|
||||
# use the max of the partition_count to get the dp world_size.
|
||||
|
||||
if type(world_size) is list:
|
||||
world_size = max(world_size)
|
||||
|
||||
if world_size != total_files:
|
||||
raise ValueError(
|
||||
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
|
||||
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
|
||||
)
|
||||
|
||||
# the groups are named differently in each stage
|
||||
if zero_stage == 2:
|
||||
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
|
||||
elif zero_stage == 3:
|
||||
fp32_groups_key = FP32_FLAT_GROUPS
|
||||
else:
|
||||
raise ValueError(f"unknown zero stage {zero_stage}")
|
||||
|
||||
if zero_stage == 2:
|
||||
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
|
||||
elif zero_stage == 3:
|
||||
# if there is more than one param group, there will be multiple flattened tensors - one
|
||||
# flattened tensor per group - for simplicity merge them into a single tensor
|
||||
#
|
||||
# XXX: could make the script more memory efficient for when there are multiple groups - it
|
||||
# will require matching the sub-lists of param_shapes for each param group flattened tensor
|
||||
|
||||
fp32_flat_groups = [
|
||||
torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
|
||||
]
|
||||
|
||||
return zero_stage, world_size, fp32_flat_groups
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
|
||||
"""
|
||||
Returns fp32 state_dict reconstructed from ds checkpoint
|
||||
|
||||
Args:
|
||||
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
|
||||
|
||||
"""
|
||||
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
|
||||
|
||||
optim_files = get_optim_files(ds_checkpoint_dir)
|
||||
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
|
||||
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
|
||||
|
||||
model_files = get_model_state_files(ds_checkpoint_dir)
|
||||
|
||||
zero_model_states = parse_model_states(model_files)
|
||||
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
||||
|
||||
if zero_stage == 2:
|
||||
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
|
||||
elif zero_stage == 3:
|
||||
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
|
||||
|
||||
|
||||
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
||||
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
||||
return
|
||||
|
||||
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
||||
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
|
||||
|
||||
if debug:
|
||||
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
|
||||
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
||||
|
||||
wanted_params = len(frozen_param_shapes)
|
||||
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
||||
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
|
||||
print(f'Frozen params: Have {avail_numel} numels to process.')
|
||||
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
||||
|
||||
total_params = 0
|
||||
total_numel = 0
|
||||
for name, shape in frozen_param_shapes.items():
|
||||
total_params += 1
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
|
||||
state_dict[name] = frozen_param_fragments[name]
|
||||
|
||||
if debug:
|
||||
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
||||
|
||||
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
||||
param_shapes = zero_model_states[0].param_shapes
|
||||
|
||||
# Reconstruction protocol:
|
||||
#
|
||||
# XXX: document this
|
||||
|
||||
if debug:
|
||||
for i in range(world_size):
|
||||
for j in range(len(fp32_flat_groups[0])):
|
||||
print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
|
||||
|
||||
# XXX: memory usage doubles here (zero2)
|
||||
num_param_groups = len(fp32_flat_groups[0])
|
||||
merged_single_partition_of_fp32_groups = []
|
||||
for i in range(num_param_groups):
|
||||
merged_partitions = [sd[i] for sd in fp32_flat_groups]
|
||||
full_single_fp32_vector = torch.cat(merged_partitions, 0)
|
||||
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
|
||||
avail_numel = sum(
|
||||
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
|
||||
|
||||
if debug:
|
||||
wanted_params = sum([len(shapes) for shapes in param_shapes])
|
||||
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
|
||||
# not asserting if there is a mismatch due to possible padding
|
||||
print(f"Have {avail_numel} numels to process.")
|
||||
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
||||
|
||||
# params
|
||||
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
||||
# out-of-core computing solution
|
||||
total_numel = 0
|
||||
total_params = 0
|
||||
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
|
||||
offset = 0
|
||||
avail_numel = full_single_fp32_vector.numel()
|
||||
for name, shape in shapes.items():
|
||||
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
total_params += 1
|
||||
|
||||
if debug:
|
||||
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
||||
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
|
||||
offset += unpartitioned_numel
|
||||
|
||||
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
|
||||
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
|
||||
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
|
||||
# live optimizer object, so we are checking that the numbers are within the right range
|
||||
align_to = 2 * world_size
|
||||
|
||||
def zero2_align(x):
|
||||
return align_to * math.ceil(x / align_to)
|
||||
|
||||
if debug:
|
||||
print(f"original offset={offset}, avail_numel={avail_numel}")
|
||||
|
||||
offset = zero2_align(offset)
|
||||
avail_numel = zero2_align(avail_numel)
|
||||
|
||||
if debug:
|
||||
print(f"aligned offset={offset}, avail_numel={avail_numel}")
|
||||
|
||||
# Sanity check
|
||||
if offset != avail_numel:
|
||||
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
||||
|
||||
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
|
||||
state_dict = OrderedDict()
|
||||
|
||||
# buffers
|
||||
buffers = zero_model_states[0].buffers
|
||||
state_dict.update(buffers)
|
||||
if debug:
|
||||
print(f"added {len(buffers)} buffers")
|
||||
|
||||
_zero2_merge_frozen_params(state_dict, zero_model_states)
|
||||
|
||||
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
||||
|
||||
# recover shared parameters
|
||||
for pair in zero_model_states[0].shared_params:
|
||||
if pair[1] in state_dict:
|
||||
state_dict[pair[0]] = state_dict[pair[1]]
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
|
||||
remainder = unpartitioned_numel % world_size
|
||||
padding_numel = (world_size - remainder) if remainder else 0
|
||||
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
|
||||
return partitioned_numel, padding_numel
|
||||
|
||||
|
||||
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
||||
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
||||
return
|
||||
|
||||
if debug:
|
||||
for i in range(world_size):
|
||||
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
|
||||
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
||||
|
||||
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
||||
wanted_params = len(frozen_param_shapes)
|
||||
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
||||
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
|
||||
print(f'Frozen params: Have {avail_numel} numels to process.')
|
||||
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
||||
|
||||
total_params = 0
|
||||
total_numel = 0
|
||||
for name, shape in zero_model_states[0].frozen_param_shapes.items():
|
||||
total_params += 1
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
|
||||
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
|
||||
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
|
||||
|
||||
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
||||
|
||||
if debug:
|
||||
print(
|
||||
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
||||
)
|
||||
|
||||
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
||||
param_shapes = zero_model_states[0].param_shapes
|
||||
avail_numel = fp32_flat_groups[0].numel() * world_size
|
||||
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
|
||||
# param, re-consolidating each param, while dealing with padding if any
|
||||
|
||||
# merge list of dicts, preserving order
|
||||
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
|
||||
|
||||
if debug:
|
||||
for i in range(world_size):
|
||||
print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
|
||||
|
||||
wanted_params = len(param_shapes)
|
||||
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
|
||||
# not asserting if there is a mismatch due to possible padding
|
||||
avail_numel = fp32_flat_groups[0].numel() * world_size
|
||||
print(f"Trainable params: Have {avail_numel} numels to process.")
|
||||
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
|
||||
|
||||
# params
|
||||
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
||||
# out-of-core computing solution
|
||||
offset = 0
|
||||
total_numel = 0
|
||||
total_params = 0
|
||||
for name, shape in param_shapes.items():
|
||||
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
total_params += 1
|
||||
|
||||
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
||||
|
||||
if debug:
|
||||
print(
|
||||
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
||||
)
|
||||
|
||||
# XXX: memory usage doubles here
|
||||
state_dict[name] = torch.cat(
|
||||
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
|
||||
0).narrow(0, 0, unpartitioned_numel).view(shape)
|
||||
offset += partitioned_numel
|
||||
|
||||
offset *= world_size
|
||||
|
||||
# Sanity check
|
||||
if offset != avail_numel:
|
||||
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
||||
|
||||
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
|
||||
state_dict = OrderedDict()
|
||||
|
||||
# buffers
|
||||
buffers = zero_model_states[0].buffers
|
||||
state_dict.update(buffers)
|
||||
if debug:
|
||||
print(f"added {len(buffers)} buffers")
|
||||
|
||||
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
|
||||
|
||||
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
||||
|
||||
# recover shared parameters
|
||||
for pair in zero_model_states[0].shared_params:
|
||||
if pair[1] in state_dict:
|
||||
state_dict[pair[0]] = state_dict[pair[1]]
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
|
||||
"""
|
||||
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
|
||||
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
|
||||
via a model hub.
|
||||
|
||||
Args:
|
||||
- ``checkpoint_dir``: path to the desired checkpoint folder
|
||||
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
|
||||
|
||||
Returns:
|
||||
- pytorch ``state_dict``
|
||||
|
||||
Note: this approach may not work if your application doesn't have sufficient free CPU memory and
|
||||
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
|
||||
the checkpoint.
|
||||
|
||||
A typical usage might be ::
|
||||
|
||||
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
||||
# do the training and checkpoint saving
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
|
||||
model = model.cpu() # move to cpu
|
||||
model.load_state_dict(state_dict)
|
||||
# submit to model hub or save the model to share with others
|
||||
|
||||
In this example the ``model`` will no longer be usable in the deepspeed context of the same
|
||||
application. i.e. you will need to re-initialize the deepspeed engine, since
|
||||
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
||||
|
||||
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
|
||||
|
||||
"""
|
||||
if tag is None:
|
||||
latest_path = os.path.join(checkpoint_dir, 'latest')
|
||||
if os.path.isfile(latest_path):
|
||||
with open(latest_path, 'r') as fd:
|
||||
tag = fd.read().strip()
|
||||
else:
|
||||
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
||||
|
||||
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
|
||||
|
||||
if not os.path.isdir(ds_checkpoint_dir):
|
||||
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
|
||||
|
||||
return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
|
||||
|
||||
|
||||
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
|
||||
"""
|
||||
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
|
||||
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
|
||||
|
||||
Args:
|
||||
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
||||
- ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
|
||||
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
||||
"""
|
||||
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
||||
print(f"Saving fp32 state dict to {output_file}")
|
||||
torch.save(state_dict, output_file)
|
||||
|
||||
|
||||
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
|
||||
"""
|
||||
1. Put the provided model to cpu
|
||||
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
|
||||
3. Load it into the provided model
|
||||
|
||||
Args:
|
||||
- ``model``: the model object to update
|
||||
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
||||
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
||||
|
||||
Returns:
|
||||
- ``model`: modified model
|
||||
|
||||
Make sure you have plenty of CPU memory available before you call this function. If you don't
|
||||
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
|
||||
conveniently placed for you in the checkpoint folder.
|
||||
|
||||
A typical usage might be ::
|
||||
|
||||
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
||||
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
|
||||
# submit to model hub or save the model to share with others
|
||||
|
||||
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
|
||||
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
|
||||
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
||||
|
||||
"""
|
||||
logger.info(f"Extracting fp32 weights")
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
||||
|
||||
logger.info(f"Overwriting model with fp32 weights")
|
||||
model = model.cpu()
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("checkpoint_dir",
|
||||
type=str,
|
||||
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
||||
parser.add_argument(
|
||||
"output_file",
|
||||
type=str,
|
||||
help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
|
||||
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
||||
args = parser.parse_args()
|
||||
|
||||
debug = args.debug
|
||||
|
||||
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
|
||||
Reference in New Issue
Block a user