初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
196
EasyLM/models/llama/convert_hf_to_easylm.py
Normal file
196
EasyLM/models/llama/convert_hf_to_easylm.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Usage:
|
||||
python convert_hf_to_easylm.py \
|
||||
--checkpoint_dir /path/hf_format_dir/ \
|
||||
--output_file /path/easylm_format.stream \
|
||||
--model_size 7b \
|
||||
--streaming
|
||||
"""
|
||||
import time
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import mlxu
|
||||
import torch
|
||||
import flax
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
|
||||
# Hyperparameters for each supported model size, keyed by the value passed as
# ``--model_size``. Within this script ``inverse_permute`` reads ``dim`` and
# ``n_heads`` and ``main`` reads ``n_layers``; ``intermediate_size`` and
# ``norm_eps`` are not read anywhere in this file.
LLAMA_STANDARD_CONFIGS = {
    "1b": dict(
        dim=2048, intermediate_size=5504, n_layers=22, n_heads=16, norm_eps=1e-6
    ),
    "3b": dict(
        dim=3200, intermediate_size=8640, n_layers=26, n_heads=32, norm_eps=1e-6
    ),
    "7b": dict(
        dim=4096, intermediate_size=11008, n_layers=32, n_heads=32, norm_eps=1e-6
    ),
    "13b": dict(
        dim=5120, intermediate_size=13824, n_layers=40, n_heads=40, norm_eps=1e-6
    ),
    "30b": dict(
        dim=6656, intermediate_size=17920, n_layers=60, n_heads=52, norm_eps=1e-6
    ),
    # NOTE: the 65b entry is the only one with norm_eps=1e-5.
    "65b": dict(
        dim=8192, intermediate_size=22016, n_layers=80, n_heads=64, norm_eps=1e-5
    ),
}
|
||||
|
||||
|
||||
def inverse_permute(params, w):
    """Reorder the rows of a square projection matrix, head by head.

    The rows of ``w`` are viewed as ``(n_heads, 2, head_dim // 2)`` blocks,
    the two inner axes are swapped, and the result is flattened back to
    ``(dim, dim)``. Within each head this interleaves the rows of the first
    half with the rows of the second half.

    NOTE(review): presumably this reverses the query/key row permutation
    applied when LLaMA weights are exported to the Hugging Face layout --
    confirm against the exporting script.

    Args:
        params: Mapping providing ``"n_heads"`` and ``"dim"``. Other keys are
            ignored.
        w: Array with ``dim * dim`` elements exposing NumPy-style
            ``reshape``/``transpose`` (``main`` passes NumPy arrays).

    Returns:
        The row-reordered array with shape ``(dim, dim)``.
    """
    n_heads = params["n_heads"]
    dim = params["dim"]
    half_head_dim = dim // n_heads // 2

    # Rows become (head, half, position-in-half); columns stay intact.
    grouped = w.reshape(n_heads, 2, half_head_dim, dim)
    # Swap "half" and "position" so the two halves alternate row by row.
    interleaved = grouped.transpose(0, 2, 1, 3)
    return interleaved.reshape(dim, dim)
|
||||
|
||||
|
||||
def _load_hf_state_dict(checkpoint_dir):
    """Merge every ``*.bin`` shard found directly in ``checkpoint_dir``.

    Shards are loaded in sorted filename order onto the CPU. A leading
    ``"model."`` is stripped from each key; later shards overwrite earlier
    ones on key collisions.

    Raises:
        FileNotFoundError: If the directory contains no ``*.bin`` files.
    """
    shard_paths = sorted(Path(checkpoint_dir).glob("*.bin"))
    if not shard_paths:
        raise FileNotFoundError(
            f"No '*.bin' checkpoint files found in {checkpoint_dir!r}"
        )

    prefix = "model."
    state = {}
    for shard_path in shard_paths:
        # NOTE(review): torch.load unpickles the file -- only run this on
        # checkpoints from a trusted source.
        shard = torch.load(shard_path, map_location="cpu")
        for name, tensor in shard.items():
            if name.startswith(prefix):
                name = name[len(prefix):]
            state[name] = tensor
    return state


def main(args):
    """Convert a Hugging Face LLaMA checkpoint into an EasyLM weight file.

    Args:
        args: Namespace providing ``model_size`` (a key of
            ``LLAMA_STANDARD_CONFIGS``), ``checkpoint_dir`` (directory holding
            the ``*.bin`` shards), ``output_file`` (destination path) and
            ``streaming`` (truthy to save through ``StreamingCheckpointer``,
            otherwise a single msgpack blob is written).

    Raises:
        KeyError: If ``model_size`` is unknown or an expected weight is
            missing from the checkpoint.
        FileNotFoundError: If ``checkpoint_dir`` holds no ``*.bin`` files.
    """
    start = time.time()
    params = LLAMA_STANDARD_CONFIGS[args.model_size]
    ckpt = _load_hf_state_dict(args.checkpoint_dir)

    def as_array(name):
        # Weight copied as-is (embedding table and norm weights).
        return ckpt[name].numpy()

    def linear_kernel(name):
        # Linear weights are stored transposed relative to the target layout.
        return as_array(name).transpose()

    def permuted_kernel(name):
        # Query/key projections additionally need their rows reordered.
        return inverse_permute(params, as_array(name)).transpose()

    def convert_layer(layer):
        base = f"layers.{layer}"
        return {
            "attention": {
                "wq": {"kernel": permuted_kernel(f"{base}.self_attn.q_proj.weight")},
                "wk": {"kernel": permuted_kernel(f"{base}.self_attn.k_proj.weight")},
                "wv": {"kernel": linear_kernel(f"{base}.self_attn.v_proj.weight")},
                "wo": {"kernel": linear_kernel(f"{base}.self_attn.o_proj.weight")},
            },
            "feed_forward": {
                "w1": {"kernel": linear_kernel(f"{base}.mlp.gate_proj.weight")},
                "w2": {"kernel": linear_kernel(f"{base}.mlp.down_proj.weight")},
                "w3": {"kernel": linear_kernel(f"{base}.mlp.up_proj.weight")},
            },
            "attention_norm": {
                "kernel": as_array(f"{base}.input_layernorm.weight")
            },
            "ffn_norm": {
                "kernel": as_array(f"{base}.post_attention_layernorm.weight")
            },
        }

    print("Start convert weight to easylm format...")
    jax_weights = {
        "transformer": {
            "wte": {"embedding": as_array("embed_tokens.weight")},
            "ln_f": {"kernel": as_array("norm.weight")},
            "h": {
                str(layer): convert_layer(layer)
                for layer in range(params["n_layers"])
            },
        },
        "lm_head": {"kernel": linear_kernel("lm_head.weight")},
    }
    print("Convert weight to easylm format finished...")
    print("Start to save...")

    if args.streaming:
        StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
    else:
        with mlxu.open_file(args.output_file, "wb") as fout:
            fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))

    print(
        f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
    )
|
||||
|
||||
|
||||
# Command-line entry point: parse the options, echo them, run the conversion.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="hf to easylm format script")

    # Both paths are mandatory; without them the conversion cannot start
    # (or, for the output path, would only fail after all the work is done).
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        required=True,
        help="Need to be converted model weight dir. it is a dir",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Save model weight file path, it is a file.",
    )
    # Offer every size defined in the config table, not a hand-copied subset.
    parser.add_argument(
        "--model_size",
        type=str,
        default="7b",
        choices=list(LLAMA_STANDARD_CONFIGS),
        help="model size",
    )
    # Streaming stays the default; --streaming is kept for existing command
    # lines and --no_streaming is the only way to select the msgpack output.
    parser.add_argument(
        "--streaming",
        dest="streaming",
        action="store_true",
        default=True,
        help="whether is model weight saved stream format",
    )
    parser.add_argument(
        "--no_streaming",
        dest="streaming",
        action="store_false",
        help="save a single msgpack file instead of the stream format",
    )

    args = parser.parse_args()

    print(f"checkpoint_dir: {args.checkpoint_dir}")
    print(f"output_file: {args.output_file}")
    print(f"model_size: {args.model_size}")
    print(f"streaming: {args.streaming}")

    main(args)
|
||||
Reference in New Issue
Block a user