197 lines
5.9 KiB
Python
197 lines
5.9 KiB
Python
|
|
r"""
Convert a Hugging Face LLaMA checkpoint (PyTorch ``*.bin`` shards in a
directory) into EasyLM's parameter format.

Usage:
    python convert_hf_to_easylm.py \
        --checkpoint_dir /path/hf_format_dir/ \
        --output_file /path/easylm_format.stream \
        --model_size 7b \
        --streaming
"""
|
||
|
|
import time
|
||
|
|
from pathlib import Path
|
||
|
|
import argparse
|
||
|
|
|
||
|
|
import mlxu
|
||
|
|
import torch
|
||
|
|
import flax
|
||
|
|
|
||
|
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||
|
|
|
||
|
|
# Architecture hyper-parameters for each supported LLaMA size, keyed by the
# size tag that ``main`` looks up via ``args.model_size``. Each value is a
# dict with the fields dim, intermediate_size, n_layers, n_heads, norm_eps
# (names suggest hidden width, feed-forward width, block count, attention
# head count and normalization epsilon -- NOTE(review): confirm against the
# EasyLM model definition). The conversion code in this script reads only
# ``dim``, ``n_heads`` and ``n_layers``.
LLAMA_STANDARD_CONFIGS = {
    size: dict(
        dim=hidden,
        intermediate_size=ffn_hidden,
        n_layers=depth,
        n_heads=heads,
        norm_eps=eps,
    )
    for size, hidden, ffn_hidden, depth, heads, eps in (
        # size    dim    ffn     layers heads  eps
        ("1b",    2048,  5504,   22,    16,    1e-6),
        ("3b",    3200,  8640,   26,    32,    1e-6),
        ("7b",    4096,  11008,  32,    32,    1e-6),
        ("13b",   5120,  13824,  40,    40,    1e-6),
        ("30b",   6656,  17920,  60,    52,    1e-6),
        ("65b",   8192,  22016,  80,    64,    1e-5),
    )
}
|
||
|
|
|
||
|
|
|
||
|
|
def inverse_permute(params, w):
    """Undo the per-head row permutation on a query/key projection weight.

    The rows of ``w`` are split into groups of ``head_dim = dim // n_heads``
    (one group per head). Within each group the rows are viewed as
    ``(2, head_dim // 2)`` and transposed to ``(head_dim // 2, 2)``, i.e. the
    two halves of each head's rows are interleaved again. NOTE(review): this
    presumably reverses the rotary-embedding permutation applied by HF's
    LLaMA conversion script -- confirm against that script.

    The number of heads in ``w`` is derived from its row count, so the same
    function handles a square q/k weight (``n_heads`` heads) and a narrower
    k weight with fewer key/value heads.

    Args:
        params: model config dict; reads ``"dim"`` and ``"n_heads"``.
        w: 2-D array of shape ``(out_features, in_features)``.

    Returns:
        A new array with the same shape as ``w``.

    Raises:
        ValueError: if ``head_dim`` is odd or ``out_features`` is not a
            multiple of ``head_dim``.
    """
    head_dim = params["dim"] // params["n_heads"]
    out_features, in_features = w.shape
    if head_dim % 2 or out_features % head_dim:
        raise ValueError(
            f"cannot inverse-permute weight of shape {w.shape} "
            f"with head_dim={head_dim}"
        )
    n_w_heads = out_features // head_dim
    return (
        w.reshape(n_w_heads, 2, head_dim // 2, in_features)
        .transpose(0, 2, 1, 3)
        .reshape(out_features, in_features)
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _to_numpy(tensor):
    """Return ``tensor`` as a NumPy array.

    ``Tensor.numpy()`` raises for bfloat16, which NumPy has no dtype for, so
    bf16 tensors are first upcast to float32 (exact: every bf16 value is
    representable in fp32). All other dtypes are converted unchanged.
    """
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.numpy()


def _load_hf_state_dict(checkpoint_dir):
    """Merge every ``*.bin`` shard in ``checkpoint_dir`` into one dict.

    A leading ``"model."`` is stripped from each key; other keys are kept
    verbatim. Shards are read in sorted filename order, so if two shards
    contain the same key the later one wins.

    Raises:
        FileNotFoundError: if ``checkpoint_dir`` has no ``*.bin`` files.
    """
    ckpt_paths = sorted(Path(checkpoint_dir).glob("*.bin"))
    if not ckpt_paths:
        raise FileNotFoundError(
            f"no *.bin checkpoint shards found in {checkpoint_dir!r}"
        )
    state = {}
    for ckpt_path in ckpt_paths:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only convert checkpoints from a trusted source.
        shard = torch.load(ckpt_path, map_location="cpu")
        for key, value in shard.items():
            if key.startswith("model."):
                key = key[len("model."):]
            state[key] = value
    return state


def _convert_layer(ckpt, params, layer):
    """Build the EasyLM parameter subtree for transformer block ``layer``.

    Projection weights are transposed before being stored as ``"kernel"``
    (NOTE(review): presumably PyTorch (out, in) -> Flax (in, out) layout);
    q/k projections additionally go through ``inverse_permute``. Norm
    weights are copied without transposition.
    """
    prefix = f"layers.{layer}."

    def kernel(name):
        # Plain transposed projection weight.
        return {"kernel": _to_numpy(ckpt[prefix + name]).transpose()}

    def permuted_kernel(name):
        # q/k projection: undo the per-head permutation, then transpose.
        weight = inverse_permute(params, _to_numpy(ckpt[prefix + name]))
        return {"kernel": weight.transpose()}

    return {
        "attention": {
            "wq": permuted_kernel("self_attn.q_proj.weight"),
            "wk": permuted_kernel("self_attn.k_proj.weight"),
            "wv": kernel("self_attn.v_proj.weight"),
            "wo": kernel("self_attn.o_proj.weight"),
        },
        "feed_forward": {
            "w1": kernel("mlp.gate_proj.weight"),
            "w2": kernel("mlp.down_proj.weight"),
            "w3": kernel("mlp.up_proj.weight"),
        },
        "attention_norm": {
            "kernel": _to_numpy(ckpt[prefix + "input_layernorm.weight"])
        },
        "ffn_norm": {
            "kernel": _to_numpy(ckpt[prefix + "post_attention_layernorm.weight"])
        },
    }


def main(args):
    """Convert the HF checkpoint in ``args.checkpoint_dir`` and save it.

    Reads ``args.checkpoint_dir``, ``args.model_size`` (a key of
    ``LLAMA_STANDARD_CONFIGS``), ``args.output_file`` and ``args.streaming``.
    With ``streaming`` the weight tree is written through
    ``StreamingCheckpointer.save_train_state_to_file``; otherwise it is
    msgpack-serialized with flax and written via ``mlxu.open_file``.

    Raises:
        FileNotFoundError: if the checkpoint directory has no ``*.bin`` shards.
        KeyError: if ``model_size`` is unknown or an expected weight is missing.
    """
    start = time.time()
    params = LLAMA_STANDARD_CONFIGS[args.model_size]
    ckpt = _load_hf_state_dict(args.checkpoint_dir)

    print("Start convert weight to easylm format...")
    jax_weights = {
        "transformer": {
            "wte": {"embedding": _to_numpy(ckpt["embed_tokens.weight"])},
            "ln_f": {"kernel": _to_numpy(ckpt["norm.weight"])},
            "h": {
                str(layer): _convert_layer(ckpt, params, layer)
                for layer in range(params["n_layers"])
            },
        },
        "lm_head": {"kernel": _to_numpy(ckpt["lm_head.weight"]).transpose()},
    }
    # Drop the dict of torch tensors before serializing; arrays that share
    # storage with them remain alive through jax_weights.
    del ckpt
    print("Convert weight to easylm format finished...")
    print("Start to save...")

    if args.streaming:
        StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
    else:
        with mlxu.open_file(args.output_file, "wb") as fout:
            fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))

    print(
        f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
    )
