初始化项目,由ModelHub XC社区提供模型

Model: Finnish-NLP/Ahma-7B
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-01 02:08:18 +08:00
commit be39ad8722
45 changed files with 297486 additions and 0 deletions

View File

@@ -0,0 +1,338 @@
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
# Copyright 2023 Xinyang Geng
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts LLaMA model checkpoint trained by EsayLM to the
# HuggingFace transformers LLaMA PyTorch format, which can then be loaded
# by HuggingFace transformers.
import gc
import json
import math
import os
import shutil
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
import flax
from flax.traverse_util import flatten_dict
import torch
from transformers import LlamaConfig, LlamaForCausalLM
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_tensor_to_dtype
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
load_checkpoint='',
tokenizer_path='',
model_size='13b',
output_dir='',
)
LLAMA_STANDARD_CONFIGS = {
'small': {
'vocab_size': 64256,
'dim': 768,
'intermediate_size': 3072,
'n_layers': 12,
'n_heads': 12,
'norm_eps': 1e-6,
},
'medium': {
'vocab_size': 64256,
'dim': 1024,
'intermediate_size': 4096,
'n_layers': 24,
'n_heads': 16,
'norm_eps': 1e-6,
},
'large': {
'vocab_size': 64256,
'dim': 1536,
'intermediate_size': 6144,
'n_layers': 24,
'n_heads': 16,
'norm_eps': 1e-6,
},
'xlarge': {
'vocab_size': 64256,
'dim': 2048,
'intermediate_size': 8192,
'n_layers': 24,
'n_heads': 32,
'norm_eps': 1e-6,
},
'1b': {
'vocab_size': 64256,
'dim': 2048,
'intermediate_size': 5504,
'n_layers': 22,
'n_heads': 16,
'norm_eps': 1e-6,
},
'3b': {
'vocab_size': 64256,
'dim': 3200,
'intermediate_size': 8640,
'n_layers': 26,
'n_heads': 32,
'norm_eps': 1e-6,
},
'7b': {
'vocab_size': 64256,
'dim': 4096,
'intermediate_size': 11008,
'n_layers': 32,
'n_heads': 32,
'norm_eps': 1e-6,
},
'13b': {
'vocab_size': 64256,
'dim': 5120,
'intermediate_size': 13824,
'n_layers': 40,
'n_heads': 40,
'norm_eps': 1e-6,
},
'30b': {
'vocab_size': 64256,
'dim': 6656,
'intermediate_size': 17920,
'n_layers': 60,
'n_heads': 52,
'norm_eps': 1e-6,
},
'65b': {
'vocab_size': 64256,
'dim': 8192,
'intermediate_size': 22016,
'n_layers': 80,
'n_heads': 64,
'norm_eps': 1e-5,
},
}
def match_keywords(string, positives, negatives):
for positive in positives:
if positive not in string:
return False
for negative in negatives:
if negative in string:
return False
return True
def load_and_convert_checkpoint(path):
_, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
flax_params = flatten_dict(flax_params['params'], sep='.')
torch_params = {}
for key, tensor in flax_params.items():
if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
tensor = tensor.T
torch_params[key] = torch.tensor(
float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16
)
return torch_params
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(loaded, model_path, model_size):
os.makedirs(model_path, exist_ok=True)
tmp_model_path = os.path.join(model_path, "tmp")
os.makedirs(tmp_model_path, exist_ok=True)
params = LLAMA_STANDARD_CONFIGS[model_size]
n_layers = params["n_layers"]
n_heads = params["n_heads"]
dim = params["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
# permute for sliced rotary
def permute(w):
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
param_count = 0
index_dict = {"weight_map": {}}
for layer_i in range(n_layers):
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
state_dict = {
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
),
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
}
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
# Unsharded
state_dict = {
"model.embed_tokens.weight": loaded["transformer.wte.embedding"],
"model.norm.weight": loaded["transformer.ln_f.kernel"],
"lm_head.weight": loaded["lm_head.kernel"],
}
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
# Write configs
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
config = LlamaConfig(
vocab_size=params["vocab_size"],
hidden_size=dim,
intermediate_size=params["intermediate_size"],
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
)
config.save_pretrained(tmp_model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
gc.collect()
print("Loading the checkpoint in a Llama model.")
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
# Avoid saving this as part of the config.
print("Model parameter count", model.num_parameters())
del model.config._name_or_path
print("Saving in the Transformers format.")
model.save_pretrained(model_path, safe_serialization=True)
shutil.rmtree(tmp_model_path)
def write_tokenizer(tokenizer_path, input_tokenizer_path):
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
os.makedirs(tokenizer_path, exist_ok=True)
write_json(
{
"bos_token": {
"content": "<s>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
"eos_token": {
"content": "</s>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
"unk_token": {
"content": "<unk>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
},
os.path.join(tokenizer_path, "special_tokens_map.json")
)
write_json(
{
"add_bos_token": True,
"add_eos_token": False,
"model_max_length": 2048,
"pad_token": None,
"sp_model_kwargs": {},
"tokenizer_class": "LlamaTokenizer",
"clean_up_tokenization_spaces": False,
"bos_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
"eos_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
"unk_token": {
"__type": "AddedToken",
"content": "<unk>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
},
os.path.join(tokenizer_path, "tokenizer_config.json"),
)
shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
def main(argv):
assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""# and FLAGS.tokenizer_path != ""
assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
# write_tokenizer(
# tokenizer_path=FLAGS.output_dir,
# input_tokenizer_path=FLAGS.tokenizer_path,
# )
write_model(
load_and_convert_checkpoint(FLAGS.load_checkpoint),
model_path=FLAGS.output_dir,
model_size=FLAGS.model_size,
)
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,196 @@
"""
Usage:
python convert_hf_to_easylm.py \
--checkpoint_dir /path/hf_format_dir/ \
--output_file /path/easylm_format.stream \
--model_size 7b \
--streaming
"""
import time
from pathlib import Path
import argparse
import mlxu
import torch
import flax
from EasyLM.checkpoint import StreamingCheckpointer
LLAMA_STANDARD_CONFIGS = {
'1b': {
'dim': 2048,
'intermediate_size': 5504,
'n_layers': 22,
'n_heads': 16,
'norm_eps': 1e-6,
},
'3b': {
'dim': 3200,
'intermediate_size': 8640,
'n_layers': 26,
'n_heads': 32,
'norm_eps': 1e-6,
},
"7b": {
"dim": 4096,
"intermediate_size": 11008,
"n_layers": 32,
"n_heads": 32,
"norm_eps": 1e-6,
},
"13b": {
"dim": 5120,
"intermediate_size": 13824,
"n_layers": 40,
"n_heads": 40,
"norm_eps": 1e-6,
},
"30b": {
"dim": 6656,
"intermediate_size": 17920,
"n_layers": 60,
"n_heads": 52,
"norm_eps": 1e-6,
},
"65b": {
"dim": 8192,
"intermediate_size": 22016,
"n_layers": 80,
"n_heads": 64,
"norm_eps": 1e-5,
},
}
def inverse_permute(params, w):
n_layers = params["n_layers"]
n_heads = params["n_heads"]
dim = params["dim"]
reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
transposed_w = reshaped_w.transpose(0, 2, 1, 3)
inverted_w = transposed_w.reshape(dim, dim)
return inverted_w
def main(args):
start = time.time()
params = LLAMA_STANDARD_CONFIGS[args.model_size]
ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
ckpt = {}
for i, ckpt_path in enumerate(ckpt_paths):
checkpoint = torch.load(ckpt_path, map_location="cpu")
for k, v in checkpoint.items():
if k.startswith("model."):
k = k[6:]
ckpt[k] = v
print(f"Start convert weight to easylm format...")
jax_weights = {
"transformer": {
"wte": {"embedding": ckpt["embed_tokens.weight"].numpy()},
"ln_f": {"kernel": ckpt["norm.weight"].numpy()},
"h": {
"%d"
% (layer): {
"attention": {
"wq": {
"kernel": inverse_permute(
params,
ckpt[f"layers.{layer}.self_attn.q_proj.weight"].numpy(),
).transpose()
},
"wk": {
"kernel": inverse_permute(
params,
ckpt[f"layers.{layer}.self_attn.k_proj.weight"].numpy(),
).transpose()
},
"wv": {
"kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"]
.numpy()
.transpose()
},
"wo": {
"kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"]
.numpy()
.transpose()
},
},
"feed_forward": {
"w1": {
"kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"]
.numpy()
.transpose()
},
"w2": {
"kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"]
.numpy()
.transpose()
},
"w3": {
"kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"]
.numpy()
.transpose()
},
},
"attention_norm": {
"kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].numpy()
},
"ffn_norm": {
"kernel": ckpt[
f"layers.{layer}.post_attention_layernorm.weight"
].numpy()
},
}
for layer in range(params["n_layers"])
},
},
"lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()},
}
print(f"Convert weight to easylm format finished...")
print(f"Start to save...")
if args.streaming:
StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
else:
with mlxu.open_file(args.output_file, "wb") as fout:
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
print(
f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="hf to easylm format script")
parser.add_argument(
"--checkpoint_dir",
type=str,
help="Need to be converted model weight dir. it is a dir",
)
parser.add_argument(
"--output_file", type=str, help="Save model weight file path, it is a file."
)
parser.add_argument(
"--model_size",
type=str,
default="7b",
choices=["7b", "13b", "30b", "65b"],
help="model size",
)
parser.add_argument(
"--streaming",
action="store_true",
default=True,
help="whether is model weight saved stream format",
)
args = parser.parse_args()
print(f"checkpoint_dir: {args.checkpoint_dir}")
print(f"output_file: {args.output_file}")
print(f"model_size: {args.model_size}")
print(f"streaming: {args.streaming}")
main(args)

View File

@@ -0,0 +1,68 @@
# This script converts the standrd LLaMA PyTorch checkpoint released by Meta
# to the EasyLM checkpoint format. The converted checkpoint can then be loaded
# by EasyLM for fine-tuning or inference.
# This script is largely borrow from https://github.com/Sea-Snell/JAX_llama
from pathlib import Path
import json
import numpy as np
import torch
import flax
import mlxu
from EasyLM.checkpoint import StreamingCheckpointer
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
checkpoint_dir='',
output_file='',
streaming=True,
)
def main(argv):
ckpt_paths = sorted(Path(FLAGS.checkpoint_dir).glob("*.pth"))
ckpts = {}
for i, ckpt_path in enumerate(ckpt_paths):
checkpoint = torch.load(ckpt_path, map_location="cpu")
ckpts[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint
ckpts = [ckpts[i] for i in sorted(list(ckpts.keys()))]
with open(Path(FLAGS.checkpoint_dir) / "params.json", "r") as f:
params = json.loads(f.read())
jax_weights = {
'transformer': {
'wte': {'embedding': np.concatenate([ckpt['tok_embeddings.weight'].numpy() for ckpt in ckpts], axis=1)},
'ln_f': {'kernel': ckpts[0]['norm.weight'].numpy()},
'h': {
'%d' % (layer): {
'attention': {
'wq': {'kernel': np.concatenate([ckpt['layers.%d.attention.wq.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
'wk': {'kernel': np.concatenate([ckpt['layers.%d.attention.wk.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
'wv': {'kernel': np.concatenate([ckpt['layers.%d.attention.wv.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
'wo': {'kernel': np.concatenate([ckpt['layers.%d.attention.wo.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
},
'feed_forward': {
'w1': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w1.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
'w2': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w2.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
'w3': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w3.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
},
'attention_norm': {'kernel': ckpts[0]['layers.%d.attention_norm.weight' % (layer)].numpy()},
'ffn_norm': {'kernel': ckpts[0]['layers.%d.ffn_norm.weight' % (layer)].numpy()},
}
for layer in range(params['n_layers'])},
},
'lm_head': {'kernel': np.concatenate([ckpt['output.weight'].numpy() for ckpt in ckpts], axis=0).transpose()},
}
if FLAGS.streaming:
StreamingCheckpointer.save_train_state_to_file(
jax_weights, FLAGS.output_file
)
else:
with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
if __name__ == '__main__':
mlxu.run(main)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,386 @@
import pprint
from functools import partial
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import optax
from transformers import GenerationConfig, FlaxLogitsProcessorList
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.serving import LMServer
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
with_sharding_constraint, FlaxTemperatureLogitsWarper
)
from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
initialize_jax_distributed=False,
mesh_dim='1,-1,1',
dtype='bf16',
input_length=1024,
seq_length=2048,
top_k=50,
top_p=1.0,
do_sample=True,
num_beams=1,
add_bos_token=True,
load_llama_config='',
load_checkpoint='',
tokenizer=LLaMAConfig.get_tokenizer_config(),
lm_server=LMServer.get_default_config(),
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
set_random_seed(FLAGS.seed)
prefix_tokenizer = LLaMAConfig.get_tokenizer(
FLAGS.tokenizer, truncation_side='left', padding_side='left'
)
tokenizer = LLaMAConfig.get_tokenizer(
FLAGS.tokenizer, truncation_side='right', padding_side='right'
)
with jax.default_device(jax.devices("cpu")[0]):
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True
)
hf_model = FlaxLLaMAForCausalLM(
llama_config,
input_shape=(1, FLAGS.seq_length),
seed=FLAGS.seed,
_do_init=False
)
model_ps = match_partition_rules(
LLaMAConfig.get_partition_rules(), params
)
shard_fns, _ = make_shard_and_gather_fns(
model_ps, get_float_dtype_by_name(FLAGS.dtype)
)
@partial(
pjit,
in_shardings=(model_ps, PS(), PS()),
out_shardings=(PS(), PS(), PS())
)
def forward_loglikelihood(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
input_tokens = batch['input_tokens']
output_tokens = batch['output_tokens']
input_mask = batch['input_mask']
output_mask = batch['output_mask']
logits = hf_model.module.apply(
params, input_tokens, attention_mask=input_mask,
deterministic=True, rngs=rng_generator(llama_config.rng_keys()),
).logits
# if llama_config.n_real_tokens is not None:
# logits = logits.at[:, :, llama_config.n_real_tokens:].set(-1e8)
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
logits, output_tokens
)
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
match_count = jnp.sum(
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
axis=-1
)
total = jnp.sum(output_mask, axis=-1)
is_greedy = match_count == total
return loglikelihood, is_greedy, rng_generator()
@partial(
pjit,
in_shardings=(model_ps, PS(), PS(), PS()),
out_shardings=(PS(), PS())
)
def forward_generate(params, rng, batch, temperature):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
output = hf_model.generate(
batch['input_tokens'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
logits_processor=FlaxLogitsProcessorList(
[FlaxTemperatureLogitsWarper(temperature)]
),
generation_config=GenerationConfig(
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=FLAGS.do_sample,
num_beams=FLAGS.num_beams,
top_k=FLAGS.top_k,
top_p=FLAGS.top_p,
)
).sequences[:, batch['input_tokens'].shape[1]:]
return output, rng_generator()
@partial(
pjit,
in_shardings=(model_ps, PS(), PS()),
out_shardings=(PS(), PS())
)
def forward_greedy_generate(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
output = hf_model.generate(
batch['input_tokens'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
generation_config=GenerationConfig(
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=False,
num_beams=1,
)
).sequences[:, batch['input_tokens'].shape[1]:]
return output, rng_generator()
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
params = tree_apply(shard_fns, params)
sharded_rng = next_rng()
class ModelServer(LMServer):
@staticmethod
def loglikelihood(prefix_text, text):
nonlocal sharded_rng
prefix = prefix_tokenizer(
prefix_text,
padding='max_length',
truncation=True,
max_length=FLAGS.input_length,
return_tensors='np',
)
inputs = tokenizer(
text,
padding='max_length',
truncation=True,
max_length=FLAGS.seq_length - FLAGS.input_length,
return_tensors='np',
)
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
bos_tokens = np.full(
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
)
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
input_mask = np.concatenate(
[prefix.attention_mask, inputs.attention_mask], axis=1
)
if FLAGS.add_bos_token:
bos_mask = np.ones_like(input_mask[:, :1])
else:
bos_mask = np.zeros_like(input_mask[:, :1])
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
output_mask = np.concatenate(
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
)
batch = dict(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_mask=input_mask,
output_mask=output_mask,
)
with mesh:
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
params, sharded_rng, batch
)
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
return loglikelihood, is_greedy
@staticmethod
def loglikelihood_rolling(text):
nonlocal sharded_rng
inputs = tokenizer(
text,
padding='longest',
truncation=False,
max_length=np.iinfo(np.int32).max,
return_tensors='np',
)
batch_size = inputs.input_ids.shape[0]
output_tokens = inputs.input_ids
attention_mask = inputs.attention_mask
if output_tokens.shape[1] < FLAGS.seq_length:
padding_length = FLAGS.seq_length - output_tokens.shape[1]
pad_tokens = np.full(
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
)
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
pad_mask = np.zeros(
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
)
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
bos_tokens = np.full(
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
)
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
total_seq_length = output_tokens.shape[1]
total_loglikelihood = 0.0
total_is_greedy = True
# Sliding window
for i in range(0, total_seq_length, FLAGS.seq_length):
# Last window
if i + FLAGS.seq_length > total_seq_length:
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
last_output_mask[:, :i - total_seq_length] = 0.0
batch = dict(
input_tokens=input_tokens[:, -FLAGS.seq_length:],
output_tokens=output_tokens[:, -FLAGS.seq_length:],
input_mask=attention_mask[:, -FLAGS.seq_length:],
output_mask=last_output_mask,
)
# Normal window
else:
batch = dict(
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
)
with mesh:
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
params, sharded_rng, batch
)
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
total_loglikelihood += loglikelihood
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
return total_loglikelihood, total_is_greedy
@staticmethod
def generate(text, temperature):
nonlocal sharded_rng
inputs = prefix_tokenizer(
text,
padding='max_length',
truncation=True,
max_length=FLAGS.input_length,
return_tensors='np',
)
input_tokens = inputs.input_ids
input_mask = inputs.attention_mask
if FLAGS.add_bos_token:
input_tokens[:, 0] = tokenizer.bos_token_id
input_mask[:, 0] = 1
batch = dict(
input_tokens=input_tokens,
attention_mask=input_mask,
)
with mesh:
output, sharded_rng = forward_generate(
params, sharded_rng, batch, temperature
)
output = jax.device_get(output)
output_text = []
for text in list(tokenizer.batch_decode(output)):
if tokenizer.eos_token in text:
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
output_text.append(text)
return output_text
@staticmethod
def greedy_until(prefix_text, until, max_length):
nonlocal sharded_rng
all_outputs = []
for pf, ut in zip(prefix_text, until):
if isinstance(ut, str):
ut = [ut]
total_length = 0
total_generated = ''
while total_length < max_length:
pf_tokens = tokenizer(
pf,
padding=False,
truncation=False,
max_length=np.iinfo(np.int32).max,
return_tensors='np',
)
input_tokens = pf_tokens.input_ids
attention_mask = pf_tokens.attention_mask
if input_tokens.shape[1] < FLAGS.input_length:
extra = FLAGS.input_length - input_tokens.shape[1]
pad_tokens = np.full(
(1, extra), tokenizer.pad_token_id, dtype=np.int32
)
input_tokens = np.concatenate(
[pad_tokens, input_tokens], axis=1
)
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
attention_mask = np.concatenate(
[pad_attention, attention_mask], axis=1
)
elif input_tokens.shape[1] > FLAGS.input_length:
input_tokens = input_tokens[:, -FLAGS.input_length:]
attention_mask = attention_mask[:, -FLAGS.input_length:]
if FLAGS.add_bos_token:
input_tokens[:, 0] = tokenizer.bos_token_id
attention_mask[:, 0] = 1
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
with mesh:
output, sharded_rng = forward_greedy_generate(
params, sharded_rng, batch
)
output = jax.device_get(output)
total_length += output.shape[1]
output_text = tokenizer.batch_decode(output)[0]
total_generated = total_generated + output_text
pf = pf + output_text
done = False
for s in ut:
if s in total_generated:
total_generated = total_generated.split(s, maxsplit=1)[0]
done = True
if done:
break
all_outputs.append(total_generated)
return all_outputs
server = ModelServer(FLAGS.lm_server)
server.run()
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,268 @@
import pprint
from functools import partial
from tqdm import tqdm, trange
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
set_random_seed, average_metrics, get_weight_decay_mask,
make_shard_and_gather_fns, with_sharding_constraint,
)
from EasyLM.models.llama.llama_model import (
LLaMAConfig, FlaxLLaMAForCausalLMModule
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
mesh_dim='1,-1,1',
dtype='fp32',
param_dtype='fp32',
total_steps=10000,
load_llama_config='',
update_llama_config='',
load_checkpoint='',
load_dataset_state='',
log_freq=50,
save_model_freq=0,
save_milestone_freq=0,
eval_freq=0,
tokenizer=LLaMAConfig.get_tokenizer_config(),
train_dataset=DatasetFactory.get_default_config(),
eval_dataset=DatasetFactory.get_default_config(),
optimizer=OptimizerFactory.get_default_config(),
checkpointer=StreamingCheckpointer.get_default_config(),
llama=LLaMAConfig.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
log_all_worker=False,
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
logger = mlxu.WandBLogger(
config=FLAGS.logger,
variant=variant,
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
)
set_random_seed(FLAGS.seed)
tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
if FLAGS.load_dataset_state != '':
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
if FLAGS.eval_freq > 0:
eval_dataset = DatasetFactory.load_dataset(
FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
)
seq_length = dataset.seq_length
if FLAGS.load_llama_config != '':
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
else:
llama_config = LLaMAConfig(**FLAGS.llama)
if FLAGS.update_llama_config != '':
llama_config.update(dict(eval(FLAGS.update_llama_config)))
llama_config.update(dict(
bos_token_id=dataset.tokenizer.bos_token_id,
eos_token_id=dataset.tokenizer.eos_token_id,
))
if llama_config.vocab_size < dataset.vocab_size:
print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
llama_config.update(dict(vocab_size=dataset.vocab_size))
model = FlaxLLaMAForCausalLMModule(
llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype), param_dtype=get_float_dtype_by_name(FLAGS.param_dtype)
)
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
FLAGS.optimizer,
get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
)
def create_trainstate_from_params(params):
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def init_fn(rng):
rng_generator = JaxRNG(rng)
params = model.init(
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
rngs=rng_generator(llama_config.rng_keys()),
)
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def train_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
def loss_and_accuracy(params):
logits = model.apply(
params, batch['input_tokens'], deterministic=False,
rngs=rng_generator(llama_config.rng_keys()),
).logits
return cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
(loss, accuracy), grads = grad_fn(train_state.params)
train_state = train_state.apply_gradients(grads=grads)
metrics = dict(
loss=loss,
accuracy=accuracy,
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
gradient_norm=global_norm(grads),
param_norm=global_norm(train_state.params),
)
return train_state, rng_generator(), metrics
def eval_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
logits = model.apply(
train_state.params, batch['input_tokens'], deterministic=True,
rngs=rng_generator(llama_config.rng_keys()),
).logits
loss, accuracy = cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
metrics = dict(
eval_loss=loss,
eval_accuracy=accuracy,
)
return rng_generator(), metrics
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
LLaMAConfig.get_partition_rules(), train_state_shapes
)
shard_fns, gather_fns = make_shard_and_gather_fns(
train_state_partition, train_state_shapes
)
checkpointer = StreamingCheckpointer(
FLAGS.checkpointer, logger.output_dir,
enable=jax.process_index() == 0,
)
sharded_init_fn = pjit(
init_fn,
in_shardings=PS(),
out_shardings=train_state_partition
)
sharded_create_trainstate_from_params = pjit(
create_trainstate_from_params,
in_shardings=(train_state_partition.params, ),
out_shardings=train_state_partition,
donate_argnums=(0, ),
)
sharded_train_step = pjit(
train_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(train_state_partition, PS(), PS()),
donate_argnums=(0, 1),
)
sharded_eval_step = pjit(
eval_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(PS(), PS()),
donate_argnums=(1,),
)
def save_checkpoint(train_state, milestone=False):
step = int(jax.device_get(train_state.step))
metadata = dict(
step=step,
variant=variant,
flags=flags_config_dict,
llama_config=llama_config.to_dict(),
)
checkpointer.save_all(
train_state=train_state,
gather_fns=gather_fns,
metadata=metadata,
dataset=dataset.get_state_dict(),
milestone=milestone,
)
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
train_state, restored_params = None, None
if FLAGS.load_checkpoint != '':
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, train_state_shapes, shard_fns
)
if train_state is None and restored_params is None:
# Initialize from scratch
train_state = sharded_init_fn(next_rng())
elif train_state is None and restored_params is not None:
# Restore from params but initialize train_state
train_state = sharded_create_trainstate_from_params(restored_params)
del restored_params
start_step = int(jax.device_get(train_state.step))
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
sharded_rng = next_rng()
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
train_state, sharded_rng, metrics = sharded_train_step(
train_state, sharded_rng, batch
)
if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
eval_metric_list = []
eval_iterator = iter(eval_dataset)
for eval_batch, _ in eval_iterator:
sharded_rng, eval_metrics = sharded_eval_step(
train_state, sharded_rng, eval_batch
)
eval_metric_list.append(eval_metrics)
metrics.update(average_metrics(eval_metric_list))
if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
log_metrics = {"step": step + 1}
log_metrics.update(metrics)
log_metrics.update(dataset_metrics)
log_metrics = jax.device_get(log_metrics)
logger.log(log_metrics)
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
save_checkpoint(train_state, milestone=True)
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
save_checkpoint(train_state)
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
if __name__ == "__main__":
mlxu.run(main)