初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
338
EasyLM/models/llama/convert_easylm_to_hf.py
Normal file
338
EasyLM/models/llama/convert_easylm_to_hf.py
Normal file
@@ -0,0 +1,338 @@
|
||||
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2023 Xinyang Geng
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts LLaMA model checkpoint trained by EasyLM to the
|
||||
# HuggingFace transformers LLaMA PyTorch format, which can then be loaded
|
||||
# by HuggingFace transformers.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import mlxu
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import flax
|
||||
from flax.traverse_util import flatten_dict
|
||||
import torch
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.jax_utils import float_tensor_to_dtype
|
||||
|
||||
|
||||
# Command-line flags of the EasyLM -> HuggingFace converter with their
# defaults. main() requires load_checkpoint and output_dir to be non-empty
# and model_size to be a key of LLAMA_STANDARD_CONFIGS.
_CONVERTER_FLAG_DEFAULTS = dict(
    load_checkpoint='',
    tokenizer_path='',
    model_size='13b',
    output_dir='',
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(**_CONVERTER_FLAG_DEFAULTS)
|
||||
|
||||
|
||||
# Architecture hyper-parameters for every supported model size. Each row of
# the table is (name, dim, intermediate_size, n_layers, n_heads, norm_eps);
# all sizes share the same vocabulary size.
LLAMA_STANDARD_CONFIGS = {
    name: {
        'vocab_size': 64256,
        'dim': dim,
        'intermediate_size': intermediate_size,
        'n_layers': n_layers,
        'n_heads': n_heads,
        'norm_eps': norm_eps,
    }
    for name, dim, intermediate_size, n_layers, n_heads, norm_eps in (
        ('small', 768, 3072, 12, 12, 1e-6),
        ('medium', 1024, 4096, 24, 16, 1e-6),
        ('large', 1536, 6144, 24, 16, 1e-6),
        ('xlarge', 2048, 8192, 24, 32, 1e-6),
        ('1b', 2048, 5504, 22, 16, 1e-6),
        ('3b', 3200, 8640, 26, 32, 1e-6),
        ('7b', 4096, 11008, 32, 32, 1e-6),
        ('13b', 5120, 13824, 40, 40, 1e-6),
        ('30b', 6656, 17920, 60, 52, 1e-6),
        ('65b', 8192, 22016, 80, 64, 1e-5),
    )
}
|
||||
|
||||
|
||||
def match_keywords(string, positives, negatives):
    """Tell whether `string` has every positive keyword and no negative one.

    Args:
        string: Text to search in.
        positives: Substrings that must all occur in `string`.
        negatives: Substrings of which none may occur in `string`.

    Returns:
        True when both conditions hold, False otherwise.
    """
    # `and` short-circuits, so negatives are only examined once every
    # positive keyword has been found.
    return (
        all(keyword in string for keyword in positives)
        and not any(keyword in string for keyword in negatives)
    )
|
||||
|
||||
|
||||
def load_and_convert_checkpoint(path):
    """Load an EasyLM checkpoint and return its parameters as torch tensors.

    Args:
        path: Checkpoint location, passed unchanged to
            ``StreamingCheckpointer.load_trainstate_checkpoint``.

    Returns:
        A dict mapping dot-joined parameter names to ``torch.float16`` tensors.
    """
    _, restored = StreamingCheckpointer.load_trainstate_checkpoint(path)
    flat_params = flatten_dict(restored['params'], sep='.')

    def to_torch(name, array):
        # Only "kernel" entries that are not norm / ln_f weights are
        # transposed -- presumably to go from the Flax (in, out) layout to
        # the torch Linear (out, in) layout; TODO confirm.
        if match_keywords(name, ["kernel"], ["norm", 'ln_f']):
            array = array.T
        return torch.tensor(
            float_tensor_to_dtype(array, 'fp32'), dtype=torch.float16
        )

    return {name: to_torch(name, array) for name, array in flat_params.items()}
|
||||
|
||||
|
||||
def read_json(path):
    """Parse the JSON file at `path` and return the decoded object.

    The file is opened as UTF-8 explicitly rather than with the platform's
    locale encoding, so non-ASCII content decodes the same way on every
    system.

    Args:
        path: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load``.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
def write_json(text, path):
    """Serialize `text` as JSON into the file at `path`.

    Args:
        text: JSON-serializable object to write (despite the name, it is
            passed to ``json.dump`` as-is and need not be a string).
        path: Destination file; opened in "w" mode, so existing content is
            replaced.
    """
    with open(path, "w") as out_file:
        json.dump(text, out_file)
|
||||
|
||||
|
||||
def write_model(loaded, model_path, model_size):
    """Save converted weights under `model_path` as a HuggingFace LLaMA model.

    The weights are first written as one ``.bin`` shard per transformer layer
    plus a final shard for the embedding, final norm and LM head, together
    with an index file and a config, into ``<model_path>/tmp``. That directory
    is then loaded with ``LlamaForCausalLM.from_pretrained`` in float16,
    re-saved into `model_path` with ``safe_serialization=True``, and removed.

    Args:
        loaded: Mapping from flattened EasyLM parameter names (for example
            ``transformer.h.0.attention.wq.kernel``) to torch tensors.
        model_path: Output directory; created when missing.
        model_size: Key into ``LLAMA_STANDARD_CONFIGS``.
    """
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)

    params = LLAMA_STANDARD_CONFIGS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    dim = params["dim"]
    dims_per_head = dim // n_heads

    # Rotary inverse frequencies; the same tensor is stored with every layer.
    base = 10000.0
    exponents = torch.arange(0, dims_per_head, 2).float() / dims_per_head
    inv_freq = 1.0 / (base ** exponents)

    def permute(w):
        # permute for sliced rotary
        per_head = w.view(n_heads, dim // n_heads // 2, 2, dim)
        return per_head.transpose(1, 2).reshape(dim, dim)

    index_dict = {"weight_map": {}}

    def save_shard(shard, shard_name):
        # Register every tensor of the shard in the index, write the shard
        # to the temporary directory and report how many elements it holds.
        element_count = 0
        for k, v in shard.items():
            index_dict["weight_map"][k] = shard_name
            element_count += v.numel()
        torch.save(shard, os.path.join(tmp_model_path, shard_name))
        return element_count

    param_count = 0
    for layer_i in range(n_layers):
        filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
        state_dict = {
            f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
                loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
            ),
            f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
                loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
            ),
            f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
            f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],

            f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
            f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
            f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],

            f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
            f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],

            f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq": inv_freq,
        }
        param_count += save_shard(state_dict, filename)

    # The last shard holds the weights that do not belong to any layer.
    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
    state_dict = {
        "model.embed_tokens.weight": loaded["transformer.wte.embedding"],
        "model.norm.weight": loaded["transformer.ln_f.kernel"],
        "lm_head.weight": loaded["lm_head.kernel"],
    }
    param_count += save_shard(state_dict, filename)

    # Write configs. total_size is reported as two bytes per counted element.
    index_dict["metadata"] = {"total_size": param_count * 2}
    write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))

    config = LlamaConfig(
        vocab_size=params["vocab_size"],
        hidden_size=dim,
        intermediate_size=params["intermediate_size"],
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
    )
    config.save_pretrained(tmp_model_path)

    # Make space so we can load the model properly now.
    del state_dict
    del loaded
    gc.collect()

    print("Loading the checkpoint in a Llama model.")
    model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
    print("Model parameter count", model.num_parameters())
    # Avoid saving this as part of the config.
    del model.config._name_or_path

    print("Saving in the Transformers format.")
    model.save_pretrained(model_path, safe_serialization=True)
    shutil.rmtree(tmp_model_path)
|
||||
|
||||
|
||||
def write_tokenizer(tokenizer_path, input_tokenizer_path):
    """Write HuggingFace tokenizer files next to a copied tokenizer model.

    Creates `tokenizer_path` when missing, writes ``special_tokens_map.json``
    and ``tokenizer_config.json`` into it, and copies `input_tokenizer_path`
    to ``tokenizer.model`` in the same directory.

    Args:
        tokenizer_path: Output directory for the tokenizer files.
        input_tokenizer_path: Existing tokenizer model file to copy.
    """
    print(f"Fetching the tokenizer from {input_tokenizer_path}.")
    os.makedirs(tokenizer_path, exist_ok=True)

    def token_entry(content):
        # Attributes shared by every special token in both JSON files.
        return {
            "content": content,
            "lstrip": False,
            "normalized": True,
            "rstrip": False,
            "single_word": False,
        }

    def added_token_entry(content):
        # tokenizer_config.json additionally tags each token with its type.
        return {"__type": "AddedToken", **token_entry(content)}

    special_tokens = (
        ("bos_token", "<s>"),
        ("eos_token", "</s>"),
        ("unk_token", "<unk>"),
    )

    special_tokens_map = {
        name: token_entry(content) for name, content in special_tokens
    }
    write_json(
        special_tokens_map,
        os.path.join(tokenizer_path, "special_tokens_map.json")
    )

    tokenizer_config = {
        "add_bos_token": True,
        "add_eos_token": False,
        "model_max_length": 2048,
        "pad_token": None,
        "sp_model_kwargs": {},
        "tokenizer_class": "LlamaTokenizer",
        "clean_up_tokenization_spaces": False,
    }
    for name, content in special_tokens:
        tokenizer_config[name] = added_token_entry(content)
    write_json(
        tokenizer_config,
        os.path.join(tokenizer_path, "tokenizer_config.json"),
    )

    shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
|
||||
|
||||
|
||||
def main(argv):
    """Convert the checkpoint named by the flags into HuggingFace format.

    Args:
        argv: Positional command-line arguments; not used here.
    """
    # tokenizer_path is deliberately not validated: the tokenizer export
    # step below is currently disabled.
    assert FLAGS.load_checkpoint != ""
    assert FLAGS.output_dir != ""
    assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS

    # Disabled tokenizer export, kept for reference:
    # write_tokenizer(
    #     tokenizer_path=FLAGS.output_dir,
    #     input_tokenizer_path=FLAGS.tokenizer_path,
    # )

    # The converted weights are passed straight into write_model without
    # being bound to a local name, so the `del loaded` inside write_model
    # drops the last reference held by this call chain.
    write_model(
        load_and_convert_checkpoint(FLAGS.load_checkpoint),
        model_path=FLAGS.output_dir,
        model_size=FLAGS.model_size,
    )


if __name__ == "__main__":
    mlxu.run(main)
|
||||
196
EasyLM/models/llama/convert_hf_to_easylm.py
Normal file
196
EasyLM/models/llama/convert_hf_to_easylm.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Usage:
|
||||
python convert_hf_to_easylm.py \
|
||||
--checkpoint_dir /path/hf_format_dir/ \
|
||||
--output_file /path/easylm_format.stream \
|
||||
--model_size 7b \
|
||||
--streaming
|
||||
"""
|
||||
import time
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import mlxu
|
||||
import torch
|
||||
import flax
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
|
||||
# Architecture hyper-parameters per model size. Each row of the table is
# (name, dim, intermediate_size, n_layers, n_heads, norm_eps).
LLAMA_STANDARD_CONFIGS = {
    name: {
        'dim': dim,
        'intermediate_size': intermediate_size,
        'n_layers': n_layers,
        'n_heads': n_heads,
        'norm_eps': norm_eps,
    }
    for name, dim, intermediate_size, n_layers, n_heads, norm_eps in (
        ('1b', 2048, 5504, 22, 16, 1e-6),
        ('3b', 3200, 8640, 26, 32, 1e-6),
        ('7b', 4096, 11008, 32, 32, 1e-6),
        ('13b', 5120, 13824, 40, 40, 1e-6),
        ('30b', 6656, 17920, 60, 52, 1e-6),
        ('65b', 8192, 22016, 80, 64, 1e-5),
    )
}
|
||||
|
||||
|
||||
def inverse_permute(params, w):
    """Regroup the rows of a (dim, dim) projection weight within each head.

    `w` is reshaped to ``(n_heads, 2, dim // n_heads // 2, dim)``, the two
    middle axes are swapped, and the result is flattened back to
    ``(dim, dim)``. This reverses a ``view(n_heads, head_dim // 2, 2, dim)``
    followed by ``transpose(1, 2)`` -- the "permute for sliced rotary" step
    used when exporting q/k projections to the HuggingFace layout.

    Args:
        params: Mapping that provides ``"n_heads"`` and ``"dim"``; no other
            key is read.
        w: Array with ``dim * dim`` elements supporting NumPy-style
            ``reshape`` and ``transpose(*axes)``.

    Returns:
        The row-regrouped array of shape ``(dim, dim)``.
    """
    n_heads = params["n_heads"]
    dim = params["dim"]
    half_head_dim = dim // n_heads // 2
    return (
        w.reshape(n_heads, 2, half_head_dim, dim)
        .transpose(0, 2, 1, 3)
        .reshape(dim, dim)
    )
|
||||
|
||||
|
||||
def main(args):
    """Convert HuggingFace-format LLaMA weights into an EasyLM checkpoint.

    Every ``*.bin`` file in ``args.checkpoint_dir`` is loaded on CPU and
    merged into one name -> tensor mapping (a leading ``model.`` is dropped
    from the names). The weights are rearranged into the nested EasyLM
    parameter tree and written to ``args.output_file``.

    Args:
        args: Namespace with ``checkpoint_dir``, ``output_file``,
            ``model_size`` (a key of ``LLAMA_STANDARD_CONFIGS``) and
            ``streaming`` (selects the StreamingCheckpointer writer over a
            plain msgpack dump).
    """
    start = time.time()
    params = LLAMA_STANDARD_CONFIGS[args.model_size]

    # Merge all shards; a name seen again in a later file replaces the
    # earlier entry.
    ckpt = {}
    for ckpt_path in sorted(Path(args.checkpoint_dir).glob("*.bin")):
        shard = torch.load(ckpt_path, map_location="cpu")
        for k, v in shard.items():
            if k.startswith("model."):
                k = k[len("model."):]
            ckpt[k] = v

    def as_numpy(name):
        # Weight `name` as a NumPy array, in its stored orientation.
        return ckpt[name].numpy()

    def kernel(name):
        # Weight `name` transposed, as used for every projection kernel.
        return as_numpy(name).transpose()

    def rotary_kernel(name):
        # q/k projections are row-regrouped per head before the transpose.
        return inverse_permute(params, as_numpy(name)).transpose()

    def convert_layer(layer):
        # Parameter subtree of one transformer block.
        return {
            "attention": {
                "wq": {"kernel": rotary_kernel(f"layers.{layer}.self_attn.q_proj.weight")},
                "wk": {"kernel": rotary_kernel(f"layers.{layer}.self_attn.k_proj.weight")},
                "wv": {"kernel": kernel(f"layers.{layer}.self_attn.v_proj.weight")},
                "wo": {"kernel": kernel(f"layers.{layer}.self_attn.o_proj.weight")},
            },
            "feed_forward": {
                "w1": {"kernel": kernel(f"layers.{layer}.mlp.gate_proj.weight")},
                "w2": {"kernel": kernel(f"layers.{layer}.mlp.down_proj.weight")},
                "w3": {"kernel": kernel(f"layers.{layer}.mlp.up_proj.weight")},
            },
            "attention_norm": {
                "kernel": as_numpy(f"layers.{layer}.input_layernorm.weight")
            },
            "ffn_norm": {
                "kernel": as_numpy(f"layers.{layer}.post_attention_layernorm.weight")
            },
        }

    print(f"Start convert weight to easylm format...")
    jax_weights = {
        "transformer": {
            "wte": {"embedding": as_numpy("embed_tokens.weight")},
            "ln_f": {"kernel": as_numpy("norm.weight")},
            "h": {
                "%d" % (layer): convert_layer(layer)
                for layer in range(params["n_layers"])
            },
        },
        "lm_head": {"kernel": kernel("lm_head.weight")},
    }
    print(f"Convert weight to easylm format finished...")
    print(f"Start to save...")

    if args.streaming:
        StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
    else:
        with mlxu.open_file(args.output_file, "wb") as fout:
            fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))

    print(
        f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, echo them, run main().
    parser = argparse.ArgumentParser(description="hf to easylm format script")

    # Both paths are mandatory; without them main() would only fail later
    # with a less helpful error.
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        required=True,
        help="Need to be converted model weight dir. it is a dir",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Save model weight file path, it is a file.",
    )
    # Offer every size main() can look up, not a hand-maintained subset.
    parser.add_argument(
        "--model_size",
        type=str,
        default="7b",
        choices=list(LLAMA_STANDARD_CONFIGS),
        help="model size",
    )
    # --streaming stays accepted (and stays the default) for backward
    # compatibility; --no_streaming is the only way to switch it off and
    # reach the plain msgpack writer in main().
    parser.add_argument(
        "--streaming",
        action="store_true",
        default=True,
        help="whether is model weight saved stream format",
    )
    parser.add_argument(
        "--no_streaming",
        dest="streaming",
        action="store_false",
        help="save the model weight as a single msgpack file instead of stream format",
    )

    args = parser.parse_args()

    print(f"checkpoint_dir: {args.checkpoint_dir}")
    print(f"output_file: {args.output_file}")
    print(f"model_size: {args.model_size}")
    print(f"streaming: {args.streaming}")

    main(args)