|
||
|
|
|
||
|
|
|
||
|
|
def build_arg_parser(model_sizes=None):
    """Create the command-line parser for this conversion script.

    Args:
        model_sizes: allowed values for ``--model_size``; defaults to the keys
            of ``LLAMA_STANDARD_CONFIGS`` so every defined config is
            selectable.

    Returns:
        An ``argparse.ArgumentParser``. ``streaming`` defaults to True and
        can be turned off with ``--no_streaming``.
    """
    if model_sizes is None:
        model_sizes = list(LLAMA_STANDARD_CONFIGS)

    parser = argparse.ArgumentParser(description="hf to easylm format script")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        required=True,
        help="Need to be converted model weight dir. it is a dir",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Save model weight file path, it is a file.",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        default="7b",
        choices=list(model_sizes),
        help="model size",
    )
    # store_true with default=True could never be disabled, so pair it with
    # an explicit opt-out writing to the same destination.
    streaming = parser.add_mutually_exclusive_group()
    streaming.add_argument(
        "--streaming",
        dest="streaming",
        action="store_true",
        help="whether is model weight saved stream format",
    )
    streaming.add_argument(
        "--no_streaming",
        dest="streaming",
        action="store_false",
        help="save a single msgpack-serialized file instead of the streaming format",
    )
    parser.set_defaults(streaming=True)
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    print(f"checkpoint_dir: {args.checkpoint_dir}")
    print(f"output_file: {args.output_file}")
    print(f"model_size: {args.model_size}")
    print(f"streaming: {args.streaming}")

    main(args)
|