初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
35
.gitattributes
vendored
Normal file
35
.gitattributes
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__pycache__
|
||||||
0
EasyLM/__init__.py
Normal file
0
EasyLM/__init__.py
Normal file
228
EasyLM/bpt.py
Normal file
228
EasyLM/bpt.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""
|
||||||
|
An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
|
||||||
|
Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682
|
||||||
|
"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.lax as lax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
"""
|
||||||
|
Computing ffn blockwise without materializing the large hidden tensor, training
|
||||||
|
4x longer sequences than the memory-efficient transformer.
|
||||||
|
Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
|
||||||
|
"""
|
||||||
|
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
|
||||||
|
# remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
|
||||||
|
# inputs: (batch, seq_len, dim)
|
||||||
|
# chunk_size: the chunk size to split the sequence
|
||||||
|
inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
|
||||||
|
def scan_ffn(remat_ffn, carry, hidden_states):
|
||||||
|
outputs = remat_ffn(hidden_states, deterministic=deterministic)
|
||||||
|
return carry, outputs
|
||||||
|
scan_axis = inputs.ndim - 2
|
||||||
|
_, res = nn.scan(
|
||||||
|
scan_ffn,
|
||||||
|
variable_broadcast="params",
|
||||||
|
split_rngs={"params": False, "dropout": True},
|
||||||
|
in_axes=scan_axis,
|
||||||
|
out_axes=scan_axis,
|
||||||
|
)(remat_ffn, None, inputs)
|
||||||
|
res = rearrange(res, 'b c n d -> b (c n) d')
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Compute attention blockwise without materializing the full attention matrix,
|
||||||
|
initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
|
||||||
|
flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA
|
||||||
|
efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370
|
||||||
|
Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x
|
||||||
|
longer sequences than memory-efficient/flash-attention and fusion of attention and FFN.
|
||||||
|
"""
|
||||||
|
def blockwise_attn(
|
||||||
|
query, key, value,
|
||||||
|
bias=None,
|
||||||
|
deterministic=True,
|
||||||
|
dropout_rng=None,
|
||||||
|
attn_pdrop=0.0,
|
||||||
|
causal=True,
|
||||||
|
query_chunk_size=2048,
|
||||||
|
key_chunk_size=2048,
|
||||||
|
dtype=jnp.float32,
|
||||||
|
policy=jax.checkpoint_policies.nothing_saveable(),
|
||||||
|
precision=None,
|
||||||
|
float32_logits=True,
|
||||||
|
prevent_cse=True,
|
||||||
|
):
|
||||||
|
# query, key, value: (batch, seq_len, num_heads, dim_per_head)
|
||||||
|
# bias: (batch, seq_len) can be used to mask out attention (e.g. padding)
|
||||||
|
# causal: whether to use causal mask
|
||||||
|
# policy: one of jax.checkpoint_policies
|
||||||
|
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||||
|
if float32_logits:
|
||||||
|
query = query.astype(jnp.float32)
|
||||||
|
key = key.astype(jnp.float32)
|
||||||
|
|
||||||
|
batch, q_len, num_heads, dim_per_head = query.shape
|
||||||
|
batch, kv_len, num_heads, dim_per_head = key.shape
|
||||||
|
batch, kv_len, num_heads, dim_per_head = value.shape
|
||||||
|
|
||||||
|
num_q = q_len // query_chunk_size
|
||||||
|
num_kv = kv_len // key_chunk_size
|
||||||
|
query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
|
||||||
|
key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
|
||||||
|
value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
|
||||||
|
|
||||||
|
query = jnp.moveaxis(query, 1, 0)
|
||||||
|
key = jnp.moveaxis(key, 1, 0)
|
||||||
|
value = jnp.moveaxis(value, 1, 0)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
|
||||||
|
assert bias_dim == 1 or bias_dim == broadcast_dim
|
||||||
|
if not deterministic and attn_pdrop > 0.0:
|
||||||
|
attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
|
||||||
|
attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
|
||||||
|
else:
|
||||||
|
attn_dropout = None
|
||||||
|
|
||||||
|
_chunk_bias_fn = functools.partial(
|
||||||
|
_chunk_attention_bias,
|
||||||
|
query_chunk_size, key_chunk_size, bias, deterministic,
|
||||||
|
attn_dropout, attn_pdrop, causal, dtype)
|
||||||
|
|
||||||
|
def scan_attention(args):
|
||||||
|
query_chunk, query_chunk_idx = args
|
||||||
|
|
||||||
|
@functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
|
||||||
|
def scan_kv_block(carry, args):
|
||||||
|
key_chunk, value_chunk, key_chunk_idx = args
|
||||||
|
(numerator, denominator, prev_max_score) = carry
|
||||||
|
attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
|
||||||
|
bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
|
||||||
|
bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
|
||||||
|
attn_weights = attn_weights + bias_chunk
|
||||||
|
|
||||||
|
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
|
||||||
|
max_score = jnp.maximum(prev_max_score, max_score)
|
||||||
|
max_score = jax.lax.stop_gradient(max_score)
|
||||||
|
exp_weights = jnp.exp(attn_weights - max_score)
|
||||||
|
exp_values = jnp.einsum(
|
||||||
|
'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
|
||||||
|
)
|
||||||
|
correction = jnp.exp(prev_max_score - max_score)
|
||||||
|
numerator = numerator * correction + exp_values
|
||||||
|
denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
|
||||||
|
return Carry(numerator, denominator, max_score), None
|
||||||
|
|
||||||
|
def skip_upper_half(carry, args):
|
||||||
|
key_chunk, value_chunk, key_chunk_idx = args
|
||||||
|
skip_block = jnp.array(False)
|
||||||
|
if causal:
|
||||||
|
skip_block = query_chunk_idx < key_chunk_idx
|
||||||
|
return jax.lax.cond(
|
||||||
|
skip_block,
|
||||||
|
lambda carry, args: (carry, None),
|
||||||
|
scan_kv_block,
|
||||||
|
carry,
|
||||||
|
args,
|
||||||
|
)
|
||||||
|
|
||||||
|
init_carry = Carry(
|
||||||
|
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
|
||||||
|
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
|
||||||
|
(-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
|
||||||
|
)
|
||||||
|
(numerator, denominator, max_score), _ = lax.scan(
|
||||||
|
skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
|
||||||
|
)
|
||||||
|
outputs = (numerator / denominator).astype(dtype)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
_, res = lax.scan(
|
||||||
|
lambda _, x: ((), scan_attention(x)),
|
||||||
|
(), xs=(query, jnp.arange(0, num_q))
|
||||||
|
)
|
||||||
|
res = rearrange(res, 'n b c h d -> b (n c) h d')
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class Carry(NamedTuple):
|
||||||
|
numerator: jax.Array
|
||||||
|
denominator: jax.Array
|
||||||
|
max_so_far: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_attention_bias(query_chunk_size, key_chunk_size,
|
||||||
|
bias, deterministic, attn_dropout, attn_pdrop, causal,
|
||||||
|
dtype, query_chunk_idx, key_chunk_idx):
|
||||||
|
query_offset = query_chunk_idx * query_chunk_size
|
||||||
|
key_offset = key_chunk_idx * key_chunk_size
|
||||||
|
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
|
||||||
|
if bias is not None:
|
||||||
|
chunk_bias = lax.dynamic_slice(
|
||||||
|
bias,
|
||||||
|
start_indices=(0, 0, query_offset, key_offset),
|
||||||
|
slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if causal:
|
||||||
|
query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
|
||||||
|
key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
|
||||||
|
offset = query_offset - key_offset
|
||||||
|
query_idx += offset
|
||||||
|
causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
|
||||||
|
chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
|
||||||
|
|
||||||
|
if not deterministic and attn_pdrop > 0.0:
|
||||||
|
attn_dropout_slice = lax.dynamic_slice(
|
||||||
|
attn_dropout,
|
||||||
|
start_indices=(0, 0, query_offset, key_offset),
|
||||||
|
slice_sizes=(
|
||||||
|
*attn_dropout.shape[:2],
|
||||||
|
min(attn_dropout.shape[-2], query_chunk_size),
|
||||||
|
min(attn_dropout.shape[-1], key_chunk_size),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
|
||||||
|
return chunk_bias.astype(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# test
|
||||||
|
def reference_attn(query, key, value, causal, dtype):
|
||||||
|
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||||
|
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
|
||||||
|
if causal:
|
||||||
|
mask_value = jnp.finfo(logits.dtype).min
|
||||||
|
_, q_seq_len, _, _ = query.shape
|
||||||
|
_, kv_seq_len, _, _ = key.shape
|
||||||
|
mask_shape = (q_seq_len, kv_seq_len)
|
||||||
|
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
|
||||||
|
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
|
||||||
|
causal_mask = (row_ids < col_ids)[None, None, :, :]
|
||||||
|
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
|
||||||
|
weights = jax.nn.softmax(logits, axis=-1)
|
||||||
|
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
|
||||||
|
return out
|
||||||
|
|
||||||
|
# random inputs
|
||||||
|
shape = (1, 32, 8, 64)
|
||||||
|
query = jax.random.normal(jax.random.PRNGKey(0), shape)
|
||||||
|
key = jax.random.normal(jax.random.PRNGKey(1), shape)
|
||||||
|
value = jax.random.normal(jax.random.PRNGKey(2), shape)
|
||||||
|
|
||||||
|
causal = True
|
||||||
|
chunk_size = 4
|
||||||
|
policy = jax.checkpoint_policies.nothing_saveable()
|
||||||
|
|
||||||
|
blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
|
||||||
|
reference = reference_attn(query, key, value, causal, 'float32')
|
||||||
|
|
||||||
|
assert jnp.allclose(reference, blockwise, atol=1e-6)
|
||||||
212
EasyLM/checkpoint.py
Normal file
212
EasyLM/checkpoint.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import mlxu
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import flax
|
||||||
|
from flax.serialization import (
|
||||||
|
from_bytes, to_bytes, to_state_dict, from_state_dict
|
||||||
|
)
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingCheckpointer(object):
|
||||||
|
""" Custom msgpack checkpointer that saves large train states by serializing
|
||||||
|
and saving tensors one by one in a streaming fashion. Avoids running
|
||||||
|
out of memory or local TPU disk with default flax checkpointer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.float_dtype = 'bf16'
|
||||||
|
config.save_optimizer_state = False
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
def __init__(self, config, checkpoint_dir, enable=True):
|
||||||
|
self.config = self.get_default_config(config)
|
||||||
|
self.checkpoint_dir = checkpoint_dir
|
||||||
|
self.enable = enable
|
||||||
|
|
||||||
|
def save_checkpoint(self, train_state, filename, gather_fns=None):
|
||||||
|
if self.enable:
|
||||||
|
path = os.path.join(self.checkpoint_dir, filename)
|
||||||
|
else:
|
||||||
|
path = '/dev/null'
|
||||||
|
self.save_train_state_to_file(
|
||||||
|
train_state, path, gather_fns, self.config.float_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None):
|
||||||
|
train_state = to_state_dict(train_state)
|
||||||
|
packer = msgpack.Packer()
|
||||||
|
flattend_train_state = flatten_dict(train_state)
|
||||||
|
if gather_fns is not None:
|
||||||
|
gather_fns = flatten_dict(to_state_dict(gather_fns))
|
||||||
|
|
||||||
|
with mlxu.open_file(path, "wb") as fout:
|
||||||
|
for key, value in flattend_train_state.items():
|
||||||
|
if gather_fns is not None:
|
||||||
|
value = gather_fns[key](value)
|
||||||
|
value = float_tensor_to_dtype(value, float_dtype)
|
||||||
|
fout.write(packer.pack((key, to_bytes(value))))
|
||||||
|
|
||||||
|
def save_pickle(self, obj, filename):
|
||||||
|
if self.enable:
|
||||||
|
path = os.path.join(self.checkpoint_dir, filename)
|
||||||
|
else:
|
||||||
|
path = '/dev/null'
|
||||||
|
mlxu.save_pickle(obj, path)
|
||||||
|
|
||||||
|
def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False):
|
||||||
|
step = int(jax.device_get(train_state.step))
|
||||||
|
if self.config.save_optimizer_state:
|
||||||
|
checkpoint_state = train_state
|
||||||
|
checkpoint_name = 'streaming_train_state'
|
||||||
|
checkpoint_gather_fns = gather_fns
|
||||||
|
else:
|
||||||
|
checkpoint_state = train_state.params['params']
|
||||||
|
checkpoint_name = 'streaming_params'
|
||||||
|
checkpoint_gather_fns = gather_fns.params['params']
|
||||||
|
|
||||||
|
if milestone:
|
||||||
|
# Save a milestone checkpoint that will not be overwritten
|
||||||
|
self.save_pickle(metadata, f'metadata_{step}.pkl')
|
||||||
|
self.save_pickle(dataset, f'dataset_{step}.pkl')
|
||||||
|
self.save_checkpoint(
|
||||||
|
checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Save a normal checkpoint that can be overwritten
|
||||||
|
self.save_pickle(metadata, 'metadata.pkl')
|
||||||
|
self.save_pickle(dataset, 'dataset.pkl')
|
||||||
|
self.save_checkpoint(
|
||||||
|
checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
|
||||||
|
if shard_fns is not None:
|
||||||
|
shard_fns = flatten_dict(
|
||||||
|
to_state_dict(shard_fns)
|
||||||
|
)
|
||||||
|
if remove_dict_prefix is not None:
|
||||||
|
remove_dict_prefix = tuple(remove_dict_prefix)
|
||||||
|
flattend_train_state = {}
|
||||||
|
with mlxu.open_file(path) as fin:
|
||||||
|
# 83886080 bytes = 80 MB, which is 16 blocks on GCS
|
||||||
|
unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
|
||||||
|
for key, value in unpacker:
|
||||||
|
key = tuple(key)
|
||||||
|
if remove_dict_prefix is not None:
|
||||||
|
if key[:len(remove_dict_prefix)] == remove_dict_prefix:
|
||||||
|
key = key[len(remove_dict_prefix):]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tensor = from_bytes(None, value)
|
||||||
|
if shard_fns is not None:
|
||||||
|
tensor = shard_fns[key](tensor)
|
||||||
|
flattend_train_state[key] = tensor
|
||||||
|
|
||||||
|
if target is not None:
|
||||||
|
flattened_target = flatten_dict(
|
||||||
|
to_state_dict(target), keep_empty_nodes=True
|
||||||
|
)
|
||||||
|
for key, value in flattened_target.items():
|
||||||
|
if key not in flattend_train_state and value == empty_node:
|
||||||
|
flattend_train_state[key] = value
|
||||||
|
|
||||||
|
train_state = unflatten_dict(flattend_train_state)
|
||||||
|
if target is None:
|
||||||
|
return train_state
|
||||||
|
|
||||||
|
return from_state_dict(target, train_state)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_flax_checkpoint(path, target=None, shard_fns=None):
|
||||||
|
""" Load a standard flax checkpoint that's not saved with the
|
||||||
|
msgpack streaming format.
|
||||||
|
"""
|
||||||
|
with mlxu.open_file(path, "rb") as fin:
|
||||||
|
encoded_bytes = fin.read()
|
||||||
|
|
||||||
|
state_dict = flax.serialization.msgpack_restore(encoded_bytes)
|
||||||
|
if shard_fns is not None:
|
||||||
|
shard_fns = to_state_dict(shard_fns)
|
||||||
|
state_dict = tree_apply(shard_fns, state_dict)
|
||||||
|
|
||||||
|
if target is None:
|
||||||
|
return state_dict
|
||||||
|
return from_state_dict(target, state_dict)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_trainstate_checkpoint(cls, load_from, trainstate_target=None,
|
||||||
|
trainstate_shard_fns=None,
|
||||||
|
disallow_trainstate=False):
|
||||||
|
if trainstate_target is not None:
|
||||||
|
params_target = trainstate_target.params['params']
|
||||||
|
else:
|
||||||
|
params_target = None
|
||||||
|
|
||||||
|
if trainstate_shard_fns is not None:
|
||||||
|
params_shard_fns = trainstate_shard_fns.params['params']
|
||||||
|
else:
|
||||||
|
params_shard_fns = None
|
||||||
|
|
||||||
|
load_type, load_path = load_from.split('::', 1)
|
||||||
|
if disallow_trainstate:
|
||||||
|
assert load_type != 'trainstate', 'Loading full trainstate is not allowed!'
|
||||||
|
train_state = None
|
||||||
|
restored_params = None
|
||||||
|
if load_type == 'trainstate':
|
||||||
|
# Load the entire train state in the streaming format
|
||||||
|
train_state = cls.load_checkpoint(
|
||||||
|
path=load_path,
|
||||||
|
target=trainstate_target,
|
||||||
|
shard_fns=trainstate_shard_fns,
|
||||||
|
)
|
||||||
|
elif load_type == 'trainstate_params':
|
||||||
|
# Load the params part of the train state in the streaming format
|
||||||
|
restored_params = cls.load_checkpoint(
|
||||||
|
path=load_path,
|
||||||
|
target=params_target,
|
||||||
|
shard_fns=params_shard_fns,
|
||||||
|
remove_dict_prefix=('params', 'params'),
|
||||||
|
)
|
||||||
|
restored_params = flax.core.frozen_dict.freeze(
|
||||||
|
{'params': restored_params}
|
||||||
|
)
|
||||||
|
elif load_type == 'params':
|
||||||
|
# Load the params in the streaming format
|
||||||
|
restored_params = cls.load_checkpoint(
|
||||||
|
path=load_path,
|
||||||
|
target=params_target,
|
||||||
|
shard_fns=params_shard_fns,
|
||||||
|
)
|
||||||
|
restored_params = flax.core.frozen_dict.freeze(
|
||||||
|
{'params': restored_params}
|
||||||
|
)
|
||||||
|
elif load_type == 'flax_params':
|
||||||
|
# Load the params in the standard flax format (non-streaming)
|
||||||
|
# This requires the entire params to fit in memory
|
||||||
|
restored_params = cls.load_flax_checkpoint(
|
||||||
|
path=load_path,
|
||||||
|
target=params_target,
|
||||||
|
shard_fns=params_shard_fns
|
||||||
|
)
|
||||||
|
restored_params = flax.core.frozen_dict.freeze(
|
||||||
|
{'params': restored_params}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Invalid load_from type: {load_type}')
|
||||||
|
|
||||||
|
return train_state, restored_params
|
||||||
436
EasyLM/data.py
Normal file
436
EasyLM/data.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
import dataclasses
|
||||||
|
import pprint
|
||||||
|
import time
|
||||||
|
from functools import partial
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import mlxu
|
||||||
|
from ml_collections.config_dict import config_dict
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from datasets import load_dataset, load_from_disk
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetFactory(object):
|
||||||
|
""" Datset builder class. """
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.type = 'huggingface'
|
||||||
|
config.text_processor = TextProcessor.get_default_config()
|
||||||
|
config.huggingface_dataset = HuggingfaceDataset.get_default_config()
|
||||||
|
config.json_dataset = JsonDataset.get_default_config()
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_dataset(cls, config, tokenizer, **kwargs):
|
||||||
|
config = cls.get_default_config(config)
|
||||||
|
text_processor = TextProcessor(config.text_processor, tokenizer)
|
||||||
|
if config.type == 'huggingface':
|
||||||
|
return HuggingfaceDataset(
|
||||||
|
config.huggingface_dataset, tokenizer, text_processor, **kwargs
|
||||||
|
)
|
||||||
|
elif config.type == 'json':
|
||||||
|
return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown dataset type: {config.type}')
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise ValueError('DatasetFactory is a static class and should not be instantiated.')
|
||||||
|
|
||||||
|
|
||||||
|
class TextProcessor(object):
|
||||||
|
""" Example processor that converts a dictionary of texts into tokens. """
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.fields_from_example = ''
|
||||||
|
config.fields = ''
|
||||||
|
config.subfield_separator = ' '
|
||||||
|
config.add_bos_token = True
|
||||||
|
config.add_eos_token = True
|
||||||
|
config.prepend_text = ''
|
||||||
|
config.base64_token_dtype = 'i4'
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
def __init__(self, config, tokenizer):
|
||||||
|
self.config = self.get_default_config(config)
|
||||||
|
assert self.config.fields != '' or self.config.fields_from_example != '', (
|
||||||
|
'Either fields or fields_from_example must be specified.'
|
||||||
|
)
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def __call__(self, example, has_aux=False):
|
||||||
|
if has_aux:
|
||||||
|
example, *aux = example
|
||||||
|
else:
|
||||||
|
aux = tuple()
|
||||||
|
token_buffer = []
|
||||||
|
loss_mask_buffer = []
|
||||||
|
|
||||||
|
if self.config.add_bos_token:
|
||||||
|
token_buffer.append(self.tokenizer.bos_token_id)
|
||||||
|
loss_mask_buffer.append(0.0)
|
||||||
|
|
||||||
|
if self.config.fields_from_example != '':
|
||||||
|
fields = example[self.config.fields_from_example].split(',')
|
||||||
|
else:
|
||||||
|
fields = self.config.fields.split(',')
|
||||||
|
|
||||||
|
for i, field in enumerate(fields):
|
||||||
|
if field.startswith('[') and field.endswith(']'):
|
||||||
|
# No loss for this field.
|
||||||
|
field = field[1:-1]
|
||||||
|
mask = 0.0
|
||||||
|
else:
|
||||||
|
mask = 1.0
|
||||||
|
|
||||||
|
if field.startswith('<|') and field.endswith('|>'):
|
||||||
|
# Special tokens.
|
||||||
|
field = field[2:-2]
|
||||||
|
if field == 'bos':
|
||||||
|
token_buffer.append(self.tokenizer.bos_token_id)
|
||||||
|
elif field == 'eos':
|
||||||
|
token_buffer.append(self.tokenizer.eos_token_id)
|
||||||
|
else:
|
||||||
|
# Token ID specified directly.
|
||||||
|
token_buffer.append(int(field))
|
||||||
|
loss_mask_buffer.append(mask)
|
||||||
|
elif field.startswith('{') and field.endswith('}'):
|
||||||
|
field = field[1:-1]
|
||||||
|
# Base64 encoded raw tokens.
|
||||||
|
tokens = np.frombuffer(
|
||||||
|
base64.b64decode(example[field]),
|
||||||
|
dtype=self.config.base64_token_dtype
|
||||||
|
).tolist()
|
||||||
|
token_buffer.extend(tokens)
|
||||||
|
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
|
||||||
|
else:
|
||||||
|
subfields = field.split('+')
|
||||||
|
text = self.config.subfield_separator.join(
|
||||||
|
[example[subfield] for subfield in subfields]
|
||||||
|
)
|
||||||
|
if i == 0:
|
||||||
|
text = self.config.prepend_text + text
|
||||||
|
tokens = self.tokenizer.encode(text)
|
||||||
|
token_buffer.extend(tokens)
|
||||||
|
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
|
||||||
|
|
||||||
|
if self.config.add_eos_token:
|
||||||
|
token_buffer.append(self.tokenizer.eos_token_id)
|
||||||
|
loss_mask_buffer.append(1.0)
|
||||||
|
|
||||||
|
return token_buffer, loss_mask_buffer, *aux
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingfaceDataset(object):
|
||||||
|
""" Huggingface dataset, where the dataset is loaded using the huggingface
|
||||||
|
datasets.load_dataset() function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.path = 'c4'
|
||||||
|
config.name = 'en'
|
||||||
|
config.split = 'train'
|
||||||
|
config.streaming = False
|
||||||
|
config.seq_length = 1024
|
||||||
|
config.batch_size = 8
|
||||||
|
config.always_start_with_bos = False
|
||||||
|
config.start_seek_loc = 0
|
||||||
|
config.tokens_count_at_start = 0
|
||||||
|
config.batch_token_dtype = 'i4'
|
||||||
|
config.reset_dataset_loc = False
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
|
||||||
|
self.config = self.get_default_config(config)
|
||||||
|
name = self.config.name if self.config.name != '' else None
|
||||||
|
split = self.config.split if self.config.split != '' else None
|
||||||
|
self._tokenizer = tokenizer
|
||||||
|
self._text_processor = text_processor
|
||||||
|
self._dataset = load_from_disk(
|
||||||
|
self.config.path
|
||||||
|
)[split]
|
||||||
|
self._dataset = self._dataset.to_iterable_dataset(num_shards=128 if len(self._dataset) > 128 else len(self._dataset))
|
||||||
|
self._eval_dataset = eval_dataset
|
||||||
|
self._train_epochs = 0
|
||||||
|
self._dataset_loc = self.config.start_seek_loc
|
||||||
|
self._total_tokens = self.config.tokens_count_at_start
|
||||||
|
self._index = 0
|
||||||
|
self.reset_dataset_loc = self.config.reset_dataset_loc
|
||||||
|
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if not self._eval_dataset and self._train_epochs > 0:
|
||||||
|
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
|
||||||
|
chunk_size = self.config.batch_size * self.config.seq_length
|
||||||
|
while True:
|
||||||
|
token_buffer = []
|
||||||
|
loss_mask_buffer = []
|
||||||
|
if not self._eval_dataset and self._train_epochs > 0:
|
||||||
|
self._dataset.set_epoch(self._train_epochs)
|
||||||
|
for index, example in enumerate(self._dataset):
|
||||||
|
self._index = index
|
||||||
|
if not self._eval_dataset and self._dataset_loc > index:
|
||||||
|
continue
|
||||||
|
tokens, loss_masks = self.text_processor(example)
|
||||||
|
token_buffer.extend(tokens)
|
||||||
|
loss_mask_buffer.extend(loss_masks)
|
||||||
|
while len(token_buffer) > chunk_size + 1:
|
||||||
|
self._total_tokens += chunk_size
|
||||||
|
metrics = {
|
||||||
|
'dataset_example_index': index,
|
||||||
|
'dataset_total_tokens': self._total_tokens,
|
||||||
|
'epoch': self._train_epochs,
|
||||||
|
}
|
||||||
|
batch = {
|
||||||
|
'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
|
||||||
|
self.config.batch_size, -1
|
||||||
|
),
|
||||||
|
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
|
||||||
|
self.config.batch_size, -1
|
||||||
|
),
|
||||||
|
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
|
||||||
|
self.config.batch_size, -1
|
||||||
|
),
|
||||||
|
}
|
||||||
|
if self.config.always_start_with_bos:
|
||||||
|
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
|
||||||
|
yield batch, metrics
|
||||||
|
token_buffer = token_buffer[chunk_size:]
|
||||||
|
loss_mask_buffer = loss_mask_buffer[chunk_size:]
|
||||||
|
|
||||||
|
if self._eval_dataset:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if self._train_epochs == 0:
|
||||||
|
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
|
||||||
|
self._dataset_loc = 0
|
||||||
|
self._train_epochs += 1
|
||||||
|
|
||||||
|
def get_state_dict(self):
|
||||||
|
return dict(
|
||||||
|
config=self.config,
|
||||||
|
dataset_loc=self._index,
|
||||||
|
total_tokens=self._total_tokens,
|
||||||
|
epochs=self._train_epochs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
if 'config' in state_dict:
|
||||||
|
self.config.update(ConfigDict(state_dict['config']))
|
||||||
|
self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
|
||||||
|
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
|
||||||
|
self._train_epochs = state_dict.get('epochs', 0)
|
||||||
|
if self.reset_dataset_loc:
|
||||||
|
self._dataset_loc = 0
|
||||||
|
self._train_epochs = 0
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def seq_length(self):
|
||||||
|
return self.config.seq_length
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer(self):
|
||||||
|
return self._tokenizer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text_processor(self):
|
||||||
|
return self._text_processor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dataset(self):
|
||||||
|
return self._dataset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
return len(self._tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
class JsonDataset(object):
|
||||||
|
""" JSON dataset, where each line of the data file contains a JSON
|
||||||
|
dictionary with text fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.path = ''
|
||||||
|
config.seq_length = 1024
|
||||||
|
config.batch_size = 8
|
||||||
|
config.always_start_with_bos = False
|
||||||
|
config.start_seek_loc = 0
|
||||||
|
config.example_index_at_start = 0
|
||||||
|
config.tokens_count_at_start = 0
|
||||||
|
config.tokenizer_processes = 1
|
||||||
|
config.tokenizer_parallel_chunk_size = 32
|
||||||
|
config.tokenizer_parallel_batch_size = 1024
|
||||||
|
config.throughput_average_window_size = 200
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
def __init__(self, config, tokenizer, text_processor):
|
||||||
|
self.config = self.get_default_config(config)
|
||||||
|
assert self.config.path != ''
|
||||||
|
self._tokenizer = tokenizer
|
||||||
|
self._text_processor = text_processor
|
||||||
|
self._index = self.config.example_index_at_start
|
||||||
|
self._file_loc = self.config.start_seek_loc
|
||||||
|
self._total_tokens = self.config.tokens_count_at_start
|
||||||
|
|
||||||
|
def parse_json(self, line):
|
||||||
|
if not line or line == '\n':
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
print(f'Error parsing json line:\n{line}')
|
||||||
|
return None
|
||||||
|
return data
|
||||||
|
|
||||||
|
def json_iterator(self):
|
||||||
|
with mlxu.open_file(self.config.path, 'r') as fin:
|
||||||
|
fin.seek(self._file_loc)
|
||||||
|
while True:
|
||||||
|
line = fin.readline()
|
||||||
|
self._file_loc = fin.tell()
|
||||||
|
if not line: # Reached EOF
|
||||||
|
self._index = 0
|
||||||
|
fin.seek(0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
data = self.parse_json(line)
|
||||||
|
if data is not None:
|
||||||
|
# JSON parsing succeeded
|
||||||
|
yield data, self._file_loc, self._index
|
||||||
|
self._index += 1
|
||||||
|
|
||||||
|
def batched(self, iterator, batch_size):
|
||||||
|
batch = []
|
||||||
|
for example in iterator:
|
||||||
|
batch.append(example)
|
||||||
|
if len(batch) == batch_size:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
if len(batch) > 0:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
def parallel_example_iterator(self):
|
||||||
|
if self.config.tokenizer_processes == 1:
|
||||||
|
for example, loc, index in self.json_iterator():
|
||||||
|
yield self.text_processor((example, loc, index), has_aux=True)
|
||||||
|
else:
|
||||||
|
process_pool = Pool(self.config.tokenizer_processes)
|
||||||
|
batched_iterator = self.batched(
|
||||||
|
self.json_iterator(), self.config.tokenizer_parallel_batch_size
|
||||||
|
)
|
||||||
|
with process_pool as pool:
|
||||||
|
map_fn = partial(self.text_processor, has_aux=True)
|
||||||
|
next_batch = pool.map_async(
|
||||||
|
map_fn, next(batched_iterator),
|
||||||
|
chunksize=self.config.tokenizer_parallel_chunk_size
|
||||||
|
)
|
||||||
|
while True:
|
||||||
|
current_batch = next_batch
|
||||||
|
next_batch = pool.map_async(
|
||||||
|
map_fn, next(batched_iterator),
|
||||||
|
chunksize=self.config.tokenizer_parallel_chunk_size
|
||||||
|
)
|
||||||
|
for example in current_batch.get():
|
||||||
|
yield example
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
chunk_size = self.config.batch_size * self.config.seq_length
|
||||||
|
token_buffer = []
|
||||||
|
loss_mask_buffer = []
|
||||||
|
last_time = 0.0
|
||||||
|
step_times = []
|
||||||
|
start_time = time.time()
|
||||||
|
start_tokens = self._total_tokens
|
||||||
|
for tokens, loss_masks, loc, index in self.parallel_example_iterator():
|
||||||
|
token_buffer.extend(tokens)
|
||||||
|
loss_mask_buffer.extend(loss_masks)
|
||||||
|
while len(token_buffer) > chunk_size + 1:
|
||||||
|
self._total_tokens += chunk_size
|
||||||
|
step_times.append(time.time() - last_time)
|
||||||
|
last_time = time.time()
|
||||||
|
if len(step_times) > self.config.throughput_average_window_size:
|
||||||
|
step_times = step_times[-self.config.throughput_average_window_size:]
|
||||||
|
average_throughput = chunk_size / np.mean(step_times)
|
||||||
|
accumulated_throughput = (
|
||||||
|
(self._total_tokens - start_tokens) / (time.time() - start_time)
|
||||||
|
)
|
||||||
|
metrics = {
|
||||||
|
'dataset_file_loc': loc,
|
||||||
|
'dataset_example_index': index,
|
||||||
|
'dataset_total_tokens': self._total_tokens,
|
||||||
|
'dataset_accumulated_tps': accumulated_throughput,
|
||||||
|
'dataset_average_tps': average_throughput,
|
||||||
|
}
|
||||||
|
batch = {
|
||||||
|
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
|
||||||
|
self.config.batch_size, -1
|
||||||
|
),
|
||||||
|
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
|
||||||
|
self.config.batch_size, -1
|
||||||
|
),
|
||||||
|
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
|
||||||
|
self.config.batch_size, -1
|
||||||
|
),
|
||||||
|
}
|
||||||
|
if self.config.always_start_with_bos:
|
||||||
|
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
|
||||||
|
yield batch, metrics
|
||||||
|
token_buffer = token_buffer[chunk_size:]
|
||||||
|
loss_mask_buffer = loss_mask_buffer[chunk_size:]
|
||||||
|
|
||||||
|
def get_state_dict(self):
|
||||||
|
return dict(
|
||||||
|
config=self.config,
|
||||||
|
index=self._index,
|
||||||
|
file_loc=self._file_loc,
|
||||||
|
total_tokens=self._total_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
if 'config' in state_dict:
|
||||||
|
self.config.update(ConfigDict(state_dict['config']))
|
||||||
|
self._index = state_dict.get('index', self.config.example_index_at_start)
|
||||||
|
self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
|
||||||
|
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def seq_length(self):
|
||||||
|
return self.config.seq_length
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer(self):
|
||||||
|
return self._tokenizer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text_processor(self):
|
||||||
|
return self._text_processor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
return len(self.tokenizer)
|
||||||
403
EasyLM/jax_utils.py
Normal file
403
EasyLM/jax_utils.py
Normal file
@@ -0,0 +1,403 @@
|
|||||||
|
import os
|
||||||
|
import math
|
||||||
|
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
|
||||||
|
from functools import partial
|
||||||
|
import re
|
||||||
|
import dataclasses
|
||||||
|
import random
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
from ml_collections.config_dict.config_dict import placeholder
|
||||||
|
|
||||||
|
import flax
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.sharding import PartitionSpec as PS
|
||||||
|
from jax.sharding import Mesh
|
||||||
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
|
||||||
|
from jax.experimental.pjit import pjit
|
||||||
|
from jax.interpreters import pxla
|
||||||
|
import numpy as np
|
||||||
|
from transformers import FlaxLogitsWarper
|
||||||
|
|
||||||
|
|
||||||
|
class JaxRNG(object):
|
||||||
|
""" A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
|
||||||
|
pure function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_seed(cls, seed):
|
||||||
|
return cls(jax.random.PRNGKey(seed))
|
||||||
|
|
||||||
|
def __init__(self, rng):
|
||||||
|
self.rng = rng
|
||||||
|
|
||||||
|
def __call__(self, keys=None):
|
||||||
|
if keys is None:
|
||||||
|
self.rng, split_rng = jax.random.split(self.rng)
|
||||||
|
return split_rng
|
||||||
|
elif isinstance(keys, int):
|
||||||
|
split_rngs = jax.random.split(self.rng, num=keys + 1)
|
||||||
|
self.rng = split_rngs[0]
|
||||||
|
return tuple(split_rngs[1:])
|
||||||
|
else:
|
||||||
|
split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
|
||||||
|
self.rng = split_rngs[0]
|
||||||
|
return {key: val for key, val in zip(keys, split_rngs[1:])}
|
||||||
|
|
||||||
|
|
||||||
|
class JaxDistributedConfig(object):
|
||||||
|
""" Utility class for initializing JAX distributed. """
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.initialize_jax_distributed = False
|
||||||
|
config.coordinator_address = placeholder(str)
|
||||||
|
config.num_processes = placeholder(int)
|
||||||
|
config.process_id = placeholder(int)
|
||||||
|
config.local_device_ids = placeholder(str)
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def initialize(cls, config):
|
||||||
|
config = cls.get_default_config(config)
|
||||||
|
if config.initialize_jax_distributed:
|
||||||
|
if config.local_device_ids is not None:
|
||||||
|
local_device_ids = [int(x) for x in config.local_device_ids.split(',')]
|
||||||
|
else:
|
||||||
|
local_device_ids = None
|
||||||
|
|
||||||
|
jax.distributed.initialize(
|
||||||
|
coordinator_address=config.coordinator_address,
|
||||||
|
num_processes=config.num_processes,
|
||||||
|
process_id=config.process_id,
|
||||||
|
local_device_ids=local_device_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
|
||||||
|
""" JIT traceable version of FlaxLogitsWarper that performs temperature scaling."""
|
||||||
|
def __init__(self, temperature):
|
||||||
|
self.temperature = temperature
|
||||||
|
|
||||||
|
def __call__(self, input_ids, scores, cur_len):
|
||||||
|
return scores / jnp.clip(self.temperature, a_min=1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
|
||||||
|
""" Create pytree of sharding and gathering functions from pytree of
|
||||||
|
partition specs.
|
||||||
|
"""
|
||||||
|
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
|
||||||
|
|
||||||
|
def make_to_dtype_fn(dtype_spec):
|
||||||
|
def to_dtype(tensor):
|
||||||
|
if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
|
||||||
|
# Convert all float tensors to the same dtype
|
||||||
|
return tensor.astype(dtype_specs)
|
||||||
|
elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
|
||||||
|
return tensor.astype(dtype_spec.dtype)
|
||||||
|
return tensor
|
||||||
|
return to_dtype
|
||||||
|
|
||||||
|
def make_shard_fn(partition_spec, dtype_spec=None):
|
||||||
|
jax_shard_function = pjit(
|
||||||
|
make_to_dtype_fn(dtype_spec),
|
||||||
|
in_shardings=None,
|
||||||
|
out_shardings=partition_spec
|
||||||
|
)
|
||||||
|
def shard_fn(tensor):
|
||||||
|
return jax_shard_function(tensor).block_until_ready()
|
||||||
|
return shard_fn
|
||||||
|
|
||||||
|
def make_gather_fn(partition_spec, dtype_spec=None):
|
||||||
|
jax_gather_fn = pjit(
|
||||||
|
make_to_dtype_fn(dtype_spec),
|
||||||
|
in_shardings=partition_spec,
|
||||||
|
out_shardings=None
|
||||||
|
)
|
||||||
|
def gather_fn(tensor):
|
||||||
|
return jax.device_get(jax_gather_fn(tensor))
|
||||||
|
return gather_fn
|
||||||
|
|
||||||
|
if dtype_specs is None or dtype_specs in float_dtypes:
|
||||||
|
shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
|
||||||
|
gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
|
||||||
|
else:
|
||||||
|
shard_fns = jax.tree_util.tree_map(
|
||||||
|
make_shard_fn, partition_specs, dtype_specs
|
||||||
|
)
|
||||||
|
gather_fns = jax.tree_util.tree_map(
|
||||||
|
make_gather_fn, partition_specs, dtype_specs
|
||||||
|
)
|
||||||
|
return shard_fns, gather_fns
|
||||||
|
|
||||||
|
|
||||||
|
def set_random_seed(seed):
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
init_rng(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def get_jax_mesh(axis_dims, names):
|
||||||
|
if axis_dims.startswith('!'):
|
||||||
|
# Allow splitting a physical mesh axis if needed
|
||||||
|
mesh_axis_splitting = True
|
||||||
|
axis_dims = axis_dims[1:]
|
||||||
|
else:
|
||||||
|
mesh_axis_splitting = False
|
||||||
|
|
||||||
|
if ':' in axis_dims:
|
||||||
|
dims = []
|
||||||
|
dim_names = []
|
||||||
|
for axis in axis_dims.split(','):
|
||||||
|
name, dim = axis.split(':')
|
||||||
|
assert name in names
|
||||||
|
dims.append(int(dim))
|
||||||
|
dim_names.append(name)
|
||||||
|
assert(set(dim_names) == set(names))
|
||||||
|
else:
|
||||||
|
dims = [int(x) for x in axis_dims.split(',')]
|
||||||
|
dim_names = names
|
||||||
|
assert len(dims) == len(names)
|
||||||
|
mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
|
||||||
|
if mesh_axis_splitting:
|
||||||
|
physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
|
||||||
|
else:
|
||||||
|
physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
|
||||||
|
return Mesh(physical_mesh, dim_names)
|
||||||
|
|
||||||
|
|
||||||
|
def names_in_current_mesh(*names):
|
||||||
|
""" Check if current mesh axes contain these names. """
|
||||||
|
mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
|
||||||
|
return set(names) <= set(mesh_axis_names)
|
||||||
|
|
||||||
|
|
||||||
|
def get_names_from_parition_spec(partition_specs):
|
||||||
|
""" Return axis names from partition specs. """
|
||||||
|
names = set()
|
||||||
|
if isinstance(partition_specs, dict):
|
||||||
|
partition_specs = partition_specs.values()
|
||||||
|
for item in partition_specs:
|
||||||
|
if item is None:
|
||||||
|
continue
|
||||||
|
elif isinstance(item, str):
|
||||||
|
names.add(item)
|
||||||
|
else:
|
||||||
|
names.update(get_names_from_parition_spec(item))
|
||||||
|
|
||||||
|
return list(names)
|
||||||
|
|
||||||
|
|
||||||
|
def with_sharding_constraint(x, partition_specs):
|
||||||
|
""" A smarter version of with_sharding_constraint that only applies the
|
||||||
|
constraint if the current mesh contains the axes in the partition specs.
|
||||||
|
"""
|
||||||
|
axis_names = get_names_from_parition_spec(partition_specs)
|
||||||
|
if names_in_current_mesh(*axis_names):
|
||||||
|
x = _with_sharding_constraint(x, partition_specs)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_function_with_rng(rng):
|
||||||
|
""" To be used as decorator, automatically bookkeep a RNG for the wrapped function. """
|
||||||
|
def wrap_function(function):
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
nonlocal rng
|
||||||
|
rng, split_rng = jax.random.split(rng)
|
||||||
|
return function(split_rng, *args, **kwargs)
|
||||||
|
return wrapped
|
||||||
|
return wrap_function
|
||||||
|
|
||||||
|
|
||||||
|
def init_rng(seed):
|
||||||
|
global jax_utils_rng
|
||||||
|
jax_utils_rng = JaxRNG.from_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def next_rng(*args, **kwargs):
|
||||||
|
global jax_utils_rng
|
||||||
|
return jax_utils_rng(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_metrics(metrics, unreplicate=False, stack=False):
|
||||||
|
if unreplicate:
|
||||||
|
metrics = flax.jax_utils.unreplicate(metrics)
|
||||||
|
metrics = jax.device_get(metrics)
|
||||||
|
if stack:
|
||||||
|
return jax.tree_map(lambda *args: np.stack(args), *metrics)
|
||||||
|
else:
|
||||||
|
return {key: float(val) for key, val in metrics.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def mse_loss(val, target, valid=None):
|
||||||
|
if valid is None:
|
||||||
|
valid = jnp.ones((*target.shape[:2], 1))
|
||||||
|
valid = valid.astype(jnp.float32)
|
||||||
|
loss = jnp.mean(
|
||||||
|
jnp.where(
|
||||||
|
valid > 0.0,
|
||||||
|
jnp.square(val - target),
|
||||||
|
0.0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
|
||||||
|
if valid is None:
|
||||||
|
valid = jnp.ones(tokens.shape[:2])
|
||||||
|
valid = valid.astype(jnp.float32)
|
||||||
|
valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
|
||||||
|
logits = logits.astype(jnp.float32) # for numerical stability
|
||||||
|
token_log_prob = jnp.squeeze(
|
||||||
|
jnp.take_along_axis(
|
||||||
|
jax.nn.log_softmax(logits, axis=-1),
|
||||||
|
jnp.expand_dims(tokens, -1),
|
||||||
|
axis=-1,
|
||||||
|
),
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
|
||||||
|
loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
|
||||||
|
correct = jnp.where(
|
||||||
|
valid > 0.0,
|
||||||
|
jnp.argmax(logits, axis=-1) == tokens,
|
||||||
|
jnp.array(False)
|
||||||
|
)
|
||||||
|
accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
|
||||||
|
return loss, accuracy
|
||||||
|
|
||||||
|
|
||||||
|
def global_norm(tree):
|
||||||
|
""" Return the global L2 norm of a pytree. """
|
||||||
|
squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
|
||||||
|
flattened, _ = jax.flatten_util.ravel_pytree(squared)
|
||||||
|
return jnp.sqrt(jnp.sum(flattened))
|
||||||
|
|
||||||
|
|
||||||
|
def average_metrics(metrics):
|
||||||
|
with jax.spmd_mode("allow_all"):
|
||||||
|
return jax.tree_map(
|
||||||
|
lambda *args: jnp.mean(jnp.stack(args)),
|
||||||
|
*metrics
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_float_dtype_by_name(dtype):
|
||||||
|
return {
|
||||||
|
'bf16': jnp.bfloat16,
|
||||||
|
'bfloat16': jnp.bfloat16,
|
||||||
|
'fp16': jnp.float16,
|
||||||
|
'float16': jnp.float16,
|
||||||
|
'fp32': jnp.float32,
|
||||||
|
'float32': jnp.float32,
|
||||||
|
'fp64': jnp.float64,
|
||||||
|
'float64': jnp.float64,
|
||||||
|
}[dtype]
|
||||||
|
|
||||||
|
|
||||||
|
def float_tensor_to_dtype(tensor, dtype):
|
||||||
|
if dtype is None or dtype == '':
|
||||||
|
return tensor
|
||||||
|
if isinstance(dtype, str):
|
||||||
|
dtype = get_float_dtype_by_name(dtype)
|
||||||
|
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
|
||||||
|
if getattr(tensor, 'dtype', None) in float_dtypes:
|
||||||
|
tensor = tensor.astype(dtype)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def float_to_dtype(tree, dtype):
|
||||||
|
return jax.tree_util.tree_map(
|
||||||
|
partial(float_tensor_to_dtype, dtype=dtype), tree
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gradient_checkpoint_policy(name):
|
||||||
|
return {
|
||||||
|
'everything_saveable': jax.checkpoint_policies.everything_saveable,
|
||||||
|
'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
|
||||||
|
'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
|
||||||
|
'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
|
||||||
|
}[name]
|
||||||
|
|
||||||
|
|
||||||
|
def tree_path_to_string(path, sep=None):
|
||||||
|
keys = []
|
||||||
|
for key in path:
|
||||||
|
if isinstance(key, jax.tree_util.SequenceKey):
|
||||||
|
keys.append(str(key.idx))
|
||||||
|
elif isinstance(key, jax.tree_util.DictKey):
|
||||||
|
keys.append(str(key.key))
|
||||||
|
elif isinstance(key, jax.tree_util.GetAttrKey):
|
||||||
|
keys.append(str(key.name))
|
||||||
|
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
|
||||||
|
keys.append(str(key.key))
|
||||||
|
else:
|
||||||
|
keys.append(str(key))
|
||||||
|
if sep is None:
|
||||||
|
return tuple(keys)
|
||||||
|
return sep.join(keys)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_tree(xs, is_leaf=None, sep=None):
|
||||||
|
flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
|
||||||
|
output = {}
|
||||||
|
for key, val in flattened:
|
||||||
|
output[tree_path_to_string(key, sep=sep)] = val
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
|
||||||
|
""" An extended version of jax.tree_util.tree_map, where the mapped function
|
||||||
|
f takes both the name (path) and the tree leaf as input.
|
||||||
|
"""
|
||||||
|
return jax.tree_util.tree_map_with_path(
|
||||||
|
lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
|
||||||
|
tree, *rest,
|
||||||
|
is_leaf=is_leaf
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def match_partition_rules(rules, params):
|
||||||
|
""" Returns a pytree of PartitionSpec according to rules. Supports handling
|
||||||
|
Flax TrainState and Optax optimizer state.
|
||||||
|
"""
|
||||||
|
def get_partition_spec(name, leaf):
|
||||||
|
if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
|
||||||
|
""" Don't partition scalar values. """
|
||||||
|
return PS()
|
||||||
|
for rule, ps in rules:
|
||||||
|
if re.search(rule, name) is not None:
|
||||||
|
return ps
|
||||||
|
raise ValueError(f'Partition rule not found for param: {name}')
|
||||||
|
return named_tree_map(get_partition_spec, params, sep='/')
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_decay_mask(exclusions):
|
||||||
|
""" Return a weight decay mask function that computes the pytree masks
|
||||||
|
according to the given exclusion rules.
|
||||||
|
"""
|
||||||
|
def decay(name, _):
|
||||||
|
for rule in exclusions:
|
||||||
|
if re.search(rule, name) is not None:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def weight_decay_mask(params):
|
||||||
|
return named_tree_map(decay, params, sep='/')
|
||||||
|
|
||||||
|
return weight_decay_mask
|
||||||
|
|
||||||
|
|
||||||
|
def tree_apply(fns, tree):
|
||||||
|
""" Apply a pytree of functions to the pytree. """
|
||||||
|
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
|
||||||
|
|
||||||
0
EasyLM/models/__init__.py
Normal file
0
EasyLM/models/__init__.py
Normal file
0
EasyLM/models/gptj/__init__.py
Normal file
0
EasyLM/models/gptj/__init__.py
Normal file
1054
EasyLM/models/gptj/gptj_model.py
Normal file
1054
EasyLM/models/gptj/gptj_model.py
Normal file
File diff suppressed because it is too large
Load Diff
396
EasyLM/models/gptj/gptj_serve.py
Normal file
396
EasyLM/models/gptj/gptj_serve.py
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
import pprint
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mlxu
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.experimental.pjit import pjit
|
||||||
|
from jax.sharding import PartitionSpec as PS
|
||||||
|
import flax
|
||||||
|
from flax import linen as nn
|
||||||
|
from flax.jax_utils import prefetch_to_device
|
||||||
|
from flax.training.train_state import TrainState
|
||||||
|
import optax
|
||||||
|
from transformers import GenerationConfig, FlaxLogitsProcessorList
|
||||||
|
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
from EasyLM.serving import LMServer
|
||||||
|
from EasyLM.jax_utils import (
|
||||||
|
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
|
||||||
|
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
|
||||||
|
with_sharding_constraint, FlaxTemperatureLogitsWarper
|
||||||
|
)
|
||||||
|
from EasyLM.models.gptj.gptj_model import (
|
||||||
|
GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
seed=42,
|
||||||
|
initialize_jax_distributed=False,
|
||||||
|
mesh_dim='1,-1,1',
|
||||||
|
dtype='bf16',
|
||||||
|
input_length=1024,
|
||||||
|
seq_length=2048,
|
||||||
|
top_k=50,
|
||||||
|
top_p=1.0,
|
||||||
|
do_sample=True,
|
||||||
|
num_beams=1,
|
||||||
|
add_bos_token=False,
|
||||||
|
load_gptj_config='',
|
||||||
|
load_checkpoint='',
|
||||||
|
tokenizer=GPTJConfig.get_tokenizer_config(),
|
||||||
|
lm_server=LMServer.get_default_config(),
|
||||||
|
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||||
|
set_random_seed(FLAGS.seed)
|
||||||
|
|
||||||
|
prefix_tokenizer = GPTJConfig.get_tokenizer(
|
||||||
|
FLAGS.tokenizer, truncation_side='left', padding_side='left'
|
||||||
|
)
|
||||||
|
tokenizer = GPTJConfig.get_tokenizer(
|
||||||
|
FLAGS.tokenizer, truncation_side='right', padding_side='right'
|
||||||
|
)
|
||||||
|
|
||||||
|
with jax.default_device(jax.devices("cpu")[0]):
|
||||||
|
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
|
||||||
|
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
||||||
|
if load_type == 'huggingface':
|
||||||
|
params = gptj_config.load_pretrained(load_path)
|
||||||
|
else:
|
||||||
|
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||||
|
FLAGS.load_checkpoint, disallow_trainstate=True
|
||||||
|
)
|
||||||
|
|
||||||
|
hf_model = FlaxGPTJForCausalLM(
|
||||||
|
gptj_config,
|
||||||
|
input_shape=(1, FLAGS.seq_length),
|
||||||
|
seed=FLAGS.seed,
|
||||||
|
_do_init=False
|
||||||
|
)
|
||||||
|
|
||||||
|
model_ps = match_partition_rules(
|
||||||
|
GPTJConfig.get_partition_rules(), params
|
||||||
|
)
|
||||||
|
shard_fns, _ = make_shard_and_gather_fns(
|
||||||
|
model_ps, get_float_dtype_by_name(FLAGS.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
@partial(
|
||||||
|
pjit,
|
||||||
|
in_shardings=(model_ps, PS(), PS()),
|
||||||
|
out_shardings=(PS(), PS(), PS())
|
||||||
|
)
|
||||||
|
def forward_loglikelihood(params, rng, batch):
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
input_tokens = batch['input_tokens']
|
||||||
|
output_tokens = batch['output_tokens']
|
||||||
|
input_mask = batch['input_mask']
|
||||||
|
output_mask = batch['output_mask']
|
||||||
|
|
||||||
|
logits = hf_model.module.apply(
|
||||||
|
params, input_tokens, attention_mask=input_mask,
|
||||||
|
deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
|
||||||
|
).logits
|
||||||
|
if gptj_config.n_real_tokens is not None:
|
||||||
|
logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)
|
||||||
|
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
|
||||||
|
logits, output_tokens
|
||||||
|
)
|
||||||
|
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
|
||||||
|
match_count = jnp.sum(
|
||||||
|
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
|
||||||
|
axis=-1
|
||||||
|
)
|
||||||
|
total = jnp.sum(output_mask, axis=-1)
|
||||||
|
is_greedy = match_count == total
|
||||||
|
return loglikelihood, is_greedy, rng_generator()
|
||||||
|
|
||||||
|
|
||||||
|
@partial(
|
||||||
|
pjit,
|
||||||
|
in_shardings=(model_ps, PS(), PS(), PS()),
|
||||||
|
out_shardings=(PS(), PS())
|
||||||
|
)
|
||||||
|
def forward_generate(params, rng, batch, temperature):
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
output = hf_model.generate(
|
||||||
|
batch['input_tokens'],
|
||||||
|
attention_mask=batch['attention_mask'],
|
||||||
|
params=params['params'],
|
||||||
|
prng_key=rng_generator(),
|
||||||
|
logits_processor=FlaxLogitsProcessorList(
|
||||||
|
[FlaxTemperatureLogitsWarper(temperature)]
|
||||||
|
),
|
||||||
|
generation_config=GenerationConfig(
|
||||||
|
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
do_sample=FLAGS.do_sample,
|
||||||
|
num_beams=FLAGS.num_beams,
|
||||||
|
top_k=FLAGS.top_k,
|
||||||
|
top_p=FLAGS.top_p,
|
||||||
|
)
|
||||||
|
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||||
|
return output, rng_generator()
|
||||||
|
|
||||||
|
@partial(
|
||||||
|
pjit,
|
||||||
|
in_shardings=(model_ps, PS(), PS()),
|
||||||
|
out_shardings=(PS(), PS())
|
||||||
|
)
|
||||||
|
def forward_greedy_generate(params, rng, batch):
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
output = hf_model.generate(
|
||||||
|
batch['input_tokens'],
|
||||||
|
attention_mask=batch['attention_mask'],
|
||||||
|
params=params['params'],
|
||||||
|
prng_key=rng_generator(),
|
||||||
|
generation_config=GenerationConfig(
|
||||||
|
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
do_sample=False,
|
||||||
|
num_beams=1,
|
||||||
|
)
|
||||||
|
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||||
|
return output, rng_generator()
|
||||||
|
|
||||||
|
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||||
|
with mesh:
|
||||||
|
params = tree_apply(shard_fns, params)
|
||||||
|
sharded_rng = next_rng()
|
||||||
|
|
||||||
|
class ModelServer(LMServer):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def loglikelihood(prefix_text, text):
|
||||||
|
nonlocal sharded_rng
|
||||||
|
prefix = prefix_tokenizer(
|
||||||
|
prefix_text,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
max_length=FLAGS.input_length,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
inputs = tokenizer(
|
||||||
|
text,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
max_length=FLAGS.seq_length - FLAGS.input_length,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
|
||||||
|
bos_tokens = np.full(
|
||||||
|
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
|
||||||
|
)
|
||||||
|
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||||
|
input_mask = np.concatenate(
|
||||||
|
[prefix.attention_mask, inputs.attention_mask], axis=1
|
||||||
|
)
|
||||||
|
if FLAGS.add_bos_token:
|
||||||
|
bos_mask = np.ones_like(input_mask[:, :1])
|
||||||
|
else:
|
||||||
|
bos_mask = np.zeros_like(input_mask[:, :1])
|
||||||
|
|
||||||
|
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
|
||||||
|
output_mask = np.concatenate(
|
||||||
|
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
|
||||||
|
)
|
||||||
|
batch = dict(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
input_mask=input_mask,
|
||||||
|
output_mask=output_mask,
|
||||||
|
)
|
||||||
|
with mesh:
|
||||||
|
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||||
|
params, sharded_rng, batch
|
||||||
|
)
|
||||||
|
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||||
|
return loglikelihood, is_greedy
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def loglikelihood_rolling(text):
|
||||||
|
nonlocal sharded_rng
|
||||||
|
inputs = tokenizer(
|
||||||
|
text,
|
||||||
|
padding='longest',
|
||||||
|
truncation=False,
|
||||||
|
max_length=np.iinfo(np.int32).max,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
batch_size = inputs.input_ids.shape[0]
|
||||||
|
output_tokens = inputs.input_ids
|
||||||
|
attention_mask = inputs.attention_mask
|
||||||
|
|
||||||
|
if output_tokens.shape[1] < FLAGS.seq_length:
|
||||||
|
padding_length = FLAGS.seq_length - output_tokens.shape[1]
|
||||||
|
pad_tokens = np.full(
|
||||||
|
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
|
||||||
|
)
|
||||||
|
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
|
||||||
|
pad_mask = np.zeros(
|
||||||
|
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
|
||||||
|
)
|
||||||
|
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
|
||||||
|
|
||||||
|
bos_tokens = np.full(
|
||||||
|
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
|
||||||
|
)
|
||||||
|
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||||
|
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
|
||||||
|
total_seq_length = output_tokens.shape[1]
|
||||||
|
|
||||||
|
total_loglikelihood = 0.0
|
||||||
|
total_is_greedy = True
|
||||||
|
# Sliding window
|
||||||
|
for i in range(0, total_seq_length, FLAGS.seq_length):
|
||||||
|
# Last window
|
||||||
|
if i + FLAGS.seq_length > total_seq_length:
|
||||||
|
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
|
||||||
|
last_output_mask[:, :i - total_seq_length] = 0.0
|
||||||
|
|
||||||
|
batch = dict(
|
||||||
|
input_tokens=input_tokens[:, -FLAGS.seq_length:],
|
||||||
|
output_tokens=output_tokens[:, -FLAGS.seq_length:],
|
||||||
|
input_mask=attention_mask[:, -FLAGS.seq_length:],
|
||||||
|
output_mask=last_output_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normal window
|
||||||
|
else:
|
||||||
|
batch = dict(
|
||||||
|
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
|
||||||
|
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
|
||||||
|
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||||
|
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||||
|
)
|
||||||
|
|
||||||
|
with mesh:
|
||||||
|
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||||
|
params, sharded_rng, batch
|
||||||
|
)
|
||||||
|
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||||
|
|
||||||
|
total_loglikelihood += loglikelihood
|
||||||
|
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
|
||||||
|
|
||||||
|
return total_loglikelihood, total_is_greedy
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate(text, temperature):
|
||||||
|
nonlocal sharded_rng
|
||||||
|
inputs = prefix_tokenizer(
|
||||||
|
text,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
max_length=FLAGS.input_length,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
input_tokens = inputs.input_ids
|
||||||
|
input_mask = inputs.attention_mask
|
||||||
|
if FLAGS.add_bos_token:
|
||||||
|
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||||
|
input_mask[:, 0] = 1
|
||||||
|
batch = dict(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
)
|
||||||
|
with mesh:
|
||||||
|
output, sharded_rng = forward_generate(
|
||||||
|
params, sharded_rng, batch, temperature
|
||||||
|
)
|
||||||
|
output = jax.device_get(output)
|
||||||
|
output_text = []
|
||||||
|
for text in list(tokenizer.batch_decode(output)):
|
||||||
|
if tokenizer.eos_token in text:
|
||||||
|
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
|
||||||
|
output_text.append(text)
|
||||||
|
|
||||||
|
return output_text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def greedy_until(prefix_text, until, max_length):
|
||||||
|
nonlocal sharded_rng
|
||||||
|
all_outputs = []
|
||||||
|
for pf, ut in zip(prefix_text, until):
|
||||||
|
if isinstance(ut, str):
|
||||||
|
ut = [ut]
|
||||||
|
total_length = 0
|
||||||
|
total_generated = ''
|
||||||
|
|
||||||
|
while total_length < max_length:
|
||||||
|
pf_tokens = tokenizer(
|
||||||
|
pf,
|
||||||
|
padding=False,
|
||||||
|
truncation=False,
|
||||||
|
max_length=np.iinfo(np.int32).max,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
input_tokens = pf_tokens.input_ids
|
||||||
|
attention_mask = pf_tokens.attention_mask
|
||||||
|
|
||||||
|
if input_tokens.shape[1] < FLAGS.input_length:
|
||||||
|
extra = FLAGS.input_length - input_tokens.shape[1]
|
||||||
|
pad_tokens = np.full(
|
||||||
|
(1, extra), tokenizer.pad_token_id, dtype=np.int32
|
||||||
|
)
|
||||||
|
input_tokens = np.concatenate(
|
||||||
|
[pad_tokens, input_tokens], axis=1
|
||||||
|
)
|
||||||
|
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
|
||||||
|
attention_mask = np.concatenate(
|
||||||
|
[pad_attention, attention_mask], axis=1
|
||||||
|
)
|
||||||
|
elif input_tokens.shape[1] > FLAGS.input_length:
|
||||||
|
input_tokens = input_tokens[:, -FLAGS.input_length:]
|
||||||
|
attention_mask = attention_mask[:, -FLAGS.input_length:]
|
||||||
|
|
||||||
|
if FLAGS.add_bos_token:
|
||||||
|
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||||
|
attention_mask[:, 0] = 1
|
||||||
|
|
||||||
|
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
with mesh:
|
||||||
|
output, sharded_rng = forward_greedy_generate(
|
||||||
|
params, sharded_rng, batch
|
||||||
|
)
|
||||||
|
output = jax.device_get(output)
|
||||||
|
|
||||||
|
total_length += output.shape[1]
|
||||||
|
output_text = tokenizer.batch_decode(output)[0]
|
||||||
|
total_generated = total_generated + output_text
|
||||||
|
pf = pf + output_text
|
||||||
|
|
||||||
|
done = False
|
||||||
|
for s in ut:
|
||||||
|
if s in total_generated:
|
||||||
|
total_generated = total_generated.split(s, maxsplit=1)[0]
|
||||||
|
done = True
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
all_outputs.append(total_generated)
|
||||||
|
|
||||||
|
return all_outputs
|
||||||
|
|
||||||
|
|
||||||
|
server = ModelServer(FLAGS.lm_server)
|
||||||
|
server.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
272
EasyLM/models/gptj/gptj_train.py
Normal file
272
EasyLM/models/gptj/gptj_train.py
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
import pprint
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import numpy as np
|
||||||
|
import mlxu
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.experimental.pjit import pjit, with_sharding_constraint
|
||||||
|
from jax.sharding import PartitionSpec as PS
|
||||||
|
from flax.training.train_state import TrainState
|
||||||
|
|
||||||
|
from EasyLM.data import DatasetFactory
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
from EasyLM.optimizers import OptimizerFactory
|
||||||
|
from EasyLM.jax_utils import (
|
||||||
|
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
|
||||||
|
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
|
||||||
|
set_random_seed, average_metrics, get_weight_decay_mask,
|
||||||
|
make_shard_and_gather_fns, tree_apply
|
||||||
|
)
|
||||||
|
from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
seed=42,
|
||||||
|
mesh_dim='1,-1,1',
|
||||||
|
dtype='fp32',
|
||||||
|
total_steps=10000,
|
||||||
|
load_gptj_config='',
|
||||||
|
update_gptj_config='',
|
||||||
|
load_checkpoint='',
|
||||||
|
load_dataset_state='',
|
||||||
|
log_freq=50,
|
||||||
|
save_model_freq=0,
|
||||||
|
save_milestone_freq=0,
|
||||||
|
eval_steps=0,
|
||||||
|
tokenizer=GPTJConfig.get_tokenizer_config(),
|
||||||
|
train_dataset=DatasetFactory.get_default_config(),
|
||||||
|
eval_dataset=DatasetFactory.get_default_config(),
|
||||||
|
optimizer=OptimizerFactory.get_default_config(),
|
||||||
|
checkpointer=StreamingCheckpointer.get_default_config(),
|
||||||
|
gptj=GPTJConfig.get_default_config(),
|
||||||
|
logger=mlxu.WandBLogger.get_default_config(),
|
||||||
|
log_all_worker=False,
|
||||||
|
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||||
|
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||||
|
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
||||||
|
logger = mlxu.WandBLogger(
|
||||||
|
config=FLAGS.logger,
|
||||||
|
variant=variant,
|
||||||
|
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
||||||
|
)
|
||||||
|
set_random_seed(FLAGS.seed)
|
||||||
|
|
||||||
|
tokenizer = GPTJConfig.get_tokenizer(FLAGS.tokenizer)
|
||||||
|
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
||||||
|
if FLAGS.load_dataset_state != '':
|
||||||
|
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
||||||
|
|
||||||
|
if FLAGS.eval_steps > 0:
|
||||||
|
eval_dataset = DatasetFactory.load_dataset(
|
||||||
|
FLAGS.eval_dataset, dataset.tokenizer
|
||||||
|
)
|
||||||
|
eval_iterator = iter(eval_dataset)
|
||||||
|
|
||||||
|
seq_length = dataset.seq_length
|
||||||
|
|
||||||
|
if FLAGS.load_gptj_config != '':
|
||||||
|
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
|
||||||
|
else:
|
||||||
|
gptj_config = GPTJConfig(**FLAGS.gptj)
|
||||||
|
|
||||||
|
if FLAGS.update_gptj_config != '':
|
||||||
|
gptj_config.update(dict(eval(FLAGS.update_gptj_config)))
|
||||||
|
|
||||||
|
gptj_config.update(dict(
|
||||||
|
bos_token_id=dataset.tokenizer.bos_token_id,
|
||||||
|
eos_token_id=dataset.tokenizer.eos_token_id,
|
||||||
|
))
|
||||||
|
if gptj_config.vocab_size < dataset.vocab_size:
|
||||||
|
gptj_config.update(dict(vocab_size=dataset.vocab_size))
|
||||||
|
|
||||||
|
model = FlaxGPTJForCausalLMModule(
|
||||||
|
gptj_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
||||||
|
FLAGS.optimizer,
|
||||||
|
get_weight_decay_mask(GPTJConfig.get_weight_decay_exclusions()),
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_trainstate_from_params(params):
|
||||||
|
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||||
|
|
||||||
|
def init_fn(rng):
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
params = model.init(
|
||||||
|
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||||
|
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||||
|
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
||||||
|
rngs=rng_generator(gptj_config.rng_keys()),
|
||||||
|
)
|
||||||
|
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||||
|
|
||||||
|
def train_step(train_state, rng, batch):
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
def loss_and_accuracy(params):
|
||||||
|
logits = model.apply(
|
||||||
|
params, batch['input_tokens'], deterministic=False,
|
||||||
|
rngs=rng_generator(gptj_config.rng_keys()),
|
||||||
|
).logits
|
||||||
|
return cross_entropy_loss_and_accuracy(
|
||||||
|
logits, batch['target_tokens'], batch['loss_masks']
|
||||||
|
)
|
||||||
|
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
||||||
|
(loss, accuracy), grads = grad_fn(train_state.params)
|
||||||
|
train_state = train_state.apply_gradients(grads=grads)
|
||||||
|
metrics = dict(
|
||||||
|
loss=loss,
|
||||||
|
accuracy=accuracy,
|
||||||
|
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
||||||
|
gradient_norm=global_norm(grads),
|
||||||
|
param_norm=global_norm(train_state.params),
|
||||||
|
)
|
||||||
|
return train_state, rng_generator(), metrics
|
||||||
|
|
||||||
|
def eval_step(train_state, rng, batch):
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
logits = model.apply(
|
||||||
|
train_state.params, batch['input_tokens'], deterministic=True,
|
||||||
|
rngs=rng_generator(gptj_config.rng_keys()),
|
||||||
|
).logits
|
||||||
|
loss, accuracy = cross_entropy_loss_and_accuracy(
|
||||||
|
logits, batch['target_tokens'], batch['loss_masks']
|
||||||
|
)
|
||||||
|
metrics = dict(
|
||||||
|
eval_loss=loss,
|
||||||
|
eval_accuracy=accuracy,
|
||||||
|
)
|
||||||
|
return rng_generator(), metrics
|
||||||
|
|
||||||
|
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
||||||
|
train_state_partition = match_partition_rules(
|
||||||
|
GPTJConfig.get_partition_rules(), train_state_shapes
|
||||||
|
)
|
||||||
|
|
||||||
|
shard_fns, gather_fns = make_shard_and_gather_fns(
|
||||||
|
train_state_partition, train_state_shapes
|
||||||
|
)
|
||||||
|
checkpointer = StreamingCheckpointer(
|
||||||
|
FLAGS.checkpointer, logger.output_dir,
|
||||||
|
enable=jax.process_index() == 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_init_fn = pjit(
|
||||||
|
init_fn,
|
||||||
|
in_shardings=PS(),
|
||||||
|
out_shardings=train_state_partition
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_create_trainstate_from_params = pjit(
|
||||||
|
create_trainstate_from_params,
|
||||||
|
in_shardings=(train_state_partition.params, ),
|
||||||
|
out_shardings=train_state_partition,
|
||||||
|
donate_argnums=(0, ),
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_train_step = pjit(
|
||||||
|
train_step,
|
||||||
|
in_shardings=(train_state_partition, PS(), PS()),
|
||||||
|
out_shardings=(train_state_partition, PS(), PS()),
|
||||||
|
donate_argnums=(0, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_eval_step = pjit(
|
||||||
|
eval_step,
|
||||||
|
in_shardings=(train_state_partition, PS(), PS()),
|
||||||
|
out_shardings=(PS(), PS()),
|
||||||
|
donate_argnums=(1,),
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_checkpoint(train_state, milestone=False):
|
||||||
|
step = int(jax.device_get(train_state.step))
|
||||||
|
metadata = dict(
|
||||||
|
step=step,
|
||||||
|
variant=variant,
|
||||||
|
flags=flags_config_dict,
|
||||||
|
gptj_config=gptj_config.to_dict(),
|
||||||
|
)
|
||||||
|
checkpointer.save_all(
|
||||||
|
train_state=train_state,
|
||||||
|
gather_fns=gather_fns,
|
||||||
|
metadata=metadata,
|
||||||
|
dataset=dataset.get_state_dict(),
|
||||||
|
milestone=milestone,
|
||||||
|
)
|
||||||
|
|
||||||
|
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||||
|
with mesh:
|
||||||
|
train_state, restored_params = None, None
|
||||||
|
if FLAGS.load_checkpoint != '':
|
||||||
|
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
||||||
|
if load_type == 'huggingface':
|
||||||
|
restored_params = tree_apply(
|
||||||
|
shard_fns.params, gptj_config.load_pretrained(load_path)
|
||||||
|
)
|
||||||
|
train_state = None
|
||||||
|
else:
|
||||||
|
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
||||||
|
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_state is None and restored_params is None:
|
||||||
|
# Initialize from scratch
|
||||||
|
train_state = sharded_init_fn(next_rng())
|
||||||
|
elif train_state is None and restored_params is not None:
|
||||||
|
# Restore from params but initialize train_state
|
||||||
|
train_state = sharded_create_trainstate_from_params(restored_params)
|
||||||
|
del restored_params
|
||||||
|
|
||||||
|
start_step = int(jax.device_get(train_state.step))
|
||||||
|
|
||||||
|
if FLAGS.save_model_freq > 0:
|
||||||
|
save_checkpoint(train_state)
|
||||||
|
|
||||||
|
sharded_rng = next_rng()
|
||||||
|
|
||||||
|
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
||||||
|
|
||||||
|
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
||||||
|
train_state, sharded_rng, metrics = sharded_train_step(
|
||||||
|
train_state, sharded_rng, batch
|
||||||
|
)
|
||||||
|
|
||||||
|
if step % FLAGS.log_freq == 0:
|
||||||
|
if FLAGS.eval_steps > 0:
|
||||||
|
eval_metric_list = []
|
||||||
|
for _ in range(FLAGS.eval_steps):
|
||||||
|
eval_batch, _ = next(eval_iterator)
|
||||||
|
sharded_rng, eval_metrics = sharded_eval_step(
|
||||||
|
train_state, sharded_rng, eval_batch
|
||||||
|
)
|
||||||
|
eval_metric_list.append(eval_metrics)
|
||||||
|
metrics.update(average_metrics(eval_metric_list))
|
||||||
|
|
||||||
|
log_metrics = {"step": step}
|
||||||
|
log_metrics.update(metrics)
|
||||||
|
log_metrics.update(dataset_metrics)
|
||||||
|
log_metrics = jax.device_get(log_metrics)
|
||||||
|
logger.log(log_metrics)
|
||||||
|
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
||||||
|
|
||||||
|
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
||||||
|
save_checkpoint(train_state, milestone=True)
|
||||||
|
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
||||||
|
save_checkpoint(train_state)
|
||||||
|
|
||||||
|
if FLAGS.save_model_freq > 0:
|
||||||
|
save_checkpoint(train_state)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
338
EasyLM/models/llama/convert_easylm_to_hf.py
Normal file
338
EasyLM/models/llama/convert_easylm_to_hf.py
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
# Copyright 2023 Xinyang Geng
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# This script converts LLaMA model checkpoint trained by EsayLM to the
|
||||||
|
# HuggingFace transformers LLaMA PyTorch format, which can then be loaded
|
||||||
|
# by HuggingFace transformers.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mlxu
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import flax
|
||||||
|
from flax.traverse_util import flatten_dict
|
||||||
|
import torch
|
||||||
|
from transformers import LlamaConfig, LlamaForCausalLM
|
||||||
|
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
from EasyLM.jax_utils import float_tensor_to_dtype
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
load_checkpoint='',
|
||||||
|
tokenizer_path='',
|
||||||
|
model_size='13b',
|
||||||
|
output_dir='',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
LLAMA_STANDARD_CONFIGS = {
|
||||||
|
'small': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 768,
|
||||||
|
'intermediate_size': 3072,
|
||||||
|
'n_layers': 12,
|
||||||
|
'n_heads': 12,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'medium': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 1024,
|
||||||
|
'intermediate_size': 4096,
|
||||||
|
'n_layers': 24,
|
||||||
|
'n_heads': 16,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'large': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 1536,
|
||||||
|
'intermediate_size': 6144,
|
||||||
|
'n_layers': 24,
|
||||||
|
'n_heads': 16,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'xlarge': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 2048,
|
||||||
|
'intermediate_size': 8192,
|
||||||
|
'n_layers': 24,
|
||||||
|
'n_heads': 32,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'1b': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 2048,
|
||||||
|
'intermediate_size': 5504,
|
||||||
|
'n_layers': 22,
|
||||||
|
'n_heads': 16,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'3b': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 3200,
|
||||||
|
'intermediate_size': 8640,
|
||||||
|
'n_layers': 26,
|
||||||
|
'n_heads': 32,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'7b': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 4096,
|
||||||
|
'intermediate_size': 11008,
|
||||||
|
'n_layers': 32,
|
||||||
|
'n_heads': 32,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'13b': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 5120,
|
||||||
|
'intermediate_size': 13824,
|
||||||
|
'n_layers': 40,
|
||||||
|
'n_heads': 40,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'30b': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 6656,
|
||||||
|
'intermediate_size': 17920,
|
||||||
|
'n_layers': 60,
|
||||||
|
'n_heads': 52,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'65b': {
|
||||||
|
'vocab_size': 64256,
|
||||||
|
'dim': 8192,
|
||||||
|
'intermediate_size': 22016,
|
||||||
|
'n_layers': 80,
|
||||||
|
'n_heads': 64,
|
||||||
|
'norm_eps': 1e-5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def match_keywords(string, positives, negatives):
|
||||||
|
for positive in positives:
|
||||||
|
if positive not in string:
|
||||||
|
return False
|
||||||
|
for negative in negatives:
|
||||||
|
if negative in string:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_convert_checkpoint(path):
|
||||||
|
_, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
|
||||||
|
flax_params = flatten_dict(flax_params['params'], sep='.')
|
||||||
|
torch_params = {}
|
||||||
|
for key, tensor in flax_params.items():
|
||||||
|
if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
|
||||||
|
tensor = tensor.T
|
||||||
|
torch_params[key] = torch.tensor(
|
||||||
|
float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16
|
||||||
|
)
|
||||||
|
return torch_params
|
||||||
|
|
||||||
|
|
||||||
|
def read_json(path):
|
||||||
|
with open(path, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def write_json(text, path):
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(text, f)
|
||||||
|
|
||||||
|
|
||||||
|
def write_model(loaded, model_path, model_size):
|
||||||
|
os.makedirs(model_path, exist_ok=True)
|
||||||
|
tmp_model_path = os.path.join(model_path, "tmp")
|
||||||
|
os.makedirs(tmp_model_path, exist_ok=True)
|
||||||
|
|
||||||
|
params = LLAMA_STANDARD_CONFIGS[model_size]
|
||||||
|
|
||||||
|
n_layers = params["n_layers"]
|
||||||
|
n_heads = params["n_heads"]
|
||||||
|
dim = params["dim"]
|
||||||
|
dims_per_head = dim // n_heads
|
||||||
|
base = 10000.0
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
||||||
|
|
||||||
|
# permute for sliced rotary
|
||||||
|
def permute(w):
|
||||||
|
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
|
||||||
|
|
||||||
|
|
||||||
|
param_count = 0
|
||||||
|
index_dict = {"weight_map": {}}
|
||||||
|
for layer_i in range(n_layers):
|
||||||
|
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
|
||||||
|
state_dict = {
|
||||||
|
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
|
||||||
|
loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
|
||||||
|
),
|
||||||
|
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
|
||||||
|
loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
|
||||||
|
),
|
||||||
|
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
|
||||||
|
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
|
||||||
|
|
||||||
|
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
|
||||||
|
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
|
||||||
|
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
|
||||||
|
|
||||||
|
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
|
||||||
|
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
index_dict["weight_map"][k] = filename
|
||||||
|
param_count += v.numel()
|
||||||
|
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
||||||
|
|
||||||
|
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
|
||||||
|
# Unsharded
|
||||||
|
state_dict = {
|
||||||
|
"model.embed_tokens.weight": loaded["transformer.wte.embedding"],
|
||||||
|
"model.norm.weight": loaded["transformer.ln_f.kernel"],
|
||||||
|
"lm_head.weight": loaded["lm_head.kernel"],
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
index_dict["weight_map"][k] = filename
|
||||||
|
param_count += v.numel()
|
||||||
|
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
||||||
|
|
||||||
|
# Write configs
|
||||||
|
index_dict["metadata"] = {"total_size": param_count * 2}
|
||||||
|
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
||||||
|
|
||||||
|
config = LlamaConfig(
|
||||||
|
vocab_size=params["vocab_size"],
|
||||||
|
hidden_size=dim,
|
||||||
|
intermediate_size=params["intermediate_size"],
|
||||||
|
num_attention_heads=params["n_heads"],
|
||||||
|
num_hidden_layers=params["n_layers"],
|
||||||
|
rms_norm_eps=params["norm_eps"],
|
||||||
|
)
|
||||||
|
config.save_pretrained(tmp_model_path)
|
||||||
|
|
||||||
|
# Make space so we can load the model properly now.
|
||||||
|
del state_dict
|
||||||
|
del loaded
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
print("Loading the checkpoint in a Llama model.")
|
||||||
|
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
|
||||||
|
# Avoid saving this as part of the config.
|
||||||
|
print("Model parameter count", model.num_parameters())
|
||||||
|
del model.config._name_or_path
|
||||||
|
|
||||||
|
print("Saving in the Transformers format.")
|
||||||
|
model.save_pretrained(model_path, safe_serialization=True)
|
||||||
|
shutil.rmtree(tmp_model_path)
|
||||||
|
|
||||||
|
|
||||||
|
def write_tokenizer(tokenizer_path, input_tokenizer_path):
|
||||||
|
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
|
||||||
|
os.makedirs(tokenizer_path, exist_ok=True)
|
||||||
|
write_json(
|
||||||
|
{
|
||||||
|
"bos_token": {
|
||||||
|
"content": "<s>",
|
||||||
|
"lstrip": False,
|
||||||
|
"normalized": True,
|
||||||
|
"rstrip": False,
|
||||||
|
"single_word": False
|
||||||
|
},
|
||||||
|
"eos_token": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": False,
|
||||||
|
"normalized": True,
|
||||||
|
"rstrip": False,
|
||||||
|
"single_word": False
|
||||||
|
},
|
||||||
|
"unk_token": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": False,
|
||||||
|
"normalized": True,
|
||||||
|
"rstrip": False,
|
||||||
|
"single_word": False
|
||||||
|
},
|
||||||
|
},
|
||||||
|
os.path.join(tokenizer_path, "special_tokens_map.json")
|
||||||
|
)
|
||||||
|
write_json(
|
||||||
|
{
|
||||||
|
"add_bos_token": True,
|
||||||
|
"add_eos_token": False,
|
||||||
|
"model_max_length": 2048,
|
||||||
|
"pad_token": None,
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"tokenizer_class": "LlamaTokenizer",
|
||||||
|
"clean_up_tokenization_spaces": False,
|
||||||
|
"bos_token": {
|
||||||
|
"__type": "AddedToken",
|
||||||
|
"content": "<s>",
|
||||||
|
"lstrip": False,
|
||||||
|
"normalized": True,
|
||||||
|
"rstrip": False,
|
||||||
|
"single_word": False
|
||||||
|
},
|
||||||
|
"eos_token": {
|
||||||
|
"__type": "AddedToken",
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": False,
|
||||||
|
"normalized": True,
|
||||||
|
"rstrip": False,
|
||||||
|
"single_word": False
|
||||||
|
},
|
||||||
|
"unk_token": {
|
||||||
|
"__type": "AddedToken",
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": False,
|
||||||
|
"normalized": True,
|
||||||
|
"rstrip": False,
|
||||||
|
"single_word": False
|
||||||
|
},
|
||||||
|
},
|
||||||
|
os.path.join(tokenizer_path, "tokenizer_config.json"),
|
||||||
|
)
|
||||||
|
shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""# and FLAGS.tokenizer_path != ""
|
||||||
|
assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
|
||||||
|
# write_tokenizer(
|
||||||
|
# tokenizer_path=FLAGS.output_dir,
|
||||||
|
# input_tokenizer_path=FLAGS.tokenizer_path,
|
||||||
|
# )
|
||||||
|
write_model(
|
||||||
|
load_and_convert_checkpoint(FLAGS.load_checkpoint),
|
||||||
|
model_path=FLAGS.output_dir,
|
||||||
|
model_size=FLAGS.model_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
196
EasyLM/models/llama/convert_hf_to_easylm.py
Normal file
196
EasyLM/models/llama/convert_hf_to_easylm.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python convert_hf_to_easylm.py \
|
||||||
|
--checkpoint_dir /path/hf_format_dir/ \
|
||||||
|
--output_file /path/easylm_format.stream \
|
||||||
|
--model_size 7b \
|
||||||
|
--streaming
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import mlxu
|
||||||
|
import torch
|
||||||
|
import flax
|
||||||
|
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
|
||||||
|
LLAMA_STANDARD_CONFIGS = {
|
||||||
|
'1b': {
|
||||||
|
'dim': 2048,
|
||||||
|
'intermediate_size': 5504,
|
||||||
|
'n_layers': 22,
|
||||||
|
'n_heads': 16,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
'3b': {
|
||||||
|
'dim': 3200,
|
||||||
|
'intermediate_size': 8640,
|
||||||
|
'n_layers': 26,
|
||||||
|
'n_heads': 32,
|
||||||
|
'norm_eps': 1e-6,
|
||||||
|
},
|
||||||
|
"7b": {
|
||||||
|
"dim": 4096,
|
||||||
|
"intermediate_size": 11008,
|
||||||
|
"n_layers": 32,
|
||||||
|
"n_heads": 32,
|
||||||
|
"norm_eps": 1e-6,
|
||||||
|
},
|
||||||
|
"13b": {
|
||||||
|
"dim": 5120,
|
||||||
|
"intermediate_size": 13824,
|
||||||
|
"n_layers": 40,
|
||||||
|
"n_heads": 40,
|
||||||
|
"norm_eps": 1e-6,
|
||||||
|
},
|
||||||
|
"30b": {
|
||||||
|
"dim": 6656,
|
||||||
|
"intermediate_size": 17920,
|
||||||
|
"n_layers": 60,
|
||||||
|
"n_heads": 52,
|
||||||
|
"norm_eps": 1e-6,
|
||||||
|
},
|
||||||
|
"65b": {
|
||||||
|
"dim": 8192,
|
||||||
|
"intermediate_size": 22016,
|
||||||
|
"n_layers": 80,
|
||||||
|
"n_heads": 64,
|
||||||
|
"norm_eps": 1e-5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def inverse_permute(params, w):
|
||||||
|
n_layers = params["n_layers"]
|
||||||
|
n_heads = params["n_heads"]
|
||||||
|
dim = params["dim"]
|
||||||
|
reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
|
||||||
|
transposed_w = reshaped_w.transpose(0, 2, 1, 3)
|
||||||
|
inverted_w = transposed_w.reshape(dim, dim)
|
||||||
|
return inverted_w
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
start = time.time()
|
||||||
|
params = LLAMA_STANDARD_CONFIGS[args.model_size]
|
||||||
|
|
||||||
|
ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
|
||||||
|
ckpt = {}
|
||||||
|
for i, ckpt_path in enumerate(ckpt_paths):
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
for k, v in checkpoint.items():
|
||||||
|
if k.startswith("model."):
|
||||||
|
k = k[6:]
|
||||||
|
ckpt[k] = v
|
||||||
|
print(f"Start convert weight to easylm format...")
|
||||||
|
jax_weights = {
|
||||||
|
"transformer": {
|
||||||
|
"wte": {"embedding": ckpt["embed_tokens.weight"].numpy()},
|
||||||
|
"ln_f": {"kernel": ckpt["norm.weight"].numpy()},
|
||||||
|
"h": {
|
||||||
|
"%d"
|
||||||
|
% (layer): {
|
||||||
|
"attention": {
|
||||||
|
"wq": {
|
||||||
|
"kernel": inverse_permute(
|
||||||
|
params,
|
||||||
|
ckpt[f"layers.{layer}.self_attn.q_proj.weight"].numpy(),
|
||||||
|
).transpose()
|
||||||
|
},
|
||||||
|
"wk": {
|
||||||
|
"kernel": inverse_permute(
|
||||||
|
params,
|
||||||
|
ckpt[f"layers.{layer}.self_attn.k_proj.weight"].numpy(),
|
||||||
|
).transpose()
|
||||||
|
},
|
||||||
|
"wv": {
|
||||||
|
"kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"]
|
||||||
|
.numpy()
|
||||||
|
.transpose()
|
||||||
|
},
|
||||||
|
"wo": {
|
||||||
|
"kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"]
|
||||||
|
.numpy()
|
||||||
|
.transpose()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"feed_forward": {
|
||||||
|
"w1": {
|
||||||
|
"kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"]
|
||||||
|
.numpy()
|
||||||
|
.transpose()
|
||||||
|
},
|
||||||
|
"w2": {
|
||||||
|
"kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"]
|
||||||
|
.numpy()
|
||||||
|
.transpose()
|
||||||
|
},
|
||||||
|
"w3": {
|
||||||
|
"kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"]
|
||||||
|
.numpy()
|
||||||
|
.transpose()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"attention_norm": {
|
||||||
|
"kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].numpy()
|
||||||
|
},
|
||||||
|
"ffn_norm": {
|
||||||
|
"kernel": ckpt[
|
||||||
|
f"layers.{layer}.post_attention_layernorm.weight"
|
||||||
|
].numpy()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for layer in range(params["n_layers"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()},
|
||||||
|
}
|
||||||
|
print(f"Convert weight to easylm format finished...")
|
||||||
|
print(f"Start to save...")
|
||||||
|
|
||||||
|
if args.streaming:
|
||||||
|
StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
|
||||||
|
else:
|
||||||
|
with mlxu.open_file(args.output_file, "wb") as fout:
|
||||||
|
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="hf to easylm format script")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint_dir",
|
||||||
|
type=str,
|
||||||
|
help="Need to be converted model weight dir. it is a dir",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_file", type=str, help="Save model weight file path, it is a file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_size",
|
||||||
|
type=str,
|
||||||
|
default="7b",
|
||||||
|
choices=["7b", "13b", "30b", "65b"],
|
||||||
|
help="model size",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--streaming",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="whether is model weight saved stream format",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"checkpoint_dir: {args.checkpoint_dir}")
|
||||||
|
print(f"output_file: {args.output_file}")
|
||||||
|
print(f"model_size: {args.model_size}")
|
||||||
|
print(f"streaming: {args.streaming}")
|
||||||
|
|
||||||
|
main(args)
|
||||||
68
EasyLM/models/llama/convert_torch_to_easylm.py
Normal file
68
EasyLM/models/llama/convert_torch_to_easylm.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# This script converts the standrd LLaMA PyTorch checkpoint released by Meta
|
||||||
|
# to the EasyLM checkpoint format. The converted checkpoint can then be loaded
|
||||||
|
# by EasyLM for fine-tuning or inference.
|
||||||
|
|
||||||
|
# This script is largely borrow from https://github.com/Sea-Snell/JAX_llama
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import flax
|
||||||
|
import mlxu
|
||||||
|
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
checkpoint_dir='',
|
||||||
|
output_file='',
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
ckpt_paths = sorted(Path(FLAGS.checkpoint_dir).glob("*.pth"))
|
||||||
|
ckpts = {}
|
||||||
|
for i, ckpt_path in enumerate(ckpt_paths):
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
ckpts[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint
|
||||||
|
ckpts = [ckpts[i] for i in sorted(list(ckpts.keys()))]
|
||||||
|
with open(Path(FLAGS.checkpoint_dir) / "params.json", "r") as f:
|
||||||
|
params = json.loads(f.read())
|
||||||
|
|
||||||
|
jax_weights = {
|
||||||
|
'transformer': {
|
||||||
|
'wte': {'embedding': np.concatenate([ckpt['tok_embeddings.weight'].numpy() for ckpt in ckpts], axis=1)},
|
||||||
|
'ln_f': {'kernel': ckpts[0]['norm.weight'].numpy()},
|
||||||
|
'h': {
|
||||||
|
'%d' % (layer): {
|
||||||
|
'attention': {
|
||||||
|
'wq': {'kernel': np.concatenate([ckpt['layers.%d.attention.wq.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||||
|
'wk': {'kernel': np.concatenate([ckpt['layers.%d.attention.wk.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||||
|
'wv': {'kernel': np.concatenate([ckpt['layers.%d.attention.wv.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||||
|
'wo': {'kernel': np.concatenate([ckpt['layers.%d.attention.wo.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
|
||||||
|
},
|
||||||
|
'feed_forward': {
|
||||||
|
'w1': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w1.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||||
|
'w2': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w2.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
|
||||||
|
'w3': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w3.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||||
|
},
|
||||||
|
'attention_norm': {'kernel': ckpts[0]['layers.%d.attention_norm.weight' % (layer)].numpy()},
|
||||||
|
'ffn_norm': {'kernel': ckpts[0]['layers.%d.ffn_norm.weight' % (layer)].numpy()},
|
||||||
|
}
|
||||||
|
for layer in range(params['n_layers'])},
|
||||||
|
},
|
||||||
|
'lm_head': {'kernel': np.concatenate([ckpt['output.weight'].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||||
|
}
|
||||||
|
if FLAGS.streaming:
|
||||||
|
StreamingCheckpointer.save_train_state_to_file(
|
||||||
|
jax_weights, FLAGS.output_file
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
|
||||||
|
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
mlxu.run(main)
|
||||||
1530
EasyLM/models/llama/llama_model.py
Normal file
1530
EasyLM/models/llama/llama_model.py
Normal file
File diff suppressed because it is too large
Load Diff
386
EasyLM/models/llama/llama_serve.py
Normal file
386
EasyLM/models/llama/llama_serve.py
Normal file
@@ -0,0 +1,386 @@
|
|||||||
|
import pprint
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mlxu
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.experimental.pjit import pjit
|
||||||
|
from jax.sharding import PartitionSpec as PS
|
||||||
|
import optax
|
||||||
|
from transformers import GenerationConfig, FlaxLogitsProcessorList
|
||||||
|
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
from EasyLM.serving import LMServer
|
||||||
|
from EasyLM.jax_utils import (
|
||||||
|
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
|
||||||
|
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
|
||||||
|
with_sharding_constraint, FlaxTemperatureLogitsWarper
|
||||||
|
)
|
||||||
|
from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
seed=42,
|
||||||
|
initialize_jax_distributed=False,
|
||||||
|
mesh_dim='1,-1,1',
|
||||||
|
dtype='bf16',
|
||||||
|
input_length=1024,
|
||||||
|
seq_length=2048,
|
||||||
|
top_k=50,
|
||||||
|
top_p=1.0,
|
||||||
|
do_sample=True,
|
||||||
|
num_beams=1,
|
||||||
|
add_bos_token=True,
|
||||||
|
load_llama_config='',
|
||||||
|
load_checkpoint='',
|
||||||
|
tokenizer=LLaMAConfig.get_tokenizer_config(),
|
||||||
|
lm_server=LMServer.get_default_config(),
|
||||||
|
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||||
|
set_random_seed(FLAGS.seed)
|
||||||
|
|
||||||
|
prefix_tokenizer = LLaMAConfig.get_tokenizer(
|
||||||
|
FLAGS.tokenizer, truncation_side='left', padding_side='left'
|
||||||
|
)
|
||||||
|
tokenizer = LLaMAConfig.get_tokenizer(
|
||||||
|
FLAGS.tokenizer, truncation_side='right', padding_side='right'
|
||||||
|
)
|
||||||
|
|
||||||
|
with jax.default_device(jax.devices("cpu")[0]):
|
||||||
|
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
|
||||||
|
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||||
|
FLAGS.load_checkpoint, disallow_trainstate=True
|
||||||
|
)
|
||||||
|
|
||||||
|
hf_model = FlaxLLaMAForCausalLM(
|
||||||
|
llama_config,
|
||||||
|
input_shape=(1, FLAGS.seq_length),
|
||||||
|
seed=FLAGS.seed,
|
||||||
|
_do_init=False
|
||||||
|
)
|
||||||
|
|
||||||
|
model_ps = match_partition_rules(
|
||||||
|
LLaMAConfig.get_partition_rules(), params
|
||||||
|
)
|
||||||
|
shard_fns, _ = make_shard_and_gather_fns(
|
||||||
|
model_ps, get_float_dtype_by_name(FLAGS.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
@partial(
|
||||||
|
pjit,
|
||||||
|
in_shardings=(model_ps, PS(), PS()),
|
||||||
|
out_shardings=(PS(), PS(), PS())
|
||||||
|
)
|
||||||
|
def forward_loglikelihood(params, rng, batch):
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
input_tokens = batch['input_tokens']
|
||||||
|
output_tokens = batch['output_tokens']
|
||||||
|
input_mask = batch['input_mask']
|
||||||
|
output_mask = batch['output_mask']
|
||||||
|
|
||||||
|
logits = hf_model.module.apply(
|
||||||
|
params, input_tokens, attention_mask=input_mask,
|
||||||
|
deterministic=True, rngs=rng_generator(llama_config.rng_keys()),
|
||||||
|
).logits
|
||||||
|
# if llama_config.n_real_tokens is not None:
|
||||||
|
# logits = logits.at[:, :, llama_config.n_real_tokens:].set(-1e8)
|
||||||
|
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
|
||||||
|
logits, output_tokens
|
||||||
|
)
|
||||||
|
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
|
||||||
|
match_count = jnp.sum(
|
||||||
|
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
|
||||||
|
axis=-1
|
||||||
|
)
|
||||||
|
total = jnp.sum(output_mask, axis=-1)
|
||||||
|
is_greedy = match_count == total
|
||||||
|
return loglikelihood, is_greedy, rng_generator()
|
||||||
|
|
||||||
|
|
||||||
|
@partial(
|
||||||
|
pjit,
|
||||||
|
in_shardings=(model_ps, PS(), PS(), PS()),
|
||||||
|
out_shardings=(PS(), PS())
|
||||||
|
)
|
||||||
|
def forward_generate(params, rng, batch, temperature):
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
output = hf_model.generate(
|
||||||
|
batch['input_tokens'],
|
||||||
|
attention_mask=batch['attention_mask'],
|
||||||
|
params=params['params'],
|
||||||
|
prng_key=rng_generator(),
|
||||||
|
logits_processor=FlaxLogitsProcessorList(
|
||||||
|
[FlaxTemperatureLogitsWarper(temperature)]
|
||||||
|
),
|
||||||
|
generation_config=GenerationConfig(
|
||||||
|
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
do_sample=FLAGS.do_sample,
|
||||||
|
num_beams=FLAGS.num_beams,
|
||||||
|
top_k=FLAGS.top_k,
|
||||||
|
top_p=FLAGS.top_p,
|
||||||
|
)
|
||||||
|
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||||
|
return output, rng_generator()
|
||||||
|
|
||||||
|
@partial(
|
||||||
|
pjit,
|
||||||
|
in_shardings=(model_ps, PS(), PS()),
|
||||||
|
out_shardings=(PS(), PS())
|
||||||
|
)
|
||||||
|
def forward_greedy_generate(params, rng, batch):
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
output = hf_model.generate(
|
||||||
|
batch['input_tokens'],
|
||||||
|
attention_mask=batch['attention_mask'],
|
||||||
|
params=params['params'],
|
||||||
|
prng_key=rng_generator(),
|
||||||
|
generation_config=GenerationConfig(
|
||||||
|
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
do_sample=False,
|
||||||
|
num_beams=1,
|
||||||
|
)
|
||||||
|
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||||
|
return output, rng_generator()
|
||||||
|
|
||||||
|
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||||
|
with mesh:
|
||||||
|
params = tree_apply(shard_fns, params)
|
||||||
|
sharded_rng = next_rng()
|
||||||
|
|
||||||
|
class ModelServer(LMServer):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def loglikelihood(prefix_text, text):
|
||||||
|
nonlocal sharded_rng
|
||||||
|
prefix = prefix_tokenizer(
|
||||||
|
prefix_text,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
max_length=FLAGS.input_length,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
inputs = tokenizer(
|
||||||
|
text,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
max_length=FLAGS.seq_length - FLAGS.input_length,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
|
||||||
|
bos_tokens = np.full(
|
||||||
|
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
|
||||||
|
)
|
||||||
|
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||||
|
input_mask = np.concatenate(
|
||||||
|
[prefix.attention_mask, inputs.attention_mask], axis=1
|
||||||
|
)
|
||||||
|
if FLAGS.add_bos_token:
|
||||||
|
bos_mask = np.ones_like(input_mask[:, :1])
|
||||||
|
else:
|
||||||
|
bos_mask = np.zeros_like(input_mask[:, :1])
|
||||||
|
|
||||||
|
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
|
||||||
|
output_mask = np.concatenate(
|
||||||
|
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
|
||||||
|
)
|
||||||
|
batch = dict(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
input_mask=input_mask,
|
||||||
|
output_mask=output_mask,
|
||||||
|
)
|
||||||
|
with mesh:
|
||||||
|
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||||
|
params, sharded_rng, batch
|
||||||
|
)
|
||||||
|
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||||
|
return loglikelihood, is_greedy
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def loglikelihood_rolling(text):
|
||||||
|
nonlocal sharded_rng
|
||||||
|
inputs = tokenizer(
|
||||||
|
text,
|
||||||
|
padding='longest',
|
||||||
|
truncation=False,
|
||||||
|
max_length=np.iinfo(np.int32).max,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
batch_size = inputs.input_ids.shape[0]
|
||||||
|
output_tokens = inputs.input_ids
|
||||||
|
attention_mask = inputs.attention_mask
|
||||||
|
|
||||||
|
if output_tokens.shape[1] < FLAGS.seq_length:
|
||||||
|
padding_length = FLAGS.seq_length - output_tokens.shape[1]
|
||||||
|
pad_tokens = np.full(
|
||||||
|
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
|
||||||
|
)
|
||||||
|
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
|
||||||
|
pad_mask = np.zeros(
|
||||||
|
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
|
||||||
|
)
|
||||||
|
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
|
||||||
|
|
||||||
|
bos_tokens = np.full(
|
||||||
|
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
|
||||||
|
)
|
||||||
|
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||||
|
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
|
||||||
|
total_seq_length = output_tokens.shape[1]
|
||||||
|
|
||||||
|
total_loglikelihood = 0.0
|
||||||
|
total_is_greedy = True
|
||||||
|
# Sliding window
|
||||||
|
for i in range(0, total_seq_length, FLAGS.seq_length):
|
||||||
|
# Last window
|
||||||
|
if i + FLAGS.seq_length > total_seq_length:
|
||||||
|
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
|
||||||
|
last_output_mask[:, :i - total_seq_length] = 0.0
|
||||||
|
|
||||||
|
batch = dict(
|
||||||
|
input_tokens=input_tokens[:, -FLAGS.seq_length:],
|
||||||
|
output_tokens=output_tokens[:, -FLAGS.seq_length:],
|
||||||
|
input_mask=attention_mask[:, -FLAGS.seq_length:],
|
||||||
|
output_mask=last_output_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normal window
|
||||||
|
else:
|
||||||
|
batch = dict(
|
||||||
|
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
|
||||||
|
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
|
||||||
|
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||||
|
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||||
|
)
|
||||||
|
|
||||||
|
with mesh:
|
||||||
|
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||||
|
params, sharded_rng, batch
|
||||||
|
)
|
||||||
|
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||||
|
|
||||||
|
total_loglikelihood += loglikelihood
|
||||||
|
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
|
||||||
|
|
||||||
|
return total_loglikelihood, total_is_greedy
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate(text, temperature):
|
||||||
|
nonlocal sharded_rng
|
||||||
|
inputs = prefix_tokenizer(
|
||||||
|
text,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
max_length=FLAGS.input_length,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
input_tokens = inputs.input_ids
|
||||||
|
input_mask = inputs.attention_mask
|
||||||
|
if FLAGS.add_bos_token:
|
||||||
|
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||||
|
input_mask[:, 0] = 1
|
||||||
|
batch = dict(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
)
|
||||||
|
with mesh:
|
||||||
|
output, sharded_rng = forward_generate(
|
||||||
|
params, sharded_rng, batch, temperature
|
||||||
|
)
|
||||||
|
output = jax.device_get(output)
|
||||||
|
output_text = []
|
||||||
|
for text in list(tokenizer.batch_decode(output)):
|
||||||
|
if tokenizer.eos_token in text:
|
||||||
|
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
|
||||||
|
output_text.append(text)
|
||||||
|
|
||||||
|
return output_text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def greedy_until(prefix_text, until, max_length):
|
||||||
|
nonlocal sharded_rng
|
||||||
|
all_outputs = []
|
||||||
|
for pf, ut in zip(prefix_text, until):
|
||||||
|
if isinstance(ut, str):
|
||||||
|
ut = [ut]
|
||||||
|
total_length = 0
|
||||||
|
total_generated = ''
|
||||||
|
|
||||||
|
while total_length < max_length:
|
||||||
|
pf_tokens = tokenizer(
|
||||||
|
pf,
|
||||||
|
padding=False,
|
||||||
|
truncation=False,
|
||||||
|
max_length=np.iinfo(np.int32).max,
|
||||||
|
return_tensors='np',
|
||||||
|
)
|
||||||
|
input_tokens = pf_tokens.input_ids
|
||||||
|
attention_mask = pf_tokens.attention_mask
|
||||||
|
|
||||||
|
if input_tokens.shape[1] < FLAGS.input_length:
|
||||||
|
extra = FLAGS.input_length - input_tokens.shape[1]
|
||||||
|
pad_tokens = np.full(
|
||||||
|
(1, extra), tokenizer.pad_token_id, dtype=np.int32
|
||||||
|
)
|
||||||
|
input_tokens = np.concatenate(
|
||||||
|
[pad_tokens, input_tokens], axis=1
|
||||||
|
)
|
||||||
|
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
|
||||||
|
attention_mask = np.concatenate(
|
||||||
|
[pad_attention, attention_mask], axis=1
|
||||||
|
)
|
||||||
|
elif input_tokens.shape[1] > FLAGS.input_length:
|
||||||
|
input_tokens = input_tokens[:, -FLAGS.input_length:]
|
||||||
|
attention_mask = attention_mask[:, -FLAGS.input_length:]
|
||||||
|
|
||||||
|
if FLAGS.add_bos_token:
|
||||||
|
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||||
|
attention_mask[:, 0] = 1
|
||||||
|
|
||||||
|
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
with mesh:
|
||||||
|
output, sharded_rng = forward_greedy_generate(
|
||||||
|
params, sharded_rng, batch
|
||||||
|
)
|
||||||
|
output = jax.device_get(output)
|
||||||
|
|
||||||
|
total_length += output.shape[1]
|
||||||
|
output_text = tokenizer.batch_decode(output)[0]
|
||||||
|
total_generated = total_generated + output_text
|
||||||
|
pf = pf + output_text
|
||||||
|
|
||||||
|
done = False
|
||||||
|
for s in ut:
|
||||||
|
if s in total_generated:
|
||||||
|
total_generated = total_generated.split(s, maxsplit=1)[0]
|
||||||
|
done = True
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
all_outputs.append(total_generated)
|
||||||
|
|
||||||
|
return all_outputs
|
||||||
|
|
||||||
|
|
||||||
|
server = ModelServer(FLAGS.lm_server)
|
||||||
|
server.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
268
EasyLM/models/llama/llama_train.py
Normal file
268
EasyLM/models/llama/llama_train.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
import pprint
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import numpy as np
|
||||||
|
import mlxu
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.experimental.pjit import pjit
|
||||||
|
from jax.sharding import PartitionSpec as PS
|
||||||
|
from flax.training.train_state import TrainState
|
||||||
|
|
||||||
|
from EasyLM.data import DatasetFactory
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
from EasyLM.optimizers import OptimizerFactory
|
||||||
|
from EasyLM.jax_utils import (
|
||||||
|
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
|
||||||
|
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
|
||||||
|
set_random_seed, average_metrics, get_weight_decay_mask,
|
||||||
|
make_shard_and_gather_fns, with_sharding_constraint,
|
||||||
|
)
|
||||||
|
from EasyLM.models.llama.llama_model import (
|
||||||
|
LLaMAConfig, FlaxLLaMAForCausalLMModule
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
seed=42,
|
||||||
|
mesh_dim='1,-1,1',
|
||||||
|
dtype='fp32',
|
||||||
|
param_dtype='fp32',
|
||||||
|
total_steps=10000,
|
||||||
|
load_llama_config='',
|
||||||
|
update_llama_config='',
|
||||||
|
load_checkpoint='',
|
||||||
|
load_dataset_state='',
|
||||||
|
log_freq=50,
|
||||||
|
save_model_freq=0,
|
||||||
|
save_milestone_freq=0,
|
||||||
|
eval_freq=0,
|
||||||
|
tokenizer=LLaMAConfig.get_tokenizer_config(),
|
||||||
|
train_dataset=DatasetFactory.get_default_config(),
|
||||||
|
eval_dataset=DatasetFactory.get_default_config(),
|
||||||
|
optimizer=OptimizerFactory.get_default_config(),
|
||||||
|
checkpointer=StreamingCheckpointer.get_default_config(),
|
||||||
|
llama=LLaMAConfig.get_default_config(),
|
||||||
|
logger=mlxu.WandBLogger.get_default_config(),
|
||||||
|
log_all_worker=False,
|
||||||
|
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||||
|
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||||
|
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
||||||
|
logger = mlxu.WandBLogger(
|
||||||
|
config=FLAGS.logger,
|
||||||
|
variant=variant,
|
||||||
|
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
||||||
|
)
|
||||||
|
set_random_seed(FLAGS.seed)
|
||||||
|
|
||||||
|
tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
|
||||||
|
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
||||||
|
if FLAGS.load_dataset_state != '':
|
||||||
|
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
||||||
|
|
||||||
|
if FLAGS.eval_freq > 0:
|
||||||
|
eval_dataset = DatasetFactory.load_dataset(
|
||||||
|
FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_length = dataset.seq_length
|
||||||
|
|
||||||
|
if FLAGS.load_llama_config != '':
|
||||||
|
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
|
||||||
|
else:
|
||||||
|
llama_config = LLaMAConfig(**FLAGS.llama)
|
||||||
|
|
||||||
|
if FLAGS.update_llama_config != '':
|
||||||
|
llama_config.update(dict(eval(FLAGS.update_llama_config)))
|
||||||
|
|
||||||
|
llama_config.update(dict(
|
||||||
|
bos_token_id=dataset.tokenizer.bos_token_id,
|
||||||
|
eos_token_id=dataset.tokenizer.eos_token_id,
|
||||||
|
))
|
||||||
|
if llama_config.vocab_size < dataset.vocab_size:
|
||||||
|
print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
|
||||||
|
llama_config.update(dict(vocab_size=dataset.vocab_size))
|
||||||
|
|
||||||
|
model = FlaxLLaMAForCausalLMModule(
|
||||||
|
llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype), param_dtype=get_float_dtype_by_name(FLAGS.param_dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
||||||
|
FLAGS.optimizer,
|
||||||
|
get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_trainstate_from_params(params):
|
||||||
|
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||||
|
|
||||||
|
def init_fn(rng):
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
params = model.init(
|
||||||
|
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||||
|
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||||
|
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
||||||
|
rngs=rng_generator(llama_config.rng_keys()),
|
||||||
|
)
|
||||||
|
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||||
|
|
||||||
|
def train_step(train_state, rng, batch):
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
def loss_and_accuracy(params):
|
||||||
|
logits = model.apply(
|
||||||
|
params, batch['input_tokens'], deterministic=False,
|
||||||
|
rngs=rng_generator(llama_config.rng_keys()),
|
||||||
|
).logits
|
||||||
|
return cross_entropy_loss_and_accuracy(
|
||||||
|
logits, batch['target_tokens'], batch['loss_masks']
|
||||||
|
)
|
||||||
|
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
||||||
|
(loss, accuracy), grads = grad_fn(train_state.params)
|
||||||
|
train_state = train_state.apply_gradients(grads=grads)
|
||||||
|
metrics = dict(
|
||||||
|
loss=loss,
|
||||||
|
accuracy=accuracy,
|
||||||
|
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
||||||
|
gradient_norm=global_norm(grads),
|
||||||
|
param_norm=global_norm(train_state.params),
|
||||||
|
)
|
||||||
|
return train_state, rng_generator(), metrics
|
||||||
|
|
||||||
|
def eval_step(train_state, rng, batch):
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||||
|
logits = model.apply(
|
||||||
|
train_state.params, batch['input_tokens'], deterministic=True,
|
||||||
|
rngs=rng_generator(llama_config.rng_keys()),
|
||||||
|
).logits
|
||||||
|
loss, accuracy = cross_entropy_loss_and_accuracy(
|
||||||
|
logits, batch['target_tokens'], batch['loss_masks']
|
||||||
|
)
|
||||||
|
metrics = dict(
|
||||||
|
eval_loss=loss,
|
||||||
|
eval_accuracy=accuracy,
|
||||||
|
)
|
||||||
|
return rng_generator(), metrics
|
||||||
|
|
||||||
|
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
||||||
|
train_state_partition = match_partition_rules(
|
||||||
|
LLaMAConfig.get_partition_rules(), train_state_shapes
|
||||||
|
)
|
||||||
|
|
||||||
|
shard_fns, gather_fns = make_shard_and_gather_fns(
|
||||||
|
train_state_partition, train_state_shapes
|
||||||
|
)
|
||||||
|
checkpointer = StreamingCheckpointer(
|
||||||
|
FLAGS.checkpointer, logger.output_dir,
|
||||||
|
enable=jax.process_index() == 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_init_fn = pjit(
|
||||||
|
init_fn,
|
||||||
|
in_shardings=PS(),
|
||||||
|
out_shardings=train_state_partition
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_create_trainstate_from_params = pjit(
|
||||||
|
create_trainstate_from_params,
|
||||||
|
in_shardings=(train_state_partition.params, ),
|
||||||
|
out_shardings=train_state_partition,
|
||||||
|
donate_argnums=(0, ),
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_train_step = pjit(
|
||||||
|
train_step,
|
||||||
|
in_shardings=(train_state_partition, PS(), PS()),
|
||||||
|
out_shardings=(train_state_partition, PS(), PS()),
|
||||||
|
donate_argnums=(0, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_eval_step = pjit(
|
||||||
|
eval_step,
|
||||||
|
in_shardings=(train_state_partition, PS(), PS()),
|
||||||
|
out_shardings=(PS(), PS()),
|
||||||
|
donate_argnums=(1,),
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_checkpoint(train_state, milestone=False):
|
||||||
|
step = int(jax.device_get(train_state.step))
|
||||||
|
metadata = dict(
|
||||||
|
step=step,
|
||||||
|
variant=variant,
|
||||||
|
flags=flags_config_dict,
|
||||||
|
llama_config=llama_config.to_dict(),
|
||||||
|
)
|
||||||
|
checkpointer.save_all(
|
||||||
|
train_state=train_state,
|
||||||
|
gather_fns=gather_fns,
|
||||||
|
metadata=metadata,
|
||||||
|
dataset=dataset.get_state_dict(),
|
||||||
|
milestone=milestone,
|
||||||
|
)
|
||||||
|
|
||||||
|
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||||
|
with mesh:
|
||||||
|
train_state, restored_params = None, None
|
||||||
|
if FLAGS.load_checkpoint != '':
|
||||||
|
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
||||||
|
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_state is None and restored_params is None:
|
||||||
|
# Initialize from scratch
|
||||||
|
train_state = sharded_init_fn(next_rng())
|
||||||
|
elif train_state is None and restored_params is not None:
|
||||||
|
# Restore from params but initialize train_state
|
||||||
|
train_state = sharded_create_trainstate_from_params(restored_params)
|
||||||
|
del restored_params
|
||||||
|
|
||||||
|
start_step = int(jax.device_get(train_state.step))
|
||||||
|
|
||||||
|
if FLAGS.save_model_freq > 0:
|
||||||
|
save_checkpoint(train_state)
|
||||||
|
|
||||||
|
sharded_rng = next_rng()
|
||||||
|
|
||||||
|
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
||||||
|
|
||||||
|
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
||||||
|
train_state, sharded_rng, metrics = sharded_train_step(
|
||||||
|
train_state, sharded_rng, batch
|
||||||
|
)
|
||||||
|
|
||||||
|
if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
|
||||||
|
eval_metric_list = []
|
||||||
|
eval_iterator = iter(eval_dataset)
|
||||||
|
for eval_batch, _ in eval_iterator:
|
||||||
|
sharded_rng, eval_metrics = sharded_eval_step(
|
||||||
|
train_state, sharded_rng, eval_batch
|
||||||
|
)
|
||||||
|
eval_metric_list.append(eval_metrics)
|
||||||
|
metrics.update(average_metrics(eval_metric_list))
|
||||||
|
|
||||||
|
if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
|
||||||
|
log_metrics = {"step": step + 1}
|
||||||
|
log_metrics.update(metrics)
|
||||||
|
log_metrics.update(dataset_metrics)
|
||||||
|
log_metrics = jax.device_get(log_metrics)
|
||||||
|
logger.log(log_metrics)
|
||||||
|
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
||||||
|
|
||||||
|
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
||||||
|
save_checkpoint(train_state, milestone=True)
|
||||||
|
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
||||||
|
save_checkpoint(train_state)
|
||||||
|
|
||||||
|
if FLAGS.save_model_freq > 0:
|
||||||
|
save_checkpoint(train_state)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
0
EasyLM/models/roberta/__init__.py
Normal file
0
EasyLM/models/roberta/__init__.py
Normal file
1694
EasyLM/models/roberta/roberta_model.py
Normal file
1694
EasyLM/models/roberta/roberta_model.py
Normal file
File diff suppressed because it is too large
Load Diff
307
EasyLM/models/roberta/roberta_train.py
Normal file
307
EasyLM/models/roberta/roberta_train.py
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
import dataclasses
|
||||||
|
import pprint
|
||||||
|
from functools import partial
|
||||||
|
import re
|
||||||
|
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import numpy as np
|
||||||
|
import mlxu
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.experimental.pjit import pjit, with_sharding_constraint
|
||||||
|
from jax.sharding import PartitionSpec as PS
|
||||||
|
from flax.training.train_state import TrainState
|
||||||
|
|
||||||
|
from EasyLM.data import DatasetFactory
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
from EasyLM.optimizers import OptimizerFactory
|
||||||
|
from EasyLM.jax_utils import (
|
||||||
|
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, get_float_dtype_by_name,
|
||||||
|
cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
|
||||||
|
set_random_seed, average_metrics, get_weight_decay_mask,
|
||||||
|
make_shard_and_gather_fns, tree_apply
|
||||||
|
)
|
||||||
|
from EasyLM.models.roberta.roberta_model import (
|
||||||
|
RobertaConfig, FlaxRobertaForMaskedLMModule
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
seed=42,
|
||||||
|
mesh_dim='-1,1,1',
|
||||||
|
dtype='fp32',
|
||||||
|
mask_token_probability=0.15,
|
||||||
|
total_steps=10000,
|
||||||
|
load_roberta_config='',
|
||||||
|
update_roberta_config='',
|
||||||
|
load_checkpoint='',
|
||||||
|
load_dataset_state='',
|
||||||
|
log_freq=50,
|
||||||
|
save_model_freq=0,
|
||||||
|
save_milestone_freq=0,
|
||||||
|
eval_steps=0,
|
||||||
|
tokenizer=RobertaConfig.get_tokenizer_config(),
|
||||||
|
train_dataset=DatasetFactory.get_default_config(),
|
||||||
|
eval_dataset=DatasetFactory.get_default_config(),
|
||||||
|
optimizer=OptimizerFactory.get_default_config(),
|
||||||
|
checkpointer=StreamingCheckpointer.get_default_config(),
|
||||||
|
roberta=RobertaConfig.get_default_config(),
|
||||||
|
logger=mlxu.WandBLogger.get_default_config(),
|
||||||
|
log_all_worker=False,
|
||||||
|
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||||
|
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||||
|
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
||||||
|
logger = mlxu.WandBLogger(
|
||||||
|
config=FLAGS.logger,
|
||||||
|
variant=variant,
|
||||||
|
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
||||||
|
)
|
||||||
|
set_random_seed(FLAGS.seed)
|
||||||
|
|
||||||
|
tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
|
||||||
|
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
||||||
|
if FLAGS.load_dataset_state != '':
|
||||||
|
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
||||||
|
|
||||||
|
if FLAGS.eval_steps > 0:
|
||||||
|
eval_dataset = DatasetFactory.load_dataset(
|
||||||
|
FLAGS.eval_dataset, dataset.tokenizer
|
||||||
|
)
|
||||||
|
eval_iterator = iter(eval_dataset)
|
||||||
|
|
||||||
|
seq_length = dataset.seq_length
|
||||||
|
|
||||||
|
if FLAGS.load_roberta_config != '':
|
||||||
|
roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
|
||||||
|
else:
|
||||||
|
roberta_config = RobertaConfig(**FLAGS.roberta)
|
||||||
|
|
||||||
|
if FLAGS.update_roberta_config != '':
|
||||||
|
roberta_config.update(dict(eval(FLAGS.update_roberta_config)))
|
||||||
|
|
||||||
|
roberta_config.update(dict(
|
||||||
|
bos_token_id=dataset.tokenizer.bos_token_id,
|
||||||
|
eos_token_id=dataset.tokenizer.eos_token_id,
|
||||||
|
pad_token_id=dataset.tokenizer.pad_token_id,
|
||||||
|
vocab_size=dataset.vocab_size,
|
||||||
|
))
|
||||||
|
|
||||||
|
model = FlaxRobertaForMaskedLMModule(
|
||||||
|
roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
||||||
|
FLAGS.optimizer,
|
||||||
|
get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_trainstate_from_params(params):
|
||||||
|
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||||
|
|
||||||
|
def init_fn(rng):
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
params = model.init(
|
||||||
|
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||||
|
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||||
|
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
||||||
|
token_type_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
rngs=rng_generator(roberta_config.rng_keys()),
|
||||||
|
)
|
||||||
|
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||||
|
|
||||||
|
def train_step(train_state, rng, batch):
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
|
||||||
|
def loss_and_accuracy(params):
|
||||||
|
altered_tokens = jax.random.uniform(
|
||||||
|
rng_generator(), shape=tokens.shape
|
||||||
|
) < FLAGS.mask_token_probability
|
||||||
|
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
|
||||||
|
altered_by_mask = altered_tokens & (random_uniform < 0.8)
|
||||||
|
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
|
||||||
|
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
|
||||||
|
random_tokens = jax.random.randint(
|
||||||
|
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
|
||||||
|
)
|
||||||
|
inputs = jnp.where(altered_by_random, random_tokens, inputs)
|
||||||
|
logits = model.apply(
|
||||||
|
params, inputs,
|
||||||
|
attention_mask=jnp.ones_like(inputs),
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
deterministic=False,
|
||||||
|
rngs=rng_generator(roberta_config.rng_keys()),
|
||||||
|
).logits
|
||||||
|
return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
|
||||||
|
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
||||||
|
(loss, accuracy), grads = grad_fn(train_state.params)
|
||||||
|
train_state = train_state.apply_gradients(grads=grads)
|
||||||
|
metrics = dict(
|
||||||
|
loss=loss,
|
||||||
|
accuracy=accuracy,
|
||||||
|
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
||||||
|
gradient_norm=global_norm(grads),
|
||||||
|
param_norm=global_norm(train_state.params),
|
||||||
|
)
|
||||||
|
return train_state, rng_generator(), metrics
|
||||||
|
|
||||||
|
def eval_step(train_state, rng, batch):
|
||||||
|
rng_generator = JaxRNG(rng)
|
||||||
|
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
|
||||||
|
altered_tokens = jax.random.uniform(
|
||||||
|
rng_generator(), shape=tokens.shape
|
||||||
|
) < FLAGS.mask_token_probability
|
||||||
|
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
|
||||||
|
altered_by_mask = altered_tokens & (random_uniform < 0.8)
|
||||||
|
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
|
||||||
|
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
|
||||||
|
random_tokens = jax.random.randint(
|
||||||
|
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
|
||||||
|
)
|
||||||
|
inputs = jnp.where(altered_by_random, random_tokens, inputs)
|
||||||
|
logits = model.apply(
|
||||||
|
train_state.params, inputs,
|
||||||
|
attention_mask=jnp.ones_like(inputs),
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
deterministic=False,
|
||||||
|
rngs=rng_generator(roberta_config.rng_keys()),
|
||||||
|
).logits
|
||||||
|
loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
|
||||||
|
metrics = dict(
|
||||||
|
eval_loss=loss,
|
||||||
|
eval_accuracy=accuracy,
|
||||||
|
)
|
||||||
|
return rng_generator(), metrics
|
||||||
|
|
||||||
|
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
||||||
|
train_state_partition = match_partition_rules(
|
||||||
|
RobertaConfig.get_partition_rules(), train_state_shapes
|
||||||
|
)
|
||||||
|
|
||||||
|
shard_fns, gather_fns = make_shard_and_gather_fns(
|
||||||
|
train_state_partition, train_state_shapes
|
||||||
|
)
|
||||||
|
checkpointer = StreamingCheckpointer(
|
||||||
|
FLAGS.checkpointer, logger.output_dir,
|
||||||
|
enable=jax.process_index() == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_init_fn = pjit(
|
||||||
|
init_fn,
|
||||||
|
in_shardings=PS(),
|
||||||
|
out_shardings=train_state_partition
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_create_trainstate_from_params = pjit(
|
||||||
|
create_trainstate_from_params,
|
||||||
|
in_shardings=(train_state_partition.params, ),
|
||||||
|
out_shardings=train_state_partition,
|
||||||
|
donate_argnums=(0, ),
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_train_step = pjit(
|
||||||
|
train_step,
|
||||||
|
in_shardings=(train_state_partition, PS(), PS()),
|
||||||
|
out_shardings=(train_state_partition, PS(), PS()),
|
||||||
|
donate_argnums=(0, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
sharded_eval_step = pjit(
|
||||||
|
eval_step,
|
||||||
|
in_shardings=(train_state_partition, PS(), PS()),
|
||||||
|
out_shardings=(PS(), PS()),
|
||||||
|
donate_argnums=(1,),
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_checkpoint(train_state, milestone=False):
|
||||||
|
step = int(jax.device_get(train_state.step))
|
||||||
|
metadata = dict(
|
||||||
|
step=step,
|
||||||
|
variant=variant,
|
||||||
|
flags=flags_config_dict,
|
||||||
|
roberta_config=roberta_config.to_dict(),
|
||||||
|
)
|
||||||
|
checkpointer.save_all(
|
||||||
|
train_state=train_state,
|
||||||
|
gather_fns=gather_fns,
|
||||||
|
metadata=metadata,
|
||||||
|
dataset=dataset.get_state_dict(),
|
||||||
|
milestone=milestone,
|
||||||
|
)
|
||||||
|
|
||||||
|
mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||||
|
with mesh:
|
||||||
|
train_state, restored_params = None, None
|
||||||
|
if FLAGS.load_checkpoint != '':
|
||||||
|
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
||||||
|
if load_type == 'huggingface':
|
||||||
|
restored_params = tree_apply(
|
||||||
|
shard_fns.params, roberta_config.load_pretrained(load_path)
|
||||||
|
)
|
||||||
|
train_state = None
|
||||||
|
else:
|
||||||
|
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
||||||
|
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_state is None and restored_params is None:
|
||||||
|
# Initialize from scratch
|
||||||
|
train_state = sharded_init_fn(next_rng())
|
||||||
|
elif train_state is None and restored_params is not None:
|
||||||
|
# Restore from params but initialize train_state
|
||||||
|
train_state = sharded_create_trainstate_from_params(restored_params)
|
||||||
|
del restored_params
|
||||||
|
|
||||||
|
start_step = int(jax.device_get(train_state.step))
|
||||||
|
|
||||||
|
if FLAGS.save_model_freq > 0:
|
||||||
|
save_checkpoint(train_state)
|
||||||
|
|
||||||
|
sharded_rng = next_rng()
|
||||||
|
|
||||||
|
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
||||||
|
|
||||||
|
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
||||||
|
train_state, sharded_rng, metrics = sharded_train_step(
|
||||||
|
train_state, sharded_rng, batch
|
||||||
|
)
|
||||||
|
|
||||||
|
if step % FLAGS.log_freq == 0:
|
||||||
|
if FLAGS.eval_steps > 0:
|
||||||
|
eval_metric_list = []
|
||||||
|
for _ in range(FLAGS.eval_steps):
|
||||||
|
eval_batch, _ = next(eval_iterator)
|
||||||
|
sharded_rng, eval_metrics = sharded_eval_step(
|
||||||
|
train_state, sharded_rng, eval_batch
|
||||||
|
)
|
||||||
|
eval_metric_list.append(eval_metrics)
|
||||||
|
metrics.update(average_metrics(eval_metric_list))
|
||||||
|
|
||||||
|
log_metrics = {"step": step}
|
||||||
|
log_metrics.update(metrics)
|
||||||
|
log_metrics.update(dataset_metrics)
|
||||||
|
log_metrics = jax.device_get(log_metrics)
|
||||||
|
logger.log(log_metrics)
|
||||||
|
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
||||||
|
|
||||||
|
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
||||||
|
save_checkpoint(train_state, milestone=True)
|
||||||
|
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
||||||
|
save_checkpoint(train_state)
|
||||||
|
|
||||||
|
if FLAGS.save_model_freq > 0:
|
||||||
|
save_checkpoint(train_state)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
346
EasyLM/optimizers.py
Normal file
346
EasyLM/optimizers.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
|
||||||
|
from functools import partial
|
||||||
|
import re
|
||||||
|
import dataclasses
|
||||||
|
import random
|
||||||
|
|
||||||
|
from ml_collections.config_dict import config_dict
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from absl import logging
|
||||||
|
import optax
|
||||||
|
|
||||||
|
from EasyLM.jax_utils import float_to_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerFactory(object):
|
||||||
|
""" Configurable optax optimizer factory. """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.accumulate_gradient_steps = 1
|
||||||
|
config.type = 'adamw'
|
||||||
|
config.palm_optimizer = PalmOptimizerFactory.get_default_config()
|
||||||
|
config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
|
||||||
|
config.lion_optimizer = LionOptimizerFactory.get_default_config()
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_optimizer(cls, config, weight_decay_mask=None):
|
||||||
|
config = cls.get_default_config(config)
|
||||||
|
if config.type == 'palm':
|
||||||
|
optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer(
|
||||||
|
config.palm_optimizer, weight_decay_mask
|
||||||
|
)
|
||||||
|
elif config.type == 'adamw':
|
||||||
|
optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer(
|
||||||
|
config.adamw_optimizer, weight_decay_mask
|
||||||
|
)
|
||||||
|
elif config.type == 'lion':
|
||||||
|
optimizer, optimizer_info = LionOptimizerFactory.get_optimizer(
|
||||||
|
config.lion_optimizer, weight_decay_mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown optimizer type: {config.type}')
|
||||||
|
|
||||||
|
if config.accumulate_gradient_steps > 1:
|
||||||
|
optimizer = optax.MultiSteps(
|
||||||
|
optimizer, config.accumulate_gradient_steps
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer, optimizer_info
|
||||||
|
|
||||||
|
|
||||||
|
class PalmOptimizerFactory(object):
|
||||||
|
""" PaLM optimizer factory. This optimizer implements the optimizer
|
||||||
|
described in the PaLM paper: https://arxiv.org/abs/2204.02311
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.lr = 0.01
|
||||||
|
config.lr_warmup_steps = 10000
|
||||||
|
config.b1 = 0.9
|
||||||
|
config.b2 = 0.99
|
||||||
|
config.clip_gradient = 1.0
|
||||||
|
config.weight_decay = 1e-4
|
||||||
|
config.bf16_momentum = False
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_optimizer(cls, config, weight_decay_mask=None):
|
||||||
|
config = cls.get_default_config(config)
|
||||||
|
|
||||||
|
def learning_rate_schedule(step):
|
||||||
|
multiplier = config.lr / 0.01
|
||||||
|
return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps))
|
||||||
|
|
||||||
|
def weight_decay_schedule(step):
|
||||||
|
multiplier = config.weight_decay / 1e-4
|
||||||
|
return -multiplier * jnp.square(learning_rate_schedule(step))
|
||||||
|
|
||||||
|
optimizer_info = dict(
|
||||||
|
learning_rate_schedule=learning_rate_schedule,
|
||||||
|
weight_decay_schedule=weight_decay_schedule,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = optax.chain(
|
||||||
|
optax.clip_by_global_norm(config.clip_gradient),
|
||||||
|
optax.adafactor(
|
||||||
|
learning_rate=learning_rate_schedule,
|
||||||
|
multiply_by_parameter_scale=True,
|
||||||
|
momentum=config.b1,
|
||||||
|
decay_rate=config.b2,
|
||||||
|
factored=False,
|
||||||
|
clipping_threshold=None,
|
||||||
|
dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
||||||
|
),
|
||||||
|
optax_add_scheduled_weight_decay(
|
||||||
|
weight_decay_schedule, weight_decay_mask
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return optimizer, optimizer_info
|
||||||
|
|
||||||
|
|
||||||
|
class AdamWOptimizerFactory(object):
|
||||||
|
""" AdamW optimizer with cosine schedule. """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.init_lr = 0.0
|
||||||
|
config.end_lr = 0.001
|
||||||
|
config.lr = 0.01
|
||||||
|
config.lr_warmup_steps = 2000
|
||||||
|
config.lr_decay_steps = 500000
|
||||||
|
config.b1 = 0.9
|
||||||
|
config.b2 = 0.95
|
||||||
|
config.clip_gradient = 1.0
|
||||||
|
config.weight_decay = 1e-4
|
||||||
|
config.bf16_momentum = False
|
||||||
|
config.multiply_by_parameter_scale = False
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_optimizer(cls, config, weight_decay_mask=None):
|
||||||
|
config = cls.get_default_config(config)
|
||||||
|
|
||||||
|
learning_rate_schedule = optax.warmup_cosine_decay_schedule(
|
||||||
|
init_value=config.init_lr,
|
||||||
|
peak_value=config.lr,
|
||||||
|
warmup_steps=config.lr_warmup_steps,
|
||||||
|
decay_steps=config.lr_decay_steps,
|
||||||
|
end_value=config.end_lr,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_info = dict(
|
||||||
|
learning_rate_schedule=learning_rate_schedule,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.multiply_by_parameter_scale:
|
||||||
|
optimizer = optax.chain(
|
||||||
|
optax.clip_by_global_norm(config.clip_gradient),
|
||||||
|
optax.adafactor(
|
||||||
|
learning_rate=learning_rate_schedule,
|
||||||
|
multiply_by_parameter_scale=True,
|
||||||
|
momentum=config.b1,
|
||||||
|
decay_rate=config.b2,
|
||||||
|
factored=False,
|
||||||
|
clipping_threshold=None,
|
||||||
|
dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
||||||
|
),
|
||||||
|
optax_add_scheduled_weight_decay(
|
||||||
|
lambda step: -learning_rate_schedule(step) * config.weight_decay,
|
||||||
|
weight_decay_mask
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optimizer = optax.chain(
|
||||||
|
optax.clip_by_global_norm(config.clip_gradient),
|
||||||
|
optax.adamw(
|
||||||
|
learning_rate=learning_rate_schedule,
|
||||||
|
weight_decay=config.weight_decay,
|
||||||
|
b1=config.b1,
|
||||||
|
b2=config.b2,
|
||||||
|
mask=weight_decay_mask,
|
||||||
|
mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer, optimizer_info
|
||||||
|
|
||||||
|
class LionOptimizerFactory(object):
|
||||||
|
""" Lion optimizer with cosine schedule. """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.init_lr = 0.0
|
||||||
|
config.end_lr = 0.0001
|
||||||
|
config.lr = 0.001
|
||||||
|
config.lr_warmup_steps = 60000
|
||||||
|
config.lr_constant_steps = 840000
|
||||||
|
config.lr_decay_steps = 100000
|
||||||
|
config.b1 = 0.9
|
||||||
|
config.b2 = 0.98
|
||||||
|
config.clip_gradient = 1.0
|
||||||
|
config.weight_decay = 1e-3
|
||||||
|
config.bf16_momentum = False
|
||||||
|
config.lr_schedule_type = "warmup_cosine_decay_schedule"
|
||||||
|
config.lr_decay_rate = 0.98
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_optimizer(cls, config, weight_decay_mask=None):
|
||||||
|
config = cls.get_default_config(config)
|
||||||
|
|
||||||
|
if config.lr_schedule_type == "warmup_cosine_decay_schedule":
|
||||||
|
learning_rate_schedule = optax.warmup_cosine_decay_schedule(
|
||||||
|
init_value=config.init_lr,
|
||||||
|
peak_value=config.lr,
|
||||||
|
warmup_steps=config.lr_warmup_steps,
|
||||||
|
decay_steps=config.lr_decay_steps,
|
||||||
|
end_value=config.end_lr,
|
||||||
|
)
|
||||||
|
elif config.lr_schedule_type == "warmup_constant":
|
||||||
|
learning_rate_schedule = optax.join_schedules(
|
||||||
|
[
|
||||||
|
optax.linear_schedule(
|
||||||
|
init_value=config.init_lr,
|
||||||
|
end_value=config.lr,
|
||||||
|
transition_steps=config.lr_warmup_steps,
|
||||||
|
),
|
||||||
|
optax.constant_schedule(config.lr),
|
||||||
|
],
|
||||||
|
[config.lr_warmup_steps],
|
||||||
|
)
|
||||||
|
elif config.lr_schedule_type == "warmup_constant_linear_decay":
|
||||||
|
learning_rate_schedule = optax.join_schedules(
|
||||||
|
[
|
||||||
|
optax.linear_schedule(
|
||||||
|
init_value=config.init_lr,
|
||||||
|
end_value=config.lr,
|
||||||
|
transition_steps=config.lr_warmup_steps,
|
||||||
|
),
|
||||||
|
optax.constant_schedule(config.lr),
|
||||||
|
optax.linear_schedule(
|
||||||
|
init_value=config.lr,
|
||||||
|
end_value=config.end_lr,
|
||||||
|
transition_steps=config.lr_decay_steps,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
[config.lr_warmup_steps, config.lr_constant_steps],
|
||||||
|
)
|
||||||
|
elif config.lr_schedule_type == "warmup_constant_exponential_decay":
|
||||||
|
learning_rate_schedule = optax.join_schedules(
|
||||||
|
[
|
||||||
|
optax.linear_schedule(
|
||||||
|
init_value=config.init_lr,
|
||||||
|
end_value=config.lr,
|
||||||
|
transition_steps=config.lr_warmup_steps,
|
||||||
|
),
|
||||||
|
optax.constant_schedule(config.lr),
|
||||||
|
optax.exponential_decay(
|
||||||
|
init_value=config.lr,
|
||||||
|
transition_steps=config.lr_decay_steps,
|
||||||
|
decay_rate=config.lr_decay_rate,
|
||||||
|
transition_begin=0,
|
||||||
|
staircase=False,
|
||||||
|
end_value=config.end_lr,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
[config.lr_warmup_steps, config.lr_constant_steps],
|
||||||
|
)
|
||||||
|
elif config.lr_schedule_type == "exponential_decay":
|
||||||
|
learning_rate_schedule = optax.exponential_decay(
|
||||||
|
init_value=config.lr,
|
||||||
|
transition_steps=config.lr_decay_steps,
|
||||||
|
decay_rate=config.lr_decay_rate,
|
||||||
|
transition_begin=0,
|
||||||
|
staircase=False,
|
||||||
|
end_value=config.end_lr,
|
||||||
|
)
|
||||||
|
elif config.lr_schedule_type == "linear_decay":
|
||||||
|
learning_rate_schedule = optax.linear_schedule(
|
||||||
|
init_value=config.lr,
|
||||||
|
end_value=config.end_lr,
|
||||||
|
transition_steps=config.lr_decay_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')
|
||||||
|
|
||||||
|
optimizer_info = dict(
|
||||||
|
learning_rate_schedule=learning_rate_schedule,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = optax.chain(
|
||||||
|
optax.clip_by_global_norm(config.clip_gradient),
|
||||||
|
optax.lion(
|
||||||
|
learning_rate=learning_rate_schedule,
|
||||||
|
weight_decay=config.weight_decay,
|
||||||
|
b1=config.b1,
|
||||||
|
b2=config.b2,
|
||||||
|
mask=weight_decay_mask,
|
||||||
|
mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer, optimizer_info
|
||||||
|
|
||||||
|
|
||||||
|
class OptaxScheduledWeightDecayState(NamedTuple):
|
||||||
|
count: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
|
||||||
|
""" Apply weight decay with schedule. """
|
||||||
|
|
||||||
|
def init_fn(params):
|
||||||
|
del params
|
||||||
|
return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
|
||||||
|
|
||||||
|
def update_fn(updates, state, params):
|
||||||
|
if params is None:
|
||||||
|
raise ValueError('Params cannot be None for weight decay!')
|
||||||
|
|
||||||
|
weight_decay = schedule_fn(state.count)
|
||||||
|
updates = jax.tree_util.tree_map(
|
||||||
|
lambda g, p: g + weight_decay * p, updates, params
|
||||||
|
)
|
||||||
|
return updates, OptaxScheduledWeightDecayState(
|
||||||
|
count=optax.safe_int32_increment(state.count)
|
||||||
|
)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
|
||||||
|
return optax.GradientTransformation(init_fn, update_fn)
|
||||||
0
EasyLM/scripts/__init__.py
Normal file
0
EasyLM/scripts/__init__.py
Normal file
150
EasyLM/scripts/benchmark_attention.py
Normal file
150
EasyLM/scripts/benchmark_attention.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
from functools import partial
|
||||||
|
from time import time
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
import jax.flatten_util
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import mlxu
|
||||||
|
from EasyLM.bpt import blockwise_attn
|
||||||
|
from EasyLM.jax_utils import (
|
||||||
|
get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, _ = mlxu.define_flags_with_default(
|
||||||
|
seed=42,
|
||||||
|
dtype='fp32',
|
||||||
|
embed_dim=2048,
|
||||||
|
n_heads=16,
|
||||||
|
ref_attn_seq_len=2048,
|
||||||
|
eff_attn_seq_len=16384,
|
||||||
|
batch_size=1,
|
||||||
|
query_chunk_size=2048,
|
||||||
|
key_chunk_size=2048,
|
||||||
|
warmup_steps=40,
|
||||||
|
steps=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
|
||||||
|
def random_kqv(rng_key, seq_len):
|
||||||
|
rng_generator = JaxRNG(rng_key)
|
||||||
|
kqv = []
|
||||||
|
for i in range(3):
|
||||||
|
kqv.append(
|
||||||
|
jax.random.normal(
|
||||||
|
rng_generator(),
|
||||||
|
(FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
|
||||||
|
dtype=get_float_dtype_by_name(FLAGS.dtype)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tuple(kqv)
|
||||||
|
|
||||||
|
def reference_attn(query, key, value):
|
||||||
|
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
||||||
|
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||||
|
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
|
||||||
|
mask_value = jnp.finfo(logits.dtype).min
|
||||||
|
_, q_seq_len, _, _ = query.shape
|
||||||
|
_, kv_seq_len, _, _ = key.shape
|
||||||
|
mask_shape = (q_seq_len, kv_seq_len)
|
||||||
|
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
|
||||||
|
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
|
||||||
|
causal_mask = (row_ids < col_ids)[None, None, :, :]
|
||||||
|
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
|
||||||
|
weights = jax.nn.softmax(logits, axis=-1)
|
||||||
|
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def efficient_attention(query, key, value):
|
||||||
|
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
||||||
|
return blockwise_attn(
|
||||||
|
query, key, value,
|
||||||
|
bias=None,
|
||||||
|
deterministic=True,
|
||||||
|
dropout_rng=None,
|
||||||
|
attn_pdrop=0.0,
|
||||||
|
causal=True,
|
||||||
|
query_chunk_size=FLAGS.query_chunk_size,
|
||||||
|
key_chunk_size=FLAGS.key_chunk_size,
|
||||||
|
dtype=get_float_dtype_by_name(FLAGS.dtype),
|
||||||
|
policy=jax.checkpoint_policies.nothing_saveable(),
|
||||||
|
precision=None,
|
||||||
|
float32_logits=True,
|
||||||
|
prevent_cse=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(1,))
|
||||||
|
def reference_attn_forward_backward(rng_key, seq_len):
|
||||||
|
@partial(jax.grad, argnums=(0, 1, 2))
|
||||||
|
@partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
|
||||||
|
def grad_fn(query, key, value):
|
||||||
|
out = reference_attn(query, key, value)
|
||||||
|
return jnp.mean(out)
|
||||||
|
|
||||||
|
query, key, value = random_kqv(rng_key, seq_len)
|
||||||
|
return jax.flatten_util.ravel_pytree(
|
||||||
|
grad_fn(query, key, value)[1]
|
||||||
|
)[0].mean()
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(1,))
|
||||||
|
def efficient_attn_forward_backward(rng_key, seq_len):
|
||||||
|
@partial(jax.grad, argnums=(0, 1, 2))
|
||||||
|
def grad_fn(query, key, value):
|
||||||
|
out = efficient_attention(query, key, value)
|
||||||
|
return jnp.mean(out)
|
||||||
|
|
||||||
|
query, key, value = random_kqv(rng_key, seq_len)
|
||||||
|
return jax.flatten_util.ravel_pytree(
|
||||||
|
grad_fn(query, key, value)[1]
|
||||||
|
)[0].mean()
|
||||||
|
|
||||||
|
|
||||||
|
set_random_seed(FLAGS.seed)
|
||||||
|
|
||||||
|
jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||||
|
jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
for i in range(FLAGS.warmup_steps):
|
||||||
|
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||||
|
jax.block_until_ready(all_results)
|
||||||
|
|
||||||
|
start_time = time()
|
||||||
|
all_results = []
|
||||||
|
for i in range(FLAGS.steps):
|
||||||
|
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||||
|
|
||||||
|
jax.block_until_ready(all_results)
|
||||||
|
elapsed_time_ref_attn = time() - start_time
|
||||||
|
print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')
|
||||||
|
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
for i in range(FLAGS.warmup_steps):
|
||||||
|
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||||
|
jax.block_until_ready(all_results)
|
||||||
|
|
||||||
|
|
||||||
|
start_time = time()
|
||||||
|
all_results = []
|
||||||
|
for i in range(FLAGS.steps):
|
||||||
|
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||||
|
|
||||||
|
jax.block_until_ready(all_results)
|
||||||
|
elapsed_time_efficient_attn = time() - start_time
|
||||||
|
print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')
|
||||||
|
|
||||||
|
flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
|
||||||
|
efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
|
||||||
|
print(f'Efficiency: {efficiency:.3f}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
mlxu.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
42
EasyLM/scripts/convert_checkpoint.py
Normal file
42
EasyLM/scripts/convert_checkpoint.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# This script converts model checkpoint trained by EsayLM to a standard
|
||||||
|
# mspack checkpoint that can be loaded by huggingface transformers or
|
||||||
|
# flax.serialization.msgpack_restore. Such conversion allows models to be
|
||||||
|
# used by other frameworks that integrate with huggingface transformers.
|
||||||
|
|
||||||
|
import pprint
|
||||||
|
from functools import partial
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import mlxu
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import flax.serialization
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
from EasyLM.jax_utils import float_to_dtype
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
load_checkpoint='',
|
||||||
|
output_file='',
|
||||||
|
streaming=False,
|
||||||
|
float_dtype='bf16',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
|
||||||
|
params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||||
|
FLAGS.load_checkpoint, disallow_trainstate=True
|
||||||
|
)[1]['params']
|
||||||
|
|
||||||
|
if FLAGS.streaming:
|
||||||
|
StreamingCheckpointer.save_train_state_to_file(
|
||||||
|
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
params = float_to_dtype(params, FLAGS.float_dtype)
|
||||||
|
with mlxu.open_file(FLAGS.output, 'wb') as fout:
|
||||||
|
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
59
EasyLM/scripts/diff_checkpoint.py
Normal file
59
EasyLM/scripts/diff_checkpoint.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# This script converts model checkpoint trained by EsayLM to a standard
|
||||||
|
# mspack checkpoint that can be loaded by huggingface transformers or
|
||||||
|
# flax.serialization.msgpack_restore. Such conversion allows models to be
|
||||||
|
# used by other frameworks that integrate with huggingface transformers.
|
||||||
|
|
||||||
|
import pprint
|
||||||
|
from functools import partial
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import flax.serialization
|
||||||
|
import mlxu
|
||||||
|
from EasyLM.checkpoint import StreamingCheckpointer
|
||||||
|
from EasyLM.jax_utils import float_to_dtype
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
recover_diff=False,
|
||||||
|
load_base_checkpoint='',
|
||||||
|
load_target_checkpoint='',
|
||||||
|
output_file='',
|
||||||
|
streaming=True,
|
||||||
|
float_dtype='bf16',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
|
||||||
|
assert FLAGS.output_file != ''
|
||||||
|
base_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||||
|
FLAGS.load_base_checkpoint, disallow_trainstate=True
|
||||||
|
)[1]['params']
|
||||||
|
|
||||||
|
target_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||||
|
FLAGS.load_target_checkpoint, disallow_trainstate=True
|
||||||
|
)[1]['params']
|
||||||
|
|
||||||
|
if FLAGS.recover_diff:
|
||||||
|
params = jax.tree_util.tree_map(
|
||||||
|
lambda b, t: b + t, base_params, target_params
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
params = jax.tree_util.tree_map(
|
||||||
|
lambda b, t: t - b, base_params, target_params
|
||||||
|
)
|
||||||
|
|
||||||
|
if FLAGS.streaming:
|
||||||
|
StreamingCheckpointer.save_train_state_to_file(
|
||||||
|
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
params = float_to_dtype(params, FLAGS.float_dtype)
|
||||||
|
with mlxu.open_file(FLAGS.output, 'wb') as fout:
|
||||||
|
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
65
EasyLM/scripts/lm_eval_harness.py
Normal file
65
EasyLM/scripts/lm_eval_harness.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
# This script runs lm_eval_harness evaluations against a served language model.
|
||||||
|
# Typically, you need to run a language model server first, e.g.:
|
||||||
|
# python -m EasyLM.models.gptj.gptj_serve ...
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import pprint
|
||||||
|
from functools import partial
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import numpy as np
|
||||||
|
import mlxu
|
||||||
|
|
||||||
|
from flax.traverse_util import flatten_dict
|
||||||
|
from lm_eval import evaluator, tasks
|
||||||
|
from lm_eval.base import LM
|
||||||
|
|
||||||
|
from EasyLM.serving import LMClient
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
tasks='wsc,piqa,winogrande,openbookqa,logiqa',
|
||||||
|
shots=0,
|
||||||
|
limit=0,
|
||||||
|
write_out=False,
|
||||||
|
lm_client=LMClient.get_default_config(),
|
||||||
|
logger=mlxu.WandBLogger.get_default_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LMEvalHarnessInterface(LM):
|
||||||
|
|
||||||
|
def __init__(self, lm_client):
|
||||||
|
self.lm_client = lm_client
|
||||||
|
|
||||||
|
def greedy_until(self, inputs):
|
||||||
|
prefix, until = zip(*inputs)
|
||||||
|
return self.lm_client.greedy_until(prefix, until)
|
||||||
|
|
||||||
|
def loglikelihood_rolling(self, inputs):
|
||||||
|
loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
|
||||||
|
return list(zip(loglikelihood, is_greedy))
|
||||||
|
|
||||||
|
def loglikelihood(self, inputs):
|
||||||
|
prefix, text = zip(*inputs)
|
||||||
|
loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
|
||||||
|
return list(zip(loglikelihood, is_greedy))
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
logger = mlxu.WandBLogger(
|
||||||
|
config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||||
|
)
|
||||||
|
model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
|
||||||
|
task_list = FLAGS.tasks.split(',')
|
||||||
|
results = evaluator.evaluate(
|
||||||
|
model, tasks.get_task_dict(task_list), False, FLAGS.shots,
|
||||||
|
limit=None if FLAGS.limit <= 0 else FLAGS.limit,
|
||||||
|
write_out=FLAGS.write_out,
|
||||||
|
)
|
||||||
|
logger.log(flatten_dict(results['results'], sep='/'))
|
||||||
|
pprint.pprint(results)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
52
EasyLM/scripts/lm_eval_json.py
Normal file
52
EasyLM/scripts/lm_eval_json.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import json
|
||||||
|
import mlxu
|
||||||
|
from EasyLM.serving import LMClient
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||||
|
input_file='',
|
||||||
|
output_file='',
|
||||||
|
prefix_field='prefix',
|
||||||
|
text_field='text',
|
||||||
|
until_field='until',
|
||||||
|
eval_type='loglikelihood',
|
||||||
|
lm_client=LMClient.get_default_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
lm_client = LMClient(FLAGS.lm_client)
|
||||||
|
with mlxu.open_file(FLAGS.input_file, 'r') as fin:
|
||||||
|
input_data = json.load(fin)
|
||||||
|
|
||||||
|
if FLAGS.eval_type == 'loglikelihood':
|
||||||
|
prefix = input_data[FLAGS.prefix_field]
|
||||||
|
text = input_data[FLAGS.text_field]
|
||||||
|
loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
|
||||||
|
output_data = {
|
||||||
|
'loglikelihood': loglikelihoods,
|
||||||
|
'is_greedy': is_greedys,
|
||||||
|
}
|
||||||
|
elif FLAGS.eval_type == 'loglikelihood_rolling':
|
||||||
|
text = input_data[FLAGS.text_field]
|
||||||
|
loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
|
||||||
|
output_data = {
|
||||||
|
'loglikelihood': loglikelihoods,
|
||||||
|
'is_greedy': is_greedys,
|
||||||
|
}
|
||||||
|
elif FLAGS.eval_type == 'greedy_until':
|
||||||
|
prefix = input_data[FLAGS.prefix_field]
|
||||||
|
until = input_data[FLAGS.until_field]
|
||||||
|
output_data = {'output_text': lm_client.greedy_until(prefix, until)}
|
||||||
|
elif FLAGS.eval_type == 'generate':
|
||||||
|
prefix = input_data[FLAGS.prefix_field]
|
||||||
|
output_data = {'output_text': lm_client.generate(prefix)}
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
|
||||||
|
|
||||||
|
with mlxu.open_file(FLAGS.output_file, 'w') as fout:
|
||||||
|
json.dump(output_data, fout)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlxu.run(main)
|
||||||
566
EasyLM/serving.py
Normal file
566
EasyLM/serving.py
Normal file
@@ -0,0 +1,566 @@
|
|||||||
|
import dataclasses
|
||||||
|
import pprint
|
||||||
|
from functools import partial
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
from threading import Lock
|
||||||
|
import urllib
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import absl.logging
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import numpy as np
|
||||||
|
import mlxu
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
import gradio as gr
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import Timeout, ConnectionError
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceRequest(BaseModel):
|
||||||
|
prefix_text: Optional[List[str]] = None
|
||||||
|
text: Optional[List[str]] = None
|
||||||
|
until: Optional[Union[List[str], List[List[str]]]] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
prompt: str
|
||||||
|
context: str = ''
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LMServer(object):
|
||||||
|
""" HTTP server for serving langauge models. """
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.host = '0.0.0.0'
|
||||||
|
config.port = 5007
|
||||||
|
config.batch_size = 1
|
||||||
|
config.logging = False
|
||||||
|
config.pre_compile = 'loglikelihood'
|
||||||
|
config.default_temperature = 1.0
|
||||||
|
config.greedy_until_max_length = 5000
|
||||||
|
config.prepend_to_prefix = ''
|
||||||
|
config.append_to_prefix = ''
|
||||||
|
config.prepend_to_text = ''
|
||||||
|
config.append_to_text = ''
|
||||||
|
config.chat_prepend_text = ''
|
||||||
|
config.chat_user_prefix = ''
|
||||||
|
config.chat_user_suffix = ''
|
||||||
|
config.chat_lm_prefix = ''
|
||||||
|
config.chat_lm_suffix = ''
|
||||||
|
config.notes = ''
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = self.get_default_config(config)
|
||||||
|
self.lock = Lock()
|
||||||
|
self.app = FastAPI()
|
||||||
|
self.app.post('/loglikelihood')(self.serve_loglikelihood)
|
||||||
|
self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling)
|
||||||
|
self.app.post('/generate')(self.serve_generate)
|
||||||
|
self.app.post('/greedy-until')(self.serve_greedy_until)
|
||||||
|
self.app.post('/chat')(self.serve_chat)
|
||||||
|
self.app.get('/ready')(self.serve_ready)
|
||||||
|
self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def loglikelihood(prefix_text, text):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def loglikelihood_rolling(text):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate(text, temperature):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def greedy_until(prefix_text, until, max_length):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_list(x):
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return x.tolist()
|
||||||
|
return x
|
||||||
|
|
||||||
|
def serve_ready(self):
|
||||||
|
return 'Ready!\n'
|
||||||
|
|
||||||
|
def serve_loglikelihood(self, data: InferenceRequest):
|
||||||
|
with self.lock:
|
||||||
|
if self.config.logging:
|
||||||
|
absl.logging.info(
|
||||||
|
'\n========= Serving Log Likelihood Request ========= \n'
|
||||||
|
+ pprint.pformat(data) + '\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
if data.prefix_text is None:
|
||||||
|
data.prefix_text = ['' for _ in data.text]
|
||||||
|
|
||||||
|
prefix_text = [
|
||||||
|
self.config.prepend_to_prefix + p + self.config.append_to_prefix
|
||||||
|
for p in data.prefix_text
|
||||||
|
]
|
||||||
|
text = [
|
||||||
|
self.config.prepend_to_text + t + self.config.append_to_text
|
||||||
|
for t in data.text
|
||||||
|
]
|
||||||
|
|
||||||
|
log_likelihood = []
|
||||||
|
is_greedy = []
|
||||||
|
for i in trange(0, len(text), self.config.batch_size, ncols=0):
|
||||||
|
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
|
||||||
|
batch_text = text[i:i + self.config.batch_size]
|
||||||
|
batch_size = len(batch_text)
|
||||||
|
|
||||||
|
if batch_size < self.config.batch_size:
|
||||||
|
extra = self.config.batch_size - batch_size
|
||||||
|
batch_prefix_text.extend(['a' for _ in range(extra)])
|
||||||
|
batch_text.extend(['a' for _ in range(extra)])
|
||||||
|
|
||||||
|
batch_log_likelihood, batch_is_greedy = self.loglikelihood(
|
||||||
|
batch_prefix_text, batch_text
|
||||||
|
)
|
||||||
|
batch_log_likelihood = self.to_list(batch_log_likelihood)
|
||||||
|
batch_is_greedy = self.to_list(batch_is_greedy)
|
||||||
|
log_likelihood.extend(batch_log_likelihood[:batch_size])
|
||||||
|
is_greedy.extend(batch_is_greedy[:batch_size])
|
||||||
|
|
||||||
|
output = {
|
||||||
|
'prefix_text': data.prefix_text,
|
||||||
|
'text': data.text,
|
||||||
|
'log_likelihood': log_likelihood,
|
||||||
|
'is_greedy': is_greedy,
|
||||||
|
}
|
||||||
|
if self.config.logging:
|
||||||
|
absl.logging.info(
|
||||||
|
'\n========= Output ========= \n'
|
||||||
|
+ pprint.pformat(output) + '\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def serve_loglikelihood_rolling(self, data: InferenceRequest):
|
||||||
|
with self.lock:
|
||||||
|
if self.config.logging:
|
||||||
|
absl.logging.info(
|
||||||
|
'\n========= Serving Log Likelihood Request ========= \n'
|
||||||
|
+ pprint.pformat(data) + '\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
text = [
|
||||||
|
self.config.prepend_to_text + t + self.config.append_to_text
|
||||||
|
for t in data.text
|
||||||
|
]
|
||||||
|
log_likelihood = []
|
||||||
|
is_greedy = []
|
||||||
|
for i in trange(0, len(text), self.config.batch_size, ncols=0):
|
||||||
|
batch_text = text[i:i + self.config.batch_size]
|
||||||
|
batch_size = len(batch_text)
|
||||||
|
|
||||||
|
if batch_size < self.config.batch_size:
|
||||||
|
extra = self.config.batch_size - batch_size
|
||||||
|
batch_text.extend(['a' for _ in range(extra)])
|
||||||
|
|
||||||
|
batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
|
||||||
|
batch_text
|
||||||
|
)
|
||||||
|
batch_log_likelihood = self.to_list(batch_log_likelihood)
|
||||||
|
batch_is_greedy = self.to_list(batch_is_greedy)
|
||||||
|
log_likelihood.extend(batch_log_likelihood[:batch_size])
|
||||||
|
is_greedy.extend(batch_is_greedy[:batch_size])
|
||||||
|
|
||||||
|
output = {
|
||||||
|
'text': data.text,
|
||||||
|
'log_likelihood': log_likelihood,
|
||||||
|
'is_greedy': is_greedy,
|
||||||
|
}
|
||||||
|
if self.config.logging:
|
||||||
|
absl.logging.info(
|
||||||
|
'\n========= Output ========= \n'
|
||||||
|
+ pprint.pformat(output) + '\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def serve_generate(self, data: InferenceRequest):
|
||||||
|
with self.lock:
|
||||||
|
if self.config.logging:
|
||||||
|
absl.logging.info(
|
||||||
|
'\n========= Serving Generate Request ========= \n'
|
||||||
|
+ pprint.pformat(data) + '\n'
|
||||||
|
)
|
||||||
|
prefix_text = [
|
||||||
|
self.config.prepend_to_prefix + p + self.config.append_to_prefix
|
||||||
|
for p in data.prefix_text
|
||||||
|
]
|
||||||
|
|
||||||
|
if data.temperature is None:
|
||||||
|
data.temperature = self.config.default_temperature
|
||||||
|
|
||||||
|
output_text = []
|
||||||
|
for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0):
|
||||||
|
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
|
||||||
|
batch_size = len(batch_prefix_text)
|
||||||
|
|
||||||
|
if batch_size < self.config.batch_size:
|
||||||
|
extra = self.config.batch_size - batch_size
|
||||||
|
batch_prefix_text.extend(['a' for _ in range(extra)])
|
||||||
|
|
||||||
|
batch_output_text = self.generate(
|
||||||
|
batch_prefix_text,
|
||||||
|
temperature=data.temperature,
|
||||||
|
)
|
||||||
|
output_text.extend(self.to_list(batch_output_text)[:batch_size])
|
||||||
|
|
||||||
|
output = {
|
||||||
|
'prefix_text': data.prefix_text,
|
||||||
|
'output_text': output_text,
|
||||||
|
'temperature': data.temperature,
|
||||||
|
}
|
||||||
|
if self.config.logging:
|
||||||
|
absl.logging.info(
|
||||||
|
'\n========= Output ========= \n'
|
||||||
|
+ pprint.pformat(output) + '\n'
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def serve_greedy_until(self, data: InferenceRequest):
|
||||||
|
with self.lock:
|
||||||
|
if self.config.logging:
|
||||||
|
absl.logging.info(
|
||||||
|
'\n========= Serving Greedy Until Request ========= \n'
|
||||||
|
+ pprint.pformat(data) + '\n'
|
||||||
|
)
|
||||||
|
prefix_text = [
|
||||||
|
self.config.prepend_to_prefix + p + self.config.append_to_prefix
|
||||||
|
for p in data.prefix_text
|
||||||
|
]
|
||||||
|
until = data.until
|
||||||
|
max_length = self.config.greedy_until_max_length
|
||||||
|
|
||||||
|
output_text = []
|
||||||
|
for i in range(0, len(prefix_text), self.config.batch_size):
|
||||||
|
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
|
||||||
|
batch_until = until[i:i + self.config.batch_size]
|
||||||
|
batch_size = len(batch_prefix_text)
|
||||||
|
|
||||||
|
batch_output_text = self.greedy_until(batch_prefix_text, batch_until, max_length)
|
||||||
|
output_text.extend(self.to_list(batch_output_text)[:batch_size])
|
||||||
|
|
||||||
|
output = {
|
||||||
|
'prefix_text': data.prefix_text,
|
||||||
|
'until': data.until,
|
||||||
|
'max_length': max_length,
|
||||||
|
'output_text': output_text,
|
||||||
|
}
|
||||||
|
if self.config.logging:
|
||||||
|
absl.logging.info(
|
||||||
|
'\n========= Output ========= \n'
|
||||||
|
+ pprint.pformat(output) + '\n'
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def process_chat(self, prompt, context, temperature):
|
||||||
|
context = (
|
||||||
|
context + self.config.chat_user_prefix
|
||||||
|
+ prompt + self.config.chat_user_suffix
|
||||||
|
+ self.config.chat_lm_prefix
|
||||||
|
)
|
||||||
|
response = self.generate(
|
||||||
|
[self.config.chat_prepend_text + context],
|
||||||
|
temperature=float(temperature),
|
||||||
|
)[0]
|
||||||
|
context = context + response + self.config.chat_lm_suffix
|
||||||
|
return response, context
|
||||||
|
|
||||||
|
def serve_chat(self, data: ChatRequest):
|
||||||
|
if data.temperature is None:
|
||||||
|
data.temperature = self.config.default_temperature
|
||||||
|
response, context = self.process_chat(
|
||||||
|
data.prompt, data.context,
|
||||||
|
temperature=data.temperature,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
'response': response,
|
||||||
|
'context': context,
|
||||||
|
'temperature': data.temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_chat_app(self):
|
||||||
|
with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
|
||||||
|
gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
|
||||||
|
gr.Markdown(self.config.notes)
|
||||||
|
chatbot = gr.Chatbot(label='Chat history')
|
||||||
|
msg = gr.Textbox(
|
||||||
|
placeholder='Type your message here...',
|
||||||
|
show_label=False
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
send = gr.Button('Send')
|
||||||
|
regenerate = gr.Button('Regenerate', interactive=False)
|
||||||
|
clear = gr.Button('Reset')
|
||||||
|
|
||||||
|
temp_slider = gr.Slider(
|
||||||
|
label='Temperature', minimum=0, maximum=2.0,
|
||||||
|
value=self.config.default_temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
context_state = gr.State(['', ''])
|
||||||
|
|
||||||
|
def user_fn(user_message, history, context):
|
||||||
|
return {
|
||||||
|
msg: gr.update(value='', interactive=False),
|
||||||
|
clear: gr.update(interactive=False),
|
||||||
|
send: gr.update(interactive=False),
|
||||||
|
regenerate: gr.update(interactive=False),
|
||||||
|
chatbot: history + [[user_message, None]],
|
||||||
|
context_state: [context[1], context[1]],
|
||||||
|
}
|
||||||
|
|
||||||
|
def model_fn(history, context, temperature):
|
||||||
|
history[-1][1], new_context = self.process_chat(
|
||||||
|
history[-1][0], context[0], temperature
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
msg: gr.update(value='', interactive=True),
|
||||||
|
clear: gr.update(interactive=True),
|
||||||
|
send: gr.update(interactive=True),
|
||||||
|
chatbot: history,
|
||||||
|
context_state: [context[0], new_context],
|
||||||
|
regenerate: gr.update(interactive=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
def regenerate_fn():
|
||||||
|
return {
|
||||||
|
msg: gr.update(value='', interactive=False),
|
||||||
|
clear: gr.update(interactive=False),
|
||||||
|
send: gr.update(interactive=False),
|
||||||
|
regenerate: gr.update(interactive=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_fn():
|
||||||
|
return {
|
||||||
|
chatbot: None,
|
||||||
|
msg: '',
|
||||||
|
context_state: ['', ''],
|
||||||
|
regenerate: gr.update(interactive=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.submit(
|
||||||
|
user_fn,
|
||||||
|
inputs=[msg, chatbot, context_state],
|
||||||
|
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||||
|
queue=False
|
||||||
|
).then(
|
||||||
|
model_fn,
|
||||||
|
inputs=[chatbot, context_state, temp_slider],
|
||||||
|
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||||
|
queue=True
|
||||||
|
)
|
||||||
|
send.click(
|
||||||
|
user_fn,
|
||||||
|
inputs=[msg, chatbot, context_state],
|
||||||
|
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||||
|
queue=False
|
||||||
|
).then(
|
||||||
|
model_fn,
|
||||||
|
inputs=[chatbot, context_state, temp_slider],
|
||||||
|
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||||
|
queue=True
|
||||||
|
)
|
||||||
|
regenerate.click(
|
||||||
|
regenerate_fn,
|
||||||
|
inputs=None,
|
||||||
|
outputs=[msg, clear, send, regenerate],
|
||||||
|
queue=False
|
||||||
|
).then(
|
||||||
|
model_fn,
|
||||||
|
inputs=[chatbot, context_state, temp_slider],
|
||||||
|
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||||
|
queue=True
|
||||||
|
)
|
||||||
|
clear.click(
|
||||||
|
clear_fn,
|
||||||
|
inputs=None,
|
||||||
|
outputs=[chatbot, msg, context_state, regenerate],
|
||||||
|
queue=False
|
||||||
|
)
|
||||||
|
|
||||||
|
gradio_chatbot.queue(concurrency_count=1)
|
||||||
|
return gradio_chatbot
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
if self.config.pre_compile != '':
|
||||||
|
if self.config.pre_compile == 'all':
|
||||||
|
pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat']
|
||||||
|
else:
|
||||||
|
pre_compile = self.config.pre_compile.split(',')
|
||||||
|
|
||||||
|
pre_compile_data = ['a' for _ in range(self.config.batch_size)]
|
||||||
|
for task in pre_compile:
|
||||||
|
if task == 'loglikelihood':
|
||||||
|
self.loglikelihood(pre_compile_data, pre_compile_data)
|
||||||
|
self.loglikelihood_rolling(pre_compile_data)
|
||||||
|
elif task == 'generate':
|
||||||
|
self.generate(pre_compile_data, 1.0)
|
||||||
|
elif task == 'greedy_until':
|
||||||
|
self.greedy_until(
|
||||||
|
pre_compile_data, pre_compile_data,
|
||||||
|
self.config.greedy_until_max_length
|
||||||
|
)
|
||||||
|
elif task == 'chat':
|
||||||
|
self.process_chat('a', 'a', 1.0)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Invalid precompile task: {task}!')
|
||||||
|
|
||||||
|
uvicorn.run(self.app, host=self.config.host, port=self.config.port)
|
||||||
|
|
||||||
|
|
||||||
|
class LMClient(object):
|
||||||
|
""" A simple client for the LM server. """
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config(updates=None):
|
||||||
|
config = ConfigDict()
|
||||||
|
config.url = 'http://localhost:5007'
|
||||||
|
config.batch_size = 1
|
||||||
|
config.wait_for_ready = True
|
||||||
|
config.dummy = False
|
||||||
|
|
||||||
|
if updates is not None:
|
||||||
|
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||||
|
return config
|
||||||
|
|
||||||
|
def __init__(self, config=None):
|
||||||
|
self.config = self.get_default_config(config)
|
||||||
|
if self.config.wait_for_ready:
|
||||||
|
self.wait_for_ready()
|
||||||
|
|
||||||
|
def wait_for_ready(self):
|
||||||
|
if self.config.dummy:
|
||||||
|
return
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
requests.get(urllib.parse.urljoin(self.config.url, 'ready'))
|
||||||
|
return
|
||||||
|
except (Timeout, ConnectionError) as e:
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def batched(iterator, batch_size):
|
||||||
|
batch = []
|
||||||
|
for example in iterator:
|
||||||
|
batch.append(example)
|
||||||
|
if len(batch) == batch_size:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
if len(batch) > 0:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
def loglikelihood(self, prefix, text):
|
||||||
|
prefix, text = list(prefix), list(text)
|
||||||
|
if self.config.dummy:
|
||||||
|
return [-1.0 for _ in text], [False for _ in text]
|
||||||
|
|
||||||
|
log_likelihood = []
|
||||||
|
is_greedy = []
|
||||||
|
|
||||||
|
batched_iterator = list(zip(
|
||||||
|
self.batched(prefix, self.config.batch_size),
|
||||||
|
self.batched(text, self.config.batch_size)
|
||||||
|
))
|
||||||
|
for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
|
||||||
|
response = requests.post(
|
||||||
|
urllib.parse.urljoin(self.config.url, 'loglikelihood'),
|
||||||
|
json={'prefix_text': batch_prefix, 'text': batch_text}
|
||||||
|
).json()
|
||||||
|
log_likelihood.extend(response['log_likelihood'])
|
||||||
|
is_greedy.extend(response['is_greedy'])
|
||||||
|
|
||||||
|
return log_likelihood, is_greedy
|
||||||
|
|
||||||
|
def loglikelihood_rolling(self, text):
|
||||||
|
text = list(text)
|
||||||
|
if self.config.dummy:
|
||||||
|
return [-1.0 for _ in text], [False for _ in text]
|
||||||
|
|
||||||
|
log_likelihood = []
|
||||||
|
is_greedy = []
|
||||||
|
batched_iterator = list(self.batched(text, self.config.batch_size))
|
||||||
|
for batch_text in tqdm(batched_iterator, ncols=0):
|
||||||
|
response = requests.post(
|
||||||
|
urllib.parse.urljoin(self.config.url, 'loglikelihood-rolling'),
|
||||||
|
json={'text': batch_text}
|
||||||
|
).json()
|
||||||
|
log_likelihood.extend(response['log_likelihood'])
|
||||||
|
is_greedy.extend(response['is_greedy'])
|
||||||
|
return log_likelihood, is_greedy
|
||||||
|
|
||||||
|
def greedy_until(self, prefix, until):
|
||||||
|
prefix, until = list(prefix), list(until)
|
||||||
|
if self.config.dummy:
|
||||||
|
results = []
|
||||||
|
for u in until:
|
||||||
|
if isinstance(u, str):
|
||||||
|
results.append('dummy text ' + u)
|
||||||
|
else:
|
||||||
|
results.append('dummy text ' + u[0])
|
||||||
|
return results
|
||||||
|
|
||||||
|
batched_iterator = list(zip(
|
||||||
|
self.batched(prefix, self.config.batch_size),
|
||||||
|
self.batched(until, self.config.batch_size),
|
||||||
|
))
|
||||||
|
output_text = []
|
||||||
|
for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
|
||||||
|
response = requests.post(
|
||||||
|
urllib.parse.urljoin(self.config.url, 'greedy-until'),
|
||||||
|
json={'prefix_text': batch_prefix, 'until': batch_until}
|
||||||
|
).json()
|
||||||
|
output_text.extend(response['output_text'])
|
||||||
|
return output_text
|
||||||
|
|
||||||
|
def generate(self, prefix, temperature=None):
|
||||||
|
prefix = list(prefix)
|
||||||
|
if self.config.dummy:
|
||||||
|
return ['' for _ in prefix]
|
||||||
|
|
||||||
|
output_text = []
|
||||||
|
batched_iterator = list(self.batched(prefix, self.config.batch_size))
|
||||||
|
for batch_prefix in tqdm(batched_iterator, ncols=0):
|
||||||
|
response = requests.post(
|
||||||
|
urllib.parse.urljoin(self.config.url, 'generate'),
|
||||||
|
json={
|
||||||
|
'prefix_text': batch_prefix,
|
||||||
|
'temperature': temperature,
|
||||||
|
}
|
||||||
|
).json()
|
||||||
|
output_text.extend(response['output_text'])
|
||||||
|
return output_text
|
||||||
|
|
||||||
|
def chat(self, prompt, context, temperature=None):
|
||||||
|
if self.config.dummy:
|
||||||
|
return ''
|
||||||
|
response = requests.post(
|
||||||
|
urllib.parse.urljoin(self.config.url, 'chat'),
|
||||||
|
json={
|
||||||
|
'prompt': prompt,
|
||||||
|
'context': context,
|
||||||
|
'temperature': temperature,
|
||||||
|
}
|
||||||
|
).json()
|
||||||
|
return response['response'], response['context']
|
||||||
302
README.md
Normal file
302
README.md
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
---
|
||||||
|
datasets:
|
||||||
|
- Finnish-NLP/CulturaX_fi_cleaned
|
||||||
|
- Finnish-NLP/HPLT_1.2_fi_cleaned
|
||||||
|
- Finnish-NLP/wikipedia_20231101_fi_cleaned
|
||||||
|
- Finnish-NLP/Reddit_fi_2006_2022
|
||||||
|
- intfloat/multilingual_cc_news
|
||||||
|
language:
|
||||||
|
- fi
|
||||||
|
license: apache-2.0
|
||||||
|
pipeline_tag: text-generation
|
||||||
|
tags:
|
||||||
|
- finnish
|
||||||
|
- llama
|
||||||
|
library_name: transformers
|
||||||
|
---
|
||||||
|
|
||||||
|
# Ahma-7B for Finnish
|
||||||
|
|
||||||
|
Ahma-7B is a 7B parameter decoder-only transformer model based on Meta's Llama (v1) architecture, pretrained from scratch on the Finnish language. Its development was informed by the research presented in the paper [Scaling Data-Constrained Language Models](https://huggingface.co/papers/2305.16264). The original Llama model architecture was introduced in
|
||||||
|
[this paper](https://arxiv.org/abs/2302.13971)
|
||||||
|
and first released at [this page](https://github.com/facebookresearch/llama).
|
||||||
|
|
||||||
|
What does Ahma mean? Ahma is the Finnish word for wolverine! In the Finnish Lapland, wolverines are the biggest cause of reindeer damage.
|
||||||
|
|
||||||
|
There are two different sized base Ahma models both pretrained from scratch, Ahma-3B for 139B tokens and Ahma-7B for 149B tokens:
|
||||||
|
|
||||||
|
| Model | Context length | Layers | Dim | Heads | Params |
|
||||||
|
|:--------------------------------------------------------------------------------|:---------------|:-------|:-----|:------|:-------|
|
||||||
|
| [Ahma-3B](https://huggingface.co/Finnish-NLP/Ahma-3B) | 2048 | 26 | 3200 | 32 | 3.6B |
|
||||||
|
| [Ahma-7B](https://huggingface.co/Finnish-NLP/Ahma-7B) | 2048 | 32 | 4096 | 32 | 7.0B |
|
||||||
|
|
||||||
|
And two instruct-tuned versions:
|
||||||
|
|
||||||
|
| Model | Context length | Layers | Dim | Heads | Params |
|
||||||
|
|:--------------------------------------------------------------------------------|:---------------|:-------|:-----|:------|:-------|
|
||||||
|
| [Ahma-3B-Instruct](https://huggingface.co/Finnish-NLP/Ahma-3B-Instruct) | 2048 | 26 | 3200 | 32 | 3.6B |
|
||||||
|
| [Ahma-7B-Instruct](https://huggingface.co/Finnish-NLP/Ahma-7B-Instruct) | 2048 | 32 | 4096 | 32 | 7.0B |
|
||||||
|
|
||||||
|
## Paper Abstract
|
||||||
|
|
||||||
|
The current trend of scaling language models involves increasing both parameter count and training dataset size. Extrapolating this trend suggests that training dataset size may soon be limited by the amount of text data available on the internet. Motivated by this limit, we investigate scaling language models in data-constrained regimes. Specifically, we run a large set of experiments varying the extent of data repetition and compute budget, ranging up to 900 billion training tokens and 9 billion parameter models. We find that with constrained data for a fixed compute budget, training with up to 4 epochs of repeated data yields negligible changes to loss compared to having unique data. However, with more repetition, the value of adding compute eventually decays to zero. We propose and empirically validate a scaling law for compute optimality that accounts for the decreasing value of repeated tokens and excess parameters. Finally, we experiment with approaches mitigating data scarcity, including augmenting the training dataset with code data or removing commonly used filters. Models and datasets from our 400 training runs are freely available at this https URL .
|
||||||
|
|
||||||
|
## Intended uses & limitations
|
||||||
|
|
||||||
|
This model was pretrained only in a self-supervised way, without any supervised training. You can use this model for text generation or fine-tune it for a downstream task. This model followed a 2-stage pretraining approach where single-turn instruction-following examples were mixed in with the other training data in the second stage (explained more later in this readme). Thanks to this approach, this pretrained model is already capable of instruction following, but you might get even better results if you specifically fine-tune it for instruction following or other use cases. For instruction-following fine-tuning, you should use the same prompt format showcased below.
|
||||||
|
|
||||||
|
### How to use
|
||||||
|
|
||||||
|
#### Fine-tuning
|
||||||
|
|
||||||
|
We have now added finetuning example notebook along with video! \
|
||||||
|
Notebook: https://huggingface.co/Finnish-NLP/Ahma-3B/blob/main/Finetune_Ahma_3B_example.ipynb \
|
||||||
|
Video: https://www.youtube.com/watch?v=6mbgn9XzpS4
|
||||||
|
|
||||||
|
|
||||||
|
#### Inference
|
||||||
|
|
||||||
|
If you want to use this model for instruction-following, you need to use the same prompt format we used in the second stage of the pretraining (basically the same format what Meta used in their Llama2 models). **Note: do not use "LlamaTokenizer" from transformers library but always use the AutoTokenizer instead, or use the plain sentencepiece tokenizer.** Here is an example using the instruction-following prompt format, with some generation arguments you can modify for your use:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
|
||||||
|
system_prompt = "Olet tekoälyavustaja. Vastaat aina mahdollisimman avuliaasti. Vastauksesi eivät saa sisältää mitään haitallista, epäeettistä, rasistista, seksististä, vaarallista tai laitonta sisältöä. Jos kysymyksessä ei ole mitään järkeä tai se ei ole asiasisällöltään johdonmukainen, selitä miksi sen sijaan, että vastaisit jotain väärin. Jos et tiedä vastausta kysymykseen, älä kerro väärää tietoa."
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt(prompt: str) -> str:
|
||||||
|
prompt = f" [INST] <<SYS>>
|
||||||
|
{system_prompt.strip()}
|
||||||
|
<</SYS>>
|
||||||
|
|
||||||
|
{prompt.strip()} [/INST] "
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("Finnish-NLP/Ahma-7B")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("Finnish-NLP/Ahma-7B")
|
||||||
|
model = model.to("cuda")
|
||||||
|
|
||||||
|
# use the custom prompt format function or the chat template feature in the tokenizer to format your inputs
|
||||||
|
|
||||||
|
# prompt = format_prompt("Listaa kolme hyötyä, joita pienet avoimen lähdekoodin kielimallit tuovat?")
|
||||||
|
# inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": system_prompt,
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Listaa kolme hyötyä, joita pienet avoimen lähdekoodin kielimallit tuovat?"},
|
||||||
|
]
|
||||||
|
inputs = tokenizer.apply_chat_template(
|
||||||
|
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
|
||||||
|
)
|
||||||
|
inputs = inputs.to("cuda")
|
||||||
|
|
||||||
|
generated_ids = model.generate(
|
||||||
|
inputs,
|
||||||
|
temperature=0.6,
|
||||||
|
penalty_alpha=0.6,
|
||||||
|
top_k=4,
|
||||||
|
do_sample=True,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
min_length=5,
|
||||||
|
max_length=2048,
|
||||||
|
)
|
||||||
|
generated_text = tokenizer.batch_decode(
|
||||||
|
generated_ids, skip_special_tokens=False
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
"""
|
||||||
|
1. Parempi luettavuus ja ymmärtäminen: Pienten avoimen lähdekoodin kielimallien avulla voidaan luoda ymmärrettävämpää ja luettavampaa tekstiä, mikä helpottaa ihmisten ymmärtämistä ja tiedon hankkimista.
|
||||||
|
2. Parempi mukautuvuus ja monipuolisuus: Avoimen lähdekoodin mallit antavat kehittäjille mahdollisuuden luoda räätälöityjä ratkaisuja omiin tarpeisiinsa, jolloin he voivat hyödyntää olemassa olevaa tietämystä ja asiantuntemusta.
|
||||||
|
3. Lisääntynyt yhteistyö ja avoimuus: Avoimen lähdekoodin mallien ansiosta kehittäjät voivat tehdä yhteistyötä muiden kanssa, jakaa ideoita ja parantaa koodin laatua jakamalla oivalluksia ja parhaita käytäntöjä. Tämä edistää yhteistyöhön perustuvaa ympäristöä ja kannustaa jatkuvaan parantamiseen.
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
You may experiment with different system prompt instructions too if you like.
|
||||||
|
|
||||||
|
### Limitations and bias
|
||||||
|
|
||||||
|
This model was trained only with Finnish texts excluding code so it should not be used for multilingual and code generation use cases.
|
||||||
|
|
||||||
|
The training data used for this model contains a lot of content from the internet, which is far from neutral. Therefore, the model can have biased predictions. This bias will also affect all fine-tuned versions of this model.
|
||||||
|
|
||||||
|
To reduce toxic content, training data was filtered with a toxicity classifier but it cannot truly eliminate all toxic text.
|
||||||
|
|
||||||
|
## Training data
|
||||||
|
|
||||||
|
This model was pretrained on the combination of 14 datasets:
|
||||||
|
- [CulturaX_fi_cleaned](https://huggingface.co/datasets/Finnish-NLP/CulturaX_fi_cleaned), we cleaned Finnish split from the original [CulturaX](https://huggingface.co/datasets/uonlp/CulturaX) dataset
|
||||||
|
- [HPLT_1.2_fi_cleaned](https://huggingface.co/datasets/Finnish-NLP/HPLT_1.2_fi_cleaned), we cleaned Finnish split from the original [HPLT v1.2](https://hplt-project.org/datasets/v1.2) dataset
|
||||||
|
- [wikipedia_20231101_fi_cleaned](https://huggingface.co/datasets/Finnish-NLP/wikipedia_20231101_fi_cleaned), we used the Finnish subset of the wikipedia (November 2023) dataset
|
||||||
|
- [Reddit_fi_2006_2022](https://huggingface.co/datasets/Finnish-NLP/Reddit_fi_2006_2022), filtered and post-processed dataset of Finnish Reddit
|
||||||
|
- [Yle Finnish News Archive 2011-2018](http://urn.fi/urn:nbn:fi:lb-2017070501)
|
||||||
|
- [Yle Finnish News Archive 2019-2020](http://urn.fi/urn:nbn:fi:lb-2021050401)
|
||||||
|
- [Finnish News Agency Archive (STT)](http://urn.fi/urn:nbn:fi:lb-2018121001)
|
||||||
|
- [The Suomi24 Sentences Corpus](http://urn.fi/urn:nbn:fi:lb-2020021803)
|
||||||
|
- [Project Lönnrot](http://www.lonnrot.net/)
|
||||||
|
- [Finnish parliament speeches](https://avoindata.eduskunta.fi)
|
||||||
|
- [multilingual_cc_news](https://huggingface.co/datasets/intfloat/multilingual_cc_news), we used the Finnish subset of the multilingual CC-News dataset
|
||||||
|
- [fi-news-corpus](https://github.com/nkrusch/fi-news-corpus)
|
||||||
|
- Finnish higher education public theses
|
||||||
|
- Finnish single-turn instruction-following datasets, combination of multiple originally openly licensed English datasets translated to Finnish. For example, [Ultrachat, Aya, Capybara, etc](https://huggingface.co/collections/Finnish-NLP/sft-dpo-dataset-65f55dde1139c3cd683ff035)
|
||||||
|
|
||||||
|
|
||||||
|
Raw datasets were automatically cleaned to filter out bad quality and non-Finnish examples. Also, a [perplexity](https://huggingface.co/course/chapter7/3#perplexity-for-language-models) score was calculated for all texts with a KenLM model which was trained with very clean Finnish texts only. This perplexity score can then be used to determine how "clean" Finnish language the text contains. To reduce toxic text, we used Finnish toxicity classifier [TurkuNLP/bert-large-finnish-cased-toxicity](https://huggingface.co/TurkuNLP/bert-large-finnish-cased-toxicity) released by TurkuNLP to classify all text examples. Classified toxicity label scores can then be used to determine how toxic the text is.
|
||||||
|
|
||||||
|
All datasets were concatenated and the whole dataset near deduplicated using MinHashLSH from [text-dedup](https://github.com/ChenghaoMou/text-dedup). Top 95% perplexity score was used as a filtering threshold to filter out the worst quality 5% of texts. To reduce amount of toxic content, the dataset was filtered to include text examples having lower than 80% score for the toxicity labels "label_identity_attack", "label_insult", "label_threat" and "label_severe_toxicity".
|
||||||
|
|
||||||
|
Finally, 20,000 text examples from each of the CulturaX, Wikipedia, Yle, STT, Suomi24, and Reddit datasets were randomly selected for evaluation dataset.
|
||||||
|
|
||||||
|
The final training dataset had 23 billion words (calculated with regex "\w+") and the evaluation dataset had 23 million words. After tokenization, the training dataset had 41 billion tokens and the evaluation dataset had 40 million tokens. For the 2-stage pretraining, training datasets are divided as follows:
|
||||||
|
|
||||||
|
The first stage:
|
||||||
|
|Dataset | Words | Ratio |
|
||||||
|
|:-----------------------------|:------------|:-------------|
|
||||||
|
|CulturaX | 12.820B | 59.88% |
|
||||||
|
|HPLT v1.2 | 5.034B | 23.51% |
|
||||||
|
|Suomi24 | 3.018B | 14.09% |
|
||||||
|
|Reddit | 0.141B | 0.66% |
|
||||||
|
|CC-News | 0.311B | 1.45% |
|
||||||
|
|FI news corpus | 0.004B | 0.02% |
|
||||||
|
|Project Lönnrot | 0.083B | 0.39% |
|
||||||
|
|**TOTAL** | **21.410B** | **100.0%** |
|
||||||
|
|
||||||
|
|
||||||
|
The second stage:
|
||||||
|
|Dataset | Words | Ratio |
|
||||||
|
|:--------------------------------------------------------------|:------------|:------------|
|
||||||
|
|CulturaX (cleaner sample using KenLM perplexity score) | 2.252B | 55.48% |
|
||||||
|
|Wikipedia | 0.095B | 2.34% |
|
||||||
|
|STT | 0.253B | 6.23% |
|
||||||
|
|Yle | 0.212B | 5.22% |
|
||||||
|
|Finnish parliament speeches | 0.021B | 0.52% |
|
||||||
|
|Finnish higher education public theses | 0.855B | 21.07% |
|
||||||
|
|Finnish instruction-following datasets (note: 2X upsampled) | 0.371B | 9.14% |
|
||||||
|
|**TOTAL** | **4.059B** | **100.0%** |
|
||||||
|
|
||||||
|
## Training procedure
|
||||||
|
|
||||||
|
### Preprocessing
|
||||||
|
|
||||||
|
Texts are tokenized using Byte Pair Encoding (BPE) using the implementation from SentencePiece splitting all numbers into individual digits and using bytes to decompose unknown UTF-8 characters. The total
|
||||||
|
vocabulary size is 64k tokens. Inputs are sequences of 2048 consecutive tokens. Texts are not lower cased so this model is case-sensitive: it makes a difference between finnish and Finnish. Both BOS and EOS tokens were used in the pretraining.
|
||||||
|
|
||||||
|
### 2-stage pretraining
|
||||||
|
|
||||||
|
The model was trained on TPUv4-32 VM, sponsored by the [Google TPU Research Cloud](https://sites.research.google/trc/about/). Training was conducted with a slightly modified Jax/Flax based [EasyLM](https://github.com/young-geng/EasyLM) framework, and inspired by the [OpenLLaMA](https://github.com/openlm-research/open_llama) project. The optimizer used was a [Lion](https://arxiv.org/abs/2302.06675).
|
||||||
|
|
||||||
|
The 2-stage pretraining approach was inspired by [MiniCPM](https://shengdinghu.notion.site/MiniCPM-Unveiling-the-Potential-of-End-side-Large-Language-Models-d4d3a8c426424654a4e80e42a711cb20) findings. For the first stage (79% of the entire training), we used noisier web-scraped datasets. For the second stage (21% of the entire training), we primarily used cleaner datasets and instruction-following datasets shuffled together, like in MiniCPM. The learning rate schedule for the 2-stage pretraining was Warmup-Stable-Decay (WSD). During the first stage, the learning rate schedule had a linear warmup for about 8 billion tokens to a peak learning rate of 1e-4 (note: with the Lion optimizer, the learning rate had to be about 10 times smaller than with the commonly used AdamW), followed by a stable phase where the rate of 1e-4 was kept constant. During the second stage, the learning rate schedule had a linear decay from 1e-4 to 6e-6 for the first 7 billion tokens, followed by a stable phase for the remaining tokens.
|
||||||
|
|
||||||
|
In the first stage, the model was trained for 118 billion tokens, which is about three epochs of the first-stage training data, inspired by the findings of [Scaling Data-Constrained Language Models](https://huggingface.co/papers/2305.16264). In the second stage, the model was trained for 31 billion tokens, which is close to five epochs of the second-stage training data.
|
||||||
|
|
||||||
|
Thanks to the WSD learning rate schedule, you can more easily experiment with different first-stage model checkpoints. For example, you could apply the second-stage training on an earlier checkpoint or continue pretraining further before the second stage. Model checkpoints were pushed to this repository every 100,000 training steps (approximately 13 billion tokens).
|
||||||
|
|
||||||
|
- [900K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/5f6eb9498b17fece810d766f81c711c38a2b2de2)
|
||||||
|
- [800K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/bc2d607ce302c1b0ff75c229496645cf232c6d98)
|
||||||
|
- [700K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/69352a497d5953c5290296a1f429a450978c7f7f)
|
||||||
|
- [600K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/760ab5f865b08d9a512c1df523a5c4deb6874322)
|
||||||
|
- [500K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/32ea3d35931da8039180e80d67f6c323719ae50a)
|
||||||
|
- [400K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/d1256a6815983053d0f9934f21f163d764fc5ecd)
|
||||||
|
- [300K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/1e3094c66e788fe81d2aadad5bf8f0431358bd38)
|
||||||
|
- [200K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/a4afd130fa0effea047deaaf8bf63b3eba1b323b)
|
||||||
|
- [100K](https://huggingface.co/Finnish-NLP/Ahma-7B/tree/245fad2f5838af1465cb40ad42caef092e875cd9)
|
||||||
|
|
||||||
|
## Evaluation results
|
||||||
|
|
||||||
|
### FIN-bench
|
||||||
|
|
||||||
|
This Ahma 7B base model was primarily evaluated using [FIN-bench by TurkuNLP](https://github.com/TurkuNLP/FIN-bench), and the same evaluation was carried out for other relevant Finnish models for comparison: [FinGPT 8B by TurkuNLP](https://huggingface.co/TurkuNLP/gpt3-finnish-8B), [Viking 7B by TurkuNLP, SiloGen and HPLT](https://huggingface.co/LumiOpen/Viking-7B), and [Poro 34B by SiloGen, TurkuNLP and HPLT](https://huggingface.co/LumiOpen/Poro-34B). Below are the results with 0-shot and 3-shot settings in FIN-bench.
|
||||||
|
|
||||||
|
0-shot results:
|
||||||
|
|
||||||
|
| Benchmark | Ahma 3B base (instruct prompt format) | Ahma 3B Instruct (instruct prompt format) | Ahma 7B base (instruct prompt format) | Ahma 7B Instruct (instruct prompt format) | FinGPT 8B | Viking 7B | Poro 34B (8bit quant) |
|
||||||
|
|:---------------------------|:--------------------------------------|:------------------------------------------|:--------------------------------------|:------------------------------------------|:----------|:----------|:----------------------|
|
||||||
|
| Analogies | 50.77 | 48.46 | 56.92 | 41.54 | 49.23 | 40.00 | 54.62 |
|
||||||
|
| Arithmetic | 27.64 | 22.14 | 11.50 | 14.70 | 33.15 | 30.16 | 30.34 |
|
||||||
|
| Cause and Effect | 59.48 | 58.82 | 59.48 | 53.60 | 66.01 | 58.82 | 62.74 |
|
||||||
|
| Emotions | 36.25 | 28.12 | 36.25 | 27.50 | 22.50 | 26.25 | 35.63 |
|
||||||
|
| Empirical Judgements | 33.33 | 35.35 | 33.33 | 33.33 | 27.27 | 33.33 | 49.49 |
|
||||||
|
| General Knowledge | 44.29 | 48.57 | 51.43 | 37.14 | 40.00 | 24.29 | 51.43 |
|
||||||
|
| HHH Alignment | 42.09 | 41.66 | 44.23 | 43.22 | 41.81 | 42.51 | 42.92 |
|
||||||
|
| Intent Recognition | 24.42 | 26.16 | 43.64 | 56.94 | 17.49 | 22.40 | 68.35 |
|
||||||
|
| Misconceptions | 46.27 | 47.01 | 46.27 | 47.01 | 53.73 | 53.73 | 52.24 |
|
||||||
|
| Paraphrase | 59.50 | 73.00 | 67.00 | 70.50 | 51.00 | 50.00 | 51.00 |
|
||||||
|
| Sentence Ambiguity | 53.33 | 65.00 | 60.00 | 63.33 | 51.67 | 48.33 | 50.00 |
|
||||||
|
| Similarities Abstraction | 65.79 | 68.42 | 71.05 | 61.84 | 60.53 | 65.79 | 60.53 |
|
||||||
|
| **Non-Arithmetic Average** | **47.55** | **48.95** | **51.33** | **48.30** | **46.17** | **44.42** | **52.08** |
|
||||||
|
| **Overall Average** | **36.49** | **34.06** | **29.20** | **29.64** | **38.93** | **36.50** | **40.00** |
|
||||||
|
|
||||||
|
3-shot results:
|
||||||
|
|
||||||
|
| Benchmark | Ahma 3B base (instruct prompt format) | Ahma 3B Instruct (instruct prompt format) | Ahma 7B base (instruct prompt format) | Ahma 7B Instruct (instruct prompt format) | FinGPT 8B | Viking 7B | Poro 34B (8bit quant) |
|
||||||
|
|:---------------------------|:--------------------------------------|:------------------------------------------|:--------------------------------------|:------------------------------------------|:----------|:----------|:----------------------|
|
||||||
|
| Analogies | 50.77 | 49.23 | 49.23 | 43.08 | 40.77 | 54.62 | 76.92 |
|
||||||
|
| Arithmetic | 38.38 | 43.89 | 20.88 | 26.81 | 43.63 | 45.78 | 53.68 |
|
||||||
|
| Cause and Effect | 60.78 | 64.71 | 66.01 | 62.74 | 64.05 | 58.17 | 67.32 |
|
||||||
|
| Emotions | 30.00 | 41.25 | 30.00 | 53.75 | 44.37 | 48.13 | 56.87 |
|
||||||
|
| Empirical Judgements | 46.46 | 44.44 | 39.39 | 39.39 | 32.32 | 43.43 | 63.64 |
|
||||||
|
| General Knowledge | 47.14 | 40.00 | 27.14 | 44.29 | 54.29 | 28.57 | 74.29 |
|
||||||
|
| HHH Alignment | 43.53 | 44.80 | 43.80 | 45.09 | 45.39 | 44.80 | 46.07 |
|
||||||
|
| Intent Recognition | 20.52 | 44.22 | 36.42 | 39.02 | 51.45 | 58.82 | 83.67 |
|
||||||
|
| Misconceptions | 50.75 | 52.24 | 46.27 | 51.49 | 52.99 | 46.27 | 52.99 |
|
||||||
|
| Paraphrase | 50.50 | 58.50 | 57.50 | 65.00 | 53.00 | 54.50 | 55.00 |
|
||||||
|
| Sentence Ambiguity | 53.33 | 48.33 | 53.33 | 51.67 | 51.67 | 53.33 | 66.67 |
|
||||||
|
| Similarities Abstraction | 69.74 | 72.37 | 72.37 | 69.74 | 64.47 | 73.68 | 75.00 |
|
||||||
|
| **Non-Arithmetic Average** | **48.48** | **51.49** | **49.05** | **51.63** | **51.19** | **50.94** | **61.96** |
|
||||||
|
| **Overall Average** | **42.87** | **47.27** | **33.41** | **37.84** | **46.99** | **48.07** | **57.36** |
|
||||||
|
|
||||||
|
As we can see, Ahma 7B base model has bad arithmetic performance but in non-arithmetic tasks it clearly outperforms same sized models like the FinGPT 8B and Viking 7B, especially in 0-shot usage. Ahma 7B base model is even on-par with the 5X larger Poro 34B model, in non-arithmetic tasks in 0-shot usage. This result might be attributed to Ahma's 2-stage pretraining and the inclusion of instruct-following examples during the pretraining phase.
|
||||||
|
|
||||||
|
In a 3-shot setting, the results are more mixed. The poorer performance of Ahma 7B base model in 3-shot settings might be due to the use of the instruct prompt format and having only single-turn instruction-following training examples.
|
||||||
|
|
||||||
|
|
||||||
|
### MTBench Finnish
|
||||||
|
|
||||||
|
This Ahma 7B base model was also evaluated using [MTBench Finnish by LumiOpen](https://github.com/LumiOpen/FastChat/tree/main/fastchat/llm_judge) even though this Ahma model is not fine-tuned for chat. Since the MTBench evaluates also multi-turn chats while Ahma base models were only pretrained with single-turn instruction following examples, we have reported MTBench Finnish results separately for their single-turn and multi-turn evaluation examples. [Poro 34B Chat by SiloGen, TurkuNLP and HPLT](https://huggingface.co/LumiOpen/Poro-34B-chat) model's presumably multi-turn results are copied from their model card for the comparison.
|
||||||
|
|
||||||
|
Single-turn results:
|
||||||
|
|
||||||
|
| Benchmark | Ahma 3B base (instruct prompt format) | Ahma 3B Instruct (instruct prompt format) | Ahma 7B base (instruct prompt format) | Ahma 7B Instruct (instruct prompt format) |
|
||||||
|
|:--------------------|:--------------------------------------|:------------------------------------------|:--------------------------------------|:------------------------------------------|
|
||||||
|
| Coding | 1.00 | 1.00 | 1.70 | 1.10 |
|
||||||
|
| Extraction | 2.00 | 1.30 | 3.10 | 3.00 |
|
||||||
|
| Humanities | 4.05 | 6.20 | 6.60 | 8.00 |
|
||||||
|
| Math | 3.00 | 3.20 | 3.90 | 2.90 |
|
||||||
|
| Reasoning | 2.90 | 4.60 | 3.70 | 5.70 |
|
||||||
|
| Roleplay | 4.80 | 6.50 | 6.60 | 7.20 |
|
||||||
|
| STEM | 5.10 | 5.95 | 6.75 | 7.30 |
|
||||||
|
| Writing | 6.60 | 9.00 | 7.10 | 8.80 |
|
||||||
|
| **Overall Average** | **3.68** | **4.72** | **4.93** | **5.50** |
|
||||||
|
|
||||||
|
Multi-turn results:
|
||||||
|
|
||||||
|
| Benchmark | Ahma 3B base (instruct prompt format) | Ahma 3B Instruct (instruct prompt format) | Ahma 7B base (instruct prompt format) | Ahma 7B Instruct (instruct prompt format) | Poro 34B Chat |
|
||||||
|
|:--------------------|:--------------------------------------|:------------------------------------------|:--------------------------------------|:------------------------------------------|:--------------|
|
||||||
|
| Coding | 1.00 | 1.00 | 1.40 | 1.05 | 3.70 |
|
||||||
|
| Extraction | 1.55 | 1.15 | 2.05 | 2.65 | 6.37 |
|
||||||
|
| Humanities | 3.25 | 6.20 | 4.95 | 7.85 | 9.25 |
|
||||||
|
| Math | 2.20 | 2.70 | 2.50 | 2.40 | 1.20 |
|
||||||
|
| Reasoning | 2.45 | 3.50 | 2.55 | 4.50 | 4.35 |
|
||||||
|
| Roleplay | 4.90 | 6.40 | 6.35 | 6.60 | 7.35 |
|
||||||
|
| STEM | 4.20 | 4.78 | 4.28 | 5.40 | 7.80 |
|
||||||
|
| Writing | 3.80 | 6.65 | 4.10 | 6.25 | 8.50 |
|
||||||
|
| **Overall Average** | **2.92** | **4.05** | **3.52** | **4.59** | **6.06** |
|
||||||
|
|
||||||
|
As we can see, Ahma 7B base model struggles with multi-turn examples, as expected, since it has only been pretrained with single-turn instruction following examples. In addition, coding performance was expectedly poor because the Ahma 7B model is not trained with code data. In single-turn setting, Ahma 7B beats both the Ahma 3B base and Instruct-tuned versions, demonstrating greater base capability to be further improved with Instruct-tuning.
|
||||||
|
|
||||||
|
## Acknowledgements
|
||||||
|
|
||||||
|
This project would not have been possible without compute generously provided by Google through the
|
||||||
|
[TPU Research Cloud](https://sites.research.google/trc/).
|
||||||
|
|
||||||
|
## Team Members
|
||||||
|
|
||||||
|
- Aapo Tanskanen, [Hugging Face profile](https://huggingface.co/aapot), [LinkedIn profile](https://www.linkedin.com/in/aapotanskanen/)
|
||||||
|
- Rasmus Toivanen, [Hugging Face profile](https://huggingface.co/RASMUS), [LinkedIn profile](https://www.linkedin.com/in/rasmustoivanen/)
|
||||||
|
|
||||||
|
Feel free to contact us for more details 🤗
|
||||||
|
|
||||||
|

|
||||||
27
config.json
Normal file
27
config.json
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"architectures": [
|
||||||
|
"LlamaForCausalLM"
|
||||||
|
],
|
||||||
|
"attention_bias": false,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 1,
|
||||||
|
"eos_token_id": 2,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 11008,
|
||||||
|
"max_position_embeddings": 2048,
|
||||||
|
"model_type": "llama",
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_hidden_layers": 32,
|
||||||
|
"num_key_value_heads": 32,
|
||||||
|
"pretraining_tp": 1,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": null,
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"torch_dtype": "float16",
|
||||||
|
"transformers_version": "4.38.0.dev0",
|
||||||
|
"use_cache": true,
|
||||||
|
"vocab_size": 64256
|
||||||
|
}
|
||||||
4
convert_to_hf_model.sh
Normal file
4
convert_to_hf_model.sh
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
JAX_PLATFORM_NAME=cpu python3 -m EasyLM.models.llama.convert_easylm_to_hf \
|
||||||
|
--load_checkpoint='' \
|
||||||
|
--model_size='7b' \
|
||||||
|
--output_dir='./'
|
||||||
6
generation_config.json
Normal file
6
generation_config.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"_from_model_config": true,
|
||||||
|
"bos_token_id": 1,
|
||||||
|
"eos_token_id": 2,
|
||||||
|
"transformers_version": "4.38.0.dev0"
|
||||||
|
}
|
||||||
3
model-00001-of-00003.safetensors
Normal file
3
model-00001-of-00003.safetensors
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:625f48801b93273fa419f51704ccc45bba97010337ed52ed8db290767a152c71
|
||||||
|
size 4978830560
|
||||||
3
model-00002-of-00003.safetensors
Normal file
3
model-00002-of-00003.safetensors
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:2d8f06f2f59ecbffe4fce68ab67d5ad530d951fbce594f7c9850d9bce0b739a3
|
||||||
|
size 4991431320
|
||||||
3
model-00003-of-00003.safetensors
Normal file
3
model-00003-of-00003.safetensors
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:c34dac66ee076f27cda65ce37ae74894f30a4e6cc49f898cf1bf67cbf3c1f10e
|
||||||
|
size 4035085208
|
||||||
298
model.safetensors.index.json
Normal file
298
model.safetensors.index.json
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
{
|
||||||
|
"metadata": {
|
||||||
|
"total_size": 14005313536
|
||||||
|
},
|
||||||
|
"weight_map": {
|
||||||
|
"lm_head.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.embed_tokens.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.11.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.11.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.11.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.11.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.11.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.23.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
||||||
|
"model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.24.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.24.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.24.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.24.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
||||||
|
"model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
||||||
|
"model.norm.weight": "model-00003-of-00003.safetensors"
|
||||||
|
}
|
||||||
|
}
|
||||||
55
pretrain_llama_7b.sh
Executable file
55
pretrain_llama_7b.sh
Executable file
@@ -0,0 +1,55 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
|
||||||
|
# Put your WANDB API key here to enable logging to wandb.
|
||||||
|
export WANDB_API_KEY=''
|
||||||
|
|
||||||
|
# TPU specific flags to improve training throughput
|
||||||
|
export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
|
||||||
|
|
||||||
|
|
||||||
|
python3 -m EasyLM.models.llama.llama_train \
|
||||||
|
--jax_distributed.initialize_jax_distributed=True \
|
||||||
|
--mesh_dim='1,-1,4' \
|
||||||
|
--dtype='bf16' \
|
||||||
|
--total_steps=900000 \
|
||||||
|
--eval_freq=50000 \
|
||||||
|
--log_freq=1000 \
|
||||||
|
--save_model_freq=2000 \
|
||||||
|
--save_milestone_freq=50000 \
|
||||||
|
--load_llama_config='7b' \
|
||||||
|
--update_llama_config='' \
|
||||||
|
--load_dataset_state='' \
|
||||||
|
--load_checkpoint='' \
|
||||||
|
--tokenizer.vocab_file='tokenizer.model' \
|
||||||
|
--optimizer.type='lion' \
|
||||||
|
--optimizer.lion_optimizer.weight_decay=1.0 \
|
||||||
|
--optimizer.lion_optimizer.lr_schedule_type='warmup_constant_linear_decay' \
|
||||||
|
--optimizer.lion_optimizer.lr=1e-4 \
|
||||||
|
--optimizer.lion_optimizer.end_lr=1e-5 \
|
||||||
|
--optimizer.lion_optimizer.lr_warmup_steps=60000 \
|
||||||
|
--optimizer.lion_optimizer.lr_constant_steps=900000 \
|
||||||
|
--optimizer.lion_optimizer.lr_decay_steps=100000 \
|
||||||
|
--optimizer.lion_optimizer.bf16_momentum=True \
|
||||||
|
--train_dataset.type='huggingface' \
|
||||||
|
--train_dataset.text_processor.fields='text' \
|
||||||
|
--train_dataset.text_processor.add_eos_token=True \
|
||||||
|
--train_dataset.text_processor.add_bos_token=True \
|
||||||
|
--train_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_first_stage' \
|
||||||
|
--train_dataset.huggingface_dataset.split='train' \
|
||||||
|
--train_dataset.huggingface_dataset.seq_length=2048 \
|
||||||
|
--train_dataset.huggingface_dataset.batch_size=64 \
|
||||||
|
--eval_dataset.type='huggingface' \
|
||||||
|
--eval_dataset.text_processor.fields='text' \
|
||||||
|
--eval_dataset.text_processor.add_eos_token=True \
|
||||||
|
--eval_dataset.text_processor.add_bos_token=True \
|
||||||
|
--eval_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_first_stage' \
|
||||||
|
--eval_dataset.huggingface_dataset.split='validation' \
|
||||||
|
--eval_dataset.huggingface_dataset.seq_length=2048 \
|
||||||
|
--eval_dataset.huggingface_dataset.batch_size=64 \
|
||||||
|
--checkpointer.save_optimizer_state=True \
|
||||||
|
--logger.online=True \
|
||||||
|
--logger.prefix='EasyLM' \
|
||||||
|
--logger.project="llama-7b-v2" \
|
||||||
|
--logger.output_dir="gs://finnish-nlp-research-us/llama-7b-v2-checkpoint" \
|
||||||
|
--logger.wandb_dir="./"
|
||||||
|
|
||||||
23
special_tokens_map.json
Normal file
23
special_tokens_map.json
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
{
|
||||||
|
"bos_token": {
|
||||||
|
"content": "<s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"eos_token": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"unk_token": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
}
|
||||||
|
}
|
||||||
223314
tokenizer.json
Normal file
223314
tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
3
tokenizer.model
Normal file
3
tokenizer.model
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:1980c00aa3cb5455177a39efa3e60e7b8887ee89c3f7b8950719592a08ad9456
|
||||||
|
size 1400411
|
||||||
64256
tokenizer.vocab
Normal file
64256
tokenizer.vocab
Normal file
File diff suppressed because it is too large
Load Diff
75
tokenizer_config.json
Normal file
75
tokenizer_config.json
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
{
|
||||||
|
"add_bos_token": true,
|
||||||
|
"add_eos_token": false,
|
||||||
|
"add_prefix_space": true,
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"0": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"1": {
|
||||||
|
"content": "<s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"3": {
|
||||||
|
"content": "[INST]",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"content": "[/INST]",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"content": "<<SYS>>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"content": "<</SYS>>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Olet tekoälyavustaja. Vastaat aina mahdollisimman avuliaasti. Vastauksesi eivät saa sisältää mitään haitallista, epäeettistä, rasistista, seksististä, vaarallista tai laitonta sisältöä. Jos kysymyksessä ei ole mitään järkeä tai se ei ole asiasisällöltään johdonmukainen, selitä miksi sen sijaan, että vastaisit jotain väärin. Jos et tiedä vastausta kysymykseen, älä kerro väärää tietoa.' %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + ' [INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + eos_token }}{% endif %}{% endfor %}",
|
||||||
|
"clean_up_tokenization_spaces": false,
|
||||||
|
"eos_token": "</s>",
|
||||||
|
"legacy": false,
|
||||||
|
"model_max_length": 1000000000000000019884624838656,
|
||||||
|
"pad_token": null,
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"spaces_between_special_tokens": false,
|
||||||
|
"tokenizer_class": "PreTrainedTokenizerFast",
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"use_default_system_prompt": false
|
||||||
|
}
|
||||||
10
train_sentencepiece.py
Normal file
10
train_sentencepiece.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import sentencepiece as spm
|
||||||
|
|
||||||
|
spm.SentencePieceTrainer.train(input="/researchdisk/training_dataset_sentences/train.txt", model_prefix="tokenizer",
|
||||||
|
model_type="bpe", split_digits=True, vocab_size=64256, byte_fallback=True,
|
||||||
|
normalization_rule_name="nfkc",
|
||||||
|
user_defined_symbols=["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"],
|
||||||
|
required_chars="abcdefghijklmnopqrstuvwxyzåäöABCDEFGHIJKLMNOPQRSTUVWXYZÅÄÖ",
|
||||||
|
train_extremely_large_corpus=True,
|
||||||
|
input_sentence_size=500000000, shuffle_input_sentence=True,
|
||||||
|
num_threads=96)
|
||||||
Reference in New Issue
Block a user