初始化项目,由ModelHub XC社区提供模型

Model: Finnish-NLP/Ahma-7B
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-01 02:08:18 +08:00
commit be39ad8722
45 changed files with 297486 additions and 0 deletions

35
.gitattributes vendored Normal file
View File

@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
__pycache__

0
EasyLM/__init__.py Normal file
View File

228
EasyLM/bpt.py Normal file
View File

@@ -0,0 +1,228 @@
"""
An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682
"""
import functools
from typing import NamedTuple
import flax.linen as nn
import jax
import jax.lax as lax
import jax.numpy as jnp
from einops import rearrange
"""
Computing ffn blockwise without materializing the large hidden tensor, training
4x longer sequences than the memory-efficient transformer.
Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
"""
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
# remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
# inputs: (batch, seq_len, dim)
# chunk_size: the chunk size to split the sequence
inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
def scan_ffn(remat_ffn, carry, hidden_states):
outputs = remat_ffn(hidden_states, deterministic=deterministic)
return carry, outputs
scan_axis = inputs.ndim - 2
_, res = nn.scan(
scan_ffn,
variable_broadcast="params",
split_rngs={"params": False, "dropout": True},
in_axes=scan_axis,
out_axes=scan_axis,
)(remat_ffn, None, inputs)
res = rearrange(res, 'b c n d -> b (c n) d')
return res
"""
Compute attention blockwise without materializing the full attention matrix,
initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA
efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370
Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x
longer sequences than memory-efficient/flash-attention and fusion of attention and FFN.
"""
def blockwise_attn(
query, key, value,
bias=None,
deterministic=True,
dropout_rng=None,
attn_pdrop=0.0,
causal=True,
query_chunk_size=2048,
key_chunk_size=2048,
dtype=jnp.float32,
policy=jax.checkpoint_policies.nothing_saveable(),
precision=None,
float32_logits=True,
prevent_cse=True,
):
# query, key, value: (batch, seq_len, num_heads, dim_per_head)
# bias: (batch, seq_len) can be used to mask out attention (e.g. padding)
# causal: whether to use causal mask
# policy: one of jax.checkpoint_policies
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
if float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
batch, q_len, num_heads, dim_per_head = query.shape
batch, kv_len, num_heads, dim_per_head = key.shape
batch, kv_len, num_heads, dim_per_head = value.shape
num_q = q_len // query_chunk_size
num_kv = kv_len // key_chunk_size
query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
query = jnp.moveaxis(query, 1, 0)
key = jnp.moveaxis(key, 1, 0)
value = jnp.moveaxis(value, 1, 0)
if bias is not None:
for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
assert bias_dim == 1 or bias_dim == broadcast_dim
if not deterministic and attn_pdrop > 0.0:
attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
else:
attn_dropout = None
_chunk_bias_fn = functools.partial(
_chunk_attention_bias,
query_chunk_size, key_chunk_size, bias, deterministic,
attn_dropout, attn_pdrop, causal, dtype)
def scan_attention(args):
query_chunk, query_chunk_idx = args
@functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
def scan_kv_block(carry, args):
key_chunk, value_chunk, key_chunk_idx = args
(numerator, denominator, prev_max_score) = carry
attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
attn_weights = attn_weights + bias_chunk
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
max_score = jnp.maximum(prev_max_score, max_score)
max_score = jax.lax.stop_gradient(max_score)
exp_weights = jnp.exp(attn_weights - max_score)
exp_values = jnp.einsum(
'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
)
correction = jnp.exp(prev_max_score - max_score)
numerator = numerator * correction + exp_values
denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
return Carry(numerator, denominator, max_score), None
def skip_upper_half(carry, args):
key_chunk, value_chunk, key_chunk_idx = args
skip_block = jnp.array(False)
if causal:
skip_block = query_chunk_idx < key_chunk_idx
return jax.lax.cond(
skip_block,
lambda carry, args: (carry, None),
scan_kv_block,
carry,
args,
)
init_carry = Carry(
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
(-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
)
(numerator, denominator, max_score), _ = lax.scan(
skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
)
outputs = (numerator / denominator).astype(dtype)
return outputs
_, res = lax.scan(
lambda _, x: ((), scan_attention(x)),
(), xs=(query, jnp.arange(0, num_q))
)
res = rearrange(res, 'n b c h d -> b (n c) h d')
return res
class Carry(NamedTuple):
numerator: jax.Array
denominator: jax.Array
max_so_far: jax.Array
def _chunk_attention_bias(query_chunk_size, key_chunk_size,
bias, deterministic, attn_dropout, attn_pdrop, causal,
dtype, query_chunk_idx, key_chunk_idx):
query_offset = query_chunk_idx * query_chunk_size
key_offset = key_chunk_idx * key_chunk_size
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
if bias is not None:
chunk_bias = lax.dynamic_slice(
bias,
start_indices=(0, 0, query_offset, key_offset),
slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
)
if causal:
query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
offset = query_offset - key_offset
query_idx += offset
causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
if not deterministic and attn_pdrop > 0.0:
attn_dropout_slice = lax.dynamic_slice(
attn_dropout,
start_indices=(0, 0, query_offset, key_offset),
slice_sizes=(
*attn_dropout.shape[:2],
min(attn_dropout.shape[-2], query_chunk_size),
min(attn_dropout.shape[-1], key_chunk_size),
),
)
chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
return chunk_bias.astype(dtype)
if __name__ == '__main__':
# test
def reference_attn(query, key, value, causal, dtype):
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
if causal:
mask_value = jnp.finfo(logits.dtype).min
_, q_seq_len, _, _ = query.shape
_, kv_seq_len, _, _ = key.shape
mask_shape = (q_seq_len, kv_seq_len)
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
causal_mask = (row_ids < col_ids)[None, None, :, :]
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
weights = jax.nn.softmax(logits, axis=-1)
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
return out
# random inputs
shape = (1, 32, 8, 64)
query = jax.random.normal(jax.random.PRNGKey(0), shape)
key = jax.random.normal(jax.random.PRNGKey(1), shape)
value = jax.random.normal(jax.random.PRNGKey(2), shape)
causal = True
chunk_size = 4
policy = jax.checkpoint_policies.nothing_saveable()
blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
reference = reference_attn(query, key, value, causal, 'float32')
assert jnp.allclose(reference, blockwise, atol=1e-6)

212
EasyLM/checkpoint.py Normal file
View File

@@ -0,0 +1,212 @@
import os
import numpy as np
from ml_collections import ConfigDict
import mlxu
import jax
import jax.numpy as jnp
import flax
from flax.serialization import (
from_bytes, to_bytes, to_state_dict, from_state_dict
)
from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
import msgpack
from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype
class StreamingCheckpointer(object):
""" Custom msgpack checkpointer that saves large train states by serializing
and saving tensors one by one in a streaming fashion. Avoids running
out of memory or local TPU disk with default flax checkpointer.
"""
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.float_dtype = 'bf16'
config.save_optimizer_state = False
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, checkpoint_dir, enable=True):
self.config = self.get_default_config(config)
self.checkpoint_dir = checkpoint_dir
self.enable = enable
def save_checkpoint(self, train_state, filename, gather_fns=None):
if self.enable:
path = os.path.join(self.checkpoint_dir, filename)
else:
path = '/dev/null'
self.save_train_state_to_file(
train_state, path, gather_fns, self.config.float_dtype
)
@staticmethod
def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None):
train_state = to_state_dict(train_state)
packer = msgpack.Packer()
flattend_train_state = flatten_dict(train_state)
if gather_fns is not None:
gather_fns = flatten_dict(to_state_dict(gather_fns))
with mlxu.open_file(path, "wb") as fout:
for key, value in flattend_train_state.items():
if gather_fns is not None:
value = gather_fns[key](value)
value = float_tensor_to_dtype(value, float_dtype)
fout.write(packer.pack((key, to_bytes(value))))
def save_pickle(self, obj, filename):
if self.enable:
path = os.path.join(self.checkpoint_dir, filename)
else:
path = '/dev/null'
mlxu.save_pickle(obj, path)
def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False):
step = int(jax.device_get(train_state.step))
if self.config.save_optimizer_state:
checkpoint_state = train_state
checkpoint_name = 'streaming_train_state'
checkpoint_gather_fns = gather_fns
else:
checkpoint_state = train_state.params['params']
checkpoint_name = 'streaming_params'
checkpoint_gather_fns = gather_fns.params['params']
if milestone:
# Save a milestone checkpoint that will not be overwritten
self.save_pickle(metadata, f'metadata_{step}.pkl')
self.save_pickle(dataset, f'dataset_{step}.pkl')
self.save_checkpoint(
checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns
)
else:
# Save a normal checkpoint that can be overwritten
self.save_pickle(metadata, 'metadata.pkl')
self.save_pickle(dataset, 'dataset.pkl')
self.save_checkpoint(
checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns
)
@staticmethod
def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
if shard_fns is not None:
shard_fns = flatten_dict(
to_state_dict(shard_fns)
)
if remove_dict_prefix is not None:
remove_dict_prefix = tuple(remove_dict_prefix)
flattend_train_state = {}
with mlxu.open_file(path) as fin:
# 83886080 bytes = 80 MB, which is 16 blocks on GCS
unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
for key, value in unpacker:
key = tuple(key)
if remove_dict_prefix is not None:
if key[:len(remove_dict_prefix)] == remove_dict_prefix:
key = key[len(remove_dict_prefix):]
else:
continue
tensor = from_bytes(None, value)
if shard_fns is not None:
tensor = shard_fns[key](tensor)
flattend_train_state[key] = tensor
if target is not None:
flattened_target = flatten_dict(
to_state_dict(target), keep_empty_nodes=True
)
for key, value in flattened_target.items():
if key not in flattend_train_state and value == empty_node:
flattend_train_state[key] = value
train_state = unflatten_dict(flattend_train_state)
if target is None:
return train_state
return from_state_dict(target, train_state)
@staticmethod
def load_flax_checkpoint(path, target=None, shard_fns=None):
""" Load a standard flax checkpoint that's not saved with the
msgpack streaming format.
"""
with mlxu.open_file(path, "rb") as fin:
encoded_bytes = fin.read()
state_dict = flax.serialization.msgpack_restore(encoded_bytes)
if shard_fns is not None:
shard_fns = to_state_dict(shard_fns)
state_dict = tree_apply(shard_fns, state_dict)
if target is None:
return state_dict
return from_state_dict(target, state_dict)
@classmethod
def load_trainstate_checkpoint(cls, load_from, trainstate_target=None,
trainstate_shard_fns=None,
disallow_trainstate=False):
if trainstate_target is not None:
params_target = trainstate_target.params['params']
else:
params_target = None
if trainstate_shard_fns is not None:
params_shard_fns = trainstate_shard_fns.params['params']
else:
params_shard_fns = None
load_type, load_path = load_from.split('::', 1)
if disallow_trainstate:
assert load_type != 'trainstate', 'Loading full trainstate is not allowed!'
train_state = None
restored_params = None
if load_type == 'trainstate':
# Load the entire train state in the streaming format
train_state = cls.load_checkpoint(
path=load_path,
target=trainstate_target,
shard_fns=trainstate_shard_fns,
)
elif load_type == 'trainstate_params':
# Load the params part of the train state in the streaming format
restored_params = cls.load_checkpoint(
path=load_path,
target=params_target,
shard_fns=params_shard_fns,
remove_dict_prefix=('params', 'params'),
)
restored_params = flax.core.frozen_dict.freeze(
{'params': restored_params}
)
elif load_type == 'params':
# Load the params in the streaming format
restored_params = cls.load_checkpoint(
path=load_path,
target=params_target,
shard_fns=params_shard_fns,
)
restored_params = flax.core.frozen_dict.freeze(
{'params': restored_params}
)
elif load_type == 'flax_params':
# Load the params in the standard flax format (non-streaming)
# This requires the entire params to fit in memory
restored_params = cls.load_flax_checkpoint(
path=load_path,
target=params_target,
shard_fns=params_shard_fns
)
restored_params = flax.core.frozen_dict.freeze(
{'params': restored_params}
)
else:
raise ValueError(f'Invalid load_from type: {load_type}')
return train_state, restored_params

436
EasyLM/data.py Normal file
View File

@@ -0,0 +1,436 @@
import dataclasses
import pprint
import time
from functools import partial
import json
import base64
from multiprocessing import Pool
import h5py
import mlxu
from ml_collections.config_dict import config_dict
from ml_collections import ConfigDict
from tqdm import tqdm, trange
import numpy as np
from datasets import load_dataset, load_from_disk
class DatasetFactory(object):
""" Datset builder class. """
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.type = 'huggingface'
config.text_processor = TextProcessor.get_default_config()
config.huggingface_dataset = HuggingfaceDataset.get_default_config()
config.json_dataset = JsonDataset.get_default_config()
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
@classmethod
def load_dataset(cls, config, tokenizer, **kwargs):
config = cls.get_default_config(config)
text_processor = TextProcessor(config.text_processor, tokenizer)
if config.type == 'huggingface':
return HuggingfaceDataset(
config.huggingface_dataset, tokenizer, text_processor, **kwargs
)
elif config.type == 'json':
return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
else:
raise ValueError(f'Unknown dataset type: {config.type}')
def __init__(self):
raise ValueError('DatasetFactory is a static class and should not be instantiated.')
class TextProcessor(object):
""" Example processor that converts a dictionary of texts into tokens. """
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.fields_from_example = ''
config.fields = ''
config.subfield_separator = ' '
config.add_bos_token = True
config.add_eos_token = True
config.prepend_text = ''
config.base64_token_dtype = 'i4'
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer):
self.config = self.get_default_config(config)
assert self.config.fields != '' or self.config.fields_from_example != '', (
'Either fields or fields_from_example must be specified.'
)
self.tokenizer = tokenizer
def __call__(self, example, has_aux=False):
if has_aux:
example, *aux = example
else:
aux = tuple()
token_buffer = []
loss_mask_buffer = []
if self.config.add_bos_token:
token_buffer.append(self.tokenizer.bos_token_id)
loss_mask_buffer.append(0.0)
if self.config.fields_from_example != '':
fields = example[self.config.fields_from_example].split(',')
else:
fields = self.config.fields.split(',')
for i, field in enumerate(fields):
if field.startswith('[') and field.endswith(']'):
# No loss for this field.
field = field[1:-1]
mask = 0.0
else:
mask = 1.0
if field.startswith('<|') and field.endswith('|>'):
# Special tokens.
field = field[2:-2]
if field == 'bos':
token_buffer.append(self.tokenizer.bos_token_id)
elif field == 'eos':
token_buffer.append(self.tokenizer.eos_token_id)
else:
# Token ID specified directly.
token_buffer.append(int(field))
loss_mask_buffer.append(mask)
elif field.startswith('{') and field.endswith('}'):
field = field[1:-1]
# Base64 encoded raw tokens.
tokens = np.frombuffer(
base64.b64decode(example[field]),
dtype=self.config.base64_token_dtype
).tolist()
token_buffer.extend(tokens)
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
else:
subfields = field.split('+')
text = self.config.subfield_separator.join(
[example[subfield] for subfield in subfields]
)
if i == 0:
text = self.config.prepend_text + text
tokens = self.tokenizer.encode(text)
token_buffer.extend(tokens)
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
if self.config.add_eos_token:
token_buffer.append(self.tokenizer.eos_token_id)
loss_mask_buffer.append(1.0)
return token_buffer, loss_mask_buffer, *aux
class HuggingfaceDataset(object):
""" Huggingface dataset, where the dataset is loaded using the huggingface
datasets.load_dataset() function.
"""
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.path = 'c4'
config.name = 'en'
config.split = 'train'
config.streaming = False
config.seq_length = 1024
config.batch_size = 8
config.always_start_with_bos = False
config.start_seek_loc = 0
config.tokens_count_at_start = 0
config.batch_token_dtype = 'i4'
config.reset_dataset_loc = False
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
self.config = self.get_default_config(config)
name = self.config.name if self.config.name != '' else None
split = self.config.split if self.config.split != '' else None
self._tokenizer = tokenizer
self._text_processor = text_processor
self._dataset = load_from_disk(
self.config.path
)[split]
self._dataset = self._dataset.to_iterable_dataset(num_shards=128 if len(self._dataset) > 128 else len(self._dataset))
self._eval_dataset = eval_dataset
self._train_epochs = 0
self._dataset_loc = self.config.start_seek_loc
self._total_tokens = self.config.tokens_count_at_start
self._index = 0
self.reset_dataset_loc = self.config.reset_dataset_loc
def __iter__(self):
if not self._eval_dataset and self._train_epochs > 0:
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
chunk_size = self.config.batch_size * self.config.seq_length
while True:
token_buffer = []
loss_mask_buffer = []
if not self._eval_dataset and self._train_epochs > 0:
self._dataset.set_epoch(self._train_epochs)
for index, example in enumerate(self._dataset):
self._index = index
if not self._eval_dataset and self._dataset_loc > index:
continue
tokens, loss_masks = self.text_processor(example)
token_buffer.extend(tokens)
loss_mask_buffer.extend(loss_masks)
while len(token_buffer) > chunk_size + 1:
self._total_tokens += chunk_size
metrics = {
'dataset_example_index': index,
'dataset_total_tokens': self._total_tokens,
'epoch': self._train_epochs,
}
batch = {
'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
self.config.batch_size, -1
),
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
self.config.batch_size, -1
),
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
self.config.batch_size, -1
),
}
if self.config.always_start_with_bos:
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
yield batch, metrics
token_buffer = token_buffer[chunk_size:]
loss_mask_buffer = loss_mask_buffer[chunk_size:]
if self._eval_dataset:
break
else:
if self._train_epochs == 0:
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
self._dataset_loc = 0
self._train_epochs += 1
def get_state_dict(self):
return dict(
config=self.config,
dataset_loc=self._index,
total_tokens=self._total_tokens,
epochs=self._train_epochs,
)
def load_state_dict(self, state_dict):
if 'config' in state_dict:
self.config.update(ConfigDict(state_dict['config']))
self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
self._train_epochs = state_dict.get('epochs', 0)
if self.reset_dataset_loc:
self._dataset_loc = 0
self._train_epochs = 0
@property
def seq_length(self):
return self.config.seq_length
@property
def tokenizer(self):
return self._tokenizer
@property
def text_processor(self):
return self._text_processor
@property
def dataset(self):
return self._dataset
@property
def vocab_size(self):
return len(self._tokenizer)
class JsonDataset(object):
""" JSON dataset, where each line of the data file contains a JSON
dictionary with text fields.
"""
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.path = ''
config.seq_length = 1024
config.batch_size = 8
config.always_start_with_bos = False
config.start_seek_loc = 0
config.example_index_at_start = 0
config.tokens_count_at_start = 0
config.tokenizer_processes = 1
config.tokenizer_parallel_chunk_size = 32
config.tokenizer_parallel_batch_size = 1024
config.throughput_average_window_size = 200
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer, text_processor):
self.config = self.get_default_config(config)
assert self.config.path != ''
self._tokenizer = tokenizer
self._text_processor = text_processor
self._index = self.config.example_index_at_start
self._file_loc = self.config.start_seek_loc
self._total_tokens = self.config.tokens_count_at_start
def parse_json(self, line):
if not line or line == '\n':
return None
try:
data = json.loads(line)
except json.decoder.JSONDecodeError:
print(f'Error parsing json line:\n{line}')
return None
return data
def json_iterator(self):
with mlxu.open_file(self.config.path, 'r') as fin:
fin.seek(self._file_loc)
while True:
line = fin.readline()
self._file_loc = fin.tell()
if not line: # Reached EOF
self._index = 0
fin.seek(0)
continue
data = self.parse_json(line)
if data is not None:
# JSON parsing succeeded
yield data, self._file_loc, self._index
self._index += 1
def batched(self, iterator, batch_size):
batch = []
for example in iterator:
batch.append(example)
if len(batch) == batch_size:
yield batch
batch = []
if len(batch) > 0:
yield batch
def parallel_example_iterator(self):
if self.config.tokenizer_processes == 1:
for example, loc, index in self.json_iterator():
yield self.text_processor((example, loc, index), has_aux=True)
else:
process_pool = Pool(self.config.tokenizer_processes)
batched_iterator = self.batched(
self.json_iterator(), self.config.tokenizer_parallel_batch_size
)
with process_pool as pool:
map_fn = partial(self.text_processor, has_aux=True)
next_batch = pool.map_async(
map_fn, next(batched_iterator),
chunksize=self.config.tokenizer_parallel_chunk_size
)
while True:
current_batch = next_batch
next_batch = pool.map_async(
map_fn, next(batched_iterator),
chunksize=self.config.tokenizer_parallel_chunk_size
)
for example in current_batch.get():
yield example
def __iter__(self):
chunk_size = self.config.batch_size * self.config.seq_length
token_buffer = []
loss_mask_buffer = []
last_time = 0.0
step_times = []
start_time = time.time()
start_tokens = self._total_tokens
for tokens, loss_masks, loc, index in self.parallel_example_iterator():
token_buffer.extend(tokens)
loss_mask_buffer.extend(loss_masks)
while len(token_buffer) > chunk_size + 1:
self._total_tokens += chunk_size
step_times.append(time.time() - last_time)
last_time = time.time()
if len(step_times) > self.config.throughput_average_window_size:
step_times = step_times[-self.config.throughput_average_window_size:]
average_throughput = chunk_size / np.mean(step_times)
accumulated_throughput = (
(self._total_tokens - start_tokens) / (time.time() - start_time)
)
metrics = {
'dataset_file_loc': loc,
'dataset_example_index': index,
'dataset_total_tokens': self._total_tokens,
'dataset_accumulated_tps': accumulated_throughput,
'dataset_average_tps': average_throughput,
}
batch = {
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
self.config.batch_size, -1
),
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
self.config.batch_size, -1
),
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
self.config.batch_size, -1
),
}
if self.config.always_start_with_bos:
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
yield batch, metrics
token_buffer = token_buffer[chunk_size:]
loss_mask_buffer = loss_mask_buffer[chunk_size:]
def get_state_dict(self):
return dict(
config=self.config,
index=self._index,
file_loc=self._file_loc,
total_tokens=self._total_tokens,
)
def load_state_dict(self, state_dict):
if 'config' in state_dict:
self.config.update(ConfigDict(state_dict['config']))
self._index = state_dict.get('index', self.config.example_index_at_start)
self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
@property
def seq_length(self):
return self.config.seq_length
@property
def tokenizer(self):
return self._tokenizer
@property
def text_processor(self):
return self._text_processor
@property
def vocab_size(self):
return len(self.tokenizer)