|
||||
68
EasyLM/models/llama/convert_torch_to_easylm.py
Normal file
68
EasyLM/models/llama/convert_torch_to_easylm.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# This script converts the standard LLaMA PyTorch checkpoint released by Meta
|
||||
# to the EasyLM checkpoint format. The converted checkpoint can then be loaded
|
||||
# by EasyLM for fine-tuning or inference.
|
||||
|
||||
# This script is largely borrowed from https://github.com/Sea-Snell/JAX_llama
|
||||
|
||||
from pathlib import Path
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import flax
|
||||
import mlxu
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
|
||||
|
||||
# Command-line flags of the PyTorch -> EasyLM converter with their defaults.
# `streaming` selects the StreamingCheckpointer writer in main(); otherwise
# the weights are dumped as one msgpack blob.
_CONVERTER_FLAG_DEFAULTS = dict(
    checkpoint_dir='',
    output_file='',
    streaming=True,
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(**_CONVERTER_FLAG_DEFAULTS)
|
||||
|
||||
|
||||
def main(argv):
    """Merge sharded ``*.pth`` checkpoints into one EasyLM checkpoint file.

    All ``*.pth`` files in ``FLAGS.checkpoint_dir`` are loaded on CPU and
    ordered by the integer in the second dot-separated field of their file
    name. Per-shard weights are concatenated along a fixed axis per weight,
    transposed where the tree below says so, and written to
    ``FLAGS.output_file``. The layer count comes from ``params.json`` in the
    same directory.

    Args:
        argv: Positional command-line arguments; not used here.
    """
    shards_by_index = {}
    for ckpt_path in sorted(Path(FLAGS.checkpoint_dir).glob("*.pth")):
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        shard_index = int(ckpt_path.name.split('.', maxsplit=2)[1])
        shards_by_index[shard_index] = checkpoint
    ckpts = [shards_by_index[index] for index in sorted(shards_by_index)]

    with open(Path(FLAGS.checkpoint_dir) / "params.json", "r") as f:
        params = json.load(f)

    def merged(name, axis):
        # Weight `name` from every shard, joined along `axis`.
        return np.concatenate([ckpt[name].numpy() for ckpt in ckpts], axis=axis)

    def first_shard(name):
        # Norm weights are read from the first shard only.
        return ckpts[0][name].numpy()

    def convert_layer(layer):
        # Parameter subtree of one transformer block.
        return {
            'attention': {
                'wq': {'kernel': merged('layers.%d.attention.wq.weight' % (layer), axis=0).transpose()},
                'wk': {'kernel': merged('layers.%d.attention.wk.weight' % (layer), axis=0).transpose()},
                'wv': {'kernel': merged('layers.%d.attention.wv.weight' % (layer), axis=0).transpose()},
                'wo': {'kernel': merged('layers.%d.attention.wo.weight' % (layer), axis=1).transpose()},
            },
            'feed_forward': {
                'w1': {'kernel': merged('layers.%d.feed_forward.w1.weight' % (layer), axis=0).transpose()},
                'w2': {'kernel': merged('layers.%d.feed_forward.w2.weight' % (layer), axis=1).transpose()},
                'w3': {'kernel': merged('layers.%d.feed_forward.w3.weight' % (layer), axis=0).transpose()},
            },
            'attention_norm': {'kernel': first_shard('layers.%d.attention_norm.weight' % (layer))},
            'ffn_norm': {'kernel': first_shard('layers.%d.ffn_norm.weight' % (layer))},
        }

    jax_weights = {
        'transformer': {
            'wte': {'embedding': merged('tok_embeddings.weight', axis=1)},
            'ln_f': {'kernel': first_shard('norm.weight')},
            'h': {
                '%d' % (layer): convert_layer(layer)
                for layer in range(params['n_layers'])
            },
        },
        'lm_head': {'kernel': merged('output.weight', axis=0).transpose()},
    }

    if FLAGS.streaming:
        StreamingCheckpointer.save_train_state_to_file(
            jax_weights, FLAGS.output_file
        )
    else:
        with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
            fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))


