初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
0
EasyLM/__init__.py
Normal file
0
EasyLM/__init__.py
Normal file
228
EasyLM/bpt.py
Normal file
228
EasyLM/bpt.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
|
||||
Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import NamedTuple
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
from einops import rearrange
|
||||
|
||||
"""
|
||||
Computing ffn blockwise without materializing the large hidden tensor, training
|
||||
4x longer sequences than the memory-efficient transformer.
|
||||
Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
|
||||
"""
|
||||
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
|
||||
# remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
|
||||
# inputs: (batch, seq_len, dim)
|
||||
# chunk_size: the chunk size to split the sequence
|
||||
inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
|
||||
def scan_ffn(remat_ffn, carry, hidden_states):
|
||||
outputs = remat_ffn(hidden_states, deterministic=deterministic)
|
||||
return carry, outputs
|
||||
scan_axis = inputs.ndim - 2
|
||||
_, res = nn.scan(
|
||||
scan_ffn,
|
||||
variable_broadcast="params",
|
||||
split_rngs={"params": False, "dropout": True},
|
||||
in_axes=scan_axis,
|
||||
out_axes=scan_axis,
|
||||
)(remat_ffn, None, inputs)
|
||||
res = rearrange(res, 'b c n d -> b (c n) d')
|
||||
return res
|
||||
|
||||
|
||||
"""
|
||||
Compute attention blockwise without materializing the full attention matrix,
|
||||
initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
|
||||
flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA
|
||||
efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370
|
||||
Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x
|
||||
longer sequences than memory-efficient/flash-attention and fusion of attention and FFN.
|
||||
"""
|
||||
def blockwise_attn(
|
||||
query, key, value,
|
||||
bias=None,
|
||||
deterministic=True,
|
||||
dropout_rng=None,
|
||||
attn_pdrop=0.0,
|
||||
causal=True,
|
||||
query_chunk_size=2048,
|
||||
key_chunk_size=2048,
|
||||
dtype=jnp.float32,
|
||||
policy=jax.checkpoint_policies.nothing_saveable(),
|
||||
precision=None,
|
||||
float32_logits=True,
|
||||
prevent_cse=True,
|
||||
):
|
||||
# query, key, value: (batch, seq_len, num_heads, dim_per_head)
|
||||
# bias: (batch, seq_len) can be used to mask out attention (e.g. padding)
|
||||
# causal: whether to use causal mask
|
||||
# policy: one of jax.checkpoint_policies
|
||||
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||
if float32_logits:
|
||||
query = query.astype(jnp.float32)
|
||||
key = key.astype(jnp.float32)
|
||||
|
||||
batch, q_len, num_heads, dim_per_head = query.shape
|
||||
batch, kv_len, num_heads, dim_per_head = key.shape
|
||||
batch, kv_len, num_heads, dim_per_head = value.shape
|
||||
|
||||
num_q = q_len // query_chunk_size
|
||||
num_kv = kv_len // key_chunk_size
|
||||
query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
|
||||
key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
|
||||
value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
|
||||
|
||||
query = jnp.moveaxis(query, 1, 0)
|
||||
key = jnp.moveaxis(key, 1, 0)
|
||||
value = jnp.moveaxis(value, 1, 0)
|
||||
|
||||
if bias is not None:
|
||||
for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
|
||||
assert bias_dim == 1 or bias_dim == broadcast_dim
|
||||
if not deterministic and attn_pdrop > 0.0:
|
||||
attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
|
||||
attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
|
||||
else:
|
||||
attn_dropout = None
|
||||
|
||||
_chunk_bias_fn = functools.partial(
|
||||
_chunk_attention_bias,
|
||||
query_chunk_size, key_chunk_size, bias, deterministic,
|
||||
attn_dropout, attn_pdrop, causal, dtype)
|
||||
|
||||
def scan_attention(args):
|
||||
query_chunk, query_chunk_idx = args
|
||||
|
||||
@functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
|
||||
def scan_kv_block(carry, args):
|
||||
key_chunk, value_chunk, key_chunk_idx = args
|
||||
(numerator, denominator, prev_max_score) = carry
|
||||
attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
|
||||
bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
|
||||
bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
|
||||
attn_weights = attn_weights + bias_chunk
|
||||
|
||||
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
|
||||
max_score = jnp.maximum(prev_max_score, max_score)
|
||||
max_score = jax.lax.stop_gradient(max_score)
|
||||
exp_weights = jnp.exp(attn_weights - max_score)
|
||||
exp_values = jnp.einsum(
|
||||
'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
|
||||
)
|
||||
correction = jnp.exp(prev_max_score - max_score)
|
||||
numerator = numerator * correction + exp_values
|
||||
denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
|
||||
return Carry(numerator, denominator, max_score), None
|
||||
|
||||
def skip_upper_half(carry, args):
|
||||
key_chunk, value_chunk, key_chunk_idx = args
|
||||
skip_block = jnp.array(False)
|
||||
if causal:
|
||||
skip_block = query_chunk_idx < key_chunk_idx
|
||||
return jax.lax.cond(
|
||||
skip_block,
|
||||
lambda carry, args: (carry, None),
|
||||
scan_kv_block,
|
||||
carry,
|
||||
args,
|
||||
)
|
||||
|
||||
init_carry = Carry(
|
||||
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
|
||||
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
|
||||
(-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
|
||||
)
|
||||
(numerator, denominator, max_score), _ = lax.scan(
|
||||
skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
|
||||
)
|
||||
outputs = (numerator / denominator).astype(dtype)
|
||||
return outputs
|
||||
|
||||
_, res = lax.scan(
|
||||
lambda _, x: ((), scan_attention(x)),
|
||||
(), xs=(query, jnp.arange(0, num_q))
|
||||
)
|
||||
res = rearrange(res, 'n b c h d -> b (n c) h d')
|
||||
return res
|
||||
|
||||
|
||||
class Carry(NamedTuple):
|
||||
numerator: jax.Array
|
||||
denominator: jax.Array
|
||||
max_so_far: jax.Array
|
||||
|
||||
|
||||
def _chunk_attention_bias(query_chunk_size, key_chunk_size,
|
||||
bias, deterministic, attn_dropout, attn_pdrop, causal,
|
||||
dtype, query_chunk_idx, key_chunk_idx):
|
||||
query_offset = query_chunk_idx * query_chunk_size
|
||||
key_offset = key_chunk_idx * key_chunk_size
|
||||
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
|
||||
if bias is not None:
|
||||
chunk_bias = lax.dynamic_slice(
|
||||
bias,
|
||||
start_indices=(0, 0, query_offset, key_offset),
|
||||
slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
|
||||
)
|
||||
|
||||
if causal:
|
||||
query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
|
||||
key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
|
||||
offset = query_offset - key_offset
|
||||
query_idx += offset
|
||||
causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
|
||||
chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
|
||||
|
||||
if not deterministic and attn_pdrop > 0.0:
|
||||
attn_dropout_slice = lax.dynamic_slice(
|
||||
attn_dropout,
|
||||
start_indices=(0, 0, query_offset, key_offset),
|
||||
slice_sizes=(
|
||||
*attn_dropout.shape[:2],
|
||||
min(attn_dropout.shape[-2], query_chunk_size),
|
||||
min(attn_dropout.shape[-1], key_chunk_size),
|
||||
),
|
||||
)
|
||||
chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
|
||||
return chunk_bias.astype(dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test
|
||||
def reference_attn(query, key, value, causal, dtype):
|
||||
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
|
||||
if causal:
|
||||
mask_value = jnp.finfo(logits.dtype).min
|
||||
_, q_seq_len, _, _ = query.shape
|
||||
_, kv_seq_len, _, _ = key.shape
|
||||
mask_shape = (q_seq_len, kv_seq_len)
|
||||
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
|
||||
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
|
||||
causal_mask = (row_ids < col_ids)[None, None, :, :]
|
||||
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
|
||||
weights = jax.nn.softmax(logits, axis=-1)
|
||||
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
|
||||
return out
|
||||
|
||||
# random inputs
|
||||
shape = (1, 32, 8, 64)
|
||||
query = jax.random.normal(jax.random.PRNGKey(0), shape)
|
||||
key = jax.random.normal(jax.random.PRNGKey(1), shape)
|
||||
value = jax.random.normal(jax.random.PRNGKey(2), shape)
|
||||
|
||||
causal = True
|
||||
chunk_size = 4
|
||||
policy = jax.checkpoint_policies.nothing_saveable()
|
||||
|
||||
blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
|
||||
reference = reference_attn(query, key, value, causal, 'float32')
|
||||
|
||||
assert jnp.allclose(reference, blockwise, atol=1e-6)
|
||||
212
EasyLM/checkpoint.py
Normal file
212
EasyLM/checkpoint.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from ml_collections import ConfigDict
|
||||
import mlxu
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import flax
|
||||
from flax.serialization import (
|
||||
from_bytes, to_bytes, to_state_dict, from_state_dict
|
||||
)
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
|
||||
import msgpack
|
||||
|
||||
from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype
|
||||
|
||||
|
||||
class StreamingCheckpointer(object):
|
||||
""" Custom msgpack checkpointer that saves large train states by serializing
|
||||
and saving tensors one by one in a streaming fashion. Avoids running
|
||||
out of memory or local TPU disk with default flax checkpointer.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.float_dtype = 'bf16'
|
||||
config.save_optimizer_state = False
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
def __init__(self, config, checkpoint_dir, enable=True):
|
||||
self.config = self.get_default_config(config)
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.enable = enable
|
||||
|
||||
def save_checkpoint(self, train_state, filename, gather_fns=None):
|
||||
if self.enable:
|
||||
path = os.path.join(self.checkpoint_dir, filename)
|
||||
else:
|
||||
path = '/dev/null'
|
||||
self.save_train_state_to_file(
|
||||
train_state, path, gather_fns, self.config.float_dtype
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None):
|
||||
train_state = to_state_dict(train_state)
|
||||
packer = msgpack.Packer()
|
||||
flattend_train_state = flatten_dict(train_state)
|
||||
if gather_fns is not None:
|
||||
gather_fns = flatten_dict(to_state_dict(gather_fns))
|
||||
|
||||
with mlxu.open_file(path, "wb") as fout:
|
||||
for key, value in flattend_train_state.items():
|
||||
if gather_fns is not None:
|
||||
value = gather_fns[key](value)
|
||||
value = float_tensor_to_dtype(value, float_dtype)
|
||||
fout.write(packer.pack((key, to_bytes(value))))
|
||||
|
||||
def save_pickle(self, obj, filename):
|
||||
if self.enable:
|
||||
path = os.path.join(self.checkpoint_dir, filename)
|
||||
else:
|
||||
path = '/dev/null'
|
||||
mlxu.save_pickle(obj, path)
|
||||
|
||||
def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False):
|
||||
step = int(jax.device_get(train_state.step))
|
||||
if self.config.save_optimizer_state:
|
||||
checkpoint_state = train_state
|
||||
checkpoint_name = 'streaming_train_state'
|
||||
checkpoint_gather_fns = gather_fns
|
||||
else:
|
||||
checkpoint_state = train_state.params['params']
|
||||
checkpoint_name = 'streaming_params'
|
||||
checkpoint_gather_fns = gather_fns.params['params']
|
||||
|
||||
if milestone:
|
||||
# Save a milestone checkpoint that will not be overwritten
|
||||
self.save_pickle(metadata, f'metadata_{step}.pkl')
|
||||
self.save_pickle(dataset, f'dataset_{step}.pkl')
|
||||
self.save_checkpoint(
|
||||
checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns
|
||||
)
|
||||
else:
|
||||
# Save a normal checkpoint that can be overwritten
|
||||
self.save_pickle(metadata, 'metadata.pkl')
|
||||
self.save_pickle(dataset, 'dataset.pkl')
|
||||
self.save_checkpoint(
|
||||
checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
|
||||
if shard_fns is not None:
|
||||
shard_fns = flatten_dict(
|
||||
to_state_dict(shard_fns)
|
||||
)
|
||||
if remove_dict_prefix is not None:
|
||||
remove_dict_prefix = tuple(remove_dict_prefix)
|
||||
flattend_train_state = {}
|
||||
with mlxu.open_file(path) as fin:
|
||||
# 83886080 bytes = 80 MB, which is 16 blocks on GCS
|
||||
unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
|
||||
for key, value in unpacker:
|
||||
key = tuple(key)
|
||||
if remove_dict_prefix is not None:
|
||||
if key[:len(remove_dict_prefix)] == remove_dict_prefix:
|
||||
key = key[len(remove_dict_prefix):]
|
||||
else:
|
||||
continue
|
||||
|
||||
tensor = from_bytes(None, value)
|
||||
if shard_fns is not None:
|
||||
tensor = shard_fns[key](tensor)
|
||||
flattend_train_state[key] = tensor
|
||||
|
||||
if target is not None:
|
||||
flattened_target = flatten_dict(
|
||||
to_state_dict(target), keep_empty_nodes=True
|
||||
)
|
||||
for key, value in flattened_target.items():
|
||||
if key not in flattend_train_state and value == empty_node:
|
||||
flattend_train_state[key] = value
|
||||
|
||||
train_state = unflatten_dict(flattend_train_state)
|
||||
if target is None:
|
||||
return train_state
|
||||
|
||||
return from_state_dict(target, train_state)
|
||||
|
||||
@staticmethod
|
||||
def load_flax_checkpoint(path, target=None, shard_fns=None):
|
||||
""" Load a standard flax checkpoint that's not saved with the
|
||||
msgpack streaming format.
|
||||
"""
|
||||
with mlxu.open_file(path, "rb") as fin:
|
||||
encoded_bytes = fin.read()
|
||||
|
||||
state_dict = flax.serialization.msgpack_restore(encoded_bytes)
|
||||
if shard_fns is not None:
|
||||
shard_fns = to_state_dict(shard_fns)
|
||||
state_dict = tree_apply(shard_fns, state_dict)
|
||||
|
||||
if target is None:
|
||||
return state_dict
|
||||
return from_state_dict(target, state_dict)
|
||||
|
||||
@classmethod
|
||||
def load_trainstate_checkpoint(cls, load_from, trainstate_target=None,
|
||||
trainstate_shard_fns=None,
|
||||
disallow_trainstate=False):
|
||||
if trainstate_target is not None:
|
||||
params_target = trainstate_target.params['params']
|
||||
else:
|
||||
params_target = None
|
||||
|
||||
if trainstate_shard_fns is not None:
|
||||
params_shard_fns = trainstate_shard_fns.params['params']
|
||||
else:
|
||||
params_shard_fns = None
|
||||
|
||||
load_type, load_path = load_from.split('::', 1)
|
||||
if disallow_trainstate:
|
||||
assert load_type != 'trainstate', 'Loading full trainstate is not allowed!'
|
||||
train_state = None
|
||||
restored_params = None
|
||||
if load_type == 'trainstate':
|
||||
# Load the entire train state in the streaming format
|
||||
train_state = cls.load_checkpoint(
|
||||
path=load_path,
|
||||
target=trainstate_target,
|
||||
shard_fns=trainstate_shard_fns,
|
||||
)
|
||||
elif load_type == 'trainstate_params':
|
||||
# Load the params part of the train state in the streaming format
|
||||
restored_params = cls.load_checkpoint(
|
||||
path=load_path,
|
||||
target=params_target,
|
||||
shard_fns=params_shard_fns,
|
||||
remove_dict_prefix=('params', 'params'),
|
||||
)
|
||||
restored_params = flax.core.frozen_dict.freeze(
|
||||
{'params': restored_params}
|
||||
)
|
||||
elif load_type == 'params':
|
||||
# Load the params in the streaming format
|
||||
restored_params = cls.load_checkpoint(
|
||||
path=load_path,
|
||||
target=params_target,
|
||||
shard_fns=params_shard_fns,
|
||||
)
|
||||
restored_params = flax.core.frozen_dict.freeze(
|
||||
{'params': restored_params}
|
||||
)
|
||||
elif load_type == 'flax_params':
|
||||
# Load the params in the standard flax format (non-streaming)
|
||||
# This requires the entire params to fit in memory
|
||||
restored_params = cls.load_flax_checkpoint(
|
||||
path=load_path,
|
||||
target=params_target,
|
||||
shard_fns=params_shard_fns
|
||||
)
|
||||
restored_params = flax.core.frozen_dict.freeze(
|
||||
{'params': restored_params}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid load_from type: {load_type}')
|
||||
|
||||
return train_state, restored_params
|
||||
436
EasyLM/data.py
Normal file
436
EasyLM/data.py
Normal file
@@ -0,0 +1,436 @@
|
||||
import dataclasses
|
||||
import pprint
|
||||
import time
|
||||
from functools import partial
|
||||
import json
|
||||
import base64
|
||||
from multiprocessing import Pool
|
||||
|
||||
import h5py
|
||||
import mlxu
|
||||
from ml_collections.config_dict import config_dict
|
||||
from ml_collections import ConfigDict
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
|
||||
class DatasetFactory(object):
|
||||
""" Datset builder class. """
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.type = 'huggingface'
|
||||
config.text_processor = TextProcessor.get_default_config()
|
||||
config.huggingface_dataset = HuggingfaceDataset.get_default_config()
|
||||
config.json_dataset = JsonDataset.get_default_config()
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def load_dataset(cls, config, tokenizer, **kwargs):
|
||||
config = cls.get_default_config(config)
|
||||
text_processor = TextProcessor(config.text_processor, tokenizer)
|
||||
if config.type == 'huggingface':
|
||||
return HuggingfaceDataset(
|
||||
config.huggingface_dataset, tokenizer, text_processor, **kwargs
|
||||
)
|
||||
elif config.type == 'json':
|
||||
return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
|
||||
else:
|
||||
raise ValueError(f'Unknown dataset type: {config.type}')
|
||||
|
||||
def __init__(self):
|
||||
raise ValueError('DatasetFactory is a static class and should not be instantiated.')
|
||||
|
||||
|
||||
class TextProcessor(object):
|
||||
""" Example processor that converts a dictionary of texts into tokens. """
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.fields_from_example = ''
|
||||
config.fields = ''
|
||||
config.subfield_separator = ' '
|
||||
config.add_bos_token = True
|
||||
config.add_eos_token = True
|
||||
config.prepend_text = ''
|
||||
config.base64_token_dtype = 'i4'
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
def __init__(self, config, tokenizer):
|
||||
self.config = self.get_default_config(config)
|
||||
assert self.config.fields != '' or self.config.fields_from_example != '', (
|
||||
'Either fields or fields_from_example must be specified.'
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(self, example, has_aux=False):
|
||||
if has_aux:
|
||||
example, *aux = example
|
||||
else:
|
||||
aux = tuple()
|
||||
token_buffer = []
|
||||
loss_mask_buffer = []
|
||||
|
||||
if self.config.add_bos_token:
|
||||
token_buffer.append(self.tokenizer.bos_token_id)
|
||||
loss_mask_buffer.append(0.0)
|
||||
|
||||
if self.config.fields_from_example != '':
|
||||
fields = example[self.config.fields_from_example].split(',')
|
||||
else:
|
||||
fields = self.config.fields.split(',')
|
||||
|
||||
for i, field in enumerate(fields):
|
||||
if field.startswith('[') and field.endswith(']'):
|
||||
# No loss for this field.
|
||||
field = field[1:-1]
|
||||
mask = 0.0
|
||||
else:
|
||||
mask = 1.0
|
||||
|
||||
if field.startswith('<|') and field.endswith('|>'):
|
||||
# Special tokens.
|
||||
field = field[2:-2]
|
||||
if field == 'bos':
|
||||
token_buffer.append(self.tokenizer.bos_token_id)
|
||||
elif field == 'eos':
|
||||
token_buffer.append(self.tokenizer.eos_token_id)
|
||||
else:
|
||||
# Token ID specified directly.
|
||||
token_buffer.append(int(field))
|
||||
loss_mask_buffer.append(mask)
|
||||
elif field.startswith('{') and field.endswith('}'):
|
||||
field = field[1:-1]
|
||||
# Base64 encoded raw tokens.
|
||||
tokens = np.frombuffer(
|
||||
base64.b64decode(example[field]),
|
||||
dtype=self.config.base64_token_dtype
|
||||
).tolist()
|
||||
token_buffer.extend(tokens)
|
||||
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
|
||||
else:
|
||||
subfields = field.split('+')
|
||||
text = self.config.subfield_separator.join(
|
||||
[example[subfield] for subfield in subfields]
|
||||
)
|
||||
if i == 0:
|
||||
text = self.config.prepend_text + text
|
||||
tokens = self.tokenizer.encode(text)
|
||||
token_buffer.extend(tokens)
|
||||
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
|
||||
|
||||
if self.config.add_eos_token:
|
||||
token_buffer.append(self.tokenizer.eos_token_id)
|
||||
loss_mask_buffer.append(1.0)
|
||||
|
||||
return token_buffer, loss_mask_buffer, *aux
|
||||
|
||||
|
||||
class HuggingfaceDataset(object):
|
||||
""" Huggingface dataset, where the dataset is loaded using the huggingface
|
||||
datasets.load_dataset() function.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.path = 'c4'
|
||||
config.name = 'en'
|
||||
config.split = 'train'
|
||||
config.streaming = False
|
||||
config.seq_length = 1024
|
||||
config.batch_size = 8
|
||||
config.always_start_with_bos = False
|
||||
config.start_seek_loc = 0
|
||||
config.tokens_count_at_start = 0
|
||||
config.batch_token_dtype = 'i4'
|
||||
config.reset_dataset_loc = False
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
|
||||
self.config = self.get_default_config(config)
|
||||
name = self.config.name if self.config.name != '' else None
|
||||
split = self.config.split if self.config.split != '' else None
|
||||
self._tokenizer = tokenizer
|
||||
self._text_processor = text_processor
|
||||
self._dataset = load_from_disk(
|
||||
self.config.path
|
||||
)[split]
|
||||
self._dataset = self._dataset.to_iterable_dataset(num_shards=128 if len(self._dataset) > 128 else len(self._dataset))
|
||||
self._eval_dataset = eval_dataset
|
||||
self._train_epochs = 0
|
||||
self._dataset_loc = self.config.start_seek_loc
|
||||
self._total_tokens = self.config.tokens_count_at_start
|
||||
self._index = 0
|
||||
self.reset_dataset_loc = self.config.reset_dataset_loc
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
if not self._eval_dataset and self._train_epochs > 0:
|
||||
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
|
||||
chunk_size = self.config.batch_size * self.config.seq_length
|
||||
while True:
|
||||
token_buffer = []
|
||||
loss_mask_buffer = []
|
||||
if not self._eval_dataset and self._train_epochs > 0:
|
||||
self._dataset.set_epoch(self._train_epochs)
|
||||
for index, example in enumerate(self._dataset):
|
||||
self._index = index
|
||||
if not self._eval_dataset and self._dataset_loc > index:
|
||||
continue
|
||||
tokens, loss_masks = self.text_processor(example)
|
||||
token_buffer.extend(tokens)
|
||||
loss_mask_buffer.extend(loss_masks)
|
||||
while len(token_buffer) > chunk_size + 1:
|
||||
self._total_tokens += chunk_size
|
||||
metrics = {
|
||||
'dataset_example_index': index,
|
||||
'dataset_total_tokens': self._total_tokens,
|
||||
'epoch': self._train_epochs,
|
||||
}
|
||||
batch = {
|
||||
'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
}
|
||||
if self.config.always_start_with_bos:
|
||||
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
|
||||
yield batch, metrics
|
||||
token_buffer = token_buffer[chunk_size:]
|
||||
loss_mask_buffer = loss_mask_buffer[chunk_size:]
|
||||
|
||||
if self._eval_dataset:
|
||||
break
|
||||
else:
|
||||
if self._train_epochs == 0:
|
||||
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
|
||||
self._dataset_loc = 0
|
||||
self._train_epochs += 1
|
||||
|
||||
def get_state_dict(self):
|
||||
return dict(
|
||||
config=self.config,
|
||||
dataset_loc=self._index,
|
||||
total_tokens=self._total_tokens,
|
||||
epochs=self._train_epochs,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if 'config' in state_dict:
|
||||
self.config.update(ConfigDict(state_dict['config']))
|
||||
self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
|
||||
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
|
||||
self._train_epochs = state_dict.get('epochs', 0)
|
||||
if self.reset_dataset_loc:
|
||||
self._dataset_loc = 0
|
||||
self._train_epochs = 0
|
||||
|
||||
|
||||
@property
|
||||
def seq_length(self):
|
||||
return self.config.seq_length
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
return self._tokenizer
|
||||
|
||||
@property
|
||||
def text_processor(self):
|
||||
return self._text_processor
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return self._dataset
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self._tokenizer)
|
||||
|
||||
|
||||
class JsonDataset(object):
|
||||
""" JSON dataset, where each line of the data file contains a JSON
|
||||
dictionary with text fields.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.path = ''
|
||||
config.seq_length = 1024
|
||||
config.batch_size = 8
|
||||
config.always_start_with_bos = False
|
||||
config.start_seek_loc = 0
|
||||
config.example_index_at_start = 0
|
||||
config.tokens_count_at_start = 0
|
||||
config.tokenizer_processes = 1
|
||||
config.tokenizer_parallel_chunk_size = 32
|
||||
config.tokenizer_parallel_batch_size = 1024
|
||||
config.throughput_average_window_size = 200
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
def __init__(self, config, tokenizer, text_processor):
|
||||
self.config = self.get_default_config(config)
|
||||
assert self.config.path != ''
|
||||
self._tokenizer = tokenizer
|
||||
self._text_processor = text_processor
|
||||
self._index = self.config.example_index_at_start
|
||||
self._file_loc = self.config.start_seek_loc
|
||||
self._total_tokens = self.config.tokens_count_at_start
|
||||
|
||||
def parse_json(self, line):
|
||||
if not line or line == '\n':
|
||||
return None
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.decoder.JSONDecodeError:
|
||||
print(f'Error parsing json line:\n{line}')
|
||||
return None
|
||||
return data
|
||||
|
||||
def json_iterator(self):
|
||||
with mlxu.open_file(self.config.path, 'r') as fin:
|
||||
fin.seek(self._file_loc)
|
||||
while True:
|
||||
line = fin.readline()
|
||||
self._file_loc = fin.tell()
|
||||
if not line: # Reached EOF
|
||||
self._index = 0
|
||||
fin.seek(0)
|
||||
continue
|
||||
|
||||
data = self.parse_json(line)
|
||||
if data is not None:
|
||||
# JSON parsing succeeded
|
||||
yield data, self._file_loc, self._index
|
||||
self._index += 1
|
||||
|
||||
def batched(self, iterator, batch_size):
|
||||
batch = []
|
||||
for example in iterator:
|
||||
batch.append(example)
|
||||
if len(batch) == batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if len(batch) > 0:
|
||||
yield batch
|
||||
|
||||
def parallel_example_iterator(self):
|
||||
if self.config.tokenizer_processes == 1:
|
||||
for example, loc, index in self.json_iterator():
|
||||
yield self.text_processor((example, loc, index), has_aux=True)
|
||||
else:
|
||||
process_pool = Pool(self.config.tokenizer_processes)
|
||||
batched_iterator = self.batched(
|
||||
self.json_iterator(), self.config.tokenizer_parallel_batch_size
|
||||
)
|
||||
with process_pool as pool:
|
||||
map_fn = partial(self.text_processor, has_aux=True)
|
||||
next_batch = pool.map_async(
|
||||
map_fn, next(batched_iterator),
|
||||
chunksize=self.config.tokenizer_parallel_chunk_size
|
||||
)
|
||||
while True:
|
||||
current_batch = next_batch
|
||||
next_batch = pool.map_async(
|
||||
map_fn, next(batched_iterator),
|
||||
chunksize=self.config.tokenizer_parallel_chunk_size
|
||||
)
|
||||
for example in current_batch.get():
|
||||
yield example
|
||||
|
||||
def __iter__(self):
|
||||
chunk_size = self.config.batch_size * self.config.seq_length
|
||||
token_buffer = []
|
||||
loss_mask_buffer = []
|
||||
last_time = 0.0
|
||||
step_times = []
|
||||
start_time = time.time()
|
||||
start_tokens = self._total_tokens
|
||||
for tokens, loss_masks, loc, index in self.parallel_example_iterator():
|
||||
token_buffer.extend(tokens)
|
||||
loss_mask_buffer.extend(loss_masks)
|
||||
while len(token_buffer) > chunk_size + 1:
|
||||
self._total_tokens += chunk_size
|
||||
step_times.append(time.time() - last_time)
|
||||
last_time = time.time()
|
||||
if len(step_times) > self.config.throughput_average_window_size:
|
||||
step_times = step_times[-self.config.throughput_average_window_size:]
|
||||
average_throughput = chunk_size / np.mean(step_times)
|
||||
accumulated_throughput = (
|
||||
(self._total_tokens - start_tokens) / (time.time() - start_time)
|
||||
)
|
||||
metrics = {
|
||||
'dataset_file_loc': loc,
|
||||
'dataset_example_index': index,
|
||||
'dataset_total_tokens': self._total_tokens,
|
||||
'dataset_accumulated_tps': accumulated_throughput,
|
||||
'dataset_average_tps': average_throughput,
|
||||
}
|
||||
batch = {
|
||||
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
}
|
||||
if self.config.always_start_with_bos:
|
||||
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
|
||||
yield batch, metrics
|
||||
token_buffer = token_buffer[chunk_size:]
|
||||
loss_mask_buffer = loss_mask_buffer[chunk_size:]
|
||||
|
||||
def get_state_dict(self):
|
||||
return dict(
|
||||
config=self.config,
|
||||
index=self._index,
|
||||
file_loc=self._file_loc,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if 'config' in state_dict:
|
||||
self.config.update(ConfigDict(state_dict['config']))
|
||||
self._index = state_dict.get('index', self.config.example_index_at_start)
|
||||
self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
|
||||
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
|
||||
|
||||
@property
|
||||
def seq_length(self):
|
||||
return self.config.seq_length
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
return self._tokenizer
|
||||
|
||||
@property
|
||||
def text_processor(self):
|
||||
return self._text_processor
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.tokenizer)
|
||||
403
EasyLM/jax_utils.py
Normal file
403
EasyLM/jax_utils.py
Normal file
@@ -0,0 +1,403 @@
|
||||
import os
|
||||
import math
|
||||
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
|
||||
from functools import partial
|
||||
import re
|
||||
import dataclasses
|
||||
import random
|
||||
from ml_collections import ConfigDict
|
||||
from ml_collections.config_dict.config_dict import placeholder
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as PS
|
||||
from jax.sharding import Mesh
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.interpreters import pxla
|
||||
import numpy as np
|
||||
from transformers import FlaxLogitsWarper
|
||||
|
||||
|
||||
class JaxRNG(object):
|
||||
""" A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
|
||||
pure function.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_seed(cls, seed):
|
||||
return cls(jax.random.PRNGKey(seed))
|
||||
|
||||
def __init__(self, rng):
|
||||
self.rng = rng
|
||||
|
||||
def __call__(self, keys=None):
|
||||
if keys is None:
|
||||
self.rng, split_rng = jax.random.split(self.rng)
|
||||
return split_rng
|
||||
elif isinstance(keys, int):
|
||||
split_rngs = jax.random.split(self.rng, num=keys + 1)
|
||||
self.rng = split_rngs[0]
|
||||
return tuple(split_rngs[1:])
|
||||
else:
|
||||
split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
|
||||
self.rng = split_rngs[0]
|
||||
return {key: val for key, val in zip(keys, split_rngs[1:])}
|
||||
|
||||
|
||||
class JaxDistributedConfig(object):
|
||||
""" Utility class for initializing JAX distributed. """
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.initialize_jax_distributed = False
|
||||
config.coordinator_address = placeholder(str)
|
||||
config.num_processes = placeholder(int)
|
||||
config.process_id = placeholder(int)
|
||||
config.local_device_ids = placeholder(str)
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, config):
|
||||
config = cls.get_default_config(config)
|
||||
if config.initialize_jax_distributed:
|
||||
if config.local_device_ids is not None:
|
||||
local_device_ids = [int(x) for x in config.local_device_ids.split(',')]
|
||||
else:
|
||||
local_device_ids = None
|
||||
|
||||
jax.distributed.initialize(
|
||||
coordinator_address=config.coordinator_address,
|
||||
num_processes=config.num_processes,
|
||||
process_id=config.process_id,
|
||||
local_device_ids=local_device_ids,
|
||||
)
|
||||
|
||||
|
||||
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
|
||||
""" JIT traceable version of FlaxLogitsWarper that performs temperature scaling."""
|
||||
def __init__(self, temperature):
|
||||
self.temperature = temperature
|
||||
|
||||
def __call__(self, input_ids, scores, cur_len):
|
||||
return scores / jnp.clip(self.temperature, a_min=1e-8)
|
||||
|
||||
|
||||
def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
|
||||
""" Create pytree of sharding and gathering functions from pytree of
|
||||
partition specs.
|
||||
"""
|
||||
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
|
||||
|
||||
def make_to_dtype_fn(dtype_spec):
|
||||
def to_dtype(tensor):
|
||||
if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
|
||||
# Convert all float tensors to the same dtype
|
||||
return tensor.astype(dtype_specs)
|
||||
elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
|
||||
return tensor.astype(dtype_spec.dtype)
|
||||
return tensor
|
||||
return to_dtype
|
||||
|
||||
def make_shard_fn(partition_spec, dtype_spec=None):
|
||||
jax_shard_function = pjit(
|
||||
make_to_dtype_fn(dtype_spec),
|
||||
in_shardings=None,
|
||||
out_shardings=partition_spec
|
||||
)
|
||||
def shard_fn(tensor):
|
||||
return jax_shard_function(tensor).block_until_ready()
|
||||
return shard_fn
|
||||
|
||||
def make_gather_fn(partition_spec, dtype_spec=None):
|
||||
jax_gather_fn = pjit(
|
||||
make_to_dtype_fn(dtype_spec),
|
||||
in_shardings=partition_spec,
|
||||
out_shardings=None
|
||||
)
|
||||
def gather_fn(tensor):
|
||||
return jax.device_get(jax_gather_fn(tensor))
|
||||
return gather_fn
|
||||
|
||||
if dtype_specs is None or dtype_specs in float_dtypes:
|
||||
shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
|
||||
gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
|
||||
else:
|
||||
shard_fns = jax.tree_util.tree_map(
|
||||
make_shard_fn, partition_specs, dtype_specs
|
||||
)
|
||||
gather_fns = jax.tree_util.tree_map(
|
||||
make_gather_fn, partition_specs, dtype_specs
|
||||
)
|
||||
return shard_fns, gather_fns
|
||||
|
||||
|
||||
def set_random_seed(seed):
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
init_rng(seed)
|
||||
|
||||
|
||||
def get_jax_mesh(axis_dims, names):
|
||||
if axis_dims.startswith('!'):
|
||||
# Allow splitting a physical mesh axis if needed
|
||||
mesh_axis_splitting = True
|
||||
axis_dims = axis_dims[1:]
|
||||
else:
|
||||
mesh_axis_splitting = False
|
||||
|
||||
if ':' in axis_dims:
|
||||
dims = []
|
||||
dim_names = []
|
||||
for axis in axis_dims.split(','):
|
||||
name, dim = axis.split(':')
|
||||
assert name in names
|
||||
dims.append(int(dim))
|
||||
dim_names.append(name)
|
||||
assert(set(dim_names) == set(names))
|
||||
else:
|
||||
dims = [int(x) for x in axis_dims.split(',')]
|
||||
dim_names = names
|
||||
assert len(dims) == len(names)
|
||||
mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
|
||||
if mesh_axis_splitting:
|
||||
physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
|
||||
else:
|
||||
physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
|
||||
return Mesh(physical_mesh, dim_names)
|
||||
|
||||
|
||||
def names_in_current_mesh(*names):
|
||||
""" Check if current mesh axes contain these names. """
|
||||
mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
|
||||
return set(names) <= set(mesh_axis_names)
|
||||
|
||||
|
||||
def get_names_from_parition_spec(partition_specs):
|
||||
""" Return axis names from partition specs. """
|
||||
names = set()
|
||||
if isinstance(partition_specs, dict):
|
||||
partition_specs = partition_specs.values()
|
||||
for item in partition_specs:
|
||||
if item is None:
|
||||
continue
|
||||
elif isinstance(item, str):
|
||||
names.add(item)
|
||||
else:
|
||||
names.update(get_names_from_parition_spec(item))
|
||||
|
||||
return list(names)
|
||||
|
||||
|
||||
def with_sharding_constraint(x, partition_specs):
|
||||
""" A smarter version of with_sharding_constraint that only applies the
|
||||
constraint if the current mesh contains the axes in the partition specs.
|
||||
"""
|
||||
axis_names = get_names_from_parition_spec(partition_specs)
|
||||
if names_in_current_mesh(*axis_names):
|
||||
x = _with_sharding_constraint(x, partition_specs)
|
||||
return x
|
||||
|
||||
|
||||
def wrap_function_with_rng(rng):
|
||||
""" To be used as decorator, automatically bookkeep a RNG for the wrapped function. """
|
||||
def wrap_function(function):
|
||||
def wrapped(*args, **kwargs):
|
||||
nonlocal rng
|
||||
rng, split_rng = jax.random.split(rng)
|
||||
return function(split_rng, *args, **kwargs)
|
||||
return wrapped
|
||||
return wrap_function
|
||||
|
||||
|
||||
def init_rng(seed):
|
||||
global jax_utils_rng
|
||||
jax_utils_rng = JaxRNG.from_seed(seed)
|
||||
|
||||
|
||||
def next_rng(*args, **kwargs):
|
||||
global jax_utils_rng
|
||||
return jax_utils_rng(*args, **kwargs)
|
||||
|
||||
|
||||
def get_metrics(metrics, unreplicate=False, stack=False):
|
||||
if unreplicate:
|
||||
metrics = flax.jax_utils.unreplicate(metrics)
|
||||
metrics = jax.device_get(metrics)
|
||||
if stack:
|
||||
return jax.tree_map(lambda *args: np.stack(args), *metrics)
|
||||
else:
|
||||
return {key: float(val) for key, val in metrics.items()}
|
||||
|
||||
|
||||
def mse_loss(val, target, valid=None):
|
||||
if valid is None:
|
||||
valid = jnp.ones((*target.shape[:2], 1))
|
||||
valid = valid.astype(jnp.float32)
|
||||
loss = jnp.mean(
|
||||
jnp.where(
|
||||
valid > 0.0,
|
||||
jnp.square(val - target),
|
||||
0.0
|
||||
)
|
||||
)
|
||||
return loss
|
||||
|
||||
|
||||
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
|
||||
if valid is None:
|
||||
valid = jnp.ones(tokens.shape[:2])
|
||||
valid = valid.astype(jnp.float32)
|
||||
valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
|
||||
logits = logits.astype(jnp.float32) # for numerical stability
|
||||
token_log_prob = jnp.squeeze(
|
||||
jnp.take_along_axis(
|
||||
jax.nn.log_softmax(logits, axis=-1),
|
||||
jnp.expand_dims(tokens, -1),
|
||||
axis=-1,
|
||||
),
|
||||
-1,
|
||||
)
|
||||
token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
|
||||
loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
|
||||
correct = jnp.where(
|
||||
valid > 0.0,
|
||||
jnp.argmax(logits, axis=-1) == tokens,
|
||||
jnp.array(False)
|
||||
)
|
||||
accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
|
||||
return loss, accuracy
|
||||
|
||||
|
||||
def global_norm(tree):
|
||||
""" Return the global L2 norm of a pytree. """
|
||||
squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
|
||||
flattened, _ = jax.flatten_util.ravel_pytree(squared)
|
||||
return jnp.sqrt(jnp.sum(flattened))
|
||||
|
||||
|
||||
def average_metrics(metrics):
|
||||
with jax.spmd_mode("allow_all"):
|
||||
return jax.tree_map(
|
||||
lambda *args: jnp.mean(jnp.stack(args)),
|
||||
*metrics
|
||||
)
|
||||
|
||||
|
||||
def get_float_dtype_by_name(dtype):
|
||||
return {
|
||||
'bf16': jnp.bfloat16,
|
||||
'bfloat16': jnp.bfloat16,
|
||||
'fp16': jnp.float16,
|
||||
'float16': jnp.float16,
|
||||
'fp32': jnp.float32,
|
||||
'float32': jnp.float32,
|
||||
'fp64': jnp.float64,
|
||||
'float64': jnp.float64,
|
||||
}[dtype]
|
||||
|
||||
|
||||
def float_tensor_to_dtype(tensor, dtype):
|
||||
if dtype is None or dtype == '':
|
||||
return tensor
|
||||
if isinstance(dtype, str):
|
||||
dtype = get_float_dtype_by_name(dtype)
|
||||
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
|
||||
if getattr(tensor, 'dtype', None) in float_dtypes:
|
||||
tensor = tensor.astype(dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
def float_to_dtype(tree, dtype):
|
||||
return jax.tree_util.tree_map(
|
||||
partial(float_tensor_to_dtype, dtype=dtype), tree
|
||||
)
|
||||
|
||||
|
||||
def get_gradient_checkpoint_policy(name):
|
||||
return {
|
||||
'everything_saveable': jax.checkpoint_policies.everything_saveable,
|
||||
'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
|
||||
'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
|
||||
'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
|
||||
}[name]
|
||||
|
||||
|
||||
def tree_path_to_string(path, sep=None):
|
||||
keys = []
|
||||
for key in path:
|
||||
if isinstance(key, jax.tree_util.SequenceKey):
|
||||
keys.append(str(key.idx))
|
||||
elif isinstance(key, jax.tree_util.DictKey):
|
||||
keys.append(str(key.key))
|
||||
elif isinstance(key, jax.tree_util.GetAttrKey):
|
||||
keys.append(str(key.name))
|
||||
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
|
||||
keys.append(str(key.key))
|
||||
else:
|
||||
keys.append(str(key))
|
||||
if sep is None:
|
||||
return tuple(keys)
|
||||
return sep.join(keys)
|
||||
|
||||
|
||||
def flatten_tree(xs, is_leaf=None, sep=None):
|
||||
flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
|
||||
output = {}
|
||||
for key, val in flattened:
|
||||
output[tree_path_to_string(key, sep=sep)] = val
|
||||
return output
|
||||
|
||||
|
||||
def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
|
||||
""" An extended version of jax.tree_util.tree_map, where the mapped function
|
||||
f takes both the name (path) and the tree leaf as input.
|
||||
"""
|
||||
return jax.tree_util.tree_map_with_path(
|
||||
lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
|
||||
tree, *rest,
|
||||
is_leaf=is_leaf
|
||||
)
|
||||
|
||||
|
||||
def match_partition_rules(rules, params):
|
||||
""" Returns a pytree of PartitionSpec according to rules. Supports handling
|
||||
Flax TrainState and Optax optimizer state.
|
||||
"""
|
||||
def get_partition_spec(name, leaf):
|
||||
if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
|
||||
""" Don't partition scalar values. """
|
||||
return PS()
|
||||
for rule, ps in rules:
|
||||
if re.search(rule, name) is not None:
|
||||
return ps
|
||||
raise ValueError(f'Partition rule not found for param: {name}')
|
||||
return named_tree_map(get_partition_spec, params, sep='/')
|
||||
|
||||
|
||||
def get_weight_decay_mask(exclusions):
|
||||
""" Return a weight decay mask function that computes the pytree masks
|
||||
according to the given exclusion rules.
|
||||
"""
|
||||
def decay(name, _):
|
||||
for rule in exclusions:
|
||||
if re.search(rule, name) is not None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def weight_decay_mask(params):
|
||||
return named_tree_map(decay, params, sep='/')
|
||||
|
||||
return weight_decay_mask
|
||||
|
||||
|
||||
def tree_apply(fns, tree):
|
||||
""" Apply a pytree of functions to the pytree. """
|
||||
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
|
||||
|
||||
0
EasyLM/models/__init__.py
Normal file
0
EasyLM/models/__init__.py
Normal file
0
EasyLM/models/gptj/__init__.py
Normal file
0
EasyLM/models/gptj/__init__.py
Normal file
1054
EasyLM/models/gptj/gptj_model.py
Normal file
1054
EasyLM/models/gptj/gptj_model.py
Normal file
File diff suppressed because it is too large
Load Diff
396
EasyLM/models/gptj/gptj_serve.py
Normal file
396
EasyLM/models/gptj/gptj_serve.py
Normal file
@@ -0,0 +1,396 @@
|
||||
import pprint
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.sharding import PartitionSpec as PS
|
||||
import flax
|
||||
from flax import linen as nn
|
||||
from flax.jax_utils import prefetch_to_device
|
||||
from flax.training.train_state import TrainState
|
||||
import optax
|
||||
from transformers import GenerationConfig, FlaxLogitsProcessorList
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.serving import LMServer
|
||||
from EasyLM.jax_utils import (
|
||||
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
|
||||
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
|
||||
with_sharding_constraint, FlaxTemperatureLogitsWarper
|
||||
)
|
||||
from EasyLM.models.gptj.gptj_model import (
|
||||
GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
|
||||
)
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
seed=42,
|
||||
initialize_jax_distributed=False,
|
||||
mesh_dim='1,-1,1',
|
||||
dtype='bf16',
|
||||
input_length=1024,
|
||||
seq_length=2048,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
add_bos_token=False,
|
||||
load_gptj_config='',
|
||||
load_checkpoint='',
|
||||
tokenizer=GPTJConfig.get_tokenizer_config(),
|
||||
lm_server=LMServer.get_default_config(),
|
||||
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
prefix_tokenizer = GPTJConfig.get_tokenizer(
|
||||
FLAGS.tokenizer, truncation_side='left', padding_side='left'
|
||||
)
|
||||
tokenizer = GPTJConfig.get_tokenizer(
|
||||
FLAGS.tokenizer, truncation_side='right', padding_side='right'
|
||||
)
|
||||
|
||||
with jax.default_device(jax.devices("cpu")[0]):
|
||||
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
|
||||
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
||||
if load_type == 'huggingface':
|
||||
params = gptj_config.load_pretrained(load_path)
|
||||
else:
|
||||
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_checkpoint, disallow_trainstate=True
|
||||
)
|
||||
|
||||
hf_model = FlaxGPTJForCausalLM(
|
||||
gptj_config,
|
||||
input_shape=(1, FLAGS.seq_length),
|
||||
seed=FLAGS.seed,
|
||||
_do_init=False
|
||||
)
|
||||
|
||||
model_ps = match_partition_rules(
|
||||
GPTJConfig.get_partition_rules(), params
|
||||
)
|
||||
shard_fns, _ = make_shard_and_gather_fns(
|
||||
model_ps, get_float_dtype_by_name(FLAGS.dtype)
|
||||
)
|
||||
|
||||
@partial(
|
||||
pjit,
|
||||
in_shardings=(model_ps, PS(), PS()),
|
||||
out_shardings=(PS(), PS(), PS())
|
||||
)
|
||||
def forward_loglikelihood(params, rng, batch):
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
rng_generator = JaxRNG(rng)
|
||||
input_tokens = batch['input_tokens']
|
||||
output_tokens = batch['output_tokens']
|
||||
input_mask = batch['input_mask']
|
||||
output_mask = batch['output_mask']
|
||||
|
||||
logits = hf_model.module.apply(
|
||||
params, input_tokens, attention_mask=input_mask,
|
||||
deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
|
||||
).logits
|
||||
if gptj_config.n_real_tokens is not None:
|
||||
logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)
|
||||
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
|
||||
logits, output_tokens
|
||||
)
|
||||
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
|
||||
match_count = jnp.sum(
|
||||
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
|
||||
axis=-1
|
||||
)
|
||||
total = jnp.sum(output_mask, axis=-1)
|
||||
is_greedy = match_count == total
|
||||
return loglikelihood, is_greedy, rng_generator()
|
||||
|
||||
|
||||
@partial(
|
||||
pjit,
|
||||
in_shardings=(model_ps, PS(), PS(), PS()),
|
||||
out_shardings=(PS(), PS())
|
||||
)
|
||||
def forward_generate(params, rng, batch, temperature):
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
rng_generator = JaxRNG(rng)
|
||||
output = hf_model.generate(
|
||||
batch['input_tokens'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
params=params['params'],
|
||||
prng_key=rng_generator(),
|
||||
logits_processor=FlaxLogitsProcessorList(
|
||||
[FlaxTemperatureLogitsWarper(temperature)]
|
||||
),
|
||||
generation_config=GenerationConfig(
|
||||
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
do_sample=FLAGS.do_sample,
|
||||
num_beams=FLAGS.num_beams,
|
||||
top_k=FLAGS.top_k,
|
||||
top_p=FLAGS.top_p,
|
||||
)
|
||||
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||
return output, rng_generator()
|
||||
|
||||
@partial(
|
||||
pjit,
|
||||
in_shardings=(model_ps, PS(), PS()),
|
||||
out_shardings=(PS(), PS())
|
||||
)
|
||||
def forward_greedy_generate(params, rng, batch):
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
rng_generator = JaxRNG(rng)
|
||||
output = hf_model.generate(
|
||||
batch['input_tokens'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
params=params['params'],
|
||||
prng_key=rng_generator(),
|
||||
generation_config=GenerationConfig(
|
||||
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
)
|
||||
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||
return output, rng_generator()
|
||||
|
||||
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||
with mesh:
|
||||
params = tree_apply(shard_fns, params)
|
||||
sharded_rng = next_rng()
|
||||
|
||||
class ModelServer(LMServer):
|
||||
|
||||
@staticmethod
|
||||
def loglikelihood(prefix_text, text):
|
||||
nonlocal sharded_rng
|
||||
prefix = prefix_tokenizer(
|
||||
prefix_text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=FLAGS.input_length,
|
||||
return_tensors='np',
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=FLAGS.seq_length - FLAGS.input_length,
|
||||
return_tensors='np',
|
||||
)
|
||||
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
|
||||
bos_tokens = np.full(
|
||||
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
|
||||
)
|
||||
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||
input_mask = np.concatenate(
|
||||
[prefix.attention_mask, inputs.attention_mask], axis=1
|
||||
)
|
||||
if FLAGS.add_bos_token:
|
||||
bos_mask = np.ones_like(input_mask[:, :1])
|
||||
else:
|
||||
bos_mask = np.zeros_like(input_mask[:, :1])
|
||||
|
||||
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
|
||||
output_mask = np.concatenate(
|
||||
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
|
||||
)
|
||||
batch = dict(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
input_mask=input_mask,
|
||||
output_mask=output_mask,
|
||||
)
|
||||
with mesh:
|
||||
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||
params, sharded_rng, batch
|
||||
)
|
||||
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||
return loglikelihood, is_greedy
|
||||
|
||||
@staticmethod
|
||||
def loglikelihood_rolling(text):
|
||||
nonlocal sharded_rng
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding='longest',
|
||||
truncation=False,
|
||||
max_length=np.iinfo(np.int32).max,
|
||||
return_tensors='np',
|
||||
)
|
||||
batch_size = inputs.input_ids.shape[0]
|
||||
output_tokens = inputs.input_ids
|
||||
attention_mask = inputs.attention_mask
|
||||
|
||||
if output_tokens.shape[1] < FLAGS.seq_length:
|
||||
padding_length = FLAGS.seq_length - output_tokens.shape[1]
|
||||
pad_tokens = np.full(
|
||||
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
|
||||
)
|
||||
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
|
||||
pad_mask = np.zeros(
|
||||
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
|
||||
)
|
||||
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
|
||||
|
||||
bos_tokens = np.full(
|
||||
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
|
||||
)
|
||||
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
|
||||
total_seq_length = output_tokens.shape[1]
|
||||
|
||||
total_loglikelihood = 0.0
|
||||
total_is_greedy = True
|
||||
# Sliding window
|
||||
for i in range(0, total_seq_length, FLAGS.seq_length):
|
||||
# Last window
|
||||
if i + FLAGS.seq_length > total_seq_length:
|
||||
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
|
||||
last_output_mask[:, :i - total_seq_length] = 0.0
|
||||
|
||||
batch = dict(
|
||||
input_tokens=input_tokens[:, -FLAGS.seq_length:],
|
||||
output_tokens=output_tokens[:, -FLAGS.seq_length:],
|
||||
input_mask=attention_mask[:, -FLAGS.seq_length:],
|
||||
output_mask=last_output_mask,
|
||||
)
|
||||
|
||||
# Normal window
|
||||
else:
|
||||
batch = dict(
|
||||
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
|
||||
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
|
||||
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||
)
|
||||
|
||||
with mesh:
|
||||
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||
params, sharded_rng, batch
|
||||
)
|
||||
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||
|
||||
total_loglikelihood += loglikelihood
|
||||
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
|
||||
|
||||
return total_loglikelihood, total_is_greedy
|
||||
|
||||
@staticmethod
|
||||
def generate(text, temperature):
|
||||
nonlocal sharded_rng
|
||||
inputs = prefix_tokenizer(
|
||||
text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=FLAGS.input_length,
|
||||
return_tensors='np',
|
||||
)
|
||||
input_tokens = inputs.input_ids
|
||||
input_mask = inputs.attention_mask
|
||||
if FLAGS.add_bos_token:
|
||||
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||
input_mask[:, 0] = 1
|
||||
batch = dict(
|
||||
input_tokens=input_tokens,
|
||||
attention_mask=input_mask,
|
||||
)
|
||||
with mesh:
|
||||
output, sharded_rng = forward_generate(
|
||||
params, sharded_rng, batch, temperature
|
||||
)
|
||||
output = jax.device_get(output)
|
||||
output_text = []
|
||||
for text in list(tokenizer.batch_decode(output)):
|
||||
if tokenizer.eos_token in text:
|
||||
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
|
||||
output_text.append(text)
|
||||
|
||||
return output_text
|
||||
|
||||
@staticmethod
|
||||
def greedy_until(prefix_text, until, max_length):
|
||||
nonlocal sharded_rng
|
||||
all_outputs = []
|
||||
for pf, ut in zip(prefix_text, until):
|
||||
if isinstance(ut, str):
|
||||
ut = [ut]
|
||||
total_length = 0
|
||||
total_generated = ''
|
||||
|
||||
while total_length < max_length:
|
||||
pf_tokens = tokenizer(
|
||||
pf,
|
||||
padding=False,
|
||||
truncation=False,
|
||||
max_length=np.iinfo(np.int32).max,
|
||||
return_tensors='np',
|
||||
)
|
||||
input_tokens = pf_tokens.input_ids
|
||||
attention_mask = pf_tokens.attention_mask
|
||||
|
||||
if input_tokens.shape[1] < FLAGS.input_length:
|
||||
extra = FLAGS.input_length - input_tokens.shape[1]
|
||||
pad_tokens = np.full(
|
||||
(1, extra), tokenizer.pad_token_id, dtype=np.int32
|
||||
)
|
||||
input_tokens = np.concatenate(
|
||||
[pad_tokens, input_tokens], axis=1
|
||||
)
|
||||
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
|
||||
attention_mask = np.concatenate(
|
||||
[pad_attention, attention_mask], axis=1
|
||||
)
|
||||
elif input_tokens.shape[1] > FLAGS.input_length:
|
||||
input_tokens = input_tokens[:, -FLAGS.input_length:]
|
||||
attention_mask = attention_mask[:, -FLAGS.input_length:]
|
||||
|
||||
if FLAGS.add_bos_token:
|
||||
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||
attention_mask[:, 0] = 1
|
||||
|
||||
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
|
||||
|
||||
with mesh:
|
||||
output, sharded_rng = forward_greedy_generate(
|
||||
params, sharded_rng, batch
|
||||
)
|
||||
output = jax.device_get(output)
|
||||
|
||||
total_length += output.shape[1]
|
||||
output_text = tokenizer.batch_decode(output)[0]
|
||||
total_generated = total_generated + output_text
|
||||
pf = pf + output_text
|
||||
|
||||
done = False
|
||||
for s in ut:
|
||||
if s in total_generated:
|
||||
total_generated = total_generated.split(s, maxsplit=1)[0]
|
||||
done = True
|
||||
if done:
|
||||
break
|
||||
|
||||
all_outputs.append(total_generated)
|
||||
|
||||
return all_outputs
|
||||
|
||||
|
||||
server = ModelServer(FLAGS.lm_server)
|
||||
server.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
272
EasyLM/models/gptj/gptj_train.py
Normal file
272
EasyLM/models/gptj/gptj_train.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import pprint
|
||||
from functools import partial
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.pjit import pjit, with_sharding_constraint
|
||||
from jax.sharding import PartitionSpec as PS
|
||||
from flax.training.train_state import TrainState
|
||||
|
||||
from EasyLM.data import DatasetFactory
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.optimizers import OptimizerFactory
|
||||
from EasyLM.jax_utils import (
|
||||
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
|
||||
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
|
||||
set_random_seed, average_metrics, get_weight_decay_mask,
|
||||
make_shard_and_gather_fns, tree_apply
|
||||
)
|
||||
from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
seed=42,
|
||||
mesh_dim='1,-1,1',
|
||||
dtype='fp32',
|
||||
total_steps=10000,
|
||||
load_gptj_config='',
|
||||
update_gptj_config='',
|
||||
load_checkpoint='',
|
||||
load_dataset_state='',
|
||||
log_freq=50,
|
||||
save_model_freq=0,
|
||||
save_milestone_freq=0,
|
||||
eval_steps=0,
|
||||
tokenizer=GPTJConfig.get_tokenizer_config(),
|
||||
train_dataset=DatasetFactory.get_default_config(),
|
||||
eval_dataset=DatasetFactory.get_default_config(),
|
||||
optimizer=OptimizerFactory.get_default_config(),
|
||||
checkpointer=StreamingCheckpointer.get_default_config(),
|
||||
gptj=GPTJConfig.get_default_config(),
|
||||
logger=mlxu.WandBLogger.get_default_config(),
|
||||
log_all_worker=False,
|
||||
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
||||
logger = mlxu.WandBLogger(
|
||||
config=FLAGS.logger,
|
||||
variant=variant,
|
||||
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
||||
)
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
tokenizer = GPTJConfig.get_tokenizer(FLAGS.tokenizer)
|
||||
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
||||
if FLAGS.load_dataset_state != '':
|
||||
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
||||
|
||||
if FLAGS.eval_steps > 0:
|
||||
eval_dataset = DatasetFactory.load_dataset(
|
||||
FLAGS.eval_dataset, dataset.tokenizer
|
||||
)
|
||||
eval_iterator = iter(eval_dataset)
|
||||
|
||||
seq_length = dataset.seq_length
|
||||
|
||||
if FLAGS.load_gptj_config != '':
|
||||
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
|
||||
else:
|
||||
gptj_config = GPTJConfig(**FLAGS.gptj)
|
||||
|
||||
if FLAGS.update_gptj_config != '':
|
||||
gptj_config.update(dict(eval(FLAGS.update_gptj_config)))
|
||||
|
||||
gptj_config.update(dict(
|
||||
bos_token_id=dataset.tokenizer.bos_token_id,
|
||||
eos_token_id=dataset.tokenizer.eos_token_id,
|
||||
))
|
||||
if gptj_config.vocab_size < dataset.vocab_size:
|
||||
gptj_config.update(dict(vocab_size=dataset.vocab_size))
|
||||
|
||||
model = FlaxGPTJForCausalLMModule(
|
||||
gptj_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
|
||||
)
|
||||
|
||||
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
||||
FLAGS.optimizer,
|
||||
get_weight_decay_mask(GPTJConfig.get_weight_decay_exclusions()),
|
||||
)
|
||||
|
||||
def create_trainstate_from_params(params):
|
||||
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||
|
||||
def init_fn(rng):
|
||||
rng_generator = JaxRNG(rng)
|
||||
params = model.init(
|
||||
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
||||
rngs=rng_generator(gptj_config.rng_keys()),
|
||||
)
|
||||
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||
|
||||
def train_step(train_state, rng, batch):
|
||||
rng_generator = JaxRNG(rng)
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
def loss_and_accuracy(params):
|
||||
logits = model.apply(
|
||||
params, batch['input_tokens'], deterministic=False,
|
||||
rngs=rng_generator(gptj_config.rng_keys()),
|
||||
).logits
|
||||
return cross_entropy_loss_and_accuracy(
|
||||
logits, batch['target_tokens'], batch['loss_masks']
|
||||
)
|
||||
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
||||
(loss, accuracy), grads = grad_fn(train_state.params)
|
||||
train_state = train_state.apply_gradients(grads=grads)
|
||||
metrics = dict(
|
||||
loss=loss,
|
||||
accuracy=accuracy,
|
||||
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
||||
gradient_norm=global_norm(grads),
|
||||
param_norm=global_norm(train_state.params),
|
||||
)
|
||||
return train_state, rng_generator(), metrics
|
||||
|
||||
def eval_step(train_state, rng, batch):
|
||||
rng_generator = JaxRNG(rng)
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
logits = model.apply(
|
||||
train_state.params, batch['input_tokens'], deterministic=True,
|
||||
rngs=rng_generator(gptj_config.rng_keys()),
|
||||
).logits
|
||||
loss, accuracy = cross_entropy_loss_and_accuracy(
|
||||
logits, batch['target_tokens'], batch['loss_masks']
|
||||
)
|
||||
metrics = dict(
|
||||
eval_loss=loss,
|
||||
eval_accuracy=accuracy,
|
||||
)
|
||||
return rng_generator(), metrics
|
||||
|
||||
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
||||
train_state_partition = match_partition_rules(
|
||||
GPTJConfig.get_partition_rules(), train_state_shapes
|
||||
)
|
||||
|
||||
shard_fns, gather_fns = make_shard_and_gather_fns(
|
||||
train_state_partition, train_state_shapes
|
||||
)
|
||||
checkpointer = StreamingCheckpointer(
|
||||
FLAGS.checkpointer, logger.output_dir,
|
||||
enable=jax.process_index() == 0,
|
||||
)
|
||||
|
||||
sharded_init_fn = pjit(
|
||||
init_fn,
|
||||
in_shardings=PS(),
|
||||
out_shardings=train_state_partition
|
||||
)
|
||||
|
||||
sharded_create_trainstate_from_params = pjit(
|
||||
create_trainstate_from_params,
|
||||
in_shardings=(train_state_partition.params, ),
|
||||
out_shardings=train_state_partition,
|
||||
donate_argnums=(0, ),
|
||||
)
|
||||
|
||||
sharded_train_step = pjit(
|
||||
train_step,
|
||||
in_shardings=(train_state_partition, PS(), PS()),
|
||||
out_shardings=(train_state_partition, PS(), PS()),
|
||||
donate_argnums=(0, 1),
|
||||
)
|
||||
|
||||
sharded_eval_step = pjit(
|
||||
eval_step,
|
||||
in_shardings=(train_state_partition, PS(), PS()),
|
||||
out_shardings=(PS(), PS()),
|
||||
donate_argnums=(1,),
|
||||
)
|
||||
|
||||
def save_checkpoint(train_state, milestone=False):
|
||||
step = int(jax.device_get(train_state.step))
|
||||
metadata = dict(
|
||||
step=step,
|
||||
variant=variant,
|
||||
flags=flags_config_dict,
|
||||
gptj_config=gptj_config.to_dict(),
|
||||
)
|
||||
checkpointer.save_all(
|
||||
train_state=train_state,
|
||||
gather_fns=gather_fns,
|
||||
metadata=metadata,
|
||||
dataset=dataset.get_state_dict(),
|
||||
milestone=milestone,
|
||||
)
|
||||
|
||||
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||
with mesh:
|
||||
train_state, restored_params = None, None
|
||||
if FLAGS.load_checkpoint != '':
|
||||
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
||||
if load_type == 'huggingface':
|
||||
restored_params = tree_apply(
|
||||
shard_fns.params, gptj_config.load_pretrained(load_path)
|
||||
)
|
||||
train_state = None
|
||||
else:
|
||||
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
||||
)
|
||||
|
||||
if train_state is None and restored_params is None:
|
||||
# Initialize from scratch
|
||||
train_state = sharded_init_fn(next_rng())
|
||||
elif train_state is None and restored_params is not None:
|
||||
# Restore from params but initialize train_state
|
||||
train_state = sharded_create_trainstate_from_params(restored_params)
|
||||
del restored_params
|
||||
|
||||
start_step = int(jax.device_get(train_state.step))
|
||||
|
||||
if FLAGS.save_model_freq > 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
sharded_rng = next_rng()
|
||||
|
||||
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
||||
|
||||
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
||||
train_state, sharded_rng, metrics = sharded_train_step(
|
||||
train_state, sharded_rng, batch
|
||||
)
|
||||
|
||||
if step % FLAGS.log_freq == 0:
|
||||
if FLAGS.eval_steps > 0:
|
||||
eval_metric_list = []
|
||||
for _ in range(FLAGS.eval_steps):
|
||||
eval_batch, _ = next(eval_iterator)
|
||||
sharded_rng, eval_metrics = sharded_eval_step(
|
||||
train_state, sharded_rng, eval_batch
|
||||
)
|
||||
eval_metric_list.append(eval_metrics)
|
||||
metrics.update(average_metrics(eval_metric_list))
|
||||
|
||||
log_metrics = {"step": step}
|
||||
log_metrics.update(metrics)
|
||||
log_metrics.update(dataset_metrics)
|
||||
log_metrics = jax.device_get(log_metrics)
|
||||
logger.log(log_metrics)
|
||||
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
||||
|
||||
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
||||
save_checkpoint(train_state, milestone=True)
|
||||
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
if FLAGS.save_model_freq > 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
338
EasyLM/models/llama/convert_easylm_to_hf.py
Normal file
338
EasyLM/models/llama/convert_easylm_to_hf.py
Normal file
@@ -0,0 +1,338 @@
|
||||
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2023 Xinyang Geng
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts LLaMA model checkpoint trained by EsayLM to the
|
||||
# HuggingFace transformers LLaMA PyTorch format, which can then be loaded
|
||||
# by HuggingFace transformers.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import mlxu
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import flax
|
||||
from flax.traverse_util import flatten_dict
|
||||
import torch
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.jax_utils import float_tensor_to_dtype
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
load_checkpoint='',
|
||||
tokenizer_path='',
|
||||
model_size='13b',
|
||||
output_dir='',
|
||||
)
|
||||
|
||||
|
||||
LLAMA_STANDARD_CONFIGS = {
|
||||
'small': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 768,
|
||||
'intermediate_size': 3072,
|
||||
'n_layers': 12,
|
||||
'n_heads': 12,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'medium': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 1024,
|
||||
'intermediate_size': 4096,
|
||||
'n_layers': 24,
|
||||
'n_heads': 16,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'large': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 1536,
|
||||
'intermediate_size': 6144,
|
||||
'n_layers': 24,
|
||||
'n_heads': 16,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'xlarge': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 2048,
|
||||
'intermediate_size': 8192,
|
||||
'n_layers': 24,
|
||||
'n_heads': 32,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'1b': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 2048,
|
||||
'intermediate_size': 5504,
|
||||
'n_layers': 22,
|
||||
'n_heads': 16,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'3b': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 3200,
|
||||
'intermediate_size': 8640,
|
||||
'n_layers': 26,
|
||||
'n_heads': 32,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'7b': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 4096,
|
||||
'intermediate_size': 11008,
|
||||
'n_layers': 32,
|
||||
'n_heads': 32,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'13b': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 5120,
|
||||
'intermediate_size': 13824,
|
||||
'n_layers': 40,
|
||||
'n_heads': 40,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'30b': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 6656,
|
||||
'intermediate_size': 17920,
|
||||
'n_layers': 60,
|
||||
'n_heads': 52,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'65b': {
|
||||
'vocab_size': 64256,
|
||||
'dim': 8192,
|
||||
'intermediate_size': 22016,
|
||||
'n_layers': 80,
|
||||
'n_heads': 64,
|
||||
'norm_eps': 1e-5,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def match_keywords(string, positives, negatives):
|
||||
for positive in positives:
|
||||
if positive not in string:
|
||||
return False
|
||||
for negative in negatives:
|
||||
if negative in string:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load_and_convert_checkpoint(path):
|
||||
_, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
|
||||
flax_params = flatten_dict(flax_params['params'], sep='.')
|
||||
torch_params = {}
|
||||
for key, tensor in flax_params.items():
|
||||
if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
|
||||
tensor = tensor.T
|
||||
torch_params[key] = torch.tensor(
|
||||
float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16
|
||||
)
|
||||
return torch_params
|
||||
|
||||
|
||||
def read_json(path):
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(text, path):
|
||||
with open(path, "w") as f:
|
||||
json.dump(text, f)
|
||||
|
||||
|
||||
def write_model(loaded, model_path, model_size):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
tmp_model_path = os.path.join(model_path, "tmp")
|
||||
os.makedirs(tmp_model_path, exist_ok=True)
|
||||
|
||||
params = LLAMA_STANDARD_CONFIGS[model_size]
|
||||
|
||||
n_layers = params["n_layers"]
|
||||
n_heads = params["n_heads"]
|
||||
dim = params["dim"]
|
||||
dims_per_head = dim // n_heads
|
||||
base = 10000.0
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
||||
|
||||
# permute for sliced rotary
|
||||
def permute(w):
|
||||
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
|
||||
|
||||
|
||||
param_count = 0
|
||||
index_dict = {"weight_map": {}}
|
||||
for layer_i in range(n_layers):
|
||||
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
|
||||
state_dict = {
|
||||
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
|
||||
loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
|
||||
loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
|
||||
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
|
||||
|
||||
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
|
||||
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
|
||||
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
|
||||
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
|
||||
|
||||
}
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
||||
for k, v in state_dict.items():
|
||||
index_dict["weight_map"][k] = filename
|
||||
param_count += v.numel()
|
||||
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
||||
|
||||
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
|
||||
# Unsharded
|
||||
state_dict = {
|
||||
"model.embed_tokens.weight": loaded["transformer.wte.embedding"],
|
||||
"model.norm.weight": loaded["transformer.ln_f.kernel"],
|
||||
"lm_head.weight": loaded["lm_head.kernel"],
|
||||
}
|
||||
|
||||
for k, v in state_dict.items():
|
||||
index_dict["weight_map"][k] = filename
|
||||
param_count += v.numel()
|
||||
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
||||
|
||||
# Write configs
|
||||
index_dict["metadata"] = {"total_size": param_count * 2}
|
||||
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
||||
|
||||
config = LlamaConfig(
|
||||
vocab_size=params["vocab_size"],
|
||||
hidden_size=dim,
|
||||
intermediate_size=params["intermediate_size"],
|
||||
num_attention_heads=params["n_heads"],
|
||||
num_hidden_layers=params["n_layers"],
|
||||
rms_norm_eps=params["norm_eps"],
|
||||
)
|
||||
config.save_pretrained(tmp_model_path)
|
||||
|
||||
# Make space so we can load the model properly now.
|
||||
del state_dict
|
||||
del loaded
|
||||
gc.collect()
|
||||
|
||||
print("Loading the checkpoint in a Llama model.")
|
||||
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
|
||||
# Avoid saving this as part of the config.
|
||||
print("Model parameter count", model.num_parameters())
|
||||
del model.config._name_or_path
|
||||
|
||||
print("Saving in the Transformers format.")
|
||||
model.save_pretrained(model_path, safe_serialization=True)
|
||||
shutil.rmtree(tmp_model_path)
|
||||
|
||||
|
||||
def write_tokenizer(tokenizer_path, input_tokenizer_path):
|
||||
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
|
||||
os.makedirs(tokenizer_path, exist_ok=True)
|
||||
write_json(
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<s>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False
|
||||
},
|
||||
},
|
||||
os.path.join(tokenizer_path, "special_tokens_map.json")
|
||||
)
|
||||
write_json(
|
||||
{
|
||||
"add_bos_token": True,
|
||||
"add_eos_token": False,
|
||||
"model_max_length": 2048,
|
||||
"pad_token": None,
|
||||
"sp_model_kwargs": {},
|
||||
"tokenizer_class": "LlamaTokenizer",
|
||||
"clean_up_tokenization_spaces": False,
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<s>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False
|
||||
},
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "</s>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False
|
||||
},
|
||||
"unk_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<unk>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False
|
||||
},
|
||||
},
|
||||
os.path.join(tokenizer_path, "tokenizer_config.json"),
|
||||
)
|
||||
shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
|
||||
|
||||
|
||||
def main(argv):
|
||||
assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""# and FLAGS.tokenizer_path != ""
|
||||
assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
|
||||
# write_tokenizer(
|
||||
# tokenizer_path=FLAGS.output_dir,
|
||||
# input_tokenizer_path=FLAGS.tokenizer_path,
|
||||
# )
|
||||
write_model(
|
||||
load_and_convert_checkpoint(FLAGS.load_checkpoint),
|
||||
model_path=FLAGS.output_dir,
|
||||
model_size=FLAGS.model_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
196
EasyLM/models/llama/convert_hf_to_easylm.py
Normal file
196
EasyLM/models/llama/convert_hf_to_easylm.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Usage:
|
||||
python convert_hf_to_easylm.py \
|
||||
--checkpoint_dir /path/hf_format_dir/ \
|
||||
--output_file /path/easylm_format.stream \
|
||||
--model_size 7b \
|
||||
--streaming
|
||||
"""
|
||||
import time
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import mlxu
|
||||
import torch
|
||||
import flax
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
|
||||
LLAMA_STANDARD_CONFIGS = {
|
||||
'1b': {
|
||||
'dim': 2048,
|
||||
'intermediate_size': 5504,
|
||||
'n_layers': 22,
|
||||
'n_heads': 16,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
'3b': {
|
||||
'dim': 3200,
|
||||
'intermediate_size': 8640,
|
||||
'n_layers': 26,
|
||||
'n_heads': 32,
|
||||
'norm_eps': 1e-6,
|
||||
},
|
||||
"7b": {
|
||||
"dim": 4096,
|
||||
"intermediate_size": 11008,
|
||||
"n_layers": 32,
|
||||
"n_heads": 32,
|
||||
"norm_eps": 1e-6,
|
||||
},
|
||||
"13b": {
|
||||
"dim": 5120,
|
||||
"intermediate_size": 13824,
|
||||
"n_layers": 40,
|
||||
"n_heads": 40,
|
||||
"norm_eps": 1e-6,
|
||||
},
|
||||
"30b": {
|
||||
"dim": 6656,
|
||||
"intermediate_size": 17920,
|
||||
"n_layers": 60,
|
||||
"n_heads": 52,
|
||||
"norm_eps": 1e-6,
|
||||
},
|
||||
"65b": {
|
||||
"dim": 8192,
|
||||
"intermediate_size": 22016,
|
||||
"n_layers": 80,
|
||||
"n_heads": 64,
|
||||
"norm_eps": 1e-5,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def inverse_permute(params, w):
|
||||
n_layers = params["n_layers"]
|
||||
n_heads = params["n_heads"]
|
||||
dim = params["dim"]
|
||||
reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
|
||||
transposed_w = reshaped_w.transpose(0, 2, 1, 3)
|
||||
inverted_w = transposed_w.reshape(dim, dim)
|
||||
return inverted_w
|
||||
|
||||
|
||||
def main(args):
|
||||
start = time.time()
|
||||
params = LLAMA_STANDARD_CONFIGS[args.model_size]
|
||||
|
||||
ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
|
||||
ckpt = {}
|
||||
for i, ckpt_path in enumerate(ckpt_paths):
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||
for k, v in checkpoint.items():
|
||||
if k.startswith("model."):
|
||||
k = k[6:]
|
||||
ckpt[k] = v
|
||||
print(f"Start convert weight to easylm format...")
|
||||
jax_weights = {
|
||||
"transformer": {
|
||||
"wte": {"embedding": ckpt["embed_tokens.weight"].numpy()},
|
||||
"ln_f": {"kernel": ckpt["norm.weight"].numpy()},
|
||||
"h": {
|
||||
"%d"
|
||||
% (layer): {
|
||||
"attention": {
|
||||
"wq": {
|
||||
"kernel": inverse_permute(
|
||||
params,
|
||||
ckpt[f"layers.{layer}.self_attn.q_proj.weight"].numpy(),
|
||||
).transpose()
|
||||
},
|
||||
"wk": {
|
||||
"kernel": inverse_permute(
|
||||
params,
|
||||
ckpt[f"layers.{layer}.self_attn.k_proj.weight"].numpy(),
|
||||
).transpose()
|
||||
},
|
||||
"wv": {
|
||||
"kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"]
|
||||
.numpy()
|
||||
.transpose()
|
||||
},
|
||||
"wo": {
|
||||
"kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"]
|
||||
.numpy()
|
||||
.transpose()
|
||||
},
|
||||
},
|
||||
"feed_forward": {
|
||||
"w1": {
|
||||
"kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"]
|
||||
.numpy()
|
||||
.transpose()
|
||||
},
|
||||
"w2": {
|
||||
"kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"]
|
||||
.numpy()
|
||||
.transpose()
|
||||
},
|
||||
"w3": {
|
||||
"kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"]
|
||||
.numpy()
|
||||
.transpose()
|
||||
},
|
||||
},
|
||||
"attention_norm": {
|
||||
"kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].numpy()
|
||||
},
|
||||
"ffn_norm": {
|
||||
"kernel": ckpt[
|
||||
f"layers.{layer}.post_attention_layernorm.weight"
|
||||
].numpy()
|
||||
},
|
||||
}
|
||||
for layer in range(params["n_layers"])
|
||||
},
|
||||
},
|
||||
"lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()},
|
||||
}
|
||||
print(f"Convert weight to easylm format finished...")
|
||||
print(f"Start to save...")
|
||||
|
||||
if args.streaming:
|
||||
StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
|
||||
else:
|
||||
with mlxu.open_file(args.output_file, "wb") as fout:
|
||||
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
|
||||
|
||||
print(
|
||||
f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="hf to easylm format script")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_dir",
|
||||
type=str,
|
||||
help="Need to be converted model weight dir. it is a dir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_file", type=str, help="Save model weight file path, it is a file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
type=str,
|
||||
default="7b",
|
||||
choices=["7b", "13b", "30b", "65b"],
|
||||
help="model size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--streaming",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="whether is model weight saved stream format",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"checkpoint_dir: {args.checkpoint_dir}")
|
||||
print(f"output_file: {args.output_file}")
|
||||
print(f"model_size: {args.model_size}")
|
||||
print(f"streaming: {args.streaming}")
|
||||
|
||||
main(args)
|
||||
68
EasyLM/models/llama/convert_torch_to_easylm.py
Normal file
68
EasyLM/models/llama/convert_torch_to_easylm.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# This script converts the standrd LLaMA PyTorch checkpoint released by Meta
|
||||
# to the EasyLM checkpoint format. The converted checkpoint can then be loaded
|
||||
# by EasyLM for fine-tuning or inference.
|
||||
|
||||
# This script is largely borrow from https://github.com/Sea-Snell/JAX_llama
|
||||
|
||||
from pathlib import Path
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import flax
|
||||
import mlxu
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
checkpoint_dir='',
|
||||
output_file='',
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
ckpt_paths = sorted(Path(FLAGS.checkpoint_dir).glob("*.pth"))
|
||||
ckpts = {}
|
||||
for i, ckpt_path in enumerate(ckpt_paths):
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||
ckpts[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint
|
||||
ckpts = [ckpts[i] for i in sorted(list(ckpts.keys()))]
|
||||
with open(Path(FLAGS.checkpoint_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
jax_weights = {
|
||||
'transformer': {
|
||||
'wte': {'embedding': np.concatenate([ckpt['tok_embeddings.weight'].numpy() for ckpt in ckpts], axis=1)},
|
||||
'ln_f': {'kernel': ckpts[0]['norm.weight'].numpy()},
|
||||
'h': {
|
||||
'%d' % (layer): {
|
||||
'attention': {
|
||||
'wq': {'kernel': np.concatenate([ckpt['layers.%d.attention.wq.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||
'wk': {'kernel': np.concatenate([ckpt['layers.%d.attention.wk.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||
'wv': {'kernel': np.concatenate([ckpt['layers.%d.attention.wv.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||
'wo': {'kernel': np.concatenate([ckpt['layers.%d.attention.wo.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
|
||||
},
|
||||
'feed_forward': {
|
||||
'w1': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w1.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||
'w2': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w2.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
|
||||
'w3': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w3.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||
},
|
||||
'attention_norm': {'kernel': ckpts[0]['layers.%d.attention_norm.weight' % (layer)].numpy()},
|
||||
'ffn_norm': {'kernel': ckpts[0]['layers.%d.ffn_norm.weight' % (layer)].numpy()},
|
||||
}
|
||||
for layer in range(params['n_layers'])},
|
||||
},
|
||||
'lm_head': {'kernel': np.concatenate([ckpt['output.weight'].numpy() for ckpt in ckpts], axis=0).transpose()},
|
||||
}
|
||||
if FLAGS.streaming:
|
||||
StreamingCheckpointer.save_train_state_to_file(
|
||||
jax_weights, FLAGS.output_file
|
||||
)
|
||||
else:
|
||||
with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
|
||||
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mlxu.run(main)
|
||||
1530
EasyLM/models/llama/llama_model.py
Normal file
1530
EasyLM/models/llama/llama_model.py
Normal file
File diff suppressed because it is too large
Load Diff
386
EasyLM/models/llama/llama_serve.py
Normal file
386
EasyLM/models/llama/llama_serve.py
Normal file
@@ -0,0 +1,386 @@
|
||||
import pprint
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.sharding import PartitionSpec as PS
|
||||
import optax
|
||||
from transformers import GenerationConfig, FlaxLogitsProcessorList
|
||||
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.serving import LMServer
|
||||
from EasyLM.jax_utils import (
|
||||
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
|
||||
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
|
||||
with_sharding_constraint, FlaxTemperatureLogitsWarper
|
||||
)
|
||||
from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
seed=42,
|
||||
initialize_jax_distributed=False,
|
||||
mesh_dim='1,-1,1',
|
||||
dtype='bf16',
|
||||
input_length=1024,
|
||||
seq_length=2048,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
add_bos_token=True,
|
||||
load_llama_config='',
|
||||
load_checkpoint='',
|
||||
tokenizer=LLaMAConfig.get_tokenizer_config(),
|
||||
lm_server=LMServer.get_default_config(),
|
||||
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
prefix_tokenizer = LLaMAConfig.get_tokenizer(
|
||||
FLAGS.tokenizer, truncation_side='left', padding_side='left'
|
||||
)
|
||||
tokenizer = LLaMAConfig.get_tokenizer(
|
||||
FLAGS.tokenizer, truncation_side='right', padding_side='right'
|
||||
)
|
||||
|
||||
with jax.default_device(jax.devices("cpu")[0]):
|
||||
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
|
||||
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_checkpoint, disallow_trainstate=True
|
||||
)
|
||||
|
||||
hf_model = FlaxLLaMAForCausalLM(
|
||||
llama_config,
|
||||
input_shape=(1, FLAGS.seq_length),
|
||||
seed=FLAGS.seed,
|
||||
_do_init=False
|
||||
)
|
||||
|
||||
model_ps = match_partition_rules(
|
||||
LLaMAConfig.get_partition_rules(), params
|
||||
)
|
||||
shard_fns, _ = make_shard_and_gather_fns(
|
||||
model_ps, get_float_dtype_by_name(FLAGS.dtype)
|
||||
)
|
||||
|
||||
@partial(
|
||||
pjit,
|
||||
in_shardings=(model_ps, PS(), PS()),
|
||||
out_shardings=(PS(), PS(), PS())
|
||||
)
|
||||
def forward_loglikelihood(params, rng, batch):
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
rng_generator = JaxRNG(rng)
|
||||
input_tokens = batch['input_tokens']
|
||||
output_tokens = batch['output_tokens']
|
||||
input_mask = batch['input_mask']
|
||||
output_mask = batch['output_mask']
|
||||
|
||||
logits = hf_model.module.apply(
|
||||
params, input_tokens, attention_mask=input_mask,
|
||||
deterministic=True, rngs=rng_generator(llama_config.rng_keys()),
|
||||
).logits
|
||||
# if llama_config.n_real_tokens is not None:
|
||||
# logits = logits.at[:, :, llama_config.n_real_tokens:].set(-1e8)
|
||||
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
|
||||
logits, output_tokens
|
||||
)
|
||||
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
|
||||
match_count = jnp.sum(
|
||||
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
|
||||
axis=-1
|
||||
)
|
||||
total = jnp.sum(output_mask, axis=-1)
|
||||
is_greedy = match_count == total
|
||||
return loglikelihood, is_greedy, rng_generator()
|
||||
|
||||
|
||||
@partial(
|
||||
pjit,
|
||||
in_shardings=(model_ps, PS(), PS(), PS()),
|
||||
out_shardings=(PS(), PS())
|
||||
)
|
||||
def forward_generate(params, rng, batch, temperature):
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
rng_generator = JaxRNG(rng)
|
||||
output = hf_model.generate(
|
||||
batch['input_tokens'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
params=params['params'],
|
||||
prng_key=rng_generator(),
|
||||
logits_processor=FlaxLogitsProcessorList(
|
||||
[FlaxTemperatureLogitsWarper(temperature)]
|
||||
),
|
||||
generation_config=GenerationConfig(
|
||||
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
do_sample=FLAGS.do_sample,
|
||||
num_beams=FLAGS.num_beams,
|
||||
top_k=FLAGS.top_k,
|
||||
top_p=FLAGS.top_p,
|
||||
)
|
||||
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||
return output, rng_generator()
|
||||
|
||||
@partial(
|
||||
pjit,
|
||||
in_shardings=(model_ps, PS(), PS()),
|
||||
out_shardings=(PS(), PS())
|
||||
)
|
||||
def forward_greedy_generate(params, rng, batch):
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
rng_generator = JaxRNG(rng)
|
||||
output = hf_model.generate(
|
||||
batch['input_tokens'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
params=params['params'],
|
||||
prng_key=rng_generator(),
|
||||
generation_config=GenerationConfig(
|
||||
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
)
|
||||
).sequences[:, batch['input_tokens'].shape[1]:]
|
||||
return output, rng_generator()
|
||||
|
||||
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||
with mesh:
|
||||
params = tree_apply(shard_fns, params)
|
||||
sharded_rng = next_rng()
|
||||
|
||||
class ModelServer(LMServer):
|
||||
|
||||
@staticmethod
|
||||
def loglikelihood(prefix_text, text):
|
||||
nonlocal sharded_rng
|
||||
prefix = prefix_tokenizer(
|
||||
prefix_text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=FLAGS.input_length,
|
||||
return_tensors='np',
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=FLAGS.seq_length - FLAGS.input_length,
|
||||
return_tensors='np',
|
||||
)
|
||||
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
|
||||
bos_tokens = np.full(
|
||||
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
|
||||
)
|
||||
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||
input_mask = np.concatenate(
|
||||
[prefix.attention_mask, inputs.attention_mask], axis=1
|
||||
)
|
||||
if FLAGS.add_bos_token:
|
||||
bos_mask = np.ones_like(input_mask[:, :1])
|
||||
else:
|
||||
bos_mask = np.zeros_like(input_mask[:, :1])
|
||||
|
||||
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
|
||||
output_mask = np.concatenate(
|
||||
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
|
||||
)
|
||||
batch = dict(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
input_mask=input_mask,
|
||||
output_mask=output_mask,
|
||||
)
|
||||
with mesh:
|
||||
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||
params, sharded_rng, batch
|
||||
)
|
||||
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||
return loglikelihood, is_greedy
|
||||
|
||||
@staticmethod
|
||||
def loglikelihood_rolling(text):
|
||||
nonlocal sharded_rng
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding='longest',
|
||||
truncation=False,
|
||||
max_length=np.iinfo(np.int32).max,
|
||||
return_tensors='np',
|
||||
)
|
||||
batch_size = inputs.input_ids.shape[0]
|
||||
output_tokens = inputs.input_ids
|
||||
attention_mask = inputs.attention_mask
|
||||
|
||||
if output_tokens.shape[1] < FLAGS.seq_length:
|
||||
padding_length = FLAGS.seq_length - output_tokens.shape[1]
|
||||
pad_tokens = np.full(
|
||||
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
|
||||
)
|
||||
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
|
||||
pad_mask = np.zeros(
|
||||
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
|
||||
)
|
||||
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
|
||||
|
||||
bos_tokens = np.full(
|
||||
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
|
||||
)
|
||||
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
||||
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
|
||||
total_seq_length = output_tokens.shape[1]
|
||||
|
||||
total_loglikelihood = 0.0
|
||||
total_is_greedy = True
|
||||
# Sliding window
|
||||
for i in range(0, total_seq_length, FLAGS.seq_length):
|
||||
# Last window
|
||||
if i + FLAGS.seq_length > total_seq_length:
|
||||
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
|
||||
last_output_mask[:, :i - total_seq_length] = 0.0
|
||||
|
||||
batch = dict(
|
||||
input_tokens=input_tokens[:, -FLAGS.seq_length:],
|
||||
output_tokens=output_tokens[:, -FLAGS.seq_length:],
|
||||
input_mask=attention_mask[:, -FLAGS.seq_length:],
|
||||
output_mask=last_output_mask,
|
||||
)
|
||||
|
||||
# Normal window
|
||||
else:
|
||||
batch = dict(
|
||||
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
|
||||
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
|
||||
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
||||
)
|
||||
|
||||
with mesh:
|
||||
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
||||
params, sharded_rng, batch
|
||||
)
|
||||
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
||||
|
||||
total_loglikelihood += loglikelihood
|
||||
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
|
||||
|
||||
return total_loglikelihood, total_is_greedy
|
||||
|
||||
@staticmethod
|
||||
def generate(text, temperature):
|
||||
nonlocal sharded_rng
|
||||
inputs = prefix_tokenizer(
|
||||
text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=FLAGS.input_length,
|
||||
return_tensors='np',
|
||||
)
|
||||
input_tokens = inputs.input_ids
|
||||
input_mask = inputs.attention_mask
|
||||
if FLAGS.add_bos_token:
|
||||
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||
input_mask[:, 0] = 1
|
||||
batch = dict(
|
||||
input_tokens=input_tokens,
|
||||
attention_mask=input_mask,
|
||||
)
|
||||
with mesh:
|
||||
output, sharded_rng = forward_generate(
|
||||
params, sharded_rng, batch, temperature
|
||||
)
|
||||
output = jax.device_get(output)
|
||||
output_text = []
|
||||
for text in list(tokenizer.batch_decode(output)):
|
||||
if tokenizer.eos_token in text:
|
||||
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
|
||||
output_text.append(text)
|
||||
|
||||
return output_text
|
||||
|
||||
@staticmethod
|
||||
def greedy_until(prefix_text, until, max_length):
|
||||
nonlocal sharded_rng
|
||||
all_outputs = []
|
||||
for pf, ut in zip(prefix_text, until):
|
||||
if isinstance(ut, str):
|
||||
ut = [ut]
|
||||
total_length = 0
|
||||
total_generated = ''
|
||||
|
||||
while total_length < max_length:
|
||||
pf_tokens = tokenizer(
|
||||
pf,
|
||||
padding=False,
|
||||
truncation=False,
|
||||
max_length=np.iinfo(np.int32).max,
|
||||
return_tensors='np',
|
||||
)
|
||||
input_tokens = pf_tokens.input_ids
|
||||
attention_mask = pf_tokens.attention_mask
|
||||
|
||||
if input_tokens.shape[1] < FLAGS.input_length:
|
||||
extra = FLAGS.input_length - input_tokens.shape[1]
|
||||
pad_tokens = np.full(
|
||||
(1, extra), tokenizer.pad_token_id, dtype=np.int32
|
||||
)
|
||||
input_tokens = np.concatenate(
|
||||
[pad_tokens, input_tokens], axis=1
|
||||
)
|
||||
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
|
||||
attention_mask = np.concatenate(
|
||||
[pad_attention, attention_mask], axis=1
|
||||
)
|
||||
elif input_tokens.shape[1] > FLAGS.input_length:
|
||||
input_tokens = input_tokens[:, -FLAGS.input_length:]
|
||||
attention_mask = attention_mask[:, -FLAGS.input_length:]
|
||||
|
||||
if FLAGS.add_bos_token:
|
||||
input_tokens[:, 0] = tokenizer.bos_token_id
|
||||
attention_mask[:, 0] = 1
|
||||
|
||||
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
|
||||
|
||||
with mesh:
|
||||
output, sharded_rng = forward_greedy_generate(
|
||||
params, sharded_rng, batch
|
||||
)
|
||||
output = jax.device_get(output)
|
||||
|
||||
total_length += output.shape[1]
|
||||
output_text = tokenizer.batch_decode(output)[0]
|
||||
total_generated = total_generated + output_text
|
||||
pf = pf + output_text
|
||||
|
||||
done = False
|
||||
for s in ut:
|
||||
if s in total_generated:
|
||||
total_generated = total_generated.split(s, maxsplit=1)[0]
|
||||
done = True
|
||||
if done:
|
||||
break
|
||||
|
||||
all_outputs.append(total_generated)
|
||||
|
||||
return all_outputs
|
||||
|
||||
|
||||
server = ModelServer(FLAGS.lm_server)
|
||||
server.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
268
EasyLM/models/llama/llama_train.py
Normal file
268
EasyLM/models/llama/llama_train.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import pprint
|
||||
from functools import partial
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.sharding import PartitionSpec as PS
|
||||
from flax.training.train_state import TrainState
|
||||
|
||||
from EasyLM.data import DatasetFactory
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.optimizers import OptimizerFactory
|
||||
from EasyLM.jax_utils import (
|
||||
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
|
||||
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
|
||||
set_random_seed, average_metrics, get_weight_decay_mask,
|
||||
make_shard_and_gather_fns, with_sharding_constraint,
|
||||
)
|
||||
from EasyLM.models.llama.llama_model import (
|
||||
LLaMAConfig, FlaxLLaMAForCausalLMModule
|
||||
)
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
seed=42,
|
||||
mesh_dim='1,-1,1',
|
||||
dtype='fp32',
|
||||
param_dtype='fp32',
|
||||
total_steps=10000,
|
||||
load_llama_config='',
|
||||
update_llama_config='',
|
||||
load_checkpoint='',
|
||||
load_dataset_state='',
|
||||
log_freq=50,
|
||||
save_model_freq=0,
|
||||
save_milestone_freq=0,
|
||||
eval_freq=0,
|
||||
tokenizer=LLaMAConfig.get_tokenizer_config(),
|
||||
train_dataset=DatasetFactory.get_default_config(),
|
||||
eval_dataset=DatasetFactory.get_default_config(),
|
||||
optimizer=OptimizerFactory.get_default_config(),
|
||||
checkpointer=StreamingCheckpointer.get_default_config(),
|
||||
llama=LLaMAConfig.get_default_config(),
|
||||
logger=mlxu.WandBLogger.get_default_config(),
|
||||
log_all_worker=False,
|
||||
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
||||
logger = mlxu.WandBLogger(
|
||||
config=FLAGS.logger,
|
||||
variant=variant,
|
||||
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
||||
)
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
|
||||
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
||||
if FLAGS.load_dataset_state != '':
|
||||
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
||||
|
||||
if FLAGS.eval_freq > 0:
|
||||
eval_dataset = DatasetFactory.load_dataset(
|
||||
FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
|
||||
)
|
||||
|
||||
seq_length = dataset.seq_length
|
||||
|
||||
if FLAGS.load_llama_config != '':
|
||||
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
|
||||
else:
|
||||
llama_config = LLaMAConfig(**FLAGS.llama)
|
||||
|
||||
if FLAGS.update_llama_config != '':
|
||||
llama_config.update(dict(eval(FLAGS.update_llama_config)))
|
||||
|
||||
llama_config.update(dict(
|
||||
bos_token_id=dataset.tokenizer.bos_token_id,
|
||||
eos_token_id=dataset.tokenizer.eos_token_id,
|
||||
))
|
||||
if llama_config.vocab_size < dataset.vocab_size:
|
||||
print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
|
||||
llama_config.update(dict(vocab_size=dataset.vocab_size))
|
||||
|
||||
model = FlaxLLaMAForCausalLMModule(
|
||||
llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype), param_dtype=get_float_dtype_by_name(FLAGS.param_dtype)
|
||||
)
|
||||
|
||||
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
||||
FLAGS.optimizer,
|
||||
get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
|
||||
)
|
||||
|
||||
def create_trainstate_from_params(params):
|
||||
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||
|
||||
def init_fn(rng):
|
||||
rng_generator = JaxRNG(rng)
|
||||
params = model.init(
|
||||
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
||||
rngs=rng_generator(llama_config.rng_keys()),
|
||||
)
|
||||
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||
|
||||
def train_step(train_state, rng, batch):
|
||||
rng_generator = JaxRNG(rng)
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
def loss_and_accuracy(params):
|
||||
logits = model.apply(
|
||||
params, batch['input_tokens'], deterministic=False,
|
||||
rngs=rng_generator(llama_config.rng_keys()),
|
||||
).logits
|
||||
return cross_entropy_loss_and_accuracy(
|
||||
logits, batch['target_tokens'], batch['loss_masks']
|
||||
)
|
||||
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
||||
(loss, accuracy), grads = grad_fn(train_state.params)
|
||||
train_state = train_state.apply_gradients(grads=grads)
|
||||
metrics = dict(
|
||||
loss=loss,
|
||||
accuracy=accuracy,
|
||||
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
||||
gradient_norm=global_norm(grads),
|
||||
param_norm=global_norm(train_state.params),
|
||||
)
|
||||
return train_state, rng_generator(), metrics
|
||||
|
||||
def eval_step(train_state, rng, batch):
|
||||
rng_generator = JaxRNG(rng)
|
||||
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
||||
logits = model.apply(
|
||||
train_state.params, batch['input_tokens'], deterministic=True,
|
||||
rngs=rng_generator(llama_config.rng_keys()),
|
||||
).logits
|
||||
loss, accuracy = cross_entropy_loss_and_accuracy(
|
||||
logits, batch['target_tokens'], batch['loss_masks']
|
||||
)
|
||||
metrics = dict(
|
||||
eval_loss=loss,
|
||||
eval_accuracy=accuracy,
|
||||
)
|
||||
return rng_generator(), metrics
|
||||
|
||||
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
||||
train_state_partition = match_partition_rules(
|
||||
LLaMAConfig.get_partition_rules(), train_state_shapes
|
||||
)
|
||||
|
||||
shard_fns, gather_fns = make_shard_and_gather_fns(
|
||||
train_state_partition, train_state_shapes
|
||||
)
|
||||
checkpointer = StreamingCheckpointer(
|
||||
FLAGS.checkpointer, logger.output_dir,
|
||||
enable=jax.process_index() == 0,
|
||||
)
|
||||
|
||||
sharded_init_fn = pjit(
|
||||
init_fn,
|
||||
in_shardings=PS(),
|
||||
out_shardings=train_state_partition
|
||||
)
|
||||
|
||||
sharded_create_trainstate_from_params = pjit(
|
||||
create_trainstate_from_params,
|
||||
in_shardings=(train_state_partition.params, ),
|
||||
out_shardings=train_state_partition,
|
||||
donate_argnums=(0, ),
|
||||
)
|
||||
|
||||
sharded_train_step = pjit(
|
||||
train_step,
|
||||
in_shardings=(train_state_partition, PS(), PS()),
|
||||
out_shardings=(train_state_partition, PS(), PS()),
|
||||
donate_argnums=(0, 1),
|
||||
)
|
||||
|
||||
sharded_eval_step = pjit(
|
||||
eval_step,
|
||||
in_shardings=(train_state_partition, PS(), PS()),
|
||||
out_shardings=(PS(), PS()),
|
||||
donate_argnums=(1,),
|
||||
)
|
||||
|
||||
def save_checkpoint(train_state, milestone=False):
|
||||
step = int(jax.device_get(train_state.step))
|
||||
metadata = dict(
|
||||
step=step,
|
||||
variant=variant,
|
||||
flags=flags_config_dict,
|
||||
llama_config=llama_config.to_dict(),
|
||||
)
|
||||
checkpointer.save_all(
|
||||
train_state=train_state,
|
||||
gather_fns=gather_fns,
|
||||
metadata=metadata,
|
||||
dataset=dataset.get_state_dict(),
|
||||
milestone=milestone,
|
||||
)
|
||||
|
||||
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||
with mesh:
|
||||
train_state, restored_params = None, None
|
||||
if FLAGS.load_checkpoint != '':
|
||||
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
||||
)
|
||||
|
||||
if train_state is None and restored_params is None:
|
||||
# Initialize from scratch
|
||||
train_state = sharded_init_fn(next_rng())
|
||||
elif train_state is None and restored_params is not None:
|
||||
# Restore from params but initialize train_state
|
||||
train_state = sharded_create_trainstate_from_params(restored_params)
|
||||
del restored_params
|
||||
|
||||
start_step = int(jax.device_get(train_state.step))
|
||||
|
||||
if FLAGS.save_model_freq > 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
sharded_rng = next_rng()
|
||||
|
||||
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
||||
|
||||
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
||||
train_state, sharded_rng, metrics = sharded_train_step(
|
||||
train_state, sharded_rng, batch
|
||||
)
|
||||
|
||||
if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
|
||||
eval_metric_list = []
|
||||
eval_iterator = iter(eval_dataset)
|
||||
for eval_batch, _ in eval_iterator:
|
||||
sharded_rng, eval_metrics = sharded_eval_step(
|
||||
train_state, sharded_rng, eval_batch
|
||||
)
|
||||
eval_metric_list.append(eval_metrics)
|
||||
metrics.update(average_metrics(eval_metric_list))
|
||||
|
||||
if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
|
||||
log_metrics = {"step": step + 1}
|
||||
log_metrics.update(metrics)
|
||||
log_metrics.update(dataset_metrics)
|
||||
log_metrics = jax.device_get(log_metrics)
|
||||
logger.log(log_metrics)
|
||||
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
||||
|
||||
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
||||
save_checkpoint(train_state, milestone=True)
|
||||
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
if FLAGS.save_model_freq > 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
0
EasyLM/models/roberta/__init__.py
Normal file
0
EasyLM/models/roberta/__init__.py
Normal file
1694
EasyLM/models/roberta/roberta_model.py
Normal file
1694
EasyLM/models/roberta/roberta_model.py
Normal file
File diff suppressed because it is too large
Load Diff
307
EasyLM/models/roberta/roberta_train.py
Normal file
307
EasyLM/models/roberta/roberta_train.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import dataclasses
|
||||
import pprint
|
||||
from functools import partial
|
||||
import re
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.pjit import pjit, with_sharding_constraint
|
||||
from jax.sharding import PartitionSpec as PS
|
||||
from flax.training.train_state import TrainState
|
||||
|
||||
from EasyLM.data import DatasetFactory
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.optimizers import OptimizerFactory
|
||||
from EasyLM.jax_utils import (
|
||||
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, get_float_dtype_by_name,
|
||||
cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
|
||||
set_random_seed, average_metrics, get_weight_decay_mask,
|
||||
make_shard_and_gather_fns, tree_apply
|
||||
)
|
||||
from EasyLM.models.roberta.roberta_model import (
|
||||
RobertaConfig, FlaxRobertaForMaskedLMModule
|
||||
)
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
seed=42,
|
||||
mesh_dim='-1,1,1',
|
||||
dtype='fp32',
|
||||
mask_token_probability=0.15,
|
||||
total_steps=10000,
|
||||
load_roberta_config='',
|
||||
update_roberta_config='',
|
||||
load_checkpoint='',
|
||||
load_dataset_state='',
|
||||
log_freq=50,
|
||||
save_model_freq=0,
|
||||
save_milestone_freq=0,
|
||||
eval_steps=0,
|
||||
tokenizer=RobertaConfig.get_tokenizer_config(),
|
||||
train_dataset=DatasetFactory.get_default_config(),
|
||||
eval_dataset=DatasetFactory.get_default_config(),
|
||||
optimizer=OptimizerFactory.get_default_config(),
|
||||
checkpointer=StreamingCheckpointer.get_default_config(),
|
||||
roberta=RobertaConfig.get_default_config(),
|
||||
logger=mlxu.WandBLogger.get_default_config(),
|
||||
log_all_worker=False,
|
||||
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
||||
logger = mlxu.WandBLogger(
|
||||
config=FLAGS.logger,
|
||||
variant=variant,
|
||||
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
||||
)
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
|
||||
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
||||
if FLAGS.load_dataset_state != '':
|
||||
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
||||
|
||||
if FLAGS.eval_steps > 0:
|
||||
eval_dataset = DatasetFactory.load_dataset(
|
||||
FLAGS.eval_dataset, dataset.tokenizer
|
||||
)
|
||||
eval_iterator = iter(eval_dataset)
|
||||
|
||||
seq_length = dataset.seq_length
|
||||
|
||||
if FLAGS.load_roberta_config != '':
|
||||
roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
|
||||
else:
|
||||
roberta_config = RobertaConfig(**FLAGS.roberta)
|
||||
|
||||
if FLAGS.update_roberta_config != '':
|
||||
roberta_config.update(dict(eval(FLAGS.update_roberta_config)))
|
||||
|
||||
roberta_config.update(dict(
|
||||
bos_token_id=dataset.tokenizer.bos_token_id,
|
||||
eos_token_id=dataset.tokenizer.eos_token_id,
|
||||
pad_token_id=dataset.tokenizer.pad_token_id,
|
||||
vocab_size=dataset.vocab_size,
|
||||
))
|
||||
|
||||
model = FlaxRobertaForMaskedLMModule(
|
||||
roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
|
||||
)
|
||||
|
||||
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
||||
FLAGS.optimizer,
|
||||
get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
|
||||
)
|
||||
|
||||
def create_trainstate_from_params(params):
|
||||
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||
|
||||
def init_fn(rng):
|
||||
rng_generator = JaxRNG(rng)
|
||||
params = model.init(
|
||||
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
||||
token_type_ids=None,
|
||||
head_mask=None,
|
||||
rngs=rng_generator(roberta_config.rng_keys()),
|
||||
)
|
||||
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||
|
||||
def train_step(train_state, rng, batch):
|
||||
rng_generator = JaxRNG(rng)
|
||||
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
|
||||
def loss_and_accuracy(params):
|
||||
altered_tokens = jax.random.uniform(
|
||||
rng_generator(), shape=tokens.shape
|
||||
) < FLAGS.mask_token_probability
|
||||
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
|
||||
altered_by_mask = altered_tokens & (random_uniform < 0.8)
|
||||
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
|
||||
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
|
||||
random_tokens = jax.random.randint(
|
||||
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
|
||||
)
|
||||
inputs = jnp.where(altered_by_random, random_tokens, inputs)
|
||||
logits = model.apply(
|
||||
params, inputs,
|
||||
attention_mask=jnp.ones_like(inputs),
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
deterministic=False,
|
||||
rngs=rng_generator(roberta_config.rng_keys()),
|
||||
).logits
|
||||
return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
|
||||
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
||||
(loss, accuracy), grads = grad_fn(train_state.params)
|
||||
train_state = train_state.apply_gradients(grads=grads)
|
||||
metrics = dict(
|
||||
loss=loss,
|
||||
accuracy=accuracy,
|
||||
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
||||
gradient_norm=global_norm(grads),
|
||||
param_norm=global_norm(train_state.params),
|
||||
)
|
||||
return train_state, rng_generator(), metrics
|
||||
|
||||
def eval_step(train_state, rng, batch):
|
||||
rng_generator = JaxRNG(rng)
|
||||
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
|
||||
altered_tokens = jax.random.uniform(
|
||||
rng_generator(), shape=tokens.shape
|
||||
) < FLAGS.mask_token_probability
|
||||
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
|
||||
altered_by_mask = altered_tokens & (random_uniform < 0.8)
|
||||
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
|
||||
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
|
||||
random_tokens = jax.random.randint(
|
||||
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
|
||||
)
|
||||
inputs = jnp.where(altered_by_random, random_tokens, inputs)
|
||||
logits = model.apply(
|
||||
train_state.params, inputs,
|
||||
attention_mask=jnp.ones_like(inputs),
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
deterministic=False,
|
||||
rngs=rng_generator(roberta_config.rng_keys()),
|
||||
).logits
|
||||
loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
|
||||
metrics = dict(
|
||||
eval_loss=loss,
|
||||
eval_accuracy=accuracy,
|
||||
)
|
||||
return rng_generator(), metrics
|
||||
|
||||
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
||||
train_state_partition = match_partition_rules(
|
||||
RobertaConfig.get_partition_rules(), train_state_shapes
|
||||
)
|
||||
|
||||
shard_fns, gather_fns = make_shard_and_gather_fns(
|
||||
train_state_partition, train_state_shapes
|
||||
)
|
||||
checkpointer = StreamingCheckpointer(
|
||||
FLAGS.checkpointer, logger.output_dir,
|
||||
enable=jax.process_index() == 0
|
||||
)
|
||||
|
||||
sharded_init_fn = pjit(
|
||||
init_fn,
|
||||
in_shardings=PS(),
|
||||
out_shardings=train_state_partition
|
||||
)
|
||||
|
||||
sharded_create_trainstate_from_params = pjit(
|
||||
create_trainstate_from_params,
|
||||
in_shardings=(train_state_partition.params, ),
|
||||
out_shardings=train_state_partition,
|
||||
donate_argnums=(0, ),
|
||||
)
|
||||
|
||||
sharded_train_step = pjit(
|
||||
train_step,
|
||||
in_shardings=(train_state_partition, PS(), PS()),
|
||||
out_shardings=(train_state_partition, PS(), PS()),
|
||||
donate_argnums=(0, 1),
|
||||
)
|
||||
|
||||
sharded_eval_step = pjit(
|
||||
eval_step,
|
||||
in_shardings=(train_state_partition, PS(), PS()),
|
||||
out_shardings=(PS(), PS()),
|
||||
donate_argnums=(1,),
|
||||
)
|
||||
|
||||
def save_checkpoint(train_state, milestone=False):
|
||||
step = int(jax.device_get(train_state.step))
|
||||
metadata = dict(
|
||||
step=step,
|
||||
variant=variant,
|
||||
flags=flags_config_dict,
|
||||
roberta_config=roberta_config.to_dict(),
|
||||
)
|
||||
checkpointer.save_all(
|
||||
train_state=train_state,
|
||||
gather_fns=gather_fns,
|
||||
metadata=metadata,
|
||||
dataset=dataset.get_state_dict(),
|
||||
milestone=milestone,
|
||||
)
|
||||
|
||||
mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||
with mesh:
|
||||
train_state, restored_params = None, None
|
||||
if FLAGS.load_checkpoint != '':
|
||||
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
||||
if load_type == 'huggingface':
|
||||
restored_params = tree_apply(
|
||||
shard_fns.params, roberta_config.load_pretrained(load_path)
|
||||
)
|
||||
train_state = None
|
||||
else:
|
||||
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
||||
)
|
||||
|
||||
if train_state is None and restored_params is None:
|
||||
# Initialize from scratch
|
||||
train_state = sharded_init_fn(next_rng())
|
||||
elif train_state is None and restored_params is not None:
|
||||
# Restore from params but initialize train_state
|
||||
train_state = sharded_create_trainstate_from_params(restored_params)
|
||||
del restored_params
|
||||
|
||||
start_step = int(jax.device_get(train_state.step))
|
||||
|
||||
if FLAGS.save_model_freq > 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
sharded_rng = next_rng()
|
||||
|
||||
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
||||
|
||||
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
||||
train_state, sharded_rng, metrics = sharded_train_step(
|
||||
train_state, sharded_rng, batch
|
||||
)
|
||||
|
||||
if step % FLAGS.log_freq == 0:
|
||||
if FLAGS.eval_steps > 0:
|
||||
eval_metric_list = []
|
||||
for _ in range(FLAGS.eval_steps):
|
||||
eval_batch, _ = next(eval_iterator)
|
||||
sharded_rng, eval_metrics = sharded_eval_step(
|
||||
train_state, sharded_rng, eval_batch
|
||||
)
|
||||
eval_metric_list.append(eval_metrics)
|
||||
metrics.update(average_metrics(eval_metric_list))
|
||||
|
||||
log_metrics = {"step": step}
|
||||
log_metrics.update(metrics)
|
||||
log_metrics.update(dataset_metrics)
|
||||
log_metrics = jax.device_get(log_metrics)
|
||||
logger.log(log_metrics)
|
||||
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
||||
|
||||
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
||||
save_checkpoint(train_state, milestone=True)
|
||||
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
if FLAGS.save_model_freq > 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
346
EasyLM/optimizers.py
Normal file
346
EasyLM/optimizers.py
Normal file
@@ -0,0 +1,346 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
|
||||
from functools import partial
|
||||
import re
|
||||
import dataclasses
|
||||
import random
|
||||
|
||||
from ml_collections.config_dict import config_dict
|
||||
from ml_collections import ConfigDict
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from absl import logging
|
||||
import optax
|
||||
|
||||
from EasyLM.jax_utils import float_to_dtype
|
||||
|
||||
|
||||
class OptimizerFactory(object):
|
||||
""" Configurable optax optimizer factory. """
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.accumulate_gradient_steps = 1
|
||||
config.type = 'adamw'
|
||||
config.palm_optimizer = PalmOptimizerFactory.get_default_config()
|
||||
config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
|
||||
config.lion_optimizer = LionOptimizerFactory.get_default_config()
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def get_optimizer(cls, config, weight_decay_mask=None):
|
||||
config = cls.get_default_config(config)
|
||||
if config.type == 'palm':
|
||||
optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer(
|
||||
config.palm_optimizer, weight_decay_mask
|
||||
)
|
||||
elif config.type == 'adamw':
|
||||
optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer(
|
||||
config.adamw_optimizer, weight_decay_mask
|
||||
)
|
||||
elif config.type == 'lion':
|
||||
optimizer, optimizer_info = LionOptimizerFactory.get_optimizer(
|
||||
config.lion_optimizer, weight_decay_mask
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown optimizer type: {config.type}')
|
||||
|
||||
if config.accumulate_gradient_steps > 1:
|
||||
optimizer = optax.MultiSteps(
|
||||
optimizer, config.accumulate_gradient_steps
|
||||
)
|
||||
|
||||
return optimizer, optimizer_info
|
||||
|
||||
|
||||
class PalmOptimizerFactory(object):
|
||||
""" PaLM optimizer factory. This optimizer implements the optimizer
|
||||
described in the PaLM paper: https://arxiv.org/abs/2204.02311
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.lr = 0.01
|
||||
config.lr_warmup_steps = 10000
|
||||
config.b1 = 0.9
|
||||
config.b2 = 0.99
|
||||
config.clip_gradient = 1.0
|
||||
config.weight_decay = 1e-4
|
||||
config.bf16_momentum = False
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def get_optimizer(cls, config, weight_decay_mask=None):
|
||||
config = cls.get_default_config(config)
|
||||
|
||||
def learning_rate_schedule(step):
|
||||
multiplier = config.lr / 0.01
|
||||
return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps))
|
||||
|
||||
def weight_decay_schedule(step):
|
||||
multiplier = config.weight_decay / 1e-4
|
||||
return -multiplier * jnp.square(learning_rate_schedule(step))
|
||||
|
||||
optimizer_info = dict(
|
||||
learning_rate_schedule=learning_rate_schedule,
|
||||
weight_decay_schedule=weight_decay_schedule,
|
||||
)
|
||||
|
||||
optimizer = optax.chain(
|
||||
optax.clip_by_global_norm(config.clip_gradient),
|
||||
optax.adafactor(
|
||||
learning_rate=learning_rate_schedule,
|
||||
multiply_by_parameter_scale=True,
|
||||
momentum=config.b1,
|
||||
decay_rate=config.b2,
|
||||
factored=False,
|
||||
clipping_threshold=None,
|
||||
dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
||||
),
|
||||
optax_add_scheduled_weight_decay(
|
||||
weight_decay_schedule, weight_decay_mask
|
||||
)
|
||||
)
|
||||
return optimizer, optimizer_info
|
||||
|
||||
|
||||
class AdamWOptimizerFactory(object):
|
||||
""" AdamW optimizer with cosine schedule. """
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.init_lr = 0.0
|
||||
config.end_lr = 0.001
|
||||
config.lr = 0.01
|
||||
config.lr_warmup_steps = 2000
|
||||
config.lr_decay_steps = 500000
|
||||
config.b1 = 0.9
|
||||
config.b2 = 0.95
|
||||
config.clip_gradient = 1.0
|
||||
config.weight_decay = 1e-4
|
||||
config.bf16_momentum = False
|
||||
config.multiply_by_parameter_scale = False
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def get_optimizer(cls, config, weight_decay_mask=None):
|
||||
config = cls.get_default_config(config)
|
||||
|
||||
learning_rate_schedule = optax.warmup_cosine_decay_schedule(
|
||||
init_value=config.init_lr,
|
||||
peak_value=config.lr,
|
||||
warmup_steps=config.lr_warmup_steps,
|
||||
decay_steps=config.lr_decay_steps,
|
||||
end_value=config.end_lr,
|
||||
)
|
||||
|
||||
optimizer_info = dict(
|
||||
learning_rate_schedule=learning_rate_schedule,
|
||||
)
|
||||
|
||||
if config.multiply_by_parameter_scale:
|
||||
optimizer = optax.chain(
|
||||
optax.clip_by_global_norm(config.clip_gradient),
|
||||
optax.adafactor(
|
||||
learning_rate=learning_rate_schedule,
|
||||
multiply_by_parameter_scale=True,
|
||||
momentum=config.b1,
|
||||
decay_rate=config.b2,
|
||||
factored=False,
|
||||
clipping_threshold=None,
|
||||
dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
||||
),
|
||||
optax_add_scheduled_weight_decay(
|
||||
lambda step: -learning_rate_schedule(step) * config.weight_decay,
|
||||
weight_decay_mask
|
||||
)
|
||||
)
|
||||
else:
|
||||
optimizer = optax.chain(
|
||||
optax.clip_by_global_norm(config.clip_gradient),
|
||||
optax.adamw(
|
||||
learning_rate=learning_rate_schedule,
|
||||
weight_decay=config.weight_decay,
|
||||
b1=config.b1,
|
||||
b2=config.b2,
|
||||
mask=weight_decay_mask,
|
||||
mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
||||
),
|
||||
)
|
||||
|
||||
return optimizer, optimizer_info
|
||||
|
||||
class LionOptimizerFactory(object):
|
||||
""" Lion optimizer with cosine schedule. """
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.init_lr = 0.0
|
||||
config.end_lr = 0.0001
|
||||
config.lr = 0.001
|
||||
config.lr_warmup_steps = 60000
|
||||
config.lr_constant_steps = 840000
|
||||
config.lr_decay_steps = 100000
|
||||
config.b1 = 0.9
|
||||
config.b2 = 0.98
|
||||
config.clip_gradient = 1.0
|
||||
config.weight_decay = 1e-3
|
||||
config.bf16_momentum = False
|
||||
config.lr_schedule_type = "warmup_cosine_decay_schedule"
|
||||
config.lr_decay_rate = 0.98
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def get_optimizer(cls, config, weight_decay_mask=None):
|
||||
config = cls.get_default_config(config)
|
||||
|
||||
if config.lr_schedule_type == "warmup_cosine_decay_schedule":
|
||||
learning_rate_schedule = optax.warmup_cosine_decay_schedule(
|
||||
init_value=config.init_lr,
|
||||
peak_value=config.lr,
|
||||
warmup_steps=config.lr_warmup_steps,
|
||||
decay_steps=config.lr_decay_steps,
|
||||
end_value=config.end_lr,
|
||||
)
|
||||
elif config.lr_schedule_type == "warmup_constant":
|
||||
learning_rate_schedule = optax.join_schedules(
|
||||
[
|
||||
optax.linear_schedule(
|
||||
init_value=config.init_lr,
|
||||
end_value=config.lr,
|
||||
transition_steps=config.lr_warmup_steps,
|
||||
),
|
||||
optax.constant_schedule(config.lr),
|
||||
],
|
||||
[config.lr_warmup_steps],
|
||||
)
|
||||
elif config.lr_schedule_type == "warmup_constant_linear_decay":
|
||||
learning_rate_schedule = optax.join_schedules(
|
||||
[
|
||||
optax.linear_schedule(
|
||||
init_value=config.init_lr,
|
||||
end_value=config.lr,
|
||||
transition_steps=config.lr_warmup_steps,
|
||||
),
|
||||
optax.constant_schedule(config.lr),
|
||||
optax.linear_schedule(
|
||||
init_value=config.lr,
|
||||
end_value=config.end_lr,
|
||||
transition_steps=config.lr_decay_steps,
|
||||
)
|
||||
],
|
||||
[config.lr_warmup_steps, config.lr_constant_steps],
|
||||
)
|
||||
elif config.lr_schedule_type == "warmup_constant_exponential_decay":
|
||||
learning_rate_schedule = optax.join_schedules(
|
||||
[
|
||||
optax.linear_schedule(
|
||||
init_value=config.init_lr,
|
||||
end_value=config.lr,
|
||||
transition_steps=config.lr_warmup_steps,
|
||||
),
|
||||
optax.constant_schedule(config.lr),
|
||||
optax.exponential_decay(
|
||||
init_value=config.lr,
|
||||
transition_steps=config.lr_decay_steps,
|
||||
decay_rate=config.lr_decay_rate,
|
||||
transition_begin=0,
|
||||
staircase=False,
|
||||
end_value=config.end_lr,
|
||||
)
|
||||
],
|
||||
[config.lr_warmup_steps, config.lr_constant_steps],
|
||||
)
|
||||
elif config.lr_schedule_type == "exponential_decay":
|
||||
learning_rate_schedule = optax.exponential_decay(
|
||||
init_value=config.lr,
|
||||
transition_steps=config.lr_decay_steps,
|
||||
decay_rate=config.lr_decay_rate,
|
||||
transition_begin=0,
|
||||
staircase=False,
|
||||
end_value=config.end_lr,
|
||||
)
|
||||
elif config.lr_schedule_type == "linear_decay":
|
||||
learning_rate_schedule = optax.linear_schedule(
|
||||
init_value=config.lr,
|
||||
end_value=config.end_lr,
|
||||
transition_steps=config.lr_decay_steps,
|
||||
)
|
||||
else:
|
||||
raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')
|
||||
|
||||
optimizer_info = dict(
|
||||
learning_rate_schedule=learning_rate_schedule,
|
||||
)
|
||||
|
||||
optimizer = optax.chain(
|
||||
optax.clip_by_global_norm(config.clip_gradient),
|
||||
optax.lion(
|
||||
learning_rate=learning_rate_schedule,
|
||||
weight_decay=config.weight_decay,
|
||||
b1=config.b1,
|
||||
b2=config.b2,
|
||||
mask=weight_decay_mask,
|
||||
mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
||||
),
|
||||
)
|
||||
|
||||
return optimizer, optimizer_info
|
||||
|
||||
|
||||
class OptaxScheduledWeightDecayState(NamedTuple):
|
||||
count: jax.Array
|
||||
|
||||
|
||||
def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
|
||||
""" Apply weight decay with schedule. """
|
||||
|
||||
def init_fn(params):
|
||||
del params
|
||||
return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
|
||||
|
||||
def update_fn(updates, state, params):
|
||||
if params is None:
|
||||
raise ValueError('Params cannot be None for weight decay!')
|
||||
|
||||
weight_decay = schedule_fn(state.count)
|
||||
updates = jax.tree_util.tree_map(
|
||||
lambda g, p: g + weight_decay * p, updates, params
|
||||
)
|
||||
return updates, OptaxScheduledWeightDecayState(
|
||||
count=optax.safe_int32_increment(state.count)
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
|
||||
return optax.GradientTransformation(init_fn, update_fn)
|
||||
0
EasyLM/scripts/__init__.py
Normal file
0
EasyLM/scripts/__init__.py
Normal file
150
EasyLM/scripts/benchmark_attention.py
Normal file
150
EasyLM/scripts/benchmark_attention.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from functools import partial
|
||||
from time import time
|
||||
import os
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.flatten_util
|
||||
import jax.numpy as jnp
|
||||
import mlxu
|
||||
from EasyLM.bpt import blockwise_attn
|
||||
from EasyLM.jax_utils import (
|
||||
get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
|
||||
)
|
||||
|
||||
|
||||
FLAGS, _ = mlxu.define_flags_with_default(
|
||||
seed=42,
|
||||
dtype='fp32',
|
||||
embed_dim=2048,
|
||||
n_heads=16,
|
||||
ref_attn_seq_len=2048,
|
||||
eff_attn_seq_len=16384,
|
||||
batch_size=1,
|
||||
query_chunk_size=2048,
|
||||
key_chunk_size=2048,
|
||||
warmup_steps=40,
|
||||
steps=200,
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
|
||||
def random_kqv(rng_key, seq_len):
|
||||
rng_generator = JaxRNG(rng_key)
|
||||
kqv = []
|
||||
for i in range(3):
|
||||
kqv.append(
|
||||
jax.random.normal(
|
||||
rng_generator(),
|
||||
(FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
|
||||
dtype=get_float_dtype_by_name(FLAGS.dtype)
|
||||
)
|
||||
)
|
||||
return tuple(kqv)
|
||||
|
||||
def reference_attn(query, key, value):
|
||||
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
||||
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
|
||||
mask_value = jnp.finfo(logits.dtype).min
|
||||
_, q_seq_len, _, _ = query.shape
|
||||
_, kv_seq_len, _, _ = key.shape
|
||||
mask_shape = (q_seq_len, kv_seq_len)
|
||||
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
|
||||
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
|
||||
causal_mask = (row_ids < col_ids)[None, None, :, :]
|
||||
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
|
||||
weights = jax.nn.softmax(logits, axis=-1)
|
||||
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
|
||||
return out
|
||||
|
||||
def efficient_attention(query, key, value):
|
||||
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
||||
return blockwise_attn(
|
||||
query, key, value,
|
||||
bias=None,
|
||||
deterministic=True,
|
||||
dropout_rng=None,
|
||||
attn_pdrop=0.0,
|
||||
causal=True,
|
||||
query_chunk_size=FLAGS.query_chunk_size,
|
||||
key_chunk_size=FLAGS.key_chunk_size,
|
||||
dtype=get_float_dtype_by_name(FLAGS.dtype),
|
||||
policy=jax.checkpoint_policies.nothing_saveable(),
|
||||
precision=None,
|
||||
float32_logits=True,
|
||||
prevent_cse=True,
|
||||
)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1,))
|
||||
def reference_attn_forward_backward(rng_key, seq_len):
|
||||
@partial(jax.grad, argnums=(0, 1, 2))
|
||||
@partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
|
||||
def grad_fn(query, key, value):
|
||||
out = reference_attn(query, key, value)
|
||||
return jnp.mean(out)
|
||||
|
||||
query, key, value = random_kqv(rng_key, seq_len)
|
||||
return jax.flatten_util.ravel_pytree(
|
||||
grad_fn(query, key, value)[1]
|
||||
)[0].mean()
|
||||
|
||||
@partial(jax.jit, static_argnums=(1,))
|
||||
def efficient_attn_forward_backward(rng_key, seq_len):
|
||||
@partial(jax.grad, argnums=(0, 1, 2))
|
||||
def grad_fn(query, key, value):
|
||||
out = efficient_attention(query, key, value)
|
||||
return jnp.mean(out)
|
||||
|
||||
query, key, value = random_kqv(rng_key, seq_len)
|
||||
return jax.flatten_util.ravel_pytree(
|
||||
grad_fn(query, key, value)[1]
|
||||
)[0].mean()
|
||||
|
||||
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||
jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||
|
||||
all_results = []
|
||||
for i in range(FLAGS.warmup_steps):
|
||||
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||
jax.block_until_ready(all_results)
|
||||
|
||||
start_time = time()
|
||||
all_results = []
|
||||
for i in range(FLAGS.steps):
|
||||
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||
|
||||
jax.block_until_ready(all_results)
|
||||
elapsed_time_ref_attn = time() - start_time
|
||||
print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')
|
||||
|
||||
|
||||
all_results = []
|
||||
for i in range(FLAGS.warmup_steps):
|
||||
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||
jax.block_until_ready(all_results)
|
||||
|
||||
|
||||
start_time = time()
|
||||
all_results = []
|
||||
for i in range(FLAGS.steps):
|
||||
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||
|
||||
jax.block_until_ready(all_results)
|
||||
elapsed_time_efficient_attn = time() - start_time
|
||||
print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')
|
||||
|
||||
flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
|
||||
efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
|
||||
print(f'Efficiency: {efficiency:.3f}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mlxu.run(main)
|
||||
|
||||
|
||||
|
||||
42
EasyLM/scripts/convert_checkpoint.py
Normal file
42
EasyLM/scripts/convert_checkpoint.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# This script converts model checkpoint trained by EsayLM to a standard
|
||||
# mspack checkpoint that can be loaded by huggingface transformers or
|
||||
# flax.serialization.msgpack_restore. Such conversion allows models to be
|
||||
# used by other frameworks that integrate with huggingface transformers.
|
||||
|
||||
import pprint
|
||||
from functools import partial
|
||||
import os
|
||||
import numpy as np
|
||||
import mlxu
|
||||
import jax.numpy as jnp
|
||||
import flax.serialization
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.jax_utils import float_to_dtype
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
load_checkpoint='',
|
||||
output_file='',
|
||||
streaming=False,
|
||||
float_dtype='bf16',
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
|
||||
params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_checkpoint, disallow_trainstate=True
|
||||
)[1]['params']
|
||||
|
||||
if FLAGS.streaming:
|
||||
StreamingCheckpointer.save_train_state_to_file(
|
||||
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
|
||||
)
|
||||
else:
|
||||
params = float_to_dtype(params, FLAGS.float_dtype)
|
||||
with mlxu.open_file(FLAGS.output, 'wb') as fout:
|
||||
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
59
EasyLM/scripts/diff_checkpoint.py
Normal file
59
EasyLM/scripts/diff_checkpoint.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# This script converts model checkpoint trained by EsayLM to a standard
|
||||
# mspack checkpoint that can be loaded by huggingface transformers or
|
||||
# flax.serialization.msgpack_restore. Such conversion allows models to be
|
||||
# used by other frameworks that integrate with huggingface transformers.
|
||||
|
||||
import pprint
|
||||
from functools import partial
|
||||
import os
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import flax.serialization
|
||||
import mlxu
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.jax_utils import float_to_dtype
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
recover_diff=False,
|
||||
load_base_checkpoint='',
|
||||
load_target_checkpoint='',
|
||||
output_file='',
|
||||
streaming=True,
|
||||
float_dtype='bf16',
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
|
||||
assert FLAGS.output_file != ''
|
||||
base_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_base_checkpoint, disallow_trainstate=True
|
||||
)[1]['params']
|
||||
|
||||
target_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_target_checkpoint, disallow_trainstate=True
|
||||
)[1]['params']
|
||||
|
||||
if FLAGS.recover_diff:
|
||||
params = jax.tree_util.tree_map(
|
||||
lambda b, t: b + t, base_params, target_params
|
||||
)
|
||||
else:
|
||||
params = jax.tree_util.tree_map(
|
||||
lambda b, t: t - b, base_params, target_params
|
||||
)
|
||||
|
||||
if FLAGS.streaming:
|
||||
StreamingCheckpointer.save_train_state_to_file(
|
||||
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
|
||||
)
|
||||
else:
|
||||
params = float_to_dtype(params, FLAGS.float_dtype)
|
||||
with mlxu.open_file(FLAGS.output, 'wb') as fout:
|
||||
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
65
EasyLM/scripts/lm_eval_harness.py
Normal file
65
EasyLM/scripts/lm_eval_harness.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# This script runs lm_eval_harness evaluations against a served language model.
|
||||
# Typically, you need to run a language model server first, e.g.:
|
||||
# python -m EasyLM.models.gptj.gptj_serve ...
|
||||
|
||||
import dataclasses
|
||||
import pprint
|
||||
from functools import partial
|
||||
import os
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
from flax.traverse_util import flatten_dict
|
||||
from lm_eval import evaluator, tasks
|
||||
from lm_eval.base import LM
|
||||
|
||||
from EasyLM.serving import LMClient
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
tasks='wsc,piqa,winogrande,openbookqa,logiqa',
|
||||
shots=0,
|
||||
limit=0,
|
||||
write_out=False,
|
||||
lm_client=LMClient.get_default_config(),
|
||||
logger=mlxu.WandBLogger.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
class LMEvalHarnessInterface(LM):
|
||||
|
||||
def __init__(self, lm_client):
|
||||
self.lm_client = lm_client
|
||||
|
||||
def greedy_until(self, inputs):
|
||||
prefix, until = zip(*inputs)
|
||||
return self.lm_client.greedy_until(prefix, until)
|
||||
|
||||
def loglikelihood_rolling(self, inputs):
|
||||
loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
|
||||
return list(zip(loglikelihood, is_greedy))
|
||||
|
||||
def loglikelihood(self, inputs):
|
||||
prefix, text = zip(*inputs)
|
||||
loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
|
||||
return list(zip(loglikelihood, is_greedy))
|
||||
|
||||
|
||||
def main(argv):
|
||||
logger = mlxu.WandBLogger(
|
||||
config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||
)
|
||||
model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
|
||||
task_list = FLAGS.tasks.split(',')
|
||||
results = evaluator.evaluate(
|
||||
model, tasks.get_task_dict(task_list), False, FLAGS.shots,
|
||||
limit=None if FLAGS.limit <= 0 else FLAGS.limit,
|
||||
write_out=FLAGS.write_out,
|
||||
)
|
||||
logger.log(flatten_dict(results['results'], sep='/'))
|
||||
pprint.pprint(results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
52
EasyLM/scripts/lm_eval_json.py
Normal file
52
EasyLM/scripts/lm_eval_json.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import json
|
||||
import mlxu
|
||||
from EasyLM.serving import LMClient
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
input_file='',
|
||||
output_file='',
|
||||
prefix_field='prefix',
|
||||
text_field='text',
|
||||
until_field='until',
|
||||
eval_type='loglikelihood',
|
||||
lm_client=LMClient.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
lm_client = LMClient(FLAGS.lm_client)
|
||||
with mlxu.open_file(FLAGS.input_file, 'r') as fin:
|
||||
input_data = json.load(fin)
|
||||
|
||||
if FLAGS.eval_type == 'loglikelihood':
|
||||
prefix = input_data[FLAGS.prefix_field]
|
||||
text = input_data[FLAGS.text_field]
|
||||
loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
|
||||
output_data = {
|
||||
'loglikelihood': loglikelihoods,
|
||||
'is_greedy': is_greedys,
|
||||
}
|
||||
elif FLAGS.eval_type == 'loglikelihood_rolling':
|
||||
text = input_data[FLAGS.text_field]
|
||||
loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
|
||||
output_data = {
|
||||
'loglikelihood': loglikelihoods,
|
||||
'is_greedy': is_greedys,
|
||||
}
|
||||
elif FLAGS.eval_type == 'greedy_until':
|
||||
prefix = input_data[FLAGS.prefix_field]
|
||||
until = input_data[FLAGS.until_field]
|
||||
output_data = {'output_text': lm_client.greedy_until(prefix, until)}
|
||||
elif FLAGS.eval_type == 'generate':
|
||||
prefix = input_data[FLAGS.prefix_field]
|
||||
output_data = {'output_text': lm_client.generate(prefix)}
|
||||
else:
|
||||
raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
|
||||
|
||||
with mlxu.open_file(FLAGS.output_file, 'w') as fout:
|
||||
json.dump(output_data, fout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
566
EasyLM/serving.py
Normal file
566
EasyLM/serving.py
Normal file
@@ -0,0 +1,566 @@
|
||||
import dataclasses
|
||||
import pprint
|
||||
from functools import partial
|
||||
import re
|
||||
import os
|
||||
from threading import Lock
|
||||
import urllib
|
||||
import time
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
import absl.logging
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
import mlxu
|
||||
from ml_collections import ConfigDict
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
import gradio as gr
|
||||
import requests
|
||||
from requests.exceptions import Timeout, ConnectionError
|
||||
|
||||
|
||||
class InferenceRequest(BaseModel):
|
||||
prefix_text: Optional[List[str]] = None
|
||||
text: Optional[List[str]] = None
|
||||
until: Optional[Union[List[str], List[List[str]]]] = None
|
||||
temperature: Optional[float] = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
prompt: str
|
||||
context: str = ''
|
||||
temperature: Optional[float] = None
|
||||
|
||||
|
||||
class LMServer(object):
|
||||
""" HTTP server for serving langauge models. """
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.host = '0.0.0.0'
|
||||
config.port = 5007
|
||||
config.batch_size = 1
|
||||
config.logging = False
|
||||
config.pre_compile = 'loglikelihood'
|
||||
config.default_temperature = 1.0
|
||||
config.greedy_until_max_length = 5000
|
||||
config.prepend_to_prefix = ''
|
||||
config.append_to_prefix = ''
|
||||
config.prepend_to_text = ''
|
||||
config.append_to_text = ''
|
||||
config.chat_prepend_text = ''
|
||||
config.chat_user_prefix = ''
|
||||
config.chat_user_suffix = ''
|
||||
config.chat_lm_prefix = ''
|
||||
config.chat_lm_suffix = ''
|
||||
config.notes = ''
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = self.get_default_config(config)
|
||||
self.lock = Lock()
|
||||
self.app = FastAPI()
|
||||
self.app.post('/loglikelihood')(self.serve_loglikelihood)
|
||||
self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling)
|
||||
self.app.post('/generate')(self.serve_generate)
|
||||
self.app.post('/greedy-until')(self.serve_greedy_until)
|
||||
self.app.post('/chat')(self.serve_chat)
|
||||
self.app.get('/ready')(self.serve_ready)
|
||||
self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/')
|
||||
|
||||
@staticmethod
|
||||
def loglikelihood(prefix_text, text):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def loglikelihood_rolling(text):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def generate(text, temperature):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def greedy_until(prefix_text, until, max_length):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def to_list(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return x.tolist()
|
||||
return x
|
||||
|
||||
def serve_ready(self):
|
||||
return 'Ready!\n'
|
||||
|
||||
def serve_loglikelihood(self, data: InferenceRequest):
|
||||
with self.lock:
|
||||
if self.config.logging:
|
||||
absl.logging.info(
|
||||
'\n========= Serving Log Likelihood Request ========= \n'
|
||||
+ pprint.pformat(data) + '\n'
|
||||
)
|
||||
|
||||
if data.prefix_text is None:
|
||||
data.prefix_text = ['' for _ in data.text]
|
||||
|
||||
prefix_text = [
|
||||
self.config.prepend_to_prefix + p + self.config.append_to_prefix
|
||||
for p in data.prefix_text
|
||||
]
|
||||
text = [
|
||||
self.config.prepend_to_text + t + self.config.append_to_text
|
||||
for t in data.text
|
||||
]
|
||||
|
||||
log_likelihood = []
|
||||
is_greedy = []
|
||||
for i in trange(0, len(text), self.config.batch_size, ncols=0):
|
||||
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
|
||||
batch_text = text[i:i + self.config.batch_size]
|
||||
batch_size = len(batch_text)
|
||||
|
||||
if batch_size < self.config.batch_size:
|
||||
extra = self.config.batch_size - batch_size
|
||||
batch_prefix_text.extend(['a' for _ in range(extra)])
|
||||
batch_text.extend(['a' for _ in range(extra)])
|
||||
|
||||
batch_log_likelihood, batch_is_greedy = self.loglikelihood(
|
||||
batch_prefix_text, batch_text
|
||||
)
|
||||
batch_log_likelihood = self.to_list(batch_log_likelihood)
|
||||
batch_is_greedy = self.to_list(batch_is_greedy)
|
||||
log_likelihood.extend(batch_log_likelihood[:batch_size])
|
||||
is_greedy.extend(batch_is_greedy[:batch_size])
|
||||
|
||||
output = {
|
||||
'prefix_text': data.prefix_text,
|
||||
'text': data.text,
|
||||
'log_likelihood': log_likelihood,
|
||||
'is_greedy': is_greedy,
|
||||
}
|
||||
if self.config.logging:
|
||||
absl.logging.info(
|
||||
'\n========= Output ========= \n'
|
||||
+ pprint.pformat(output) + '\n'
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def serve_loglikelihood_rolling(self, data: InferenceRequest):
|
||||
with self.lock:
|
||||
if self.config.logging:
|
||||
absl.logging.info(
|
||||
'\n========= Serving Log Likelihood Request ========= \n'
|
||||
+ pprint.pformat(data) + '\n'
|
||||
)
|
||||
|
||||
text = [
|
||||
self.config.prepend_to_text + t + self.config.append_to_text
|
||||
for t in data.text
|
||||
]
|
||||
log_likelihood = []
|
||||
is_greedy = []
|
||||
for i in trange(0, len(text), self.config.batch_size, ncols=0):
|
||||
batch_text = text[i:i + self.config.batch_size]
|
||||
batch_size = len(batch_text)
|
||||
|
||||
if batch_size < self.config.batch_size:
|
||||
extra = self.config.batch_size - batch_size
|
||||
batch_text.extend(['a' for _ in range(extra)])
|
||||
|
||||
batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
|
||||
batch_text
|
||||
)
|
||||
batch_log_likelihood = self.to_list(batch_log_likelihood)
|
||||
batch_is_greedy = self.to_list(batch_is_greedy)
|
||||
log_likelihood.extend(batch_log_likelihood[:batch_size])
|
||||
is_greedy.extend(batch_is_greedy[:batch_size])
|
||||
|
||||
output = {
|
||||
'text': data.text,
|
||||
'log_likelihood': log_likelihood,
|
||||
'is_greedy': is_greedy,
|
||||
}
|
||||
if self.config.logging:
|
||||
absl.logging.info(
|
||||
'\n========= Output ========= \n'
|
||||
+ pprint.pformat(output) + '\n'
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def serve_generate(self, data: InferenceRequest):
|
||||
with self.lock:
|
||||
if self.config.logging:
|
||||
absl.logging.info(
|
||||
'\n========= Serving Generate Request ========= \n'
|
||||
+ pprint.pformat(data) + '\n'
|
||||
)
|
||||
prefix_text = [
|
||||
self.config.prepend_to_prefix + p + self.config.append_to_prefix
|
||||
for p in data.prefix_text
|
||||
]
|
||||
|
||||
if data.temperature is None:
|
||||
data.temperature = self.config.default_temperature
|
||||
|
||||
output_text = []
|
||||
for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0):
|
||||
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
|
||||
batch_size = len(batch_prefix_text)
|
||||
|
||||
if batch_size < self.config.batch_size:
|
||||
extra = self.config.batch_size - batch_size
|
||||
batch_prefix_text.extend(['a' for _ in range(extra)])
|
||||
|
||||
batch_output_text = self.generate(
|
||||
batch_prefix_text,
|
||||
temperature=data.temperature,
|
||||
)
|
||||
output_text.extend(self.to_list(batch_output_text)[:batch_size])
|
||||
|
||||
output = {
|
||||
'prefix_text': data.prefix_text,
|
||||
'output_text': output_text,
|
||||
'temperature': data.temperature,
|
||||
}
|
||||
if self.config.logging:
|
||||
absl.logging.info(
|
||||
'\n========= Output ========= \n'
|
||||
+ pprint.pformat(output) + '\n'
|
||||
)
|
||||
return output
|
||||
|
||||
def serve_greedy_until(self, data: InferenceRequest):
|
||||
with self.lock:
|
||||
if self.config.logging:
|
||||
absl.logging.info(
|
||||
'\n========= Serving Greedy Until Request ========= \n'
|
||||
+ pprint.pformat(data) + '\n'
|
||||
)
|
||||
prefix_text = [
|
||||
self.config.prepend_to_prefix + p + self.config.append_to_prefix
|
||||
for p in data.prefix_text
|
||||
]
|
||||
until = data.until
|
||||
max_length = self.config.greedy_until_max_length
|
||||
|
||||
output_text = []
|
||||
for i in range(0, len(prefix_text), self.config.batch_size):
|
||||
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
|
||||
batch_until = until[i:i + self.config.batch_size]
|
||||
batch_size = len(batch_prefix_text)
|
||||
|
||||
batch_output_text = self.greedy_until(batch_prefix_text, batch_until, max_length)
|
||||
output_text.extend(self.to_list(batch_output_text)[:batch_size])
|
||||
|
||||
output = {
|
||||
'prefix_text': data.prefix_text,
|
||||
'until': data.until,
|
||||
'max_length': max_length,
|
||||
'output_text': output_text,
|
||||
}
|
||||
if self.config.logging:
|
||||
absl.logging.info(
|
||||
'\n========= Output ========= \n'
|
||||
+ pprint.pformat(output) + '\n'
|
||||
)
|
||||
return output
|
||||
|
||||
def process_chat(self, prompt, context, temperature):
|
||||
context = (
|
||||
context + self.config.chat_user_prefix
|
||||
+ prompt + self.config.chat_user_suffix
|
||||
+ self.config.chat_lm_prefix
|
||||
)
|
||||
response = self.generate(
|
||||
[self.config.chat_prepend_text + context],
|
||||
temperature=float(temperature),
|
||||
)[0]
|
||||
context = context + response + self.config.chat_lm_suffix
|
||||
return response, context
|
||||
|
||||
def serve_chat(self, data: ChatRequest):
|
||||
if data.temperature is None:
|
||||
data.temperature = self.config.default_temperature
|
||||
response, context = self.process_chat(
|
||||
data.prompt, data.context,
|
||||
temperature=data.temperature,
|
||||
)
|
||||
return {
|
||||
'response': response,
|
||||
'context': context,
|
||||
'temperature': data.temperature,
|
||||
}
|
||||
|
||||
def create_chat_app(self):
|
||||
with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
|
||||
gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
|
||||
gr.Markdown(self.config.notes)
|
||||
chatbot = gr.Chatbot(label='Chat history')
|
||||
msg = gr.Textbox(
|
||||
placeholder='Type your message here...',
|
||||
show_label=False
|
||||
)
|
||||
with gr.Row():
|
||||
send = gr.Button('Send')
|
||||
regenerate = gr.Button('Regenerate', interactive=False)
|
||||
clear = gr.Button('Reset')
|
||||
|
||||
temp_slider = gr.Slider(
|
||||
label='Temperature', minimum=0, maximum=2.0,
|
||||
value=self.config.default_temperature
|
||||
)
|
||||
|
||||
context_state = gr.State(['', ''])
|
||||
|
||||
def user_fn(user_message, history, context):
|
||||
return {
|
||||
msg: gr.update(value='', interactive=False),
|
||||
clear: gr.update(interactive=False),
|
||||
send: gr.update(interactive=False),
|
||||
regenerate: gr.update(interactive=False),
|
||||
chatbot: history + [[user_message, None]],
|
||||
context_state: [context[1], context[1]],
|
||||
}
|
||||
|
||||
def model_fn(history, context, temperature):
|
||||
history[-1][1], new_context = self.process_chat(
|
||||
history[-1][0], context[0], temperature
|
||||
)
|
||||
return {
|
||||
msg: gr.update(value='', interactive=True),
|
||||
clear: gr.update(interactive=True),
|
||||
send: gr.update(interactive=True),
|
||||
chatbot: history,
|
||||
context_state: [context[0], new_context],
|
||||
regenerate: gr.update(interactive=True),
|
||||
}
|
||||
|
||||
def regenerate_fn():
|
||||
return {
|
||||
msg: gr.update(value='', interactive=False),
|
||||
clear: gr.update(interactive=False),
|
||||
send: gr.update(interactive=False),
|
||||
regenerate: gr.update(interactive=False),
|
||||
}
|
||||
|
||||
def clear_fn():
|
||||
return {
|
||||
chatbot: None,
|
||||
msg: '',
|
||||
context_state: ['', ''],
|
||||
regenerate: gr.update(interactive=False),
|
||||
}
|
||||
|
||||
msg.submit(
|
||||
user_fn,
|
||||
inputs=[msg, chatbot, context_state],
|
||||
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||
queue=False
|
||||
).then(
|
||||
model_fn,
|
||||
inputs=[chatbot, context_state, temp_slider],
|
||||
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||
queue=True
|
||||
)
|
||||
send.click(
|
||||
user_fn,
|
||||
inputs=[msg, chatbot, context_state],
|
||||
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||
queue=False
|
||||
).then(
|
||||
model_fn,
|
||||
inputs=[chatbot, context_state, temp_slider],
|
||||
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||
queue=True
|
||||
)
|
||||
regenerate.click(
|
||||
regenerate_fn,
|
||||
inputs=None,
|
||||
outputs=[msg, clear, send, regenerate],
|
||||
queue=False
|
||||
).then(
|
||||
model_fn,
|
||||
inputs=[chatbot, context_state, temp_slider],
|
||||
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
||||
queue=True
|
||||
)
|
||||
clear.click(
|
||||
clear_fn,
|
||||
inputs=None,
|
||||
outputs=[chatbot, msg, context_state, regenerate],
|
||||
queue=False
|
||||
)
|
||||
|
||||
gradio_chatbot.queue(concurrency_count=1)
|
||||
return gradio_chatbot
|
||||
|
||||
def run(self):
|
||||
if self.config.pre_compile != '':
|
||||
if self.config.pre_compile == 'all':
|
||||
pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat']
|
||||
else:
|
||||
pre_compile = self.config.pre_compile.split(',')
|
||||
|
||||
pre_compile_data = ['a' for _ in range(self.config.batch_size)]
|
||||
for task in pre_compile:
|
||||
if task == 'loglikelihood':
|
||||
self.loglikelihood(pre_compile_data, pre_compile_data)
|
||||
self.loglikelihood_rolling(pre_compile_data)
|
||||
elif task == 'generate':
|
||||
self.generate(pre_compile_data, 1.0)
|
||||
elif task == 'greedy_until':
|
||||
self.greedy_until(
|
||||
pre_compile_data, pre_compile_data,
|
||||
self.config.greedy_until_max_length
|
||||
)
|
||||
elif task == 'chat':
|
||||
self.process_chat('a', 'a', 1.0)
|
||||
else:
|
||||
raise ValueError(f'Invalid precompile task: {task}!')
|
||||
|
||||
uvicorn.run(self.app, host=self.config.host, port=self.config.port)
|
||||
|
||||
|
||||
class LMClient(object):
|
||||
""" A simple client for the LM server. """
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.url = 'http://localhost:5007'
|
||||
config.batch_size = 1
|
||||
config.wait_for_ready = True
|
||||
config.dummy = False
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.config = self.get_default_config(config)
|
||||
if self.config.wait_for_ready:
|
||||
self.wait_for_ready()
|
||||
|
||||
def wait_for_ready(self):
|
||||
if self.config.dummy:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
requests.get(urllib.parse.urljoin(self.config.url, 'ready'))
|
||||
return
|
||||
except (Timeout, ConnectionError) as e:
|
||||
time.sleep(10)
|
||||
|
||||
@staticmethod
|
||||
def batched(iterator, batch_size):
|
||||
batch = []
|
||||
for example in iterator:
|
||||
batch.append(example)
|
||||
if len(batch) == batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if len(batch) > 0:
|
||||
yield batch
|
||||
|
||||
def loglikelihood(self, prefix, text):
|
||||
prefix, text = list(prefix), list(text)
|
||||
if self.config.dummy:
|
||||
return [-1.0 for _ in text], [False for _ in text]
|
||||
|
||||
log_likelihood = []
|
||||
is_greedy = []
|
||||
|
||||
batched_iterator = list(zip(
|
||||
self.batched(prefix, self.config.batch_size),
|
||||
self.batched(text, self.config.batch_size)
|
||||
))
|
||||
for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
|
||||
response = requests.post(
|
||||
urllib.parse.urljoin(self.config.url, 'loglikelihood'),
|
||||
json={'prefix_text': batch_prefix, 'text': batch_text}
|
||||
).json()
|
||||
log_likelihood.extend(response['log_likelihood'])
|
||||
is_greedy.extend(response['is_greedy'])
|
||||
|
||||
return log_likelihood, is_greedy
|
||||
|
||||
def loglikelihood_rolling(self, text):
|
||||
text = list(text)
|
||||
if self.config.dummy:
|
||||
return [-1.0 for _ in text], [False for _ in text]
|
||||
|
||||
log_likelihood = []
|
||||
is_greedy = []
|
||||
batched_iterator = list(self.batched(text, self.config.batch_size))
|
||||
for batch_text in tqdm(batched_iterator, ncols=0):
|
||||
response = requests.post(
|
||||
urllib.parse.urljoin(self.config.url, 'loglikelihood-rolling'),
|
||||
json={'text': batch_text}
|
||||
).json()
|
||||
log_likelihood.extend(response['log_likelihood'])
|
||||
is_greedy.extend(response['is_greedy'])
|
||||
return log_likelihood, is_greedy
|
||||
|
||||
def greedy_until(self, prefix, until):
|
||||
prefix, until = list(prefix), list(until)
|
||||
if self.config.dummy:
|
||||
results = []
|
||||
for u in until:
|
||||
if isinstance(u, str):
|
||||
results.append('dummy text ' + u)
|
||||
else:
|
||||
results.append('dummy text ' + u[0])
|
||||
return results
|
||||
|
||||
batched_iterator = list(zip(
|
||||
self.batched(prefix, self.config.batch_size),
|
||||
self.batched(until, self.config.batch_size),
|
||||
))
|
||||
output_text = []
|
||||
for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
|
||||
response = requests.post(
|
||||
urllib.parse.urljoin(self.config.url, 'greedy-until'),
|
||||
json={'prefix_text': batch_prefix, 'until': batch_until}
|
||||
).json()
|
||||
output_text.extend(response['output_text'])
|
||||
return output_text
|
||||
|
||||
def generate(self, prefix, temperature=None):
|
||||
prefix = list(prefix)
|
||||
if self.config.dummy:
|
||||
return ['' for _ in prefix]
|
||||
|
||||
output_text = []
|
||||
batched_iterator = list(self.batched(prefix, self.config.batch_size))
|
||||
for batch_prefix in tqdm(batched_iterator, ncols=0):
|
||||
response = requests.post(
|
||||
urllib.parse.urljoin(self.config.url, 'generate'),
|
||||
json={
|
||||
'prefix_text': batch_prefix,
|
||||
'temperature': temperature,
|
||||
}
|
||||
).json()
|
||||
output_text.extend(response['output_text'])
|
||||
return output_text
|
||||
|
||||
def chat(self, prompt, context, temperature=None):
|
||||
if self.config.dummy:
|
||||
return ''
|
||||
response = requests.post(
|
||||
urllib.parse.urljoin(self.config.url, 'chat'),
|
||||
json={
|
||||
'prompt': prompt,
|
||||
'context': context,
|
||||
'temperature': temperature,
|
||||
}
|
||||
).json()
|
||||
return response['response'], response['context']
|
||||
Reference in New Issue
Block a user