403
EasyLM/jax_utils.py Normal file
View File

@@ -0,0 +1,403 @@
import os
import math
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
from functools import partial
import re
import dataclasses
import random
from ml_collections import ConfigDict
from ml_collections.config_dict.config_dict import placeholder
import flax
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as PS
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
from jax.experimental.pjit import pjit
from jax.interpreters import pxla
import numpy as np
from transformers import FlaxLogitsWarper
class JaxRNG(object):
""" A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
pure function.
"""
@classmethod
def from_seed(cls, seed):
return cls(jax.random.PRNGKey(seed))
def __init__(self, rng):
self.rng = rng
def __call__(self, keys=None):
if keys is None:
self.rng, split_rng = jax.random.split(self.rng)
return split_rng
elif isinstance(keys, int):
split_rngs = jax.random.split(self.rng, num=keys + 1)
self.rng = split_rngs[0]
return tuple(split_rngs[1:])
else:
split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
self.rng = split_rngs[0]
return {key: val for key, val in zip(keys, split_rngs[1:])}
class JaxDistributedConfig(object):
""" Utility class for initializing JAX distributed. """
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.initialize_jax_distributed = False
config.coordinator_address = placeholder(str)
config.num_processes = placeholder(int)
config.process_id = placeholder(int)
config.local_device_ids = placeholder(str)
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
@classmethod
def initialize(cls, config):
config = cls.get_default_config(config)
if config.initialize_jax_distributed:
if config.local_device_ids is not None:
local_device_ids = [int(x) for x in config.local_device_ids.split(',')]
else:
local_device_ids = None
jax.distributed.initialize(
coordinator_address=config.coordinator_address,
num_processes=config.num_processes,
process_id=config.process_id,
local_device_ids=local_device_ids,
)
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
""" JIT traceable version of FlaxLogitsWarper that performs temperature scaling."""
def __init__(self, temperature):
self.temperature = temperature
def __call__(self, input_ids, scores, cur_len):
return scores / jnp.clip(self.temperature, a_min=1e-8)
def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
""" Create pytree of sharding and gathering functions from pytree of
partition specs.
"""
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
def make_to_dtype_fn(dtype_spec):
def to_dtype(tensor):
if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
# Convert all float tensors to the same dtype
return tensor.astype(dtype_specs)
elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
return tensor.astype(dtype_spec.dtype)
return tensor
return to_dtype
def make_shard_fn(partition_spec, dtype_spec=None):
jax_shard_function = pjit(
make_to_dtype_fn(dtype_spec),
in_shardings=None,
out_shardings=partition_spec
)
def shard_fn(tensor):
return jax_shard_function(tensor).block_until_ready()
return shard_fn
def make_gather_fn(partition_spec, dtype_spec=None):
jax_gather_fn = pjit(
make_to_dtype_fn(dtype_spec),
in_shardings=partition_spec,
out_shardings=None
)
def gather_fn(tensor):
return jax.device_get(jax_gather_fn(tensor))
return gather_fn
if dtype_specs is None or dtype_specs in float_dtypes:
shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
else:
shard_fns = jax.tree_util.tree_map(
make_shard_fn, partition_specs, dtype_specs
)
gather_fns = jax.tree_util.tree_map(
make_gather_fn, partition_specs, dtype_specs
)
return shard_fns, gather_fns
def set_random_seed(seed):
np.random.seed(seed)
random.seed(seed)
init_rng(seed)
def get_jax_mesh(axis_dims, names):
if axis_dims.startswith('!'):
# Allow splitting a physical mesh axis if needed
mesh_axis_splitting = True
axis_dims = axis_dims[1:]
else:
mesh_axis_splitting = False
if ':' in axis_dims:
dims = []
dim_names = []
for axis in axis_dims.split(','):
name, dim = axis.split(':')
assert name in names
dims.append(int(dim))
dim_names.append(name)
assert(set(dim_names) == set(names))
else:
dims = [int(x) for x in axis_dims.split(',')]
dim_names = names
assert len(dims) == len(names)
mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
if mesh_axis_splitting:
physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
else:
physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
return Mesh(physical_mesh, dim_names)
def names_in_current_mesh(*names):
""" Check if current mesh axes contain these names. """
mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
return set(names) <= set(mesh_axis_names)
def get_names_from_parition_spec(partition_specs):
""" Return axis names from partition specs. """
names = set()
if isinstance(partition_specs, dict):
partition_specs = partition_specs.values()
for item in partition_specs:
if item is None:
continue
elif isinstance(item, str):
names.add(item)
else:
names.update(get_names_from_parition_spec(item))
return list(names)
def with_sharding_constraint(x, partition_specs):
""" A smarter version of with_sharding_constraint that only applies the
constraint if the current mesh contains the axes in the partition specs.
"""
axis_names = get_names_from_parition_spec(partition_specs)
if names_in_current_mesh(*axis_names):
x = _with_sharding_constraint(x, partition_specs)
return x
def wrap_function_with_rng(rng):
""" To be used as decorator, automatically bookkeep a RNG for the wrapped function. """
def wrap_function(function):
def wrapped(*args, **kwargs):
nonlocal rng
rng, split_rng = jax.random.split(rng)
return function(split_rng, *args, **kwargs)
return wrapped
return wrap_function
def init_rng(seed):
global jax_utils_rng
jax_utils_rng = JaxRNG.from_seed(seed)
def next_rng(*args, **kwargs):
global jax_utils_rng
return jax_utils_rng(*args, **kwargs)
def get_metrics(metrics, unreplicate=False, stack=False):
if unreplicate:
metrics = flax.jax_utils.unreplicate(metrics)
metrics = jax.device_get(metrics)
if stack:
return jax.tree_map(lambda *args: np.stack(args), *metrics)
else:
return {key: float(val) for key, val in metrics.items()}
def mse_loss(val, target, valid=None):
if valid is None:
valid = jnp.ones((*target.shape[:2], 1))
valid = valid.astype(jnp.float32)
loss = jnp.mean(
jnp.where(
valid > 0.0,
jnp.square(val - target),
0.0
)
)
return loss
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
if valid is None:
valid = jnp.ones(tokens.shape[:2])
valid = valid.astype(jnp.float32)
valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
logits = logits.astype(jnp.float32) # for numerical stability
token_log_prob = jnp.squeeze(
jnp.take_along_axis(
jax.nn.log_softmax(logits, axis=-1),
jnp.expand_dims(tokens, -1),
axis=-1,
),
-1,
)
token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
correct = jnp.where(
valid > 0.0,
jnp.argmax(logits, axis=-1) == tokens,
jnp.array(False)
)
accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
return loss, accuracy
def global_norm(tree):
""" Return the global L2 norm of a pytree. """
squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
flattened, _ = jax.flatten_util.ravel_pytree(squared)
return jnp.sqrt(jnp.sum(flattened))
def average_metrics(metrics):
with jax.spmd_mode("allow_all"):
return jax.tree_map(
lambda *args: jnp.mean(jnp.stack(args)),
*metrics
)
def get_float_dtype_by_name(dtype):
return {
'bf16': jnp.bfloat16,
'bfloat16': jnp.bfloat16,
'fp16': jnp.float16,
'float16': jnp.float16,
'fp32': jnp.float32,
'float32': jnp.float32,
'fp64': jnp.float64,
'float64': jnp.float64,
}[dtype]
def float_tensor_to_dtype(tensor, dtype):
if dtype is None or dtype == '':
return tensor
if isinstance(dtype, str):
dtype = get_float_dtype_by_name(dtype)
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
if getattr(tensor, 'dtype', None) in float_dtypes:
tensor = tensor.astype(dtype)
return tensor
def float_to_dtype(tree, dtype):
return jax.tree_util.tree_map(
partial(float_tensor_to_dtype, dtype=dtype), tree
)
def get_gradient_checkpoint_policy(name):
return {
'everything_saveable': jax.checkpoint_policies.everything_saveable,
'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
}[name]
def tree_path_to_string(path, sep=None):
keys = []
for key in path:
if isinstance(key, jax.tree_util.SequenceKey):
keys.append(str(key.idx))
elif isinstance(key, jax.tree_util.DictKey):
keys.append(str(key.key))
elif isinstance(key, jax.tree_util.GetAttrKey):
keys.append(str(key.name))
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
keys.append(str(key.key))
else:
keys.append(str(key))
if sep is None:
return tuple(keys)
return sep.join(keys)
def flatten_tree(xs, is_leaf=None, sep=None):
flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
output = {}
for key, val in flattened:
output[tree_path_to_string(key, sep=sep)] = val
return output
def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
""" An extended version of jax.tree_util.tree_map, where the mapped function
f takes both the name (path) and the tree leaf as input.
"""
return jax.tree_util.tree_map_with_path(
lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
tree, *rest,
is_leaf=is_leaf
)
def match_partition_rules(rules, params):
""" Returns a pytree of PartitionSpec according to rules. Supports handling
Flax TrainState and Optax optimizer state.
"""
def get_partition_spec(name, leaf):
if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
""" Don't partition scalar values. """
return PS()
for rule, ps in rules:
if re.search(rule, name) is not None:
return ps
raise ValueError(f'Partition rule not found for param: {name}')
return named_tree_map(get_partition_spec, params, sep='/')
def get_weight_decay_mask(exclusions):
""" Return a weight decay mask function that computes the pytree masks
according to the given exclusion rules.
"""
def decay(name, _):
for rule in exclusions:
if re.search(rule, name) is not None:
return False
return True
def weight_decay_mask(params):
return named_tree_map(decay, params, sep='/')
return weight_decay_mask
def tree_apply(fns, tree):
""" Apply a pytree of functions to the pytree. """
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)

View File

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,396 @@
import pprint
from functools import partial
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import flax
from flax import linen as nn
from flax.jax_utils import prefetch_to_device
from flax.training.train_state import TrainState
import optax
from transformers import GenerationConfig, FlaxLogitsProcessorList
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.serving import LMServer
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
with_sharding_constraint, FlaxTemperatureLogitsWarper
)
from EasyLM.models.gptj.gptj_model import (
GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
initialize_jax_distributed=False,
mesh_dim='1,-1,1',
dtype='bf16',
input_length=1024,
seq_length=2048,
top_k=50,
top_p=1.0,
do_sample=True,
num_beams=1,
add_bos_token=False,
load_gptj_config='',
load_checkpoint='',
tokenizer=GPTJConfig.get_tokenizer_config(),
lm_server=LMServer.get_default_config(),
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
set_random_seed(FLAGS.seed)
prefix_tokenizer = GPTJConfig.get_tokenizer(
FLAGS.tokenizer, truncation_side='left', padding_side='left'
)
tokenizer = GPTJConfig.get_tokenizer(
FLAGS.tokenizer, truncation_side='right', padding_side='right'
)
with jax.default_device(jax.devices("cpu")[0]):
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
if load_type == 'huggingface':
params = gptj_config.load_pretrained(load_path)
else:
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True
)
hf_model = FlaxGPTJForCausalLM(
gptj_config,
input_shape=(1, FLAGS.seq_length),
seed=FLAGS.seed,
_do_init=False
)
model_ps = match_partition_rules(
GPTJConfig.get_partition_rules(), params
)
shard_fns, _ = make_shard_and_gather_fns(
model_ps, get_float_dtype_by_name(FLAGS.dtype)
)
@partial(
pjit,
in_shardings=(model_ps, PS(), PS()),
out_shardings=(PS(), PS(), PS())
)
def forward_loglikelihood(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
input_tokens = batch['input_tokens']
output_tokens = batch['output_tokens']
input_mask = batch['input_mask']
output_mask = batch['output_mask']
logits = hf_model.module.apply(
params, input_tokens, attention_mask=input_mask,
deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
).logits
if gptj_config.n_real_tokens is not None:
logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
logits, output_tokens
)
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
match_count = jnp.sum(
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
axis=-1
)
total = jnp.sum(output_mask, axis=-1)
is_greedy = match_count == total
return loglikelihood, is_greedy, rng_generator()
@partial(
pjit,
in_shardings=(model_ps, PS(), PS(), PS()),
out_shardings=(PS(), PS())
)
def forward_generate(params, rng, batch, temperature):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
output = hf_model.generate(
batch['input_tokens'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
logits_processor=FlaxLogitsProcessorList(
[FlaxTemperatureLogitsWarper(temperature)]
),
generation_config=GenerationConfig(
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=FLAGS.do_sample,
num_beams=FLAGS.num_beams,
top_k=FLAGS.top_k,
top_p=FLAGS.top_p,
)
).sequences[:, batch['input_tokens'].shape[1]:]
return output, rng_generator()
@partial(
pjit,
in_shardings=(model_ps, PS(), PS()),
out_shardings=(PS(), PS())
)
def forward_greedy_generate(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
output = hf_model.generate(
batch['input_tokens'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
generation_config=GenerationConfig(
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=False,
num_beams=1,
)
).sequences[:, batch['input_tokens'].shape[1]:]
return output, rng_generator()
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
params = tree_apply(shard_fns, params)
sharded_rng = next_rng()
class ModelServer(LMServer):
@staticmethod
def loglikelihood(prefix_text, text):
nonlocal sharded_rng
prefix = prefix_tokenizer(
prefix_text,
padding='max_length',
truncation=True,
max_length=FLAGS.input_length,
return_tensors='np',
)
inputs = tokenizer(
text,
padding='max_length',
truncation=True,
max_length=FLAGS.seq_length - FLAGS.input_length,
return_tensors='np',
)
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
bos_tokens = np.full(
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
)
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
input_mask = np.concatenate(
[prefix.attention_mask, inputs.attention_mask], axis=1
)
if FLAGS.add_bos_token:
bos_mask = np.ones_like(input_mask[:, :1])
else:
bos_mask = np.zeros_like(input_mask[:, :1])
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
output_mask = np.concatenate(
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
)
batch = dict(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_mask=input_mask,
output_mask=output_mask,
)
with mesh:
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
params, sharded_rng, batch
)
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
return loglikelihood, is_greedy
@staticmethod
def loglikelihood_rolling(text):
nonlocal sharded_rng
inputs = tokenizer(
text,
padding='longest',
truncation=False,
max_length=np.iinfo(np.int32).max,
return_tensors='np',
)
batch_size = inputs.input_ids.shape[0]
output_tokens = inputs.input_ids
attention_mask = inputs.attention_mask
if output_tokens.shape[1] < FLAGS.seq_length:
padding_length = FLAGS.seq_length - output_tokens.shape[1]
pad_tokens = np.full(
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
)
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
pad_mask = np.zeros(
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
)
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
bos_tokens = np.full(
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
)
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
total_seq_length = output_tokens.shape[1]
total_loglikelihood = 0.0
total_is_greedy = True
# Sliding window
for i in range(0, total_seq_length, FLAGS.seq_length):
# Last window
if i + FLAGS.seq_length > total_seq_length:
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
last_output_mask[:, :i - total_seq_length] = 0.0
batch = dict(
input_tokens=input_tokens[:, -FLAGS.seq_length:],
output_tokens=output_tokens[:, -FLAGS.seq_length:],
input_mask=attention_mask[:, -FLAGS.seq_length:],
output_mask=last_output_mask,
)
# Normal window
else:
batch = dict(
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
)
with mesh:
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
params, sharded_rng, batch
)
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
total_loglikelihood += loglikelihood
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
return total_loglikelihood, total_is_greedy
@staticmethod
def generate(text, temperature):
nonlocal sharded_rng
inputs = prefix_tokenizer(
text,
padding='max_length',
truncation=True,
max_length=FLAGS.input_length,
return_tensors='np',
)
input_tokens = inputs.input_ids
input_mask = inputs.attention_mask
if FLAGS.add_bos_token:
input_tokens[:, 0] = tokenizer.bos_token_id
input_mask[:, 0] = 1
batch = dict(
input_tokens=input_tokens,
attention_mask=input_mask,
)
with mesh:
output, sharded_rng = forward_generate(
params, sharded_rng, batch, temperature
)
output = jax.device_get(output)
output_text = []
for text in list(tokenizer.batch_decode(output)):
if tokenizer.eos_token in text:
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
output_text.append(text)
return output_text
@staticmethod
def greedy_until(prefix_text, until, max_length):
nonlocal sharded_rng
all_outputs = []
for pf, ut in zip(prefix_text, until):
if isinstance(ut, str):
ut = [ut]
total_length = 0
total_generated = ''
while total_length < max_length:
pf_tokens = tokenizer(
pf,
padding=False,
truncation=False,
max_length=np.iinfo(np.int32).max,
return_tensors='np',
)
input_tokens = pf_tokens.input_ids
attention_mask = pf_tokens.attention_mask
if input_tokens.shape[1] < FLAGS.input_length:
extra = FLAGS.input_length - input_tokens.shape[1]
pad_tokens = np.full(
(1, extra), tokenizer.pad_token_id, dtype=np.int32
)
input_tokens = np.concatenate(
[pad_tokens, input_tokens], axis=1
)
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
attention_mask = np.concatenate(
[pad_attention, attention_mask], axis=1
)
elif input_tokens.shape[1] > FLAGS.input_length:
input_tokens = input_tokens[:, -FLAGS.input_length:]
attention_mask = attention_mask[:, -FLAGS.input_length:]
if FLAGS.add_bos_token:
input_tokens[:, 0] = tokenizer.bos_token_id
attention_mask[:, 0] = 1
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
with mesh:
output, sharded_rng = forward_greedy_generate(
params, sharded_rng, batch
)
output = jax.device_get(output)
total_length += output.shape[1]
output_text = tokenizer.batch_decode(output)[0]
total_generated = total_generated + output_text
pf = pf + output_text
done = False
for s in ut:
if s in total_generated:
total_generated = total_generated.split(s, maxsplit=1)[0]
done = True
if done:
break
all_outputs.append(total_generated)
return all_outputs
server = ModelServer(FLAGS.lm_server)
server.run()
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,272 @@
import pprint
from functools import partial
from tqdm import tqdm, trange
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
set_random_seed, average_metrics, get_weight_decay_mask,
make_shard_and_gather_fns, tree_apply
)
from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
mesh_dim='1,-1,1',
dtype='fp32',
total_steps=10000,
load_gptj_config='',
update_gptj_config='',
load_checkpoint='',
load_dataset_state='',
log_freq=50,
save_model_freq=0,
save_milestone_freq=0,
eval_steps=0,
tokenizer=GPTJConfig.get_tokenizer_config(),
train_dataset=DatasetFactory.get_default_config(),
eval_dataset=DatasetFactory.get_default_config(),
optimizer=OptimizerFactory.get_default_config(),
checkpointer=StreamingCheckpointer.get_default_config(),
gptj=GPTJConfig.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
log_all_worker=False,
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
logger = mlxu.WandBLogger(
config=FLAGS.logger,
variant=variant,
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
)
set_random_seed(FLAGS.seed)
tokenizer = GPTJConfig.get_tokenizer(FLAGS.tokenizer)
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
if FLAGS.load_dataset_state != '':
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
if FLAGS.eval_steps > 0:
eval_dataset = DatasetFactory.load_dataset(
FLAGS.eval_dataset, dataset.tokenizer
)
eval_iterator = iter(eval_dataset)
seq_length = dataset.seq_length
if FLAGS.load_gptj_config != '':
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
else:
gptj_config = GPTJConfig(**FLAGS.gptj)
if FLAGS.update_gptj_config != '':
gptj_config.update(dict(eval(FLAGS.update_gptj_config)))
gptj_config.update(dict(
bos_token_id=dataset.tokenizer.bos_token_id,
eos_token_id=dataset.tokenizer.eos_token_id,
))
if gptj_config.vocab_size < dataset.vocab_size:
gptj_config.update(dict(vocab_size=dataset.vocab_size))
model = FlaxGPTJForCausalLMModule(
gptj_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
)
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
FLAGS.optimizer,
get_weight_decay_mask(GPTJConfig.get_weight_decay_exclusions()),
)
def create_trainstate_from_params(params):
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def init_fn(rng):
rng_generator = JaxRNG(rng)
params = model.init(
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
rngs=rng_generator(gptj_config.rng_keys()),
)
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def train_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
def loss_and_accuracy(params):
logits = model.apply(
params, batch['input_tokens'], deterministic=False,
rngs=rng_generator(gptj_config.rng_keys()),
).logits
return cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
(loss, accuracy), grads = grad_fn(train_state.params)
train_state = train_state.apply_gradients(grads=grads)
metrics = dict(
loss=loss,
accuracy=accuracy,
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
gradient_norm=global_norm(grads),
param_norm=global_norm(train_state.params),
)
return train_state, rng_generator(), metrics
def eval_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
logits = model.apply(
train_state.params, batch['input_tokens'], deterministic=True,
rngs=rng_generator(gptj_config.rng_keys()),
).logits
loss, accuracy = cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
metrics = dict(
eval_loss=loss,
eval_accuracy=accuracy,
)
return rng_generator(), metrics
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
GPTJConfig.get_partition_rules(), train_state_shapes
)
shard_fns, gather_fns = make_shard_and_gather_fns(
train_state_partition, train_state_shapes
)
checkpointer = StreamingCheckpointer(
FLAGS.checkpointer, logger.output_dir,
enable=jax.process_index() == 0,
)
sharded_init_fn = pjit(
init_fn,
in_shardings=PS(),
out_shardings=train_state_partition
)
sharded_create_trainstate_from_params = pjit(
create_trainstate_from_params,
in_shardings=(train_state_partition.params, ),
out_shardings=train_state_partition,
donate_argnums=(0, ),
)
sharded_train_step = pjit(
train_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(train_state_partition, PS(), PS()),
donate_argnums=(0, 1),
)
sharded_eval_step = pjit(
eval_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(PS(), PS()),
donate_argnums=(1,),
)
def save_checkpoint(train_state, milestone=False):
step = int(jax.device_get(train_state.step))
metadata = dict(
step=step,
variant=variant,
flags=flags_config_dict,
gptj_config=gptj_config.to_dict(),
)
checkpointer.save_all(
train_state=train_state,
gather_fns=gather_fns,
metadata=metadata,
dataset=dataset.get_state_dict(),
milestone=milestone,
)
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
train_state, restored_params = None, None
if FLAGS.load_checkpoint != '':
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
if load_type == 'huggingface':
restored_params = tree_apply(
shard_fns.params, gptj_config.load_pretrained(load_path)
)
train_state = None
else:
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, train_state_shapes, shard_fns
)
if train_state is None and restored_params is None:
# Initialize from scratch
train_state = sharded_init_fn(next_rng())
elif train_state is None and restored_params is not None:
# Restore from params but initialize train_state
train_state = sharded_create_trainstate_from_params(restored_params)
del restored_params
start_step = int(jax.device_get(train_state.step))
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
sharded_rng = next_rng()
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
train_state, sharded_rng, metrics = sharded_train_step(
train_state, sharded_rng, batch
)
if step % FLAGS.log_freq == 0:
if FLAGS.eval_steps > 0:
eval_metric_list = []
for _ in range(FLAGS.eval_steps):
eval_batch, _ = next(eval_iterator)
sharded_rng, eval_metrics = sharded_eval_step(
train_state, sharded_rng, eval_batch
)
eval_metric_list.append(eval_metrics)
metrics.update(average_metrics(eval_metric_list))
log_metrics = {"step": step}
log_metrics.update(metrics)
log_metrics.update(dataset_metrics)
log_metrics = jax.device_get(log_metrics)
logger.log(log_metrics)
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
save_checkpoint(train_state, milestone=True)
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
save_checkpoint(train_state)
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,338 @@
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
# Copyright 2023 Xinyang Geng
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts LLaMA model checkpoint trained by EsayLM to the
# HuggingFace transformers LLaMA PyTorch format, which can then be loaded
# by HuggingFace transformers.
import gc
import json
import math
import os
import shutil
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
import flax
from flax.traverse_util import flatten_dict
import torch
from transformers import LlamaConfig, LlamaForCausalLM
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_tensor_to_dtype
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
load_checkpoint='',
tokenizer_path='',
model_size='13b',
output_dir='',
)
LLAMA_STANDARD_CONFIGS = {
'small': {
'vocab_size': 64256,
'dim': 768,
'intermediate_size': 3072,
'n_layers': 12,
'n_heads': 12,
'norm_eps': 1e-6,
},
'medium': {
'vocab_size': 64256,
'dim': 1024,
'intermediate_size': 4096,
'n_layers': 24,
'n_heads': 16,
'norm_eps': 1e-6,
},
'large': {
'vocab_size': 64256,
'dim': 1536,
'intermediate_size': 6144,
'n_layers': 24,
'n_heads': 16,
'norm_eps': 1e-6,
},
'xlarge': {
'vocab_size': 64256,
'dim': 2048,
'intermediate_size': 8192,
'n_layers': 24,
'n_heads': 32,
'norm_eps': 1e-6,
},
'1b': {
'vocab_size': 64256,
'dim': 2048,
'intermediate_size': 5504,
'n_layers': 22,
'n_heads': 16,
'norm_eps': 1e-6,
},
'3b': {
'vocab_size': 64256,
'dim': 3200,
'intermediate_size': 8640,
'n_layers': 26,
'n_heads': 32,
'norm_eps': 1e-6,
},
'7b': {
'vocab_size': 64256,
'dim': 4096,
'intermediate_size': 11008,
'n_layers': 32,
'n_heads': 32,
'norm_eps': 1e-6,
},
'13b': {
'vocab_size': 64256,
'dim': 5120,
'intermediate_size': 13824,
'n_layers': 40,
'n_heads': 40,
'norm_eps': 1e-6,
},
'30b': {
'vocab_size': 64256,
'dim': 6656,
'intermediate_size': 17920,
'n_layers': 60,
'n_heads': 52,
'norm_eps': 1e-6,
},
'65b': {
'vocab_size': 64256,
'dim': 8192,
'intermediate_size': 22016,
'n_layers': 80,
'n_heads': 64,
'norm_eps': 1e-5,
},
}
def match_keywords(string, positives, negatives):
for positive in positives:
if positive not in string:
return False
for negative in negatives:
if negative in string:
return False
return True
def load_and_convert_checkpoint(path):
_, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
flax_params = flatten_dict(flax_params['params'], sep='.')
torch_params = {}
for key, tensor in flax_params.items():
if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
tensor = tensor.T
torch_params[key] = torch.tensor(
float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16
)
return torch_params
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(loaded, model_path, model_size):
os.makedirs(model_path, exist_ok=True)
tmp_model_path = os.path.join(model_path, "tmp")
os.makedirs(tmp_model_path, exist_ok=True)
params = LLAMA_STANDARD_CONFIGS[model_size]
n_layers = params["n_layers"]
n_heads = params["n_heads"]
dim = params["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
# permute for sliced rotary
def permute(w):
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
param_count = 0
index_dict = {"weight_map": {}}
for layer_i in range(n_layers):
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
state_dict = {
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
),
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
}
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
# Unsharded
state_dict = {
"model.embed_tokens.weight": loaded["transformer.wte.embedding"],
"model.norm.weight": loaded["transformer.ln_f.kernel"],
"lm_head.weight": loaded["lm_head.kernel"],
}
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
# Write configs
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
config = LlamaConfig(
vocab_size=params["vocab_size"],
hidden_size=dim,
intermediate_size=params["intermediate_size"],
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
)
config.save_pretrained(tmp_model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
gc.collect()
print("Loading the checkpoint in a Llama model.")
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
# Avoid saving this as part of the config.
print("Model parameter count", model.num_parameters())
del model.config._name_or_path
print("Saving in the Transformers format.")
model.save_pretrained(model_path, safe_serialization=True)
shutil.rmtree(tmp_model_path)
def write_tokenizer(tokenizer_path, input_tokenizer_path):
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
os.makedirs(tokenizer_path, exist_ok=True)
write_json(
{
"bos_token": {
"content": "<s>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
"eos_token": {
"content": "</s>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
"unk_token": {
"content": "<unk>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
},
os.path.join(tokenizer_path, "special_tokens_map.json")
)
write_json(
{
"add_bos_token": True,
"add_eos_token": False,
"model_max_length": 2048,
"pad_token": None,
"sp_model_kwargs": {},
"tokenizer_class": "LlamaTokenizer",
"clean_up_tokenization_spaces": False,
"bos_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
"eos_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
"unk_token": {
"__type": "AddedToken",
"content": "<unk>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False
},
},
os.path.join(tokenizer_path, "tokenizer_config.json"),
)
shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
def main(argv):
assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""# and FLAGS.tokenizer_path != ""
assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
# write_tokenizer(
# tokenizer_path=FLAGS.output_dir,
# input_tokenizer_path=FLAGS.tokenizer_path,
# )
write_model(
load_and_convert_checkpoint(FLAGS.load_checkpoint),
model_path=FLAGS.output_dir,
model_size=FLAGS.model_size,
)
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,196 @@
"""
Usage:
python convert_hf_to_easylm.py \
--checkpoint_dir /path/hf_format_dir/ \
--output_file /path/easylm_format.stream \
--model_size 7b \
--streaming
"""
import time
from pathlib import Path
import argparse
import mlxu
import torch
import flax
from EasyLM.checkpoint import StreamingCheckpointer
LLAMA_STANDARD_CONFIGS = {
'1b': {
'dim': 2048,
'intermediate_size': 5504,
'n_layers': 22,
'n_heads': 16,
'norm_eps': 1e-6,
},
'3b': {
'dim': 3200,
'intermediate_size': 8640,
'n_layers': 26,
'n_heads': 32,
'norm_eps': 1e-6,
},
"7b": {
"dim": 4096,
"intermediate_size": 11008,
"n_layers": 32,
"n_heads": 32,
"norm_eps": 1e-6,
},
"13b": {
"dim": 5120,
"intermediate_size": 13824,
"n_layers": 40,
"n_heads": 40,
"norm_eps": 1e-6,
},
"30b": {
"dim": 6656,
"intermediate_size": 17920,
"n_layers": 60,
"n_heads": 52,
"norm_eps": 1e-6,
},
"65b": {
"dim": 8192,
"intermediate_size": 22016,
"n_layers": 80,
"n_heads": 64,
"norm_eps": 1e-5,
},
}
def inverse_permute(params, w):
n_layers = params["n_layers"]
n_heads = params["n_heads"]
dim = params["dim"]
reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
transposed_w = reshaped_w.transpose(0, 2, 1, 3)
inverted_w = transposed_w.reshape(dim, dim)
return inverted_w
def main(args):
start = time.time()
params = LLAMA_STANDARD_CONFIGS[args.model_size]
ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
ckpt = {}
for i, ckpt_path in enumerate(ckpt_paths):
checkpoint = torch.load(ckpt_path, map_location="cpu")
for k, v in checkpoint.items():
if k.startswith("model."):
k = k[6:]
ckpt[k] = v
print(f"Start convert weight to easylm format...")
jax_weights = {
"transformer": {
"wte": {"embedding": ckpt["embed_tokens.weight"].numpy()},
"ln_f": {"kernel": ckpt["norm.weight"].numpy()},
"h": {
"%d"
% (layer): {
"attention": {
"wq": {
"kernel": inverse_permute(
params,
ckpt[f"layers.{layer}.self_attn.q_proj.weight"].numpy(),
).transpose()
},
"wk": {
"kernel": inverse_permute(
params,
ckpt[f"layers.{layer}.self_attn.k_proj.weight"].numpy(),
).transpose()
},
"wv": {
"kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"]
.numpy()
.transpose()
},
"wo": {
"kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"]
.numpy()
.transpose()
},
},
"feed_forward": {
"w1": {
"kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"]
.numpy()
.transpose()
},
"w2": {
"kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"]
.numpy()
.transpose()
},
"w3": {
"kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"]
.numpy()
.transpose()
},
},
"attention_norm": {
"kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].numpy()
},
"ffn_norm": {
"kernel": ckpt[
f"layers.{layer}.post_attention_layernorm.weight"
].numpy()
},
}
for layer in range(params["n_layers"])
},
},
"lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()},
}
print(f"Convert weight to easylm format finished...")
print(f"Start to save...")
if args.streaming:
StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
else:
with mlxu.open_file(args.output_file, "wb") as fout:
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
print(
f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="hf to easylm format script")
parser.add_argument(
"--checkpoint_dir",
type=str,
help="Need to be converted model weight dir. it is a dir",
)
parser.add_argument(
"--output_file", type=str, help="Save model weight file path, it is a file."
)
parser.add_argument(
"--model_size",
type=str,
default="7b",
choices=["7b", "13b", "30b", "65b"],
help="model size",
)
parser.add_argument(
"--streaming",
action="store_true",
default=True,
help="whether is model weight saved stream format",
)
args = parser.parse_args()
print(f"checkpoint_dir: {args.checkpoint_dir}")
print(f"output_file: {args.output_file}")
print(f"model_size: {args.model_size}")
print(f"streaming: {args.streaming}")
main(args)

View File

@@ -0,0 +1,68 @@
# This script converts the standrd LLaMA PyTorch checkpoint released by Meta
# to the EasyLM checkpoint format. The converted checkpoint can then be loaded
# by EasyLM for fine-tuning or inference.
# This script is largely borrow from https://github.com/Sea-Snell/JAX_llama
from pathlib import Path
import json
import numpy as np
import torch
import flax
import mlxu
from EasyLM.checkpoint import StreamingCheckpointer
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
checkpoint_dir='',
output_file='',
streaming=True,
)
def main(argv):
ckpt_paths = sorted(Path(FLAGS.checkpoint_dir).glob("*.pth"))
ckpts = {}
for i, ckpt_path in enumerate(ckpt_paths):
checkpoint = torch.load(ckpt_path, map_location="cpu")
ckpts[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint
ckpts = [ckpts[i] for i in sorted(list(ckpts.keys()))]
with open(Path(FLAGS.checkpoint_dir) / "params.json", "r") as f:
params = json.loads(f.read())
jax_weights = {
'transformer': {
'wte': {'embedding': np.concatenate([ckpt['tok_embeddings.weight'].numpy() for ckpt in ckpts], axis=1)},
'ln_f': {'kernel': ckpts[0]['norm.weight'].numpy()},
'h': {
'%d' % (layer): {
'attention': {
'wq': {'kernel': np.concatenate([ckpt['layers.%d.attention.wq.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
'wk': {'kernel': np.concatenate([ckpt['layers.%d.attention.wk.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
'wv': {'kernel': np.concatenate([ckpt['layers.%d.attention.wv.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
'wo': {'kernel': np.concatenate([ckpt['layers.%d.attention.wo.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
},
'feed_forward': {
'w1': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w1.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
'w2': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w2.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
'w3': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w3.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
},
'attention_norm': {'kernel': ckpts[0]['layers.%d.attention_norm.weight' % (layer)].numpy()},
'ffn_norm': {'kernel': ckpts[0]['layers.%d.ffn_norm.weight' % (layer)].numpy()},
}
for layer in range(params['n_layers'])},
},
'lm_head': {'kernel': np.concatenate([ckpt['output.weight'].numpy() for ckpt in ckpts], axis=0).transpose()},
}
if FLAGS.streaming:
StreamingCheckpointer.save_train_state_to_file(
jax_weights, FLAGS.output_file
)
else:
with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
if __name__ == '__main__':
mlxu.run(main)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,386 @@
import pprint
from functools import partial
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import optax
from transformers import GenerationConfig, FlaxLogitsProcessorList
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.serving import LMServer
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
with_sharding_constraint, FlaxTemperatureLogitsWarper
)
from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
initialize_jax_distributed=False,
mesh_dim='1,-1,1',
dtype='bf16',
input_length=1024,
seq_length=2048,
top_k=50,
top_p=1.0,
do_sample=True,
num_beams=1,
add_bos_token=True,
load_llama_config='',
load_checkpoint='',
tokenizer=LLaMAConfig.get_tokenizer_config(),
lm_server=LMServer.get_default_config(),
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
set_random_seed(FLAGS.seed)
prefix_tokenizer = LLaMAConfig.get_tokenizer(
FLAGS.tokenizer, truncation_side='left', padding_side='left'
)
tokenizer = LLaMAConfig.get_tokenizer(
FLAGS.tokenizer, truncation_side='right', padding_side='right'
)
with jax.default_device(jax.devices("cpu")[0]):
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True
)
hf_model = FlaxLLaMAForCausalLM(
llama_config,
input_shape=(1, FLAGS.seq_length),
seed=FLAGS.seed,
_do_init=False
)
model_ps = match_partition_rules(
LLaMAConfig.get_partition_rules(), params
)
shard_fns, _ = make_shard_and_gather_fns(
model_ps, get_float_dtype_by_name(FLAGS.dtype)
)
@partial(
pjit,
in_shardings=(model_ps, PS(), PS()),
out_shardings=(PS(), PS(), PS())
)
def forward_loglikelihood(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
input_tokens = batch['input_tokens']
output_tokens = batch['output_tokens']
input_mask = batch['input_mask']
output_mask = batch['output_mask']
logits = hf_model.module.apply(
params, input_tokens, attention_mask=input_mask,
deterministic=True, rngs=rng_generator(llama_config.rng_keys()),
).logits
# if llama_config.n_real_tokens is not None:
# logits = logits.at[:, :, llama_config.n_real_tokens:].set(-1e8)
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
logits, output_tokens
)
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
match_count = jnp.sum(
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
axis=-1
)
total = jnp.sum(output_mask, axis=-1)
is_greedy = match_count == total
return loglikelihood, is_greedy, rng_generator()
@partial(
pjit,
in_shardings=(model_ps, PS(), PS(), PS()),
out_shardings=(PS(), PS())
)
def forward_generate(params, rng, batch, temperature):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
output = hf_model.generate(
batch['input_tokens'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
logits_processor=FlaxLogitsProcessorList(
[FlaxTemperatureLogitsWarper(temperature)]
),
generation_config=GenerationConfig(
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=FLAGS.do_sample,
num_beams=FLAGS.num_beams,
top_k=FLAGS.top_k,
top_p=FLAGS.top_p,
)
).sequences[:, batch['input_tokens'].shape[1]:]
return output, rng_generator()
@partial(
pjit,
in_shardings=(model_ps, PS(), PS()),
out_shardings=(PS(), PS())
)
def forward_greedy_generate(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
output = hf_model.generate(
batch['input_tokens'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
generation_config=GenerationConfig(
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=False,
num_beams=1,
)
).sequences[:, batch['input_tokens'].shape[1]:]
return output, rng_generator()
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
params = tree_apply(shard_fns, params)
sharded_rng = next_rng()
class ModelServer(LMServer):
@staticmethod
def loglikelihood(prefix_text, text):
nonlocal sharded_rng
prefix = prefix_tokenizer(
prefix_text,
padding='max_length',
truncation=True,
max_length=FLAGS.input_length,
return_tensors='np',
)
inputs = tokenizer(
text,
padding='max_length',
truncation=True,
max_length=FLAGS.seq_length - FLAGS.input_length,
return_tensors='np',
)
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
bos_tokens = np.full(
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
)
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
input_mask = np.concatenate(
[prefix.attention_mask, inputs.attention_mask], axis=1
)
if FLAGS.add_bos_token:
bos_mask = np.ones_like(input_mask[:, :1])
else:
bos_mask = np.zeros_like(input_mask[:, :1])
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
output_mask = np.concatenate(
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
)
batch = dict(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_mask=input_mask,
output_mask=output_mask,
)
with mesh:
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
params, sharded_rng, batch
)
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
return loglikelihood, is_greedy
@staticmethod
def loglikelihood_rolling(text):
nonlocal sharded_rng
inputs = tokenizer(
text,
padding='longest',
truncation=False,
max_length=np.iinfo(np.int32).max,
return_tensors='np',
)
batch_size = inputs.input_ids.shape[0]
output_tokens = inputs.input_ids
attention_mask = inputs.attention_mask
if output_tokens.shape[1] < FLAGS.seq_length:
padding_length = FLAGS.seq_length - output_tokens.shape[1]
pad_tokens = np.full(
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
)
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
pad_mask = np.zeros(
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
)
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
bos_tokens = np.full(
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
)
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
total_seq_length = output_tokens.shape[1]
total_loglikelihood = 0.0
total_is_greedy = True
# Sliding window
for i in range(0, total_seq_length, FLAGS.seq_length):
# Last window
if i + FLAGS.seq_length > total_seq_length:
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
last_output_mask[:, :i - total_seq_length] = 0.0
batch = dict(
input_tokens=input_tokens[:, -FLAGS.seq_length:],
output_tokens=output_tokens[:, -FLAGS.seq_length:],
input_mask=attention_mask[:, -FLAGS.seq_length:],
output_mask=last_output_mask,
)
# Normal window
else:
batch = dict(
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
)
with mesh:
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
params, sharded_rng, batch
)
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
total_loglikelihood += loglikelihood
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
return total_loglikelihood, total_is_greedy
@staticmethod
def generate(text, temperature):
nonlocal sharded_rng
inputs = prefix_tokenizer(
text,
padding='max_length',
truncation=True,
max_length=FLAGS.input_length,
return_tensors='np',
)
input_tokens = inputs.input_ids
input_mask = inputs.attention_mask
if FLAGS.add_bos_token:
input_tokens[:, 0] = tokenizer.bos_token_id
input_mask[:, 0] = 1
batch = dict(
input_tokens=input_tokens,
attention_mask=input_mask,
)
with mesh:
output, sharded_rng = forward_generate(
params, sharded_rng, batch, temperature
)
output = jax.device_get(output)
output_text = []
for text in list(tokenizer.batch_decode(output)):
if tokenizer.eos_token in text:
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
output_text.append(text)
return output_text
@staticmethod
def greedy_until(prefix_text, until, max_length):
nonlocal sharded_rng
all_outputs = []
for pf, ut in zip(prefix_text, until):
if isinstance(ut, str):
ut = [ut]
total_length = 0
total_generated = ''
while total_length < max_length:
pf_tokens = tokenizer(
pf,
padding=False,
truncation=False,
max_length=np.iinfo(np.int32).max,
return_tensors='np',
)
input_tokens = pf_tokens.input_ids
attention_mask = pf_tokens.attention_mask
if input_tokens.shape[1] < FLAGS.input_length:
extra = FLAGS.input_length - input_tokens.shape[1]
pad_tokens = np.full(
(1, extra), tokenizer.pad_token_id, dtype=np.int32
)
input_tokens = np.concatenate(
[pad_tokens, input_tokens], axis=1
)
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
attention_mask = np.concatenate(
[pad_attention, attention_mask], axis=1
)
elif input_tokens.shape[1] > FLAGS.input_length:
input_tokens = input_tokens[:, -FLAGS.input_length:]
attention_mask = attention_mask[:, -FLAGS.input_length:]
if FLAGS.add_bos_token:
input_tokens[:, 0] = tokenizer.bos_token_id
attention_mask[:, 0] = 1
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
with mesh:
output, sharded_rng = forward_greedy_generate(
params, sharded_rng, batch
)
output = jax.device_get(output)
total_length += output.shape[1]
output_text = tokenizer.batch_decode(output)[0]
total_generated = total_generated + output_text
pf = pf + output_text
done = False
for s in ut:
if s in total_generated:
total_generated = total_generated.split(s, maxsplit=1)[0]
done = True
if done:
break
all_outputs.append(total_generated)
return all_outputs
server = ModelServer(FLAGS.lm_server)
server.run()
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,268 @@
import pprint
from functools import partial
from tqdm import tqdm, trange
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
set_random_seed, average_metrics, get_weight_decay_mask,
make_shard_and_gather_fns, with_sharding_constraint,
)
from EasyLM.models.llama.llama_model import (
LLaMAConfig, FlaxLLaMAForCausalLMModule
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
mesh_dim='1,-1,1',
dtype='fp32',
param_dtype='fp32',
total_steps=10000,
load_llama_config='',
update_llama_config='',
load_checkpoint='',
load_dataset_state='',
log_freq=50,
save_model_freq=0,
save_milestone_freq=0,
eval_freq=0,
tokenizer=LLaMAConfig.get_tokenizer_config(),
train_dataset=DatasetFactory.get_default_config(),
eval_dataset=DatasetFactory.get_default_config(),
optimizer=OptimizerFactory.get_default_config(),
checkpointer=StreamingCheckpointer.get_default_config(),
llama=LLaMAConfig.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
log_all_worker=False,
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
logger = mlxu.WandBLogger(
config=FLAGS.logger,
variant=variant,
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
)
set_random_seed(FLAGS.seed)
tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
if FLAGS.load_dataset_state != '':
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
if FLAGS.eval_freq > 0:
eval_dataset = DatasetFactory.load_dataset(
FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
)
seq_length = dataset.seq_length
if FLAGS.load_llama_config != '':
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
else:
llama_config = LLaMAConfig(**FLAGS.llama)
if FLAGS.update_llama_config != '':
llama_config.update(dict(eval(FLAGS.update_llama_config)))
llama_config.update(dict(
bos_token_id=dataset.tokenizer.bos_token_id,
eos_token_id=dataset.tokenizer.eos_token_id,
))
if llama_config.vocab_size < dataset.vocab_size:
print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
llama_config.update(dict(vocab_size=dataset.vocab_size))
model = FlaxLLaMAForCausalLMModule(
llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype), param_dtype=get_float_dtype_by_name(FLAGS.param_dtype)
)
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
FLAGS.optimizer,
get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
)
def create_trainstate_from_params(params):
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def init_fn(rng):
rng_generator = JaxRNG(rng)
params = model.init(
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
rngs=rng_generator(llama_config.rng_keys()),
)
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def train_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
def loss_and_accuracy(params):
logits = model.apply(
params, batch['input_tokens'], deterministic=False,
rngs=rng_generator(llama_config.rng_keys()),
).logits
return cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
(loss, accuracy), grads = grad_fn(train_state.params)
train_state = train_state.apply_gradients(grads=grads)
metrics = dict(
loss=loss,
accuracy=accuracy,
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
gradient_norm=global_norm(grads),
param_norm=global_norm(train_state.params),
)
return train_state, rng_generator(), metrics
def eval_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
logits = model.apply(
train_state.params, batch['input_tokens'], deterministic=True,
rngs=rng_generator(llama_config.rng_keys()),
).logits
loss, accuracy = cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
metrics = dict(
eval_loss=loss,
eval_accuracy=accuracy,
)
return rng_generator(), metrics
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
LLaMAConfig.get_partition_rules(), train_state_shapes
)
shard_fns, gather_fns = make_shard_and_gather_fns(
train_state_partition, train_state_shapes
)
checkpointer = StreamingCheckpointer(
FLAGS.checkpointer, logger.output_dir,
enable=jax.process_index() == 0,
)
sharded_init_fn = pjit(
init_fn,
in_shardings=PS(),
out_shardings=train_state_partition
)
sharded_create_trainstate_from_params = pjit(
create_trainstate_from_params,
in_shardings=(train_state_partition.params, ),
out_shardings=train_state_partition,
donate_argnums=(0, ),
)
sharded_train_step = pjit(
train_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(train_state_partition, PS(), PS()),
donate_argnums=(0, 1),
)
sharded_eval_step = pjit(
eval_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(PS(), PS()),
donate_argnums=(1,),
)
def save_checkpoint(train_state, milestone=False):
step = int(jax.device_get(train_state.step))
metadata = dict(
step=step,
variant=variant,
flags=flags_config_dict,
llama_config=llama_config.to_dict(),
)
checkpointer.save_all(
train_state=train_state,
gather_fns=gather_fns,
metadata=metadata,
dataset=dataset.get_state_dict(),
milestone=milestone,
)
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
train_state, restored_params = None, None
if FLAGS.load_checkpoint != '':
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, train_state_shapes, shard_fns
)
if train_state is None and restored_params is None:
# Initialize from scratch
train_state = sharded_init_fn(next_rng())
elif train_state is None and restored_params is not None:
# Restore from params but initialize train_state
train_state = sharded_create_trainstate_from_params(restored_params)
del restored_params
start_step = int(jax.device_get(train_state.step))
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
sharded_rng = next_rng()
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
train_state, sharded_rng, metrics = sharded_train_step(
train_state, sharded_rng, batch
)
if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
eval_metric_list = []
eval_iterator = iter(eval_dataset)
for eval_batch, _ in eval_iterator:
sharded_rng, eval_metrics = sharded_eval_step(
train_state, sharded_rng, eval_batch
)
eval_metric_list.append(eval_metrics)
metrics.update(average_metrics(eval_metric_list))
if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
log_metrics = {"step": step + 1}
log_metrics.update(metrics)
log_metrics.update(dataset_metrics)
log_metrics = jax.device_get(log_metrics)
logger.log(log_metrics)
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
save_checkpoint(train_state, milestone=True)
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
save_checkpoint(train_state)
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
if __name__ == "__main__":
mlxu.run(main)

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,307 @@
import dataclasses
import pprint
from functools import partial
import re
from tqdm import tqdm, trange
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, get_float_dtype_by_name,
cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
set_random_seed, average_metrics, get_weight_decay_mask,
make_shard_and_gather_fns, tree_apply
)
from EasyLM.models.roberta.roberta_model import (
RobertaConfig, FlaxRobertaForMaskedLMModule
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
mesh_dim='-1,1,1',
dtype='fp32',
mask_token_probability=0.15,
total_steps=10000,
load_roberta_config='',
update_roberta_config='',
load_checkpoint='',
load_dataset_state='',
log_freq=50,
save_model_freq=0,
save_milestone_freq=0,
eval_steps=0,
tokenizer=RobertaConfig.get_tokenizer_config(),
train_dataset=DatasetFactory.get_default_config(),
eval_dataset=DatasetFactory.get_default_config(),
optimizer=OptimizerFactory.get_default_config(),
checkpointer=StreamingCheckpointer.get_default_config(),
roberta=RobertaConfig.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
log_all_worker=False,
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
logger = mlxu.WandBLogger(
config=FLAGS.logger,
variant=variant,
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
)
set_random_seed(FLAGS.seed)
tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
if FLAGS.load_dataset_state != '':
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
if FLAGS.eval_steps > 0:
eval_dataset = DatasetFactory.load_dataset(
FLAGS.eval_dataset, dataset.tokenizer
)
eval_iterator = iter(eval_dataset)
seq_length = dataset.seq_length
if FLAGS.load_roberta_config != '':
roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
else:
roberta_config = RobertaConfig(**FLAGS.roberta)
if FLAGS.update_roberta_config != '':
roberta_config.update(dict(eval(FLAGS.update_roberta_config)))
roberta_config.update(dict(
bos_token_id=dataset.tokenizer.bos_token_id,
eos_token_id=dataset.tokenizer.eos_token_id,
pad_token_id=dataset.tokenizer.pad_token_id,
vocab_size=dataset.vocab_size,
))
model = FlaxRobertaForMaskedLMModule(
roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
)
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
FLAGS.optimizer,
get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
)
def create_trainstate_from_params(params):
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def init_fn(rng):
rng_generator = JaxRNG(rng)
params = model.init(
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
token_type_ids=None,
head_mask=None,
rngs=rng_generator(roberta_config.rng_keys()),
)
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def train_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
def loss_and_accuracy(params):
altered_tokens = jax.random.uniform(
rng_generator(), shape=tokens.shape
) < FLAGS.mask_token_probability
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
altered_by_mask = altered_tokens & (random_uniform < 0.8)
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
random_tokens = jax.random.randint(
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
)
inputs = jnp.where(altered_by_random, random_tokens, inputs)
logits = model.apply(
params, inputs,
attention_mask=jnp.ones_like(inputs),
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic=False,
rngs=rng_generator(roberta_config.rng_keys()),
).logits
return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
(loss, accuracy), grads = grad_fn(train_state.params)
train_state = train_state.apply_gradients(grads=grads)
metrics = dict(
loss=loss,
accuracy=accuracy,
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
gradient_norm=global_norm(grads),
param_norm=global_norm(train_state.params),
)
return train_state, rng_generator(), metrics
def eval_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
altered_tokens = jax.random.uniform(
rng_generator(), shape=tokens.shape
) < FLAGS.mask_token_probability
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
altered_by_mask = altered_tokens & (random_uniform < 0.8)
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
random_tokens = jax.random.randint(
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
)
inputs = jnp.where(altered_by_random, random_tokens, inputs)
logits = model.apply(
train_state.params, inputs,
attention_mask=jnp.ones_like(inputs),
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic=False,
rngs=rng_generator(roberta_config.rng_keys()),
).logits
loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
metrics = dict(
eval_loss=loss,
eval_accuracy=accuracy,
)
return rng_generator(), metrics
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
RobertaConfig.get_partition_rules(), train_state_shapes
)
shard_fns, gather_fns = make_shard_and_gather_fns(
train_state_partition, train_state_shapes
)
checkpointer = StreamingCheckpointer(
FLAGS.checkpointer, logger.output_dir,
enable=jax.process_index() == 0
)
sharded_init_fn = pjit(
init_fn,
in_shardings=PS(),
out_shardings=train_state_partition
)
sharded_create_trainstate_from_params = pjit(
create_trainstate_from_params,
in_shardings=(train_state_partition.params, ),
out_shardings=train_state_partition,
donate_argnums=(0, ),
)
sharded_train_step = pjit(
train_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(train_state_partition, PS(), PS()),
donate_argnums=(0, 1),
)
sharded_eval_step = pjit(
eval_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(PS(), PS()),
donate_argnums=(1,),
)
def save_checkpoint(train_state, milestone=False):
step = int(jax.device_get(train_state.step))
metadata = dict(
step=step,
variant=variant,
flags=flags_config_dict,
roberta_config=roberta_config.to_dict(),
)
checkpointer.save_all(
train_state=train_state,
gather_fns=gather_fns,
metadata=metadata,
dataset=dataset.get_state_dict(),
milestone=milestone,
)
mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
train_state, restored_params = None, None
if FLAGS.load_checkpoint != '':
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
if load_type == 'huggingface':
restored_params = tree_apply(
shard_fns.params, roberta_config.load_pretrained(load_path)
)
train_state = None
else:
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, train_state_shapes, shard_fns
)
if train_state is None and restored_params is None:
# Initialize from scratch
train_state = sharded_init_fn(next_rng())
elif train_state is None and restored_params is not None:
# Restore from params but initialize train_state
train_state = sharded_create_trainstate_from_params(restored_params)
del restored_params
start_step = int(jax.device_get(train_state.step))
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
sharded_rng = next_rng()
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
train_state, sharded_rng, metrics = sharded_train_step(
train_state, sharded_rng, batch
)
if step % FLAGS.log_freq == 0:
if FLAGS.eval_steps > 0:
eval_metric_list = []
for _ in range(FLAGS.eval_steps):
eval_batch, _ = next(eval_iterator)
sharded_rng, eval_metrics = sharded_eval_step(
train_state, sharded_rng, eval_batch
)
eval_metric_list.append(eval_metrics)
metrics.update(average_metrics(eval_metric_list))
log_metrics = {"step": step}
log_metrics.update(metrics)
log_metrics.update(dataset_metrics)
log_metrics = jax.device_get(log_metrics)
logger.log(log_metrics)
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
save_checkpoint(train_state, milestone=True)
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
save_checkpoint(train_state)
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
if __name__ == "__main__":
mlxu.run(main)

346
EasyLM/optimizers.py Normal file
View File

@@ -0,0 +1,346 @@
import os
import time
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
from functools import partial
import re
import dataclasses
import random
from ml_collections.config_dict import config_dict
from ml_collections import ConfigDict
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging
import optax
from EasyLM.jax_utils import float_to_dtype
class OptimizerFactory(object):
""" Configurable optax optimizer factory. """
def __init__(self):
raise NotImplementedError
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.accumulate_gradient_steps = 1
config.type = 'adamw'
config.palm_optimizer = PalmOptimizerFactory.get_default_config()
config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
config.lion_optimizer = LionOptimizerFactory.get_default_config()
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
@classmethod
def get_optimizer(cls, config, weight_decay_mask=None):
config = cls.get_default_config(config)
if config.type == 'palm':
optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer(
config.palm_optimizer, weight_decay_mask
)
elif config.type == 'adamw':
optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer(
config.adamw_optimizer, weight_decay_mask
)
elif config.type == 'lion':
optimizer, optimizer_info = LionOptimizerFactory.get_optimizer(
config.lion_optimizer, weight_decay_mask
)
else:
raise ValueError(f'Unknown optimizer type: {config.type}')
if config.accumulate_gradient_steps > 1:
optimizer = optax.MultiSteps(
optimizer, config.accumulate_gradient_steps
)
return optimizer, optimizer_info
class PalmOptimizerFactory(object):
""" PaLM optimizer factory. This optimizer implements the optimizer
described in the PaLM paper: https://arxiv.org/abs/2204.02311
"""
def __init__(self):
raise NotImplementedError
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.lr = 0.01
config.lr_warmup_steps = 10000
config.b1 = 0.9
config.b2 = 0.99
config.clip_gradient = 1.0
config.weight_decay = 1e-4
config.bf16_momentum = False
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
@classmethod
def get_optimizer(cls, config, weight_decay_mask=None):
config = cls.get_default_config(config)
def learning_rate_schedule(step):
multiplier = config.lr / 0.01
return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps))
def weight_decay_schedule(step):
multiplier = config.weight_decay / 1e-4
return -multiplier * jnp.square(learning_rate_schedule(step))
optimizer_info = dict(
learning_rate_schedule=learning_rate_schedule,
weight_decay_schedule=weight_decay_schedule,
)
optimizer = optax.chain(
optax.clip_by_global_norm(config.clip_gradient),
optax.adafactor(
learning_rate=learning_rate_schedule,
multiply_by_parameter_scale=True,
momentum=config.b1,
decay_rate=config.b2,
factored=False,
clipping_threshold=None,
dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
),
optax_add_scheduled_weight_decay(
weight_decay_schedule, weight_decay_mask
)
)
return optimizer, optimizer_info
class AdamWOptimizerFactory(object):
""" AdamW optimizer with cosine schedule. """
def __init__(self):
raise NotImplementedError
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.init_lr = 0.0
config.end_lr = 0.001
config.lr = 0.01
config.lr_warmup_steps = 2000
config.lr_decay_steps = 500000
config.b1 = 0.9
config.b2 = 0.95
config.clip_gradient = 1.0
config.weight_decay = 1e-4
config.bf16_momentum = False
config.multiply_by_parameter_scale = False
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
@classmethod
def get_optimizer(cls, config, weight_decay_mask=None):
config = cls.get_default_config(config)
learning_rate_schedule = optax.warmup_cosine_decay_schedule(
init_value=config.init_lr,
peak_value=config.lr,
warmup_steps=config.lr_warmup_steps,
decay_steps=config.lr_decay_steps,
end_value=config.end_lr,
)
optimizer_info = dict(
learning_rate_schedule=learning_rate_schedule,
)
if config.multiply_by_parameter_scale:
optimizer = optax.chain(
optax.clip_by_global_norm(config.clip_gradient),
optax.adafactor(
learning_rate=learning_rate_schedule,
multiply_by_parameter_scale=True,
momentum=config.b1,
decay_rate=config.b2,
factored=False,
clipping_threshold=None,
dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
),
optax_add_scheduled_weight_decay(
lambda step: -learning_rate_schedule(step) * config.weight_decay,
weight_decay_mask
)
)
else:
optimizer = optax.chain(
optax.clip_by_global_norm(config.clip_gradient),
optax.adamw(
learning_rate=learning_rate_schedule,
weight_decay=config.weight_decay,
b1=config.b1,
b2=config.b2,
mask=weight_decay_mask,
mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
),
)
return optimizer, optimizer_info
class LionOptimizerFactory(object):
""" Lion optimizer with cosine schedule. """
def __init__(self):
raise NotImplementedError
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.init_lr = 0.0
config.end_lr = 0.0001
config.lr = 0.001
config.lr_warmup_steps = 60000
config.lr_constant_steps = 840000
config.lr_decay_steps = 100000
config.b1 = 0.9
config.b2 = 0.98
config.clip_gradient = 1.0
config.weight_decay = 1e-3
config.bf16_momentum = False
config.lr_schedule_type = "warmup_cosine_decay_schedule"
config.lr_decay_rate = 0.98
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
@classmethod
def get_optimizer(cls, config, weight_decay_mask=None):
config = cls.get_default_config(config)
if config.lr_schedule_type == "warmup_cosine_decay_schedule":
learning_rate_schedule = optax.warmup_cosine_decay_schedule(
init_value=config.init_lr,
peak_value=config.lr,
warmup_steps=config.lr_warmup_steps,
decay_steps=config.lr_decay_steps,
end_value=config.end_lr,
)
elif config.lr_schedule_type == "warmup_constant":
learning_rate_schedule = optax.join_schedules(
[
optax.linear_schedule(
init_value=config.init_lr,
end_value=config.lr,
transition_steps=config.lr_warmup_steps,
),
optax.constant_schedule(config.lr),
],
[config.lr_warmup_steps],
)
elif config.lr_schedule_type == "warmup_constant_linear_decay":
learning_rate_schedule = optax.join_schedules(
[
optax.linear_schedule(
init_value=config.init_lr,
end_value=config.lr,
transition_steps=config.lr_warmup_steps,
),
optax.constant_schedule(config.lr),
optax.linear_schedule(
init_value=config.lr,
end_value=config.end_lr,
transition_steps=config.lr_decay_steps,
)
],
[config.lr_warmup_steps, config.lr_constant_steps],
)
elif config.lr_schedule_type == "warmup_constant_exponential_decay":
learning_rate_schedule = optax.join_schedules(
[
optax.linear_schedule(
init_value=config.init_lr,
end_value=config.lr,
transition_steps=config.lr_warmup_steps,
),
optax.constant_schedule(config.lr),
optax.exponential_decay(
init_value=config.lr,
transition_steps=config.lr_decay_steps,
decay_rate=config.lr_decay_rate,
transition_begin=0,
staircase=False,
end_value=config.end_lr,
)
],
[config.lr_warmup_steps, config.lr_constant_steps],
)
elif config.lr_schedule_type == "exponential_decay":
learning_rate_schedule = optax.exponential_decay(
init_value=config.lr,
transition_steps=config.lr_decay_steps,
decay_rate=config.lr_decay_rate,
transition_begin=0,
staircase=False,
end_value=config.end_lr,
)
elif config.lr_schedule_type == "linear_decay":
learning_rate_schedule = optax.linear_schedule(
init_value=config.lr,
end_value=config.end_lr,
transition_steps=config.lr_decay_steps,
)
else:
raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')
optimizer_info = dict(
learning_rate_schedule=learning_rate_schedule,
)
optimizer = optax.chain(
optax.clip_by_global_norm(config.clip_gradient),
optax.lion(
learning_rate=learning_rate_schedule,
weight_decay=config.weight_decay,
b1=config.b1,
b2=config.b2,
mask=weight_decay_mask,
mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
),
)
return optimizer, optimizer_info
class OptaxScheduledWeightDecayState(NamedTuple):
count: jax.Array
def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
""" Apply weight decay with schedule. """
def init_fn(params):
del params
return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
def update_fn(updates, state, params):
if params is None:
raise ValueError('Params cannot be None for weight decay!')
weight_decay = schedule_fn(state.count)
updates = jax.tree_util.tree_map(
lambda g, p: g + weight_decay * p, updates, params
)
return updates, OptaxScheduledWeightDecayState(
count=optax.safe_int32_increment(state.count)
)
if mask is not None:
return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
return optax.GradientTransformation(init_fn, update_fn)

View File

View File

@@ -0,0 +1,150 @@
from functools import partial
from time import time
import os
import numpy as np
import jax
import jax.flatten_util
import jax.numpy as jnp
import mlxu
from EasyLM.bpt import blockwise_attn
from EasyLM.jax_utils import (
get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
)
FLAGS, _ = mlxu.define_flags_with_default(
seed=42,
dtype='fp32',
embed_dim=2048,
n_heads=16,
ref_attn_seq_len=2048,
eff_attn_seq_len=16384,
batch_size=1,
query_chunk_size=2048,
key_chunk_size=2048,
warmup_steps=40,
steps=200,
)
def main(argv):
def random_kqv(rng_key, seq_len):
rng_generator = JaxRNG(rng_key)
kqv = []
for i in range(3):
kqv.append(
jax.random.normal(
rng_generator(),
(FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
dtype=get_float_dtype_by_name(FLAGS.dtype)
)
)
return tuple(kqv)
def reference_attn(query, key, value):
dtype = get_float_dtype_by_name(FLAGS.dtype)
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
mask_value = jnp.finfo(logits.dtype).min
_, q_seq_len, _, _ = query.shape
_, kv_seq_len, _, _ = key.shape
mask_shape = (q_seq_len, kv_seq_len)
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
causal_mask = (row_ids < col_ids)[None, None, :, :]
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
weights = jax.nn.softmax(logits, axis=-1)
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
return out
def efficient_attention(query, key, value):
dtype = get_float_dtype_by_name(FLAGS.dtype)
return blockwise_attn(
query, key, value,
bias=None,
deterministic=True,
dropout_rng=None,
attn_pdrop=0.0,
causal=True,
query_chunk_size=FLAGS.query_chunk_size,
key_chunk_size=FLAGS.key_chunk_size,
dtype=get_float_dtype_by_name(FLAGS.dtype),
policy=jax.checkpoint_policies.nothing_saveable(),
precision=None,
float32_logits=True,
prevent_cse=True,
)
@partial(jax.jit, static_argnums=(1,))
def reference_attn_forward_backward(rng_key, seq_len):
@partial(jax.grad, argnums=(0, 1, 2))
@partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
def grad_fn(query, key, value):
out = reference_attn(query, key, value)
return jnp.mean(out)
query, key, value = random_kqv(rng_key, seq_len)
return jax.flatten_util.ravel_pytree(
grad_fn(query, key, value)[1]
)[0].mean()
@partial(jax.jit, static_argnums=(1,))
def efficient_attn_forward_backward(rng_key, seq_len):
@partial(jax.grad, argnums=(0, 1, 2))
def grad_fn(query, key, value):
out = efficient_attention(query, key, value)
return jnp.mean(out)
query, key, value = random_kqv(rng_key, seq_len)
return jax.flatten_util.ravel_pytree(
grad_fn(query, key, value)[1]
)[0].mean()
set_random_seed(FLAGS.seed)
jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
all_results = []
for i in range(FLAGS.warmup_steps):
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
jax.block_until_ready(all_results)
start_time = time()
all_results = []
for i in range(FLAGS.steps):
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
jax.block_until_ready(all_results)
elapsed_time_ref_attn = time() - start_time
print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')
all_results = []
for i in range(FLAGS.warmup_steps):
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
jax.block_until_ready(all_results)
start_time = time()
all_results = []
for i in range(FLAGS.steps):
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
jax.block_until_ready(all_results)
elapsed_time_efficient_attn = time() - start_time
print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')
flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
print(f'Efficiency: {efficiency:.3f}')
if __name__ == '__main__':
mlxu.run(main)

View File

@@ -0,0 +1,42 @@
# This script converts model checkpoint trained by EsayLM to a standard
# mspack checkpoint that can be loaded by huggingface transformers or
# flax.serialization.msgpack_restore. Such conversion allows models to be
# used by other frameworks that integrate with huggingface transformers.
import pprint
from functools import partial
import os
import numpy as np
import mlxu
import jax.numpy as jnp
import flax.serialization
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_to_dtype
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
load_checkpoint='',
output_file='',
streaming=False,
float_dtype='bf16',
)
def main(argv):
assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True
)[1]['params']
if FLAGS.streaming:
StreamingCheckpointer.save_train_state_to_file(
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
)
else:
params = float_to_dtype(params, FLAGS.float_dtype)
with mlxu.open_file(FLAGS.output, 'wb') as fout:
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,59 @@
# This script converts model checkpoint trained by EsayLM to a standard
# mspack checkpoint that can be loaded by huggingface transformers or
# flax.serialization.msgpack_restore. Such conversion allows models to be
# used by other frameworks that integrate with huggingface transformers.
import pprint
from functools import partial
import os
import numpy as np
import jax
import jax.numpy as jnp
import flax.serialization
import mlxu
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_to_dtype
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
recover_diff=False,
load_base_checkpoint='',
load_target_checkpoint='',
output_file='',
streaming=True,
float_dtype='bf16',
)
def main(argv):
assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
assert FLAGS.output_file != ''
base_params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_base_checkpoint, disallow_trainstate=True
)[1]['params']
target_params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_target_checkpoint, disallow_trainstate=True
)[1]['params']
if FLAGS.recover_diff:
params = jax.tree_util.tree_map(
lambda b, t: b + t, base_params, target_params
)
else:
params = jax.tree_util.tree_map(
lambda b, t: t - b, base_params, target_params
)
if FLAGS.streaming:
StreamingCheckpointer.save_train_state_to_file(
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
)
else:
params = float_to_dtype(params, FLAGS.float_dtype)
with mlxu.open_file(FLAGS.output, 'wb') as fout:
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,65 @@
# This script runs lm_eval_harness evaluations against a served language model.
# Typically, you need to run a language model server first, e.g.:
# python -m EasyLM.models.gptj.gptj_serve ...
import dataclasses
import pprint
from functools import partial
import os
from tqdm import tqdm, trange
import numpy as np
import mlxu
from flax.traverse_util import flatten_dict
from lm_eval import evaluator, tasks
from lm_eval.base import LM
from EasyLM.serving import LMClient
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
tasks='wsc,piqa,winogrande,openbookqa,logiqa',
shots=0,
limit=0,
write_out=False,
lm_client=LMClient.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
)
class LMEvalHarnessInterface(LM):
def __init__(self, lm_client):
self.lm_client = lm_client
def greedy_until(self, inputs):
prefix, until = zip(*inputs)
return self.lm_client.greedy_until(prefix, until)
def loglikelihood_rolling(self, inputs):
loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
return list(zip(loglikelihood, is_greedy))
def loglikelihood(self, inputs):
prefix, text = zip(*inputs)
loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
return list(zip(loglikelihood, is_greedy))
def main(argv):
logger = mlxu.WandBLogger(
config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
)
model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
task_list = FLAGS.tasks.split(',')
results = evaluator.evaluate(
model, tasks.get_task_dict(task_list), False, FLAGS.shots,
limit=None if FLAGS.limit <= 0 else FLAGS.limit,
write_out=FLAGS.write_out,
)
logger.log(flatten_dict(results['results'], sep='/'))
pprint.pprint(results)
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,52 @@
import json
import mlxu
from EasyLM.serving import LMClient
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
input_file='',
output_file='',
prefix_field='prefix',
text_field='text',
until_field='until',
eval_type='loglikelihood',
lm_client=LMClient.get_default_config(),
)
def main(argv):
lm_client = LMClient(FLAGS.lm_client)
with mlxu.open_file(FLAGS.input_file, 'r') as fin:
input_data = json.load(fin)
if FLAGS.eval_type == 'loglikelihood':
prefix = input_data[FLAGS.prefix_field]
text = input_data[FLAGS.text_field]
loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
output_data = {
'loglikelihood': loglikelihoods,
'is_greedy': is_greedys,
}
elif FLAGS.eval_type == 'loglikelihood_rolling':
text = input_data[FLAGS.text_field]
loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
output_data = {
'loglikelihood': loglikelihoods,
'is_greedy': is_greedys,
}
elif FLAGS.eval_type == 'greedy_until':
prefix = input_data[FLAGS.prefix_field]
until = input_data[FLAGS.until_field]
output_data = {'output_text': lm_client.greedy_until(prefix, until)}
elif FLAGS.eval_type == 'generate':
prefix = input_data[FLAGS.prefix_field]
output_data = {'output_text': lm_client.generate(prefix)}
else:
raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
with mlxu.open_file(FLAGS.output_file, 'w') as fout:
json.dump(output_data, fout)
if __name__ == "__main__":
mlxu.run(main)

566
EasyLM/serving.py Normal file
View File

@@ -0,0 +1,566 @@
import dataclasses
import pprint
from functools import partial
import re
import os
from threading import Lock
import urllib
import time
from typing import List, Optional, Union
from pydantic import BaseModel
import absl.logging
from tqdm import tqdm, trange
import numpy as np
import mlxu
from ml_collections import ConfigDict
import uvicorn
from fastapi import FastAPI
import gradio as gr
import requests
from requests.exceptions import Timeout, ConnectionError
class InferenceRequest(BaseModel):
prefix_text: Optional[List[str]] = None
text: Optional[List[str]] = None
until: Optional[Union[List[str], List[List[str]]]] = None
temperature: Optional[float] = None
class ChatRequest(BaseModel):
prompt: str
context: str = ''
temperature: Optional[float] = None
class LMServer(object):
""" HTTP server for serving langauge models. """
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.host = '0.0.0.0'
config.port = 5007
config.batch_size = 1
config.logging = False
config.pre_compile = 'loglikelihood'
config.default_temperature = 1.0
config.greedy_until_max_length = 5000
config.prepend_to_prefix = ''
config.append_to_prefix = ''
config.prepend_to_text = ''
config.append_to_text = ''
config.chat_prepend_text = ''
config.chat_user_prefix = ''
config.chat_user_suffix = ''
config.chat_lm_prefix = ''
config.chat_lm_suffix = ''
config.notes = ''
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config):
self.config = self.get_default_config(config)
self.lock = Lock()
self.app = FastAPI()
self.app.post('/loglikelihood')(self.serve_loglikelihood)
self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling)
self.app.post('/generate')(self.serve_generate)
self.app.post('/greedy-until')(self.serve_greedy_until)
self.app.post('/chat')(self.serve_chat)
self.app.get('/ready')(self.serve_ready)
self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/')
@staticmethod
def loglikelihood(prefix_text, text):
raise NotImplementedError()
@staticmethod
def loglikelihood_rolling(text):
raise NotImplementedError()
@staticmethod
def generate(text, temperature):
raise NotImplementedError()
@staticmethod
def greedy_until(prefix_text, until, max_length):
raise NotImplementedError()
@staticmethod
def to_list(x):
if isinstance(x, np.ndarray):
return x.tolist()
return x
def serve_ready(self):
return 'Ready!\n'
def serve_loglikelihood(self, data: InferenceRequest):
with self.lock:
if self.config.logging:
absl.logging.info(
'\n========= Serving Log Likelihood Request ========= \n'
+ pprint.pformat(data) + '\n'
)
if data.prefix_text is None:
data.prefix_text = ['' for _ in data.text]
prefix_text = [
self.config.prepend_to_prefix + p + self.config.append_to_prefix
for p in data.prefix_text
]
text = [
self.config.prepend_to_text + t + self.config.append_to_text
for t in data.text
]
log_likelihood = []
is_greedy = []
for i in trange(0, len(text), self.config.batch_size, ncols=0):
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
batch_text = text[i:i + self.config.batch_size]
batch_size = len(batch_text)
if batch_size < self.config.batch_size:
extra = self.config.batch_size - batch_size
batch_prefix_text.extend(['a' for _ in range(extra)])
batch_text.extend(['a' for _ in range(extra)])
batch_log_likelihood, batch_is_greedy = self.loglikelihood(
batch_prefix_text, batch_text
)
batch_log_likelihood = self.to_list(batch_log_likelihood)
batch_is_greedy = self.to_list(batch_is_greedy)
log_likelihood.extend(batch_log_likelihood[:batch_size])
is_greedy.extend(batch_is_greedy[:batch_size])
output = {
'prefix_text': data.prefix_text,
'text': data.text,
'log_likelihood': log_likelihood,
'is_greedy': is_greedy,
}
if self.config.logging:
absl.logging.info(
'\n========= Output ========= \n'
+ pprint.pformat(output) + '\n'
)
return output
def serve_loglikelihood_rolling(self, data: InferenceRequest):
with self.lock:
if self.config.logging:
absl.logging.info(
'\n========= Serving Log Likelihood Request ========= \n'
+ pprint.pformat(data) + '\n'
)
text = [
self.config.prepend_to_text + t + self.config.append_to_text
for t in data.text
]
log_likelihood = []
is_greedy = []
for i in trange(0, len(text), self.config.batch_size, ncols=0):
batch_text = text[i:i + self.config.batch_size]
batch_size = len(batch_text)
if batch_size < self.config.batch_size:
extra = self.config.batch_size - batch_size
batch_text.extend(['a' for _ in range(extra)])
batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
batch_text
)
batch_log_likelihood = self.to_list(batch_log_likelihood)
batch_is_greedy = self.to_list(batch_is_greedy)
log_likelihood.extend(batch_log_likelihood[:batch_size])
is_greedy.extend(batch_is_greedy[:batch_size])
output = {
'text': data.text,
'log_likelihood': log_likelihood,
'is_greedy': is_greedy,
}
if self.config.logging:
absl.logging.info(
'\n========= Output ========= \n'
+ pprint.pformat(output) + '\n'
)
return output
def serve_generate(self, data: InferenceRequest):
with self.lock:
if self.config.logging:
absl.logging.info(
'\n========= Serving Generate Request ========= \n'
+ pprint.pformat(data) + '\n'
)
prefix_text = [
self.config.prepend_to_prefix + p + self.config.append_to_prefix
for p in data.prefix_text
]
if data.temperature is None:
data.temperature = self.config.default_temperature
output_text = []
for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0):
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
batch_size = len(batch_prefix_text)
if batch_size < self.config.batch_size:
extra = self.config.batch_size - batch_size
batch_prefix_text.extend(['a' for _ in range(extra)])
batch_output_text = self.generate(
batch_prefix_text,
temperature=data.temperature,
)
output_text.extend(self.to_list(batch_output_text)[:batch_size])
output = {
'prefix_text': data.prefix_text,
'output_text': output_text,
'temperature': data.temperature,
}
if self.config.logging:
absl.logging.info(
'\n========= Output ========= \n'
+ pprint.pformat(output) + '\n'
)
return output
def serve_greedy_until(self, data: InferenceRequest):
with self.lock:
if self.config.logging:
absl.logging.info(
'\n========= Serving Greedy Until Request ========= \n'
+ pprint.pformat(data) + '\n'
)
prefix_text = [
self.config.prepend_to_prefix + p + self.config.append_to_prefix
for p in data.prefix_text
]
until = data.until
max_length = self.config.greedy_until_max_length
output_text = []
for i in range(0, len(prefix_text), self.config.batch_size):
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
batch_until = until[i:i + self.config.batch_size]
batch_size = len(batch_prefix_text)
batch_output_text = self.greedy_until(batch_prefix_text, batch_until, max_length)
output_text.extend(self.to_list(batch_output_text)[:batch_size])
output = {
'prefix_text': data.prefix_text,
'until': data.until,
'max_length': max_length,
'output_text': output_text,
}
if self.config.logging:
absl.logging.info(
'\n========= Output ========= \n'
+ pprint.pformat(output) + '\n'
)
return output
def process_chat(self, prompt, context, temperature):
context = (
context + self.config.chat_user_prefix
+ prompt + self.config.chat_user_suffix
+ self.config.chat_lm_prefix
)
response = self.generate(
[self.config.chat_prepend_text + context],
temperature=float(temperature),
)[0]
context = context + response + self.config.chat_lm_suffix
return response, context
def serve_chat(self, data: ChatRequest):
if data.temperature is None:
data.temperature = self.config.default_temperature
response, context = self.process_chat(
data.prompt, data.context,
temperature=data.temperature,
)
return {
'response': response,
'context': context,
'temperature': data.temperature,
}
def create_chat_app(self):
with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
gr.Markdown(self.config.notes)
chatbot = gr.Chatbot(label='Chat history')
msg = gr.Textbox(
placeholder='Type your message here...',
show_label=False
)
with gr.Row():
send = gr.Button('Send')
regenerate = gr.Button('Regenerate', interactive=False)
clear = gr.Button('Reset')
temp_slider = gr.Slider(
label='Temperature', minimum=0, maximum=2.0,
value=self.config.default_temperature
)
context_state = gr.State(['', ''])
def user_fn(user_message, history, context):
return {
msg: gr.update(value='', interactive=False),
clear: gr.update(interactive=False),
send: gr.update(interactive=False),
regenerate: gr.update(interactive=False),
chatbot: history + [[user_message, None]],
context_state: [context[1], context[1]],
}
def model_fn(history, context, temperature):
history[-1][1], new_context = self.process_chat(
history[-1][0], context[0], temperature
)
return {
msg: gr.update(value='', interactive=True),
clear: gr.update(interactive=True),
send: gr.update(interactive=True),
chatbot: history,
context_state: [context[0], new_context],
regenerate: gr.update(interactive=True),
}
def regenerate_fn():
return {
msg: gr.update(value='', interactive=False),
clear: gr.update(interactive=False),
send: gr.update(interactive=False),
regenerate: gr.update(interactive=False),
}
def clear_fn():
return {
chatbot: None,
msg: '',
context_state: ['', ''],
regenerate: gr.update(interactive=False),
}
msg.submit(
user_fn,
inputs=[msg, chatbot, context_state],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=False
).then(
model_fn,
inputs=[chatbot, context_state, temp_slider],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=True
)
send.click(
user_fn,
inputs=[msg, chatbot, context_state],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=False
).then(
model_fn,
inputs=[chatbot, context_state, temp_slider],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=True
)
regenerate.click(
regenerate_fn,
inputs=None,
outputs=[msg, clear, send, regenerate],
queue=False
).then(
model_fn,
inputs=[chatbot, context_state, temp_slider],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=True
)
clear.click(
clear_fn,
inputs=None,
outputs=[chatbot, msg, context_state, regenerate],
queue=False
)
gradio_chatbot.queue(concurrency_count=1)
return gradio_chatbot
def run(self):
if self.config.pre_compile != '':
if self.config.pre_compile == 'all':
pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat']
else:
pre_compile = self.config.pre_compile.split(',')
pre_compile_data = ['a' for _ in range(self.config.batch_size)]
for task in pre_compile:
if task == 'loglikelihood':
self.loglikelihood(pre_compile_data, pre_compile_data)
self.loglikelihood_rolling(pre_compile_data)
elif task == 'generate':
self.generate(pre_compile_data, 1.0)
elif task == 'greedy_until':
self.greedy_until(
pre_compile_data, pre_compile_data,
self.config.greedy_until_max_length
)
elif task == 'chat':
self.process_chat('a', 'a', 1.0)
else:
raise ValueError(f'Invalid precompile task: {task}!')
uvicorn.run(self.app, host=self.config.host, port=self.config.port)
class LMClient(object):
""" A simple client for the LM server. """
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.url = 'http://localhost:5007'
config.batch_size = 1
config.wait_for_ready = True
config.dummy = False
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config=None):
self.config = self.get_default_config(config)
if self.config.wait_for_ready:
self.wait_for_ready()
def wait_for_ready(self):
if self.config.dummy:
return
while True:
try:
requests.get(urllib.parse.urljoin(self.config.url, 'ready'))
return
except (Timeout, ConnectionError) as e:
time.sleep(10)
@staticmethod
def batched(iterator, batch_size):
batch = []
for example in iterator:
batch.append(example)
if len(batch) == batch_size:
yield batch
batch = []
if len(batch) > 0:
yield batch
def loglikelihood(self, prefix, text):
prefix, text = list(prefix), list(text)
if self.config.dummy:
return [-1.0 for _ in text], [False for _ in text]
log_likelihood = []
is_greedy = []
batched_iterator = list(zip(
self.batched(prefix, self.config.batch_size),
self.batched(text, self.config.batch_size)
))
for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
response = requests.post(
urllib.parse.urljoin(self.config.url, 'loglikelihood'),
json={'prefix_text': batch_prefix, 'text': batch_text}
).json()
log_likelihood.extend(response['log_likelihood'])
is_greedy.extend(response['is_greedy'])
return log_likelihood, is_greedy
def loglikelihood_rolling(self, text):
text = list(text)
if self.config.dummy:
return [-1.0 for _ in text], [False for _ in text]
log_likelihood = []
is_greedy = []
batched_iterator = list(self.batched(text, self.config.batch_size))
for batch_text in tqdm(batched_iterator, ncols=0):
response = requests.post(
urllib.parse.urljoin(self.config.url, 'loglikelihood-rolling'),
json={'text': batch_text}
).json()
log_likelihood.extend(response['log_likelihood'])
is_greedy.extend(response['is_greedy'])
return log_likelihood, is_greedy
def greedy_until(self, prefix, until):
prefix, until = list(prefix), list(until)
if self.config.dummy:
results = []
for u in until:
if isinstance(u, str):
results.append('dummy text ' + u)
else:
results.append('dummy text ' + u[0])
return results
batched_iterator = list(zip(
self.batched(prefix, self.config.batch_size),
self.batched(until, self.config.batch_size),
))
output_text = []
for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
response = requests.post(
urllib.parse.urljoin(self.config.url, 'greedy-until'),
json={'prefix_text': batch_prefix, 'until': batch_until}
).json()
output_text.extend(response['output_text'])
return output_text
def generate(self, prefix, temperature=None):
prefix = list(prefix)
if self.config.dummy:
return ['' for _ in prefix]
output_text = []
batched_iterator = list(self.batched(prefix, self.config.batch_size))
for batch_prefix in tqdm(batched_iterator, ncols=0):
response = requests.post(
urllib.parse.urljoin(self.config.url, 'generate'),
json={
'prefix_text': batch_prefix,
'temperature': temperature,
}
).json()
output_text.extend(response['output_text'])
return output_text
def chat(self, prompt, context, temperature=None):
if self.config.dummy:
return ''
response = requests.post(
urllib.parse.urljoin(self.config.url, 'chat'),
json={
'prompt': prompt,
'context': context,
'temperature': temperature,
}
).json()
return response['response'], response['context']

302
README.md Normal file
View File

@@ -0,0 +1,302 @@
---
datasets:
- Finnish-NLP/CulturaX_fi_cleaned
- Finnish-NLP/HPLT_1.2_fi_cleaned
- Finnish-NLP/wikipedia_20231101_fi_cleaned
- Finnish-NLP/Reddit_fi_2006_2022
- intfloat/multilingual_cc_news
language:
- fi
license: apache-2.0
pipeline_tag: text-generation
tags:
- finnish
- llama
library_name: transformers
---
# Ahma-7B for Finnish
Ahma-7B is a 7B parameter decoder-only transformer model based on Meta's Llama (v1) architecture, pretrained from scratch on the Finnish language. Its development was informed by the research presented in the paper [Scaling Data-Constrained Language Models](https://huggingface.co/papers/2305.16264). The original Llama model architecture was introduced in
[this paper](https://arxiv.org/abs/2302.13971)
and first released at [this page](https://github.com/facebookresearch/llama).
What does Ahma mean? Ahma is the Finnish word for wolverine! In the Finnish Lapland, wolverines are the biggest cause of reindeer damage.
There are two different sized base Ahma models both pretrained from scratch, Ahma-3B for 139B tokens and Ahma-7B for 149B tokens:
| Model | Context length | Layers | Dim | Heads | Params |
|:--------------------------------------------------------------------------------|:---------------|:-------|:-----|:------|:-------|
| [Ahma-3B](https://huggingface.co/Finnish-NLP/Ahma-3B) | 2048 | 26 | 3200 | 32 | 3.6B |
| [Ahma-7B](https://huggingface.co/Finnish-NLP/Ahma-7B) | 2048 | 32 | 4096 | 32 | 7.0B |
And two instruct-tuned versions:
| Model | Context length | Layers | Dim | Heads | Params |
|:--------------------------------------------------------------------------------|:---------------|:-------|:-----|:------|:-------|
| [Ahma-3B-Instruct](https://huggingface.co/Finnish-NLP/Ahma-3B-Instruct) | 2048 | 26 | 3200 | 32 | 3.6B |
| [Ahma-7B-Instruct](https://huggingface.co/Finnish-NLP/Ahma-7B-Instruct) | 2048 | 32 | 4096 | 32 | 7.0B |
## Paper Abstract
The current trend of scaling language models involves increasing both parameter count and training dataset size. Extrapolating this trend suggests that training dataset size may soon be limited by the amount of text data available on the internet. Motivated by this limit, we investigate scaling language models in data-constrained regimes. Specifically, we run a large set of experiments varying the extent of data repetition and compute budget, ranging up to 900 billion training tokens and 9 billion parameter models. We find that with constrained data for a fixed compute budget, training with up to 4 epochs of repeated data yields negligible changes to loss compared to having unique data. However, with more repetition, the value of adding compute eventually decays to zero. We propose and empirically validate a scaling law for compute optimality that accounts for the decreasing value of repeated tokens and excess parameters. Finally, we experiment with approaches mitigating data scarcity, including augmenting the training dataset with code data or removing commonly used filters. Models and datasets from our 400 training runs are freely available at this https URL .
## Intended uses & limitations
This model was pretrained only in a self-supervised way, without any supervised training. You can use this model for text generation or fine-tune it for a downstream task. This model followed a 2-stage pretraining approach where single-turn instruction-following examples were mixed in with the other training data in the second stage (explained more later in this readme). Thanks to this approach, this pretrained model is already capable of instruction following, but you might get even better results if you specifically fine-tune it for instruction following or other use cases. For instruction-following fine-tuning, you should use the same prompt format showcased below.
### How to use
#### Fine-tuning
We have now added finetuning example notebook along with video! \
Notebook: https://huggingface.co/Finnish-NLP/Ahma-3B/blob/main/Finetune_Ahma_3B_example.ipynb \
Video: https://www.youtube.com/watch?v=6mbgn9XzpS4
#### Inference
If you want to use this model for instruction-following, you need to use the same prompt format we used in the second stage of the pretraining (basically the same format what Meta used in their Llama2 models). **Note: do not use "LlamaTokenizer" from transformers library but always use the AutoTokenizer instead, or use the plain sentencepiece tokenizer.** Here is an example using the instruction-following prompt format, with some generation arguments you can modify for your use:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
system_prompt = "Olet tekoälyavustaja. Vastaat aina mahdollisimman avuliaasti. Vastauksesi eivät saa sisältää mitään haitallista, epäeettistä, rasistista, seksististä, vaarallista tai laitonta sisältöä. Jos kysymyksessä ei ole mitään järkeä tai se ei ole asiasisällöltään johdonmukainen, selitä miksi sen sijaan, että vastaisit jotain väärin. Jos et tiedä vastausta kysymykseen, älä kerro väärää tietoa."
def format_prompt(prompt: str) -> str:
prompt = f" [INST] <<SYS>>
{system_prompt.strip()}
<</SYS>>
{prompt.strip()} [/INST] "
return prompt
tokenizer = AutoTokenizer.from_pretrained("Finnish-NLP/Ahma-7B")
model = AutoModelForCausalLM.from_pretrained("Finnish-NLP/Ahma-7B")
model = model.to("cuda")
# use the custom prompt format function or the chat template feature in the tokenizer to format your inputs
# prompt = format_prompt("Listaa kolme hyötyä, joita pienet avoimen lähdekoodin kielimallit tuovat?")
# inputs = tokenizer(prompt, return_tensors="pt")
messages = [
{
"role": "system",
"content": system_prompt,
},
{"role": "user", "content": "Listaa kolme hyötyä, joita pienet avoimen lähdekoodin kielimallit tuovat?"},
]
inputs = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
inputs = inputs.to("cuda")
generated_ids = model.generate(
inputs,
temperature=0.6,
penalty_alpha=0.6,
top_k=4,
do_sample=True,
repetition_penalty=1.2,
min_length=5,
max_length=2048,
)
generated_text = tokenizer.batch_decode(
generated_ids, skip_special_tokens=False
)[0]
"""
1. Parempi luettavuus ja ymmärtäminen: Pienten avoimen lähdekoodin kielimallien avulla voidaan luoda ymmärrettävämpää ja luettavampaa tekstiä, mikä helpottaa ihmisten ymmärtämistä ja tiedon hankkimista.
2. Parempi mukautuvuus ja monipuolisuus: Avoimen lähdekoodin mallit antavat kehittäjille mahdollisuuden luoda räätälöityjä ratkaisuja omiin tarpeisiinsa, jolloin he voivat hyödyntää olemassa olevaa tietämystä ja asiantuntemusta.
3. Lisääntynyt yhteistyö ja avoimuus: Avoimen lähdekoodin mallien ansiosta kehittäjät voivat tehdä yhteistyötä muiden kanssa, jakaa ideoita ja parantaa koodin laatua jakamalla oivalluksia ja parhaita käytäntöjä. Tämä edistää yhteistyöhön perustuvaa ympäristöä ja kannustaa jatkuvaan parantamiseen.
"""
```
You may experiment with different system prompt instructions too if you like.
### Limitations and bias
This model was trained only with Finnish texts excluding code so it should not be used for multilingual and code generation use cases.
The training data used for this model contains a lot of content from the internet, which is far from neutral. Therefore, the model can have biased predictions. This bias will also affect all fine-tuned versions of this model.
To reduce toxic content, training data was filtered with a toxicity classifier but it cannot truly eliminate all toxic text.
## Training data
This model was pretrained on the combination of 14 datasets:
- [CulturaX_fi_cleaned](https://huggingface.co/datasets/Finnish-NLP/CulturaX_fi_cleaned), we cleaned Finnish split from the original [CulturaX](https://huggingface.co/datasets/uonlp/CulturaX) dataset
- [HPLT_1.2_fi_cleaned](https://huggingface.co/datasets/Finnish-NLP/HPLT_1.2_fi_cleaned), we cleaned Finnish split from the original [HPLT v1.2](https://hplt-project.org/datasets/v1.2) dataset
- [wikipedia_20231101_fi_cleaned](https://huggingface.co/datasets/Finnish-NLP/wikipedia_20231101_fi_cleaned), we used the Finnish subset of the wikipedia (November 2023) dataset
- [Reddit_fi_2006_2022](https://huggingface.co/datasets/Finnish-NLP/Reddit_fi_2006_2022), filtered and post-processed dataset of Finnish Reddit
- [Yle Finnish News Archive 2011-2018](http://urn.fi/urn:nbn:fi:lb-2017070501)
- [Yle Finnish News Archive 2019-2020](http://urn.fi/urn:nbn:fi:lb-2021050401)
- [Finnish News Agency Archive (STT)](http://urn.fi/urn:nbn:fi:lb-2018121001)
- [The Suomi24 Sentences Corpus](http://urn.fi/urn:nbn:fi:lb-2020021803)
- [Project Lönnrot](http://www.lonnrot.net/)
- [Finnish parliament speeches](https://avoindata.eduskunta.fi)
- [multilingual_cc_news](https://huggingface.co/datasets/intfloat/multilingual_cc_news), we used the Finnish subset of the multilingual CC-News dataset
- [fi-news-corpus](https://github.com/nkrusch/fi-news-corpus)
- Finnish higher education public theses
- Finnish single-turn instruction-following datasets, combination of multiple originally openly licensed English datasets translated to Finnish. For example, [Ultrachat, Aya, Capybara, etc](https://huggingface.co/collections/Finnish-NLP/sft-dpo-dataset-65f55dde1139c3cd683ff035)
Raw datasets were automatically cleaned to filter out bad quality and non-Finnish examples. Also, a [perplexity](https://huggingface.co/course/chapter7/3#perplexity-for-language-models) score was calculated for all texts with a KenLM model which was trained with very clean Finnish texts only. This perplexity score can then be used to determine how "clean" Finnish language the text contains. To reduce toxic text, we used Finnish toxicity classifier [TurkuNLP/bert-large-finnish-cased-toxicity](https://huggingface.co/TurkuNLP/bert-large-finnish-cased-toxicity) released by TurkuNLP to classify all text examples. Classified toxicity label scores can then be used to determine how toxic the text is.
All datasets were concatenated and the whole dataset near deduplicated using MinHashLSH from [text-dedup](https://github.com/ChenghaoMou/text-dedup). Top 95% perplexity score was used as a filtering threshold to filter out the worst quality 5% of texts. To reduce amount of toxic content, the dataset was filtered to include text examples having lower than 80% score for the toxicity labels "label_identity_attack", "label_insult", "label_threat" and "label_severe_toxicity".
Finally, 20,000 text examples from each of the CulturaX, Wikipedia, Yle, STT, Suomi24, and Reddit datasets were randomly selected for evaluation dataset.
The final training dataset had 23 billion words (calculated with regex "\w+") and the evaluation dataset had 23 million words. After tokenization, the training dataset had 41 billion tokens and the evaluation dataset had 40 million tokens. For the 2-stage pretraining, training datasets are divided as follows:
The first stage:
|Dataset | Words | Ratio |
|:-----------------------------|:------------|:-------------|
|CulturaX | 12.820B | 59.88% |
|HPLT v1.2 | 5.034B | 23.51% |
|Suomi24 | 3.018B | 14.09% |
|Reddit | 0.141B | 0.66% |
|CC-News | 0.311B | 1.45% |
|FI news corpus | 0.004B | 0.02% |
|Project Lönnrot | 0.083B | 0.39% |
|**TOTAL** | **21.410B** | **100.0%** |
The second stage:
|Dataset | Words | Ratio |
|:--------------------------------------------------------------|:------------|:------------|
|CulturaX (cleaner sample using KenLM perplexity score) | 2.252B | 55.48% |
|Wikipedia | 0.095B | 2.34% |
|STT | 0.253B | 6.23% |
|Yle | 0.212B | 5.22% |
|Finnish parliament speeches | 0.021B | 0.52% |
|Finnish higher education public theses | 0.855B | 21.07% |
|Finnish instruction-following datasets (note: 2X upsampled) | 0.371B | 9.14% |
|**TOTAL** | **4.059B** | **100.0%** |
## Training procedure
### Preprocessing
Texts are tokenized using Byte Pair Encoding (BPE) using the implementation from SentencePiece splitting all numbers into individual digits and using bytes to decompose unknown UTF-8 characters. The total
vocabulary size is 64k tokens. Inputs are sequences of 2048 consecutive tokens. Texts are not lower cased so this model is case-sensitive: it makes a difference between finnish and Finnish. Both BOS and EOS tokens were used in the pretraining.
### 2-stage pretraining
The model was trained on TPUv4-32 VM, sponsored by the [Google TPU Research Cloud](https://sites.research.google/trc/about/). Training was conducted with a slightly modified Jax/Flax based [EasyLM](https://github.com/young-geng/EasyLM) framework, and inspired by the [OpenLLaMA](https://github.com/openlm-research/open_llama) project. The optimizer used was a [Lion](https://arxiv.org/abs/2302.06675).
The 2-stage pretraining approach was inspired by [MiniCPM](https://shengdinghu.notion.site/MiniCPM-Unveiling-the-Potential-of-End-side-Large-Language-Models-d4d3a8c426424654a4e80e42a711cb20) findings. For the first stage (79% of the entire training), we used noisier web-scraped datasets. For the second stage (21% of the entire training), we primarily used cleaner datasets and instruction-following datasets shuffled together, like in MiniCPM. The learning rate schedule for the 2-stage pretraining was Warmup-Stable-Decay (WSD). During the first stage, the learning rate schedule had a linear warmup for about 8 billion tokens to a peak learning rate of 1e-4 (note: with the Lion optimizer, the learning rate had to be about 10 times smaller than with the commonly used AdamW), followed by a stable phase where the rate of 1e-4 was kept constant. During the second stage, the learning rate schedule had a linear decay from 1e-4 to 6e-6 for the first 7 billion tokens, followed by a stable phase for the remaining tokens.
In the first stage, the model was trained for 118 billion tokens, which is about three epochs of the first-stage training data, inspired by the findings of [Scaling Data-Constrained Language Models](https://huggingface.co/papers/2305.16264). In the second stage, the model was trained for 31 billion tokens, which is close to five epochs of the second-stage training data.
Thanks to the WSD learning rate schedule, you can more easily experiment with different first-stage model checkpoints. For example, you could apply the second-stage training on an earlier checkpoint or continue pretraining further before the second stage. Model checkpoints were pushed to this repository every 100,000 training steps (approximately 13 billion tokens).
- [900K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/5f6eb9498b17fece810d766f81c711c38a2b2de2)
- [800K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/bc2d607ce302c1b0ff75c229496645cf232c6d98)
- [700K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/69352a497d5953c5290296a1f429a450978c7f7f)
- [600K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/760ab5f865b08d9a512c1df523a5c4deb6874322)
- [500K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/32ea3d35931da8039180e80d67f6c323719ae50a)
- [400K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/d1256a6815983053d0f9934f21f163d764fc5ecd)
- [300K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/1e3094c66e788fe81d2aadad5bf8f0431358bd38)
- [200K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/a4afd130fa0effea047deaaf8bf63b3eba1b323b)
- [100K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/245fad2f5838af1465cb40ad42caef092e875cd9)
## Evaluation results
### FIN-bench
This Ahma 7B base model was primarily evaluated using [FIN-bench by TurkuNLP](https://github.com/TurkuNLP/FIN-bench), and the same evaluation was carried out for other relevant Finnish models for comparison: [FinGPT 8B by TurkuNLP](https://huggingface.co/TurkuNLP/gpt3-finnish-8B), [Viking 7B by TurkuNLP, SiloGen and HPLT](https://huggingface.co/LumiOpen/Viking-7B), and [Poro 34B by SiloGen, TurkuNLP and HPLT](https://huggingface.co/LumiOpen/Poro-34B). Below are the results with 0-shot and 3-shot settings in FIN-bench.
0-shot results:
| Benchmark | Ahma 3B base (instruct prompt format) | Ahma 3B Instruct (instruct prompt format) | Ahma 7B base (instruct prompt format) | Ahma 7B Instruct (instruct prompt format) | FinGPT 8B | Viking 7B | Poro 34B (8bit quant) |
|:---------------------------|:--------------------------------------|:------------------------------------------|:--------------------------------------|:------------------------------------------|:----------|:----------|:----------------------|
| Analogies | 50.77 | 48.46 | 56.92 | 41.54 | 49.23 | 40.00 | 54.62 |
| Arithmetic | 27.64 | 22.14 | 11.50 | 14.70 | 33.15 | 30.16 | 30.34 |
| Cause and Effect | 59.48 | 58.82 | 59.48 | 53.60 | 66.01 | 58.82 | 62.74 |
| Emotions | 36.25 | 28.12 | 36.25 | 27.50 | 22.50 | 26.25 | 35.63 |
| Empirical Judgements | 33.33 | 35.35 | 33.33 | 33.33 | 27.27 | 33.33 | 49.49 |
| General Knowledge | 44.29 | 48.57 | 51.43 | 37.14 | 40.00 | 24.29 | 51.43 |
| HHH Alignment | 42.09 | 41.66 | 44.23 | 43.22 | 41.81 | 42.51 | 42.92 |
| Intent Recognition | 24.42 | 26.16 | 43.64 | 56.94 | 17.49 | 22.40 | 68.35 |
| Misconceptions | 46.27 | 47.01 | 46.27 | 47.01 | 53.73 | 53.73 | 52.24 |
| Paraphrase | 59.50 | 73.00 | 67.00 | 70.50 | 51.00 | 50.00 | 51.00 |
| Sentence Ambiguity | 53.33 | 65.00 | 60.00 | 63.33 | 51.67 | 48.33 | 50.00 |
| Similarities Abstraction | 65.79 | 68.42 | 71.05 | 61.84 | 60.53 | 65.79 | 60.53 |
| **Non-Arithmetic Average** | **47.55** | **48.95** | **51.33** | **48.30** | **46.17** | **44.42** | **52.08** |
| **Overall Average** | **36.49** | **34.06** | **29.20** | **29.64** | **38.93** | **36.50** | **40.00** |
3-shot results:
| Benchmark | Ahma 3B base (instruct prompt format) | Ahma 3B Instruct (instruct prompt format) | Ahma 7B base (instruct prompt format) | Ahma 7B Instruct (instruct prompt format) | FinGPT 8B | Viking 7B | Poro 34B (8bit quant) |
|:---------------------------|:--------------------------------------|:------------------------------------------|:--------------------------------------|:------------------------------------------|:----------|:----------|:----------------------|
| Analogies | 50.77 | 49.23 | 49.23 | 43.08 | 40.77 | 54.62 | 76.92 |
| Arithmetic | 38.38 | 43.89 | 20.88 | 26.81 | 43.63 | 45.78 | 53.68 |
| Cause and Effect | 60.78 | 64.71 | 66.01 | 62.74 | 64.05 | 58.17 | 67.32 |
| Emotions | 30.00 | 41.25 | 30.00 | 53.75 | 44.37 | 48.13 | 56.87 |
| Empirical Judgements | 46.46 | 44.44 | 39.39 | 39.39 | 32.32 | 43.43 | 63.64 |
| General Knowledge | 47.14 | 40.00 | 27.14 | 44.29 | 54.29 | 28.57 | 74.29 |
| HHH Alignment | 43.53 | 44.80 | 43.80 | 45.09 | 45.39 | 44.80 | 46.07 |
| Intent Recognition | 20.52 | 44.22 | 36.42 | 39.02 | 51.45 | 58.82 | 83.67 |
| Misconceptions | 50.75 | 52.24 | 46.27 | 51.49 | 52.99 | 46.27 | 52.99 |
| Paraphrase | 50.50 | 58.50 | 57.50 | 65.00 | 53.00 | 54.50 | 55.00 |
| Sentence Ambiguity | 53.33 | 48.33 | 53.33 | 51.67 | 51.67 | 53.33 | 66.67 |
| Similarities Abstraction | 69.74 | 72.37 | 72.37 | 69.74 | 64.47 | 73.68 | 75.00 |
| **Non-Arithmetic Average** | **48.48** | **51.49** | **49.05** | **51.63** | **51.19** | **50.94** | **61.96** |
| **Overall Average** | **42.87** | **47.27** | **33.41** | **37.84** | **46.99** | **48.07** | **57.36** |
As we can see, Ahma 7B base model has bad arithmetic performance but in non-arithmetic tasks it clearly outperforms same sized models like the FinGPT 8B and Viking 7B, especially in 0-shot usage. Ahma 7B base model is even on-par with the 5X larger Poro 34B model, in non-arithmetic tasks in 0-shot usage. This result might be attributed to Ahma's 2-stage pretraining and the inclusion of instruct-following examples during the pretraining phase.
In a 3-shot setting, the results are more mixed. The poorer performance of Ahma 7B base model in 3-shot settings might be due to the use of the instruct prompt format and having only single-turn instruction-following training examples.
### MTBench Finnish
This Ahma 7B base model was also evaluated using [MTBench Finnish by LumiOpen](https://github.com/LumiOpen/FastChat/tree/main/fastchat/llm_judge) even though this Ahma model is not fine-tuned for chat. Since the MTBench evaluates also multi-turn chats while Ahma base models were only pretrained with single-turn instruction following examples, we have reported MTBench Finnish results separately for their single-turn and multi-turn evaluation examples. [Poro 34B Chat by SiloGen, TurkuNLP and HPLT](https://huggingface.co/LumiOpen/Poro-34B-chat) model's presumably multi-turn results are copied from their model card for the comparison.
Single-turn results:
| Benchmark | Ahma 3B base (instruct prompt format) | Ahma 3B Instruct (instruct prompt format) | Ahma 7B base (instruct prompt format) | Ahma 7B Instruct (instruct prompt format) |
|:--------------------|:--------------------------------------|:------------------------------------------|:--------------------------------------|:------------------------------------------|
| Coding | 1.00 | 1.00 | 1.70 | 1.10 |
| Extraction | 2.00 | 1.30 | 3.10 | 3.00 |
| Humanities | 4.05 | 6.20 | 6.60 | 8.00 |
| Math | 3.00 | 3.20 | 3.90 | 2.90 |
| Reasoning | 2.90 | 4.60 | 3.70 | 5.70 |
| Roleplay | 4.80 | 6.50 | 6.60 | 7.20 |
| STEM | 5.10 | 5.95 | 6.75 | 7.30 |
| Writing | 6.60 | 9.00 | 7.10 | 8.80 |
| **Overall Average** | **3.68** | **4.72** | **4.93** | **5.50** |
Multi-turn results:
| Benchmark | Ahma 3B base (instruct prompt format) | Ahma 3B Instruct (instruct prompt format) | Ahma 7B base (instruct prompt format) | Ahma 7B Instruct (instruct prompt format) | Poro 34B Chat |
|:--------------------|:--------------------------------------|:------------------------------------------|:--------------------------------------|:------------------------------------------|:--------------|
| Coding | 1.00 | 1.00 | 1.40 | 1.05 | 3.70 |
| Extraction | 1.55 | 1.15 | 2.05 | 2.65 | 6.37 |
| Humanities | 3.25 | 6.20 | 4.95 | 7.85 | 9.25 |
| Math | 2.20 | 2.70 | 2.50 | 2.40 | 1.20 |
| Reasoning | 2.45 | 3.50 | 2.55 | 4.50 | 4.35 |
| Roleplay | 4.90 | 6.40 | 6.35 | 6.60 | 7.35 |
| STEM | 4.20 | 4.78 | 4.28 | 5.40 | 7.80 |
| Writing | 3.80 | 6.65 | 4.10 | 6.25 | 8.50 |
| **Overall Average** | **2.92** | **4.05** | **3.52** | **4.59** | **6.06** |
As we can see, Ahma 7B base model struggles with multi-turn examples, as expected, since it has only been pretrained with single-turn instruction following examples. In addition, coding performance was expectedly poor because the Ahma 7B model is not trained with code data. In single-turn setting, Ahma 7B beats both the Ahma 3B base and Instruct-tuned versions, demonstrating greater base capability to be further improved with Instruct-tuning.
## Acknowledgements
This project would not have been possible without compute generously provided by Google through the
[TPU Research Cloud](https://sites.research.google/trc/).
## Team Members
- Aapo Tanskanen, [Hugging Face profile](https://huggingface.co/aapot), [LinkedIn profile](https://www.linkedin.com/in/aapotanskanen/)
- Rasmus Toivanen, [Hugging Face profile](https://huggingface.co/RASMUS), [LinkedIn profile](https://www.linkedin.com/in/rasmustoivanen/)
Feel free to contact us for more details 🤗
![Ahma](ahma.jpg)

BIN
ahma.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

27
config.json Normal file
View File

@@ -0,0 +1,27 @@
{
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 2048,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"pretraining_tp": 1,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000.0,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.38.0.dev0",
"use_cache": true,
"vocab_size": 64256
}

4
convert_to_hf_model.sh Normal file
View File

@@ -0,0 +1,4 @@
JAX_PLATFORM_NAME=cpu python3 -m EasyLM.models.llama.convert_easylm_to_hf \
--load_checkpoint='' \
--model_size='7b' \
--output_dir='./'

6
generation_config.json Normal file
View File

@@ -0,0 +1,6 @@
{
"_from_model_config": true,
"bos_token_id": 1,
"eos_token_id": 2,
"transformers_version": "4.38.0.dev0"
}

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:625f48801b93273fa419f51704ccc45bba97010337ed52ed8db290767a152c71
size 4978830560

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2d8f06f2f59ecbffe4fce68ab67d5ad530d951fbce594f7c9850d9bce0b739a3
size 4991431320

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c34dac66ee076f27cda65ce37ae74894f30a4e6cc49f898cf1bf67cbf3c1f10e
size 4035085208

View File

@@ -0,0 +1,298 @@
{
"metadata": {
"total_size": 14005313536
},
"weight_map": {
"lm_head.weight": "model-00003-of-00003.safetensors",
"model.embed_tokens.weight": "model-00001-of-00003.safetensors",
"model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.11.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.11.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.11.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.11.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.11.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.23.input_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
"model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.24.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.24.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.24.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.24.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
"model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
"model.norm.weight": "model-00003-of-00003.safetensors"
}
}

55
pretrain_llama_7b.sh Executable file
View File

@@ -0,0 +1,55 @@
#! /bin/bash
# Put your WANDB API key here to enable logging to wandb.
export WANDB_API_KEY=''
# TPU specific flags to improve training throughput
export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
python3 -m EasyLM.models.llama.llama_train \
--jax_distributed.initialize_jax_distributed=True \
--mesh_dim='1,-1,4' \
--dtype='bf16' \
--total_steps=900000 \
--eval_freq=50000 \
--log_freq=1000 \
--save_model_freq=2000 \
--save_milestone_freq=50000 \
--load_llama_config='7b' \
--update_llama_config='' \
--load_dataset_state='' \
--load_checkpoint='' \
--tokenizer.vocab_file='tokenizer.model' \
--optimizer.type='lion' \
--optimizer.lion_optimizer.weight_decay=1.0 \
--optimizer.lion_optimizer.lr_schedule_type='warmup_constant_linear_decay' \
--optimizer.lion_optimizer.lr=1e-4 \
--optimizer.lion_optimizer.end_lr=1e-5 \
--optimizer.lion_optimizer.lr_warmup_steps=60000 \
--optimizer.lion_optimizer.lr_constant_steps=900000 \
--optimizer.lion_optimizer.lr_decay_steps=100000 \
--optimizer.lion_optimizer.bf16_momentum=True \
--train_dataset.type='huggingface' \
--train_dataset.text_processor.fields='text' \
--train_dataset.text_processor.add_eos_token=True \
--train_dataset.text_processor.add_bos_token=True \
--train_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_first_stage' \
--train_dataset.huggingface_dataset.split='train' \
--train_dataset.huggingface_dataset.seq_length=2048 \
--train_dataset.huggingface_dataset.batch_size=64 \
--eval_dataset.type='huggingface' \
--eval_dataset.text_processor.fields='text' \
--eval_dataset.text_processor.add_eos_token=True \
--eval_dataset.text_processor.add_bos_token=True \
--eval_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_first_stage' \
--eval_dataset.huggingface_dataset.split='validation' \
--eval_dataset.huggingface_dataset.seq_length=2048 \
--eval_dataset.huggingface_dataset.batch_size=64 \
--checkpointer.save_optimizer_state=True \
--logger.online=True \
--logger.prefix='EasyLM' \
--logger.project="llama-7b-v2" \
--logger.output_dir="gs://finnish-nlp-research-us/llama-7b-v2-checkpoint" \
--logger.wandb_dir="./"

23
special_tokens_map.json Normal file
View File

@@ -0,0 +1,23 @@
{
"bos_token": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

223314
tokenizer.json Normal file

File diff suppressed because it is too large Load Diff

3
tokenizer.model Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1980c00aa3cb5455177a39efa3e60e7b8887ee89c3f7b8950719592a08ad9456
size 1400411

64256
tokenizer.vocab Normal file

File diff suppressed because it is too large Load Diff

75
tokenizer_config.json Normal file
View File

@@ -0,0 +1,75 @@
{
"add_bos_token": true,
"add_eos_token": false,
"add_prefix_space": true,
"added_tokens_decoder": {
"0": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"3": {
"content": "[INST]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"4": {
"content": "[/INST]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"5": {
"content": "<<SYS>>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"6": {
"content": "<</SYS>>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<s>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Olet tekoälyavustaja. Vastaat aina mahdollisimman avuliaasti. Vastauksesi eivät saa sisältää mitään haitallista, epäeettistä, rasistista, seksististä, vaarallista tai laitonta sisältöä. Jos kysymyksessä ei ole mitään järkeä tai se ei ole asiasisällöltään johdonmukainen, selitä miksi sen sijaan, että vastaisit jotain väärin. Jos et tiedä vastausta kysymykseen, älä kerro väärää tietoa.' %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + ' [INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + eos_token }}{% endif %}{% endfor %}",
"clean_up_tokenization_spaces": false,
"eos_token": "</s>",
"legacy": false,
"model_max_length": 1000000000000000019884624838656,
"pad_token": null,
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"use_default_system_prompt": false
}

10
train_sentencepiece.py Normal file
View File

@@ -0,0 +1,10 @@
import sentencepiece as spm
spm.SentencePieceTrainer.train(input="/researchdisk/training_dataset_sentences/train.txt", model_prefix="tokenizer",
model_type="bpe", split_digits=True, vocab_size=64256, byte_fallback=True,
normalization_rule_name="nfkc",
user_defined_symbols=["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"],
required_chars="abcdefghijklmnopqrstuvwxyzåäöABCDEFGHIJKLMNOPQRSTUVWXYZÅÄÖ",
train_extremely_large_corpus=True,
input_sentence_size=500000000, shuffle_input_sentence=True,
num_threads=96)