if __name__ == '__main__':
    mlxu.run(main)
|
||||
1530
EasyLM/models/llama/llama_model.py
Normal file
1530
EasyLM/models/llama/llama_model.py
Normal file
File diff suppressed because it is too large
Load Diff
386
EasyLM/models/llama/llama_serve.py
Normal file
386
EasyLM/models/llama/llama_serve.py
Normal file
@@ -0,0 +1,386 @@
|
||||
import pprint
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.sharding import PartitionSpec as PS
|
||||
import optax
|
||||
from transformers import GenerationConfig, FlaxLogitsProcessorList
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.serving import LMServer
|
||||
from EasyLM.jax_utils import (
|
||||
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
|
||||
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
|
||||
with_sharding_constraint, FlaxTemperatureLogitsWarper
|
||||
)
|
||||
from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
|
||||
|
||||
|
||||
# Command-line flags of the LLaMA serving script with their defaults. The
# last three entries are nested configs supplied by the tokenizer, the LM
# server and the JAX distributed helpers.
_SERVE_FLAG_DEFAULTS = dict(
    seed=42,
    initialize_jax_distributed=False,
    mesh_dim='1,-1,1',
    dtype='bf16',
    input_length=1024,
    seq_length=2048,
    top_k=50,
    top_p=1.0,
    do_sample=True,
    num_beams=1,
    add_bos_token=True,
    load_llama_config='',
    load_checkpoint='',
    tokenizer=LLaMAConfig.get_tokenizer_config(),
    lm_server=LMServer.get_default_config(),
    jax_distributed=JaxDistributedConfig.get_default_config(),
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(**_SERVE_FLAG_DEFAULTS)
|
||||
|
||||
|
||||
def main(argv):
|
||||
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
prefix_tokenizer = LLaMAConfig.get_tokenizer(
|
||||
FLAGS.tokenizer, truncation_side='left', padding_side='left'
|
||||
)
|
||||
tokenizer = LLaMAConfig.get_tokenizer(
|
||||
FLAGS.tokenizer, truncation_side='right', padding_side='right'
|
||||
)
|
||||
|
||||
with jax.default_device(jax.devices("cpu")[0]):
|
||||
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
|
||||
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_checkpoint, disallow_trainstate=True
|
||||
)
|
||||
|
||||
hf_model = FlaxLLaMAForCausalLM(
|
||||
llama_config,
|
||||
input_shape=(1, FLAGS.seq_length),
|
||||
seed=FLAGS.seed,
|
||||
_do_init=False
|
||||
)
|
||||
|
||||
model_ps = match_partition_rules(
|
||||
LLaMAConfig.get_partition_rules(), params
|
||||
)
|
||||
shard_fns, _ = make_shard_and_gather_fns(
|
||||
model_ps, get_float_dtype_by_name(FLAGS.dtype)
|
||||
)
|
||||
|
||||
@partial(
|
||||
pjit,
|
||||
in_shardings=(model_ps, PS(), PS()),
|
||||
out_shardings=(PS(), PS(), PS())
|
||||
)
|
||||
def forward_loglikelihood(params, rng, batch):
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
rng_generator = JaxRNG(rng)
|
||||
input_tokens = batch['input_tokens']
|
||||
output_tokens = batch['output_tokens']
|
||||
input_mask = batch['input_mask']
|
||||
output_mask = batch['output_mask']
|
||||
|
||||
logits = hf_model.module.apply(
|
||||
params, input_tokens, attention_mask=input_mask,
|
||||
deterministic=True, rngs=rng_generator(llama_config.rng_keys()),
|
||||
).logits
|
||||
# if llama_config.n_real_tokens is not None:
|
||||
# logits = logits.at[:, :, llama_config.n_real_tokens:].set(-1e8)
|
||||
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
|
||||
logits, output_tokens
|
||||
)
|
||||
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
|
||||
match_count = jnp.sum(
|
||||
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
|
||||
axis=-1
|
||||
)
|
||||
total = jnp.sum(output_mask, axis=-1)
|
||||
is_greedy = match_count == total
|
||||
return loglikelihood, is_greedy, rng_generator()
|
||||
|
||||
|
||||
@partial(
|
||||
pjit,
|
||||
in_shardings=(model_ps, PS(), PS(), PS()),
|
||||
out_shardings=(PS(), PS())
|
||||
)
|
||||
def forward_generate(params, rng, batch, temperature):
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
rng_generator = JaxRNG(rng)
|
||||
output = hf_model.generate(
|
||||
batch['input_tokens'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
params=params['params'],
|
||||
prng_key=rng_generator(),
|
||||
logits_processor=FlaxLogitsProcessorList(
|
||||
[FlaxTemperatureLogitsWarper(temperature)]
|
||||
),
|
||||
generation_config=GenerationConfig(
|
||||
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
do_sample=FLAGS.do_sample,
|
||||
num_beams=FLAGS.num_beams,
|
||||
top_k=FLAGS.top_k,
|
||||
top_p=FLAGS.top_p,
|
||||
)
|
||||
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||
return output, rng_generator()
|
||||
|
||||
@partial(
|
||||
pjit,
|
||||
in_shardings=(model_ps, PS(), PS()),
|
||||
out_shardings=(PS(), PS())
|
||||
)
|
||||
def forward_greedy_generate(params, rng, batch):
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
rng_generator = JaxRNG(rng)
|
||||
output = hf_model.generate(
|
||||
batch['input_tokens'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
params=params['params'],
|
||||
prng_key=rng_generator(),
|
||||
generation_config=GenerationConfig(
|
||||
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
)
|
||||
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||
return output, rng_generator()
|
||||
|
||||
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||
with mesh:
|
||||
params = tree_apply(shard_fns, params)
|
||||
sharded_rng = next_rng()
|
||||
|
||||
class ModelServer(LMServer):
|
||||
|
||||
@staticmethod
|
||||
def loglikelihood(prefix_text, text):
|
||||
nonlocal sharded_rng
|
||||
prefix = prefix_tokenizer(
|
||||
prefix_text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=FLAGS.input_length,
|
||||
return_tensors='np',
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=FLAGS.seq_length - FLAGS.input_length,
|
||||
return_tensors='np',
|
||||
)
|
||||
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
|
||||
bos_tokens = np.full(
|
||||
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
|
||||
)
|
||||
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||
input_mask = np.concatenate(
|
||||
[prefix.attention_mask, inputs.attention_mask], axis=1
|
||||
)
|
||||
if FLAGS.add_bos_token:
|
||||
bos_mask = np.ones_like(input_mask[:, :1])
|
||||
else:
|
||||
bos_mask = np.zeros_like(input_mask[:, :1])
|
||||
|
||||
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
|
||||
output_mask = np.concatenate(
|
||||
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
|
||||
)
|
||||
batch = dict(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
input_mask=input_mask,
|
||||
output_mask=output_mask,
|
||||
)
|
||||
with mesh:
|
||||
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||
params, sharded_rng, batch
|
||||
)
|
||||
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||
return loglikelihood, is_greedy
|
||||
|
||||
@staticmethod
|
||||
def loglikelihood_rolling(text):
|
||||
nonlocal sharded_rng
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding='longest',
|
||||
truncation=False,
|
||||
max_length=np.iinfo(np.int32).max,
|
||||
return_tensors='np',
|
||||
)
|
||||
batch_size = inputs.input_ids.shape[0]
|
||||
output_tokens = inputs.input_ids
|
||||
attention_mask = inputs.attention_mask
|
||||
|
||||
if output_tokens.shape[1] < FLAGS.seq_length:
|
||||
padding_length = FLAGS.seq_length - output_tokens.shape[1]
|
||||
pad_tokens = np.full(
|
||||
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
|
||||
)
|
||||
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
|
||||
pad_mask = np.zeros(
|
||||
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
|
||||
)
|
||||
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
|
||||
|
||||
bos_tokens = np.full(
|
||||
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
|
||||
)
|
||||
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
|
||||
total_seq_length = output_tokens.shape[1]
|
||||
|
||||
total_loglikelihood = 0.0
|
||||
total_is_greedy = True
|
||||
# Sliding window
|
||||
for i in range(0, total_seq_length, FLAGS.seq_length):
|
||||
# Last window
|
||||
if i + FLAGS.seq_length > total_seq_length:
|
||||
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
|
||||
last_output_mask[:, :i - total_seq_length] = 0.0
|
||||
|
||||
batch = dict(
|
||||
input_tokens=input_tokens[:, -FLAGS.seq_length:],
|
||||
output_tokens=output_tokens[:, -FLAGS.seq_length:],
|
||||
input_mask=attention_mask[:, -FLAGS.seq_length:],
|
||||
output_mask=last_output_mask,
|
||||
)
|
||||
|
||||
# Normal window
|
||||
else:
|
||||
batch = dict(
|
||||
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
|
||||
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
|
||||
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||
)
|
||||
|
||||
with mesh:
|
||||
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||
params, sharded_rng, batch
|
||||
)
|
||||
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||
|
||||
total_loglikelihood += loglikelihood
|
||||
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
|
||||
|
||||
return total_loglikelihood, total_is_greedy
|
||||
|
||||
@staticmethod
|
||||
def generate(text, temperature):
|
||||
nonlocal sharded_rng
|
||||
inputs = prefix_tokenizer(
|
||||
text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=FLAGS.input_length,
|
||||
return_tensors='np',
|
||||
)
|
||||
input_tokens = inputs.input_ids
|
||||
input_mask = inputs.attention_mask
|
||||
if FLAGS.add_bos_token:
|
||||
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||
input_mask[:, 0] = 1
|
||||
batch = dict(
|
||||
input_tokens=input_tokens,
|
||||
attention_mask=input_mask,
|
||||
)
|
||||
with mesh:
|
||||
output, sharded_rng = forward_generate(
|
||||
params, sharded_rng, batch, temperature
|
||||
)
|
||||
output = jax.device_get(output)
|
||||
output_text = []
|
||||
for text in list(tokenizer.batch_decode(output)):
|
||||
if tokenizer.eos_token in text:
|
||||
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
|
||||
output_text.append(text)
|
||||
|
||||
return output_text
|
||||
|
||||
@staticmethod
|
||||
def greedy_until(prefix_text, until, max_length):
|
||||
nonlocal sharded_rng
|
||||
all_outputs = []
|
||||
for pf, ut in zip(prefix_text, until):
|
||||
if isinstance(ut, str):
|
||||
ut = [ut]
|
||||
total_length = 0
|
||||
total_generated = ''
|
||||
|
||||
while total_length < max_length:
|
||||
pf_tokens = tokenizer(
|
||||
pf,
|
||||
padding=False,
|
||||
truncation=False,
|
||||
max_length=np.iinfo(np.int32).max,
|
||||
return_tensors='np',
|
||||
)
|
||||
input_tokens = pf_tokens.input_ids
|
||||
attention_mask = pf_tokens.attention_mask
|
||||
|
||||
if input_tokens.shape[1] < FLAGS.input_length:
|
||||
extra = FLAGS.input_length - input_tokens.shape[1]
|
||||
pad_tokens = np.full(
|
||||
(1, extra), tokenizer.pad_token_id, dtype=np.int32
|
||||
)
|
||||
input_tokens = np.concatenate(
|
||||
[pad_tokens, input_tokens], axis=1
|
||||
)
|
||||
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
|
||||
attention_mask = np.concatenate(
|
||||
[pad_attention, attention_mask], axis=1
|
||||
)
|
||||
elif input_tokens.shape[1] > FLAGS.input_length:
|
||||
input_tokens = input_tokens[:, -FLAGS.input_length:]
|
||||
attention_mask = attention_mask[:, -FLAGS.input_length:]
|
||||
|
||||
if FLAGS.add_bos_token:
|
||||
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||
attention_mask[:, 0] = 1
|
||||
|
||||
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
|
||||
|
||||
with mesh:
|
||||
output, sharded_rng = forward_greedy_generate(
|
||||
params, sharded_rng, batch
|
||||
)
|
||||
output = jax.device_get(output)
|
||||
|
||||
total_length += output.shape[1]
|
||||
output_text = tokenizer.batch_decode(output)[0]
|
||||
total_generated = total_generated + output_text
|
||||
pf = pf + output_text
|
||||
|
||||
done = False
|
||||
for s in ut:
|
||||
if s in total_generated:
|
||||
total_generated = total_generated.split(s, maxsplit=1)[0]
|
||||
done = True
|
||||
if done:
|
||||
break
|
||||
|
||||
all_outputs.append(total_generated)
|
||||
|
||||
return all_outputs
|
||||
|
||||
|
||||
server = ModelServer(FLAGS.lm_server)
|
||||
server.run()
|
||||
|
||||
|
||||
# Script entry point: only runs when the file is executed directly.
# `main` is handed to mlxu.run (presumably an absl-style runner that parses
# the command-line flags first -- confirm against mlxu).
if __name__ == "__main__":
    mlxu.run(main)
|
||||
268
EasyLM/models/llama/llama_train.py
Normal file
268
EasyLM/models/llama/llama_train.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import pprint
|
||||
from functools import partial
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.sharding import PartitionSpec as PS
|
||||
from flax.training.train_state import TrainState
|
||||
|
||||
from EasyLM.data import DatasetFactory
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.optimizers import OptimizerFactory
|
||||
from EasyLM.jax_utils import (
|
||||
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
|
||||
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
|
||||
set_random_seed, average_metrics, get_weight_decay_mask,
|
||||
make_shard_and_gather_fns, with_sharding_constraint,
|
||||
)
|
||||
from EasyLM.models.llama.llama_model import (
|
||||
LLaMAConfig, FlaxLLaMAForCausalLMModule
|
||||
)
|
||||
|
||||
|
||||
# Command-line flags for the LLaMA training script. The comments below
# describe how `main` in this file consumes each flag.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    seed=42,  # passed to set_random_seed
    mesh_dim='1,-1,1',  # passed to LLaMAConfig.get_jax_mesh
    dtype='fp32',  # model `dtype`, resolved via get_float_dtype_by_name
    param_dtype='fp32',  # model `param_dtype`, resolved the same way
    total_steps=10000,  # upper bound of the training step range
    load_llama_config='',  # if non-empty, given to LLaMAConfig.load_config
    update_llama_config='',  # if non-empty, eval()'d and applied as overrides
    load_checkpoint='',  # if non-empty, restored via the checkpointer
    load_dataset_state='',  # if non-empty, pickle loaded into the dataset
    log_freq=50,  # log metrics every N steps (0 disables)
    save_model_freq=0,  # regular checkpoint every N steps (0 disables)
    save_milestone_freq=0,  # milestone checkpoint every N steps (0 disables)
    eval_freq=0,  # run evaluation every N steps (0 disables)
    tokenizer=LLaMAConfig.get_tokenizer_config(),
    train_dataset=DatasetFactory.get_default_config(),
    eval_dataset=DatasetFactory.get_default_config(),
    optimizer=OptimizerFactory.get_default_config(),
    checkpointer=StreamingCheckpointer.get_default_config(),
    llama=LLaMAConfig.get_default_config(),  # used when load_llama_config is empty
    logger=mlxu.WandBLogger.get_default_config(),
    log_all_worker=False,  # if True, the logger is enabled on every process
    jax_distributed=JaxDistributedConfig.get_default_config(),
)
|
||||
|
||||
|
||||
def main(argv):
    """Train a LLaMA model with pjit-sharded train/eval steps.

    Builds the tokenizer, dataset(s), model config, model and optimizer from
    FLAGS, wraps init/train/eval in pjit with the partition rules from
    LLaMAConfig, optionally restores a checkpoint, then runs the training
    loop with periodic evaluation, logging and checkpointing.

    Args:
        argv: positional arguments supplied by the runner; not used here.
    """
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
    # Logging is enabled on process 0 only, unless log_all_worker is set.
    logger = mlxu.WandBLogger(
        config=FLAGS.logger,
        variant=variant,
        enable=FLAGS.log_all_worker or (jax.process_index() == 0),
    )
    set_random_seed(FLAGS.seed)

    tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
    dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
    if FLAGS.load_dataset_state != '':
        # Restore the dataset state from a previously pickled state dict.
        dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))

    # NOTE: eval_dataset is only bound when evaluation is enabled; every
    # later use is guarded by the same eval_freq > 0 condition.
    if FLAGS.eval_freq > 0:
        eval_dataset = DatasetFactory.load_dataset(
            FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
        )

    seq_length = dataset.seq_length

    if FLAGS.load_llama_config != '':
        llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
    else:
        llama_config = LLaMAConfig(**FLAGS.llama)

    if FLAGS.update_llama_config != '':
        # NOTE(review): eval() executes arbitrary Python taken from the flag
        # value; this flag must only ever be set from trusted input.
        llama_config.update(dict(eval(FLAGS.update_llama_config)))

    # Keep the model's special-token ids in sync with the dataset tokenizer.
    llama_config.update(dict(
        bos_token_id=dataset.tokenizer.bos_token_id,
        eos_token_id=dataset.tokenizer.eos_token_id,
    ))
    # Grow (never shrink) the model vocabulary to cover the dataset's.
    if llama_config.vocab_size < dataset.vocab_size:
        print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
        llama_config.update(dict(vocab_size=dataset.vocab_size))

    model = FlaxLLaMAForCausalLMModule(
        llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype), param_dtype=get_float_dtype_by_name(FLAGS.param_dtype)
    )

    optimizer, optimizer_info = OptimizerFactory.get_optimizer(
        FLAGS.optimizer,
        get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
    )

    def create_trainstate_from_params(params):
        """Wrap already-loaded params in a fresh TrainState."""
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def init_fn(rng):
        """Initialize model params from `rng` and wrap them in a TrainState."""
        rng_generator = JaxRNG(rng)
        # Dummy inputs of shape (4, seq_length) are used only for init.
        params = model.init(
            input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
            rngs=rng_generator(llama_config.rng_keys()),
        )
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def train_step(train_state, rng, batch):
        """Run one optimization step.

        Returns the updated train state, a fresh RNG, and a metrics dict.
        """
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        def loss_and_accuracy(params):
            logits = model.apply(
                params, batch['input_tokens'], deterministic=False,
                rngs=rng_generator(llama_config.rng_keys()),
            ).logits
            return cross_entropy_loss_and_accuracy(
                logits, batch['target_tokens'], batch['loss_masks']
            )
        # has_aux=True: the function returns (loss, accuracy) and only the
        # loss is differentiated.
        grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
        (loss, accuracy), grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        # Learning rate and param norm are read from the post-update state.
        metrics = dict(
            loss=loss,
            accuracy=accuracy,
            learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
            gradient_norm=global_norm(grads),
            param_norm=global_norm(train_state.params),
        )
        return train_state, rng_generator(), metrics

    def eval_step(train_state, rng, batch):
        """Compute eval loss/accuracy for one batch; returns (new rng, metrics)."""
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        logits = model.apply(
            train_state.params, batch['input_tokens'], deterministic=True,
            rngs=rng_generator(llama_config.rng_keys()),
        ).logits
        loss, accuracy = cross_entropy_loss_and_accuracy(
            logits, batch['target_tokens'], batch['loss_masks']
        )
        metrics = dict(
            eval_loss=loss,
            eval_accuracy=accuracy,
        )
        return rng_generator(), metrics

    # Trace init_fn abstractly to get the train-state shapes without
    # materializing parameters, then derive partition specs from them.
    train_state_shapes = jax.eval_shape(init_fn, next_rng())
    train_state_partition = match_partition_rules(
        LLaMAConfig.get_partition_rules(), train_state_shapes
    )

    shard_fns, gather_fns = make_shard_and_gather_fns(
        train_state_partition, train_state_shapes
    )
    # The checkpointer is enabled on process 0 only.
    checkpointer = StreamingCheckpointer(
        FLAGS.checkpointer, logger.output_dir,
        enable=jax.process_index() == 0,
    )

    sharded_init_fn = pjit(
        init_fn,
        in_shardings=PS(),
        out_shardings=train_state_partition
    )

    # donate_argnums=(0,): the input params are donated to the new state.
    sharded_create_trainstate_from_params = pjit(
        create_trainstate_from_params,
        in_shardings=(train_state_partition.params, ),
        out_shardings=train_state_partition,
        donate_argnums=(0, ),
    )

    # The old train state and rng are donated; callers must use the
    # returned values afterwards.
    sharded_train_step = pjit(
        train_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(train_state_partition, PS(), PS()),
        donate_argnums=(0, 1),
    )

    # Only the rng is donated here; the train state is reused after eval.
    sharded_eval_step = pjit(
        eval_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(PS(), PS()),
        donate_argnums=(1,),
    )

    def save_checkpoint(train_state, milestone=False):
        """Save train state, dataset state and run metadata via the checkpointer."""
        step = int(jax.device_get(train_state.step))
        metadata = dict(
            step=step,
            variant=variant,
            flags=flags_config_dict,
            llama_config=llama_config.to_dict(),
        )
        checkpointer.save_all(
            train_state=train_state,
            gather_fns=gather_fns,
            metadata=metadata,
            dataset=dataset.get_state_dict(),
            milestone=milestone,
        )

    mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        train_state, restored_params = None, None
        if FLAGS.load_checkpoint != '':
            train_state, restored_params = checkpointer.load_trainstate_checkpoint(
                FLAGS.load_checkpoint, train_state_shapes, shard_fns
            )

        if train_state is None and restored_params is None:
            # Initialize from scratch
            train_state = sharded_init_fn(next_rng())
        elif train_state is None and restored_params is not None:
            # Restore from params but initialize train_state
            train_state = sharded_create_trainstate_from_params(restored_params)
            del restored_params

        # Resume from the step stored in the (possibly restored) train state.
        start_step = int(jax.device_get(train_state.step))

        # Save once before training starts when checkpointing is enabled.
        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)

        sharded_rng = next_rng()

        step_counter = trange(start_step, FLAGS.total_steps, ncols=0)

        # zip stops when either the step range or the dataset is exhausted.
        for step, (batch, dataset_metrics) in zip(step_counter, dataset):
            train_state, sharded_rng, metrics = sharded_train_step(
                train_state, sharded_rng, batch
            )

            if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
                # Evaluate on every batch the eval dataset yields.
                eval_metric_list = []
                eval_iterator = iter(eval_dataset)
                for eval_batch, _ in eval_iterator:
                    sharded_rng, eval_metrics = sharded_eval_step(
                        train_state, sharded_rng, eval_batch
                    )
                    eval_metric_list.append(eval_metrics)
                metrics.update(average_metrics(eval_metric_list))

            if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
                log_metrics = {"step": step + 1}
                log_metrics.update(metrics)
                log_metrics.update(dataset_metrics)
                log_metrics = jax.device_get(log_metrics)
                logger.log(log_metrics)
                tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")

            # A milestone save takes precedence over a regular save when
            # both frequencies hit on the same step.
            if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
                save_checkpoint(train_state, milestone=True)
            elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
                save_checkpoint(train_state)

        # Final save after the loop ends.
        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)
|
||||
|
||||
|
||||
# Script entry point: only runs when the file is executed directly.
# `main` is handed to mlxu.run (presumably an absl-style runner that parses
# the command-line flags first -- confirm against mlxu).
if __name__ == "__main__":
    mlxu.run(main)
|
||||
Reference in New Issue
Block a user