初始化项目,由ModelHub XC社区提供模型

Model: Finnish-NLP/Ahma-7B
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-01 02:08:18 +08:00
commit be39ad8722
45 changed files with 297486 additions and 0 deletions

0
EasyLM/__init__.py Normal file
View File

228
EasyLM/bpt.py Normal file
View File

@@ -0,0 +1,228 @@
"""
An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
Also includes a reference implementation of the memory-efficient transformer https://arxiv.org/abs/2112.05682
"""
import functools
from typing import NamedTuple
import flax.linen as nn
import jax
import jax.lax as lax
import jax.numpy as jnp
from einops import rearrange
"""
Computing ffn blockwise without materializing the large hidden tensor, training
4x longer sequences than the memory-efficient transformer.
Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
"""
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
    """Apply a feed-forward module to the sequence one slice at a time.

    Args:
        remat_ffn: a rematerialized ffn with policy
            jax.checkpoint_policies.nothing_saveable(). It is called as
            ``remat_ffn(hidden_states, deterministic=...)``.
        inputs: (batch, seq_len, dim)
        chunk_size: the chunk size to split the sequence. The sequence axis
            is split as ``(c n)`` with ``c=chunk_size`` and the scan runs
            over the ``n`` axis, so every scan step receives a
            (batch, chunk_size, dim) slice.
        deterministic: forwarded unchanged to ``remat_ffn``.

    Returns:
        The per-slice outputs merged back into a (batch, seq_len, dim) array.
    """
    chunked = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
    # For the 4-D chunked tensor this is axis 2, i.e. the `n` axis.
    step_axis = chunked.ndim - 2

    def _apply_ffn(ffn_module, carry, hidden_states):
        # The carry is unused; it is threaded through only to satisfy scan.
        return carry, ffn_module(hidden_states, deterministic=deterministic)

    scanned_ffn = nn.scan(
        _apply_ffn,
        variable_broadcast="params",
        split_rngs={"params": False, "dropout": True},
        in_axes=step_axis,
        out_axes=step_axis,
    )
    _, chunk_outputs = scanned_ffn(remat_ffn, None, chunked)
    return rearrange(chunk_outputs, 'b c n d -> b (c n) d')
"""
Compute attention blockwise without materializing the full attention matrix,
initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA
efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370
Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x
longer sequences than memory-efficient/flash-attention and fusion of attention and FFN.
"""
def blockwise_attn(
    query, key, value,
    bias=None,
    deterministic=True,
    dropout_rng=None,
    attn_pdrop=0.0,
    causal=True,
    query_chunk_size=2048,
    key_chunk_size=2048,
    dtype=jnp.float32,
    policy=jax.checkpoint_policies.nothing_saveable(),
    precision=None,
    float32_logits=True,
    prevent_cse=True,
):
    """Attention evaluated one (query chunk, key/value chunk) pair at a time.

    For every query chunk, a running numerator, denominator and row-wise
    maximum score are carried across the key/value chunks, so the attention
    logits are only ever formed for one chunk pair.

    Args:
        query, key, value: (batch, seq_len, num_heads, dim_per_head). The
            query length must be a multiple of ``query_chunk_size`` and the
            key/value length a multiple of ``key_chunk_size``; otherwise the
            chunking reshape fails.
        bias: optional additive bias. Each of its dimensions must be 1 or
            match (batch, num_heads, q_len, kv_len); it is sliced per chunk
            pair with four start indices.
        deterministic: when False and ``attn_pdrop > 0`` a dropout mask is
            sampled.
        dropout_rng: PRNG key, only split when dropout is active.
        attn_pdrop: probability passed to the Bernoulli draw for the dropout
            mask. Note the mask is sampled for the full
            (batch, num_heads, q_len, kv_len) shape.
        causal: whether to use causal mask. When set, key chunks whose index
            is greater than the query chunk index are skipped entirely.
        query_chunk_size, key_chunk_size: chunk lengths along the sequence.
        dtype: dtype of the scaling factor, of the bias chunks and of the
            returned array.
        policy: one of jax.checkpoint_policies.
            NOTE(review): the default is the result of *calling*
            ``nothing_saveable()`` rather than the policy function itself;
            confirm this is what jax.checkpoint is meant to receive.
        precision: forwarded to both einsum calls.
        float32_logits: cast query and key to float32 before the dot product.
        prevent_cse: forwarded to jax.checkpoint.

    Returns:
        Array of shape (batch, q_len, num_heads, dim_per_head) in ``dtype``.
    """
    query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    # The later unpackings overwrite the earlier ones; value's dimensions win.
    batch, q_len, num_heads, dim_per_head = query.shape
    batch, kv_len, num_heads, dim_per_head = key.shape
    batch, kv_len, num_heads, dim_per_head = value.shape

    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size

    def _to_chunks(tensor, num_chunks, chunk_size):
        # (batch, seq, heads, dim) -> (num_chunks, batch, chunk, heads, dim)
        chunked = tensor.reshape(
            (batch, num_chunks, chunk_size, num_heads, dim_per_head)
        )
        return jnp.moveaxis(chunked, 1, 0)

    query = _to_chunks(query, num_q, query_chunk_size)
    key = _to_chunks(key, num_kv, key_chunk_size)
    value = _to_chunks(value, num_kv, key_chunk_size)

    if bias is not None:
        full_shape = (batch, num_heads, q_len, kv_len)
        for bias_dim, broadcast_dim in zip(bias.shape, full_shape):
            assert bias_dim == 1 or bias_dim == broadcast_dim

    attn_dropout = None
    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        # True entries are later pushed to the most negative value of `dtype`.
        attn_dropout = jax.random.bernoulli(
            attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len)
        )

    _chunk_bias_fn = functools.partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, deterministic,
        attn_dropout, attn_pdrop, causal, dtype)

    def attend_query_chunk(args):
        query_chunk, query_chunk_idx = args

        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def accumulate_kv_chunk(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            (numerator, denominator, prev_max_score) = carry

            attn_weights = jnp.einsum(
                'bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision
            )
            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
            # Bias comes as (b, h, q, k); the logits are laid out (b, q, h, k).
            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
            attn_weights = attn_weights + bias_chunk

            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jnp.maximum(prev_max_score, max_score)
            max_score = jax.lax.stop_gradient(max_score)

            exp_weights = jnp.exp(attn_weights - max_score)
            exp_values = jnp.einsum(
                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
            )
            # Rescale what was accumulated under the previous maximum.
            correction = jnp.exp(prev_max_score - max_score)
            numerator = numerator * correction + exp_values
            denominator = (
                denominator * correction
                + exp_weights.sum(axis=-1, keepdims=True)
            )
            return Carry(numerator, denominator, max_score), None

        def maybe_skip_kv_chunk(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            if causal:
                skip_block = query_chunk_idx < key_chunk_idx
            else:
                skip_block = jnp.array(False)
            return jax.lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                accumulate_kv_chunk,
                carry,
                args,
            )

        accumulator_shape = (batch, query_chunk_size, num_heads, dim_per_head)
        init_carry = Carry(
            jnp.zeros(accumulator_shape, dtype=query.dtype),
            jnp.zeros(accumulator_shape, dtype=query.dtype),
            (-jnp.inf) * jnp.ones(
                (batch, query_chunk_size, num_heads, 1), dtype=query.dtype
            ),
        )
        (numerator, denominator, max_score), _ = lax.scan(
            maybe_skip_kv_chunk,
            init_carry,
            xs=(key, value, jnp.arange(0, num_kv)),
        )
        return (numerator / denominator).astype(dtype)

    def _scan_step(_, query_args):
        return (), attend_query_chunk(query_args)

    _, res = lax.scan(_scan_step, (), xs=(query, jnp.arange(0, num_q)))
    return rearrange(res, 'n b c h d -> b (n c) h d')
class Carry(NamedTuple):
    """Scan carry used by blockwise_attn while walking key/value chunks."""

    # Running sum of exp-weighted values.
    numerator: jax.Array
    # Running sum of the exp weights.
    denominator: jax.Array
    # Largest attention score seen so far for each query row.
    max_so_far: jax.Array
def _chunk_attention_bias(query_chunk_size, key_chunk_size,
bias, deterministic, attn_dropout, attn_pdrop, causal,
dtype, query_chunk_idx, key_chunk_idx):
query_offset = query_chunk_idx * query_chunk_size
key_offset = key_chunk_idx * key_chunk_size
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
if bias is not None:
chunk_bias = lax.dynamic_slice(
bias,
start_indices=(0, 0, query_offset, key_offset),
slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
)
if causal:
query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
offset = query_offset - key_offset
query_idx += offset
causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
if not deterministic and attn_pdrop > 0.0:
attn_dropout_slice = lax.dynamic_slice(
attn_dropout,
start_indices=(0, 0, query_offset, key_offset),
slice_sizes=(
*attn_dropout.shape[:2],
min(attn_dropout.shape[-2], query_chunk_size),
min(attn_dropout.shape[-1], key_chunk_size),
),
)
chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
return chunk_bias.astype(dtype)
if __name__ == '__main__':
    # Self-test: blockwise attention must match a plain softmax attention.
    def reference_attn(query, key, value, causal, dtype):
        """Unchunked attention used as the ground truth for the check."""
        scaled_query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", scaled_query, key)
        if causal:
            mask_value = jnp.finfo(logits.dtype).min
            q_seq_len = query.shape[1]
            kv_seq_len = key.shape[1]
            mask_shape = (q_seq_len, kv_seq_len)
            row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
            col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
            causal_mask = (row_ids < col_ids)[None, None, :, :]
            logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        return jnp.einsum("bhqk,bkhc->bqhc", weights, value)

    # random inputs
    shape = (1, 32, 8, 64)
    query = jax.random.normal(jax.random.PRNGKey(0), shape)
    key = jax.random.normal(jax.random.PRNGKey(1), shape)
    value = jax.random.normal(jax.random.PRNGKey(2), shape)

    causal = True
    chunk_size = 4
    policy = jax.checkpoint_policies.nothing_saveable()

    blockwise = blockwise_attn(
        query, key, value,
        bias=None,
        deterministic=False,
        dropout_rng=None,
        attn_pdrop=0.0,
        causal=causal,
        query_chunk_size=chunk_size,
        key_chunk_size=chunk_size,
        dtype=jnp.float32,
        policy=policy,
        precision='float32',
        float32_logits=True,
        prevent_cse=False,
    )
    reference = reference_attn(query, key, value, causal, 'float32')

    assert jnp.allclose(reference, blockwise, atol=1e-6)

212
EasyLM/checkpoint.py Normal file
View File

@@ -0,0 +1,212 @@
import os
import numpy as np
from ml_collections import ConfigDict
import mlxu
import jax
import jax.numpy as jnp
import flax
from flax.serialization import (
from_bytes, to_bytes, to_state_dict, from_state_dict
)
from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
import msgpack
from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype
class StreamingCheckpointer(object):
    """ Custom msgpack checkpointer that saves large train states by serializing
        and saving tensors one by one in a streaming fashion. Avoids running
        out of memory or local TPU disk with default flax checkpointer.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default config, optionally overridden by `updates`."""
        config = ConfigDict()
        config.float_dtype = 'bf16'
        config.save_optimizer_state = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, checkpoint_dir, enable=True):
        """Store the resolved config and output directory.

        Args:
            config: overrides applied on top of the default config.
            checkpoint_dir: directory that save paths are joined onto.
            enable: when False, every save is written to '/dev/null' instead.
        """
        self.config = self.get_default_config(config)
        self.checkpoint_dir = checkpoint_dir
        self.enable = enable

    def _resolve_path(self, filename):
        """Return the output path for `filename` ('/dev/null' when disabled)."""
        if self.enable:
            return os.path.join(self.checkpoint_dir, filename)
        return '/dev/null'

    def save_checkpoint(self, train_state, filename, gather_fns=None):
        """Stream `train_state` to `filename` using the configured float dtype."""
        path = self._resolve_path(filename)
        self.save_train_state_to_file(
            train_state, path, gather_fns, self.config.float_dtype
        )

    @staticmethod
    def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None):
        """Write the flattened state to `path`, one msgpack record per tensor.

        Each record is a `(key, bytes)` pair, where `key` is the flattened
        state-dict path and `bytes` is the flax serialization of the value.

        Args:
            train_state: object convertible with flax `to_state_dict`.
            path: destination opened through `mlxu.open_file`.
            gather_fns: optional tree of callables matching `train_state`;
                each value is passed through its callable before being saved.
            float_dtype: forwarded to `float_tensor_to_dtype`
                (presumably the target dtype for float tensors; the helper is
                defined outside this class).
        """
        train_state = to_state_dict(train_state)
        packer = msgpack.Packer()
        flat_state = flatten_dict(train_state)
        if gather_fns is not None:
            gather_fns = flatten_dict(to_state_dict(gather_fns))

        with mlxu.open_file(path, "wb") as fout:
            for key, value in flat_state.items():
                if gather_fns is not None:
                    value = gather_fns[key](value)
                value = float_tensor_to_dtype(value, float_dtype)
                fout.write(packer.pack((key, to_bytes(value))))

    def save_pickle(self, obj, filename):
        """Pickle `obj` to `filename` via `mlxu.save_pickle`."""
        mlxu.save_pickle(obj, self._resolve_path(filename))

    def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False):
        """Save metadata, dataset state and model state for the current step.

        With `config.save_optimizer_state` the whole train state is saved as
        'streaming_train_state'; otherwise only `train_state.params['params']`
        is saved as 'streaming_params'. Milestone saves append the step number
        to every filename so they are never overwritten.
        """
        step = int(jax.device_get(train_state.step))
        if self.config.save_optimizer_state:
            checkpoint_state = train_state
            checkpoint_name = 'streaming_train_state'
            checkpoint_gather_fns = gather_fns
        else:
            checkpoint_state = train_state.params['params']
            checkpoint_name = 'streaming_params'
            checkpoint_gather_fns = gather_fns.params['params']

        if milestone:
            # Save a milestone checkpoint that will not be overwritten
            self.save_pickle(metadata, f'metadata_{step}.pkl')
            self.save_pickle(dataset, f'dataset_{step}.pkl')
            self.save_checkpoint(
                checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns
            )
        else:
            # Save a normal checkpoint that can be overwritten
            self.save_pickle(metadata, 'metadata.pkl')
            self.save_pickle(dataset, 'dataset.pkl')
            self.save_checkpoint(
                checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns
            )

    @staticmethod
    def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
        """Load a streaming-format checkpoint record by record.

        Args:
            path: checkpoint location opened through `mlxu.open_file`.
            target: optional template; when given, the result is restored
                into it with `from_state_dict`, and its empty nodes missing
                from the checkpoint are filled in first.
            shard_fns: optional tree of callables applied to each tensor,
                looked up by the (prefix-stripped) flattened key.
            remove_dict_prefix: optional key prefix. Records whose key starts
                with it have it stripped; all other records are skipped.

        Returns:
            The nested state dict, or `target` restored from it.
        """
        if shard_fns is not None:
            shard_fns = flatten_dict(
                to_state_dict(shard_fns)
            )
        if remove_dict_prefix is not None:
            remove_dict_prefix = tuple(remove_dict_prefix)

        flat_state = {}
        with mlxu.open_file(path) as fin:
            # 83886080 bytes = 80 MB, which is 16 blocks on GCS
            unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
            for key, value in unpacker:
                key = tuple(key)
                if remove_dict_prefix is not None:
                    if key[:len(remove_dict_prefix)] == remove_dict_prefix:
                        key = key[len(remove_dict_prefix):]
                    else:
                        continue

                tensor = from_bytes(None, value)
                if shard_fns is not None:
                    tensor = shard_fns[key](tensor)
                flat_state[key] = tensor

        if target is not None:
            flattened_target = flatten_dict(
                to_state_dict(target), keep_empty_nodes=True
            )
            for key, value in flattened_target.items():
                if key not in flat_state and value == empty_node:
                    flat_state[key] = value

        train_state = unflatten_dict(flat_state)
        if target is None:
            return train_state

        return from_state_dict(target, train_state)

    @staticmethod
    def load_flax_checkpoint(path, target=None, shard_fns=None):
        """ Load a standard flax checkpoint that's not saved with the
            msgpack streaming format.
        """
        with mlxu.open_file(path, "rb") as fin:
            encoded_bytes = fin.read()

        state_dict = flax.serialization.msgpack_restore(encoded_bytes)
        if shard_fns is not None:
            shard_fns = to_state_dict(shard_fns)
            state_dict = tree_apply(shard_fns, state_dict)

        if target is None:
            return state_dict
        return from_state_dict(target, state_dict)

    @classmethod
    def load_trainstate_checkpoint(cls, load_from, trainstate_target=None,
                                   trainstate_shard_fns=None,
                                   disallow_trainstate=False):
        """Load a train state or parameters described by `load_from`.

        Args:
            load_from: string of the form '<type>::<path>', where type is one
                of 'trainstate', 'trainstate_params', 'params', 'flax_params'.
            trainstate_target: optional train-state template.
            trainstate_shard_fns: optional tree of shard callables.
            disallow_trainstate: assert that the type is not 'trainstate'.

        Returns:
            `(train_state, restored_params)`; exactly one is not None.

        Raises:
            ValueError: if `load_from` has no '::' separator or names an
                unknown type.
        """
        if trainstate_target is not None:
            params_target = trainstate_target.params['params']
        else:
            params_target = None

        if trainstate_shard_fns is not None:
            params_shard_fns = trainstate_shard_fns.params['params']
        else:
            params_shard_fns = None

        # partition() splits on the first '::' only, like split('::', 1), but
        # lets a missing separator be reported with a readable message.
        load_type, separator, load_path = load_from.partition('::')
        if not separator:
            raise ValueError(
                f'Invalid load_from value {load_from!r}: '
                f'expected the form <type>::<path>'
            )
        if disallow_trainstate:
            assert load_type != 'trainstate', 'Loading full trainstate is not allowed!'
        train_state = None
        restored_params = None
        if load_type == 'trainstate':
            # Load the entire train state in the streaming format
            train_state = cls.load_checkpoint(
                path=load_path,
                target=trainstate_target,
                shard_fns=trainstate_shard_fns,
            )
        elif load_type == 'trainstate_params':
            # Load the params part of the train state in the streaming format
            restored_params = cls.load_checkpoint(
                path=load_path,
                target=params_target,
                shard_fns=params_shard_fns,
                remove_dict_prefix=('params', 'params'),
            )
            restored_params = flax.core.frozen_dict.freeze(
                {'params': restored_params}
            )
        elif load_type == 'params':
            # Load the params in the streaming format
            restored_params = cls.load_checkpoint(
                path=load_path,
                target=params_target,
                shard_fns=params_shard_fns,
            )
            restored_params = flax.core.frozen_dict.freeze(
                {'params': restored_params}
            )
        elif load_type == 'flax_params':
            # Load the params in the standard flax format (non-streaming)
            # This requires the entire params to fit in memory
            restored_params = cls.load_flax_checkpoint(
                path=load_path,
                target=params_target,
                shard_fns=params_shard_fns
            )
            restored_params = flax.core.frozen_dict.freeze(
                {'params': restored_params}
            )
        else:
            raise ValueError(f'Invalid load_from type: {load_type}')

        return train_state, restored_params

436
EasyLM/data.py Normal file
View File

@@ -0,0 +1,436 @@
import dataclasses
import pprint
import time
from functools import partial
import json
import base64
from multiprocessing import Pool
import h5py
import mlxu
from ml_collections.config_dict import config_dict
from ml_collections import ConfigDict
from tqdm import tqdm, trange
import numpy as np
from datasets import load_dataset, load_from_disk
class DatasetFactory(object):
    """ Dataset builder class. """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default factory config, with `updates` applied on top."""
        config = ConfigDict()
        config.type = 'huggingface'
        config.text_processor = TextProcessor.get_default_config()
        config.huggingface_dataset = HuggingfaceDataset.get_default_config()
        config.json_dataset = JsonDataset.get_default_config()

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def load_dataset(cls, config, tokenizer, **kwargs):
        """Build the dataset selected by `config.type`.

        The text processor is always constructed first; extra keyword
        arguments are forwarded to the dataset constructor.

        Raises:
            ValueError: if `config.type` is neither 'huggingface' nor 'json'.
        """
        config = cls.get_default_config(config)
        text_processor = TextProcessor(config.text_processor, tokenizer)
        dataset_type = config.type
        if dataset_type == 'huggingface':
            return HuggingfaceDataset(
                config.huggingface_dataset, tokenizer, text_processor, **kwargs
            )
        if dataset_type == 'json':
            return JsonDataset(
                config.json_dataset, tokenizer, text_processor, **kwargs
            )
        raise ValueError(f'Unknown dataset type: {config.type}')

    def __init__(self):
        raise ValueError('DatasetFactory is a static class and should not be instantiated.')
class TextProcessor(object):
    """ Example processor that converts a dictionary of texts into tokens.

    The comma-separated field spec (from `config.fields`, or read from the
    example itself via `config.fields_from_example`) supports:
      * `[field]`   - the field contributes tokens with loss mask 0.0
      * `<|bos|>`, `<|eos|>`, `<|N|>` - a single special or literal token id
      * `{field}`   - base64-encoded raw token ids stored in the example
      * `a+b`       - text subfields joined with `config.subfield_separator`
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default processor config, with `updates` applied."""
        config = ConfigDict()
        config.fields_from_example = ''
        config.fields = ''
        config.subfield_separator = ' '
        config.add_bos_token = True
        config.add_eos_token = True
        config.prepend_text = ''
        config.base64_token_dtype = 'i4'
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer):
        self.config = self.get_default_config(config)
        assert not (
            self.config.fields == '' and self.config.fields_from_example == ''
        ), (
            'Either fields or fields_from_example must be specified.'
        )
        self.tokenizer = tokenizer

    def __call__(self, example, has_aux=False):
        """Tokenize one example.

        Args:
            example: mapping of field name to value; when `has_aux` is set,
                a sequence whose first item is that mapping.
            has_aux: pass the remaining items of `example` through unchanged.

        Returns:
            `(tokens, loss_masks, *aux)` with one mask value per token.
        """
        aux = tuple()
        if has_aux:
            example, *aux = example

        token_buffer = []
        loss_mask_buffer = []

        if self.config.add_bos_token:
            # The leading BOS never contributes to the loss.
            token_buffer.append(self.tokenizer.bos_token_id)
            loss_mask_buffer.append(0.0)

        if self.config.fields_from_example != '':
            field_spec = example[self.config.fields_from_example]
        else:
            field_spec = self.config.fields

        for position, field in enumerate(field_spec.split(',')):
            no_loss = field.startswith('[') and field.endswith(']')
            if no_loss:
                # No loss for this field.
                field = field[1:-1]
            mask = 0.0 if no_loss else 1.0

            if field.startswith('<|') and field.endswith('|>'):
                # Special tokens.
                token_name = field[2:-2]
                if token_name == 'bos':
                    token_id = self.tokenizer.bos_token_id
                elif token_name == 'eos':
                    token_id = self.tokenizer.eos_token_id
                else:
                    # Token ID specified directly.
                    token_id = int(token_name)
                token_buffer.append(token_id)
                loss_mask_buffer.append(mask)
                continue

            if field.startswith('{') and field.endswith('}'):
                # Base64 encoded raw tokens.
                raw_bytes = base64.b64decode(example[field[1:-1]])
                tokens = np.frombuffer(
                    raw_bytes, dtype=self.config.base64_token_dtype
                ).tolist()
            else:
                pieces = [example[subfield] for subfield in field.split('+')]
                text = self.config.subfield_separator.join(pieces)
                if position == 0:
                    # Only the very first field gets the configured prefix.
                    text = self.config.prepend_text + text
                tokens = self.tokenizer.encode(text)

            token_buffer.extend(tokens)
            loss_mask_buffer.extend([mask] * len(tokens))

        if self.config.add_eos_token:
            token_buffer.append(self.tokenizer.eos_token_id)
            loss_mask_buffer.append(1.0)

        return token_buffer, loss_mask_buffer, *aux
class HuggingfaceDataset(object):
    """ Huggingface dataset that yields fixed-size token batches.

    In this version the data is read with `datasets.load_from_disk` from
    `config.path` and converted to an iterable dataset.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default dataset config, with `updates` applied."""
        config = ConfigDict()
        config.path = 'c4'
        config.name = 'en'
        config.split = 'train'
        config.streaming = False
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.tokens_count_at_start = 0
        config.batch_token_dtype = 'i4'
        config.reset_dataset_loc = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
        self.config = self.get_default_config(config)
        # NOTE(review): `name` is resolved but never passed on below.
        name = self.config.name if self.config.name != '' else None
        split = self.config.split if self.config.split != '' else None
        self._tokenizer = tokenizer
        self._text_processor = text_processor

        on_disk = load_from_disk(self.config.path)[split]
        # At most 128 shards, and never more shards than examples.
        shard_count = min(128, len(on_disk))
        self._dataset = on_disk.to_iterable_dataset(num_shards=shard_count)

        self._eval_dataset = eval_dataset
        self._train_epochs = 0
        self._dataset_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start
        self._index = 0
        self.reset_dataset_loc = self.config.reset_dataset_loc

    def __iter__(self):
        """Yield `(batch, metrics)` pairs.

        An eval dataset is traversed once. A training dataset loops forever;
        it is shuffled once the first epoch has finished, and examples before
        the stored resume location are skipped.
        """
        if not self._eval_dataset and self._train_epochs > 0:
            # Resuming past the first epoch: restore the shuffled view.
            self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
        chunk_size = self.config.batch_size * self.config.seq_length

        def as_rows(values, dtype):
            return np.array(values, dtype=dtype).reshape(
                self.config.batch_size, -1
            )

        while True:
            token_buffer = []
            loss_mask_buffer = []
            if not self._eval_dataset and self._train_epochs > 0:
                self._dataset.set_epoch(self._train_epochs)

            for index, example in enumerate(self._dataset):
                self._index = index
                if not self._eval_dataset and self._dataset_loc > index:
                    # Fast-forward to the resume position.
                    continue
                tokens, loss_masks = self.text_processor(example)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend(loss_masks)

                # One extra token is needed for the shifted targets.
                while len(token_buffer) > chunk_size + 1:
                    self._total_tokens += chunk_size
                    metrics = {
                        'dataset_example_index': index,
                        'dataset_total_tokens': self._total_tokens,
                        'epoch': self._train_epochs,
                    }
                    token_dtype = self.config.batch_token_dtype
                    batch = {
                        'input_tokens': as_rows(
                            token_buffer[:chunk_size], token_dtype
                        ),
                        'target_tokens': as_rows(
                            token_buffer[1:chunk_size + 1], token_dtype
                        ),
                        'loss_masks': as_rows(
                            loss_mask_buffer[1:chunk_size + 1], np.float32
                        ),
                    }
                    if self.config.always_start_with_bos:
                        batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
                    yield batch, metrics
                    token_buffer = token_buffer[chunk_size:]
                    loss_mask_buffer = loss_mask_buffer[chunk_size:]

            if self._eval_dataset:
                break

            if self._train_epochs == 0:
                self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
            self._dataset_loc = 0
            self._train_epochs += 1

    def get_state_dict(self):
        """Return the config plus the iteration counters needed to resume."""
        return dict(
            config=self.config,
            dataset_loc=self._index,
            total_tokens=self._total_tokens,
            epochs=self._train_epochs,
        )

    def load_state_dict(self, state_dict):
        """Restore counters saved by `get_state_dict`.

        When `reset_dataset_loc` was set at construction, the location and
        epoch counter are reset to zero after loading.
        """
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
        self._train_epochs = state_dict.get('epochs', 0)
        if self.reset_dataset_loc:
            self._dataset_loc = 0
            self._train_epochs = 0

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def dataset(self):
        return self._dataset

    @property
    def vocab_size(self):
        return len(self._tokenizer)
class JsonDataset(object):
    """ JSON dataset, where each line of the data file contains a JSON
        dictionary with text fields.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default dataset config, with `updates` applied."""
        config = ConfigDict()
        config.path = ''
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.example_index_at_start = 0
        config.tokens_count_at_start = 0
        config.tokenizer_processes = 1
        config.tokenizer_parallel_chunk_size = 32
        config.tokenizer_parallel_batch_size = 1024
        config.throughput_average_window_size = 200

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor):
        self.config = self.get_default_config(config)
        assert self.config.path != ''
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._index = self.config.example_index_at_start
        self._file_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start

    def parse_json(self, line):
        """Parse one line; return None for blank or malformed lines."""
        if not line or line == '\n':
            return None
        try:
            data = json.loads(line)
        except json.decoder.JSONDecodeError:
            print(f'Error parsing json line:\n{line}')
            return None
        return data

    def json_iterator(self):
        """Yield `(data, file_loc, index)` forever, wrapping around at EOF.

        `file_loc` is the file position just after the yielded line. The
        index counts every non-EOF line read, parsed or not, and restarts at
        zero when the file wraps.
        """
        with mlxu.open_file(self.config.path, 'r') as fin:
            fin.seek(self._file_loc)
            while True:
                line = fin.readline()
                self._file_loc = fin.tell()
                if not line:   # Reached EOF
                    self._index = 0
                    fin.seek(0)
                    continue

                data = self.parse_json(line)
                if data is not None:
                    # JSON parsing succeeded
                    yield data, self._file_loc, self._index
                self._index += 1

    def batched(self, iterator, batch_size):
        """Group items into lists of `batch_size`; the last may be shorter."""
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def parallel_example_iterator(self):
        """Yield tokenized examples, optionally tokenizing in a process pool.

        With more than one tokenizer process, the next batch is submitted to
        the pool before the current one is consumed.
        """
        if self.config.tokenizer_processes == 1:
            for example, loc, index in self.json_iterator():
                yield self.text_processor((example, loc, index), has_aux=True)
        else:
            process_pool = Pool(self.config.tokenizer_processes)
            batched_iterator = self.batched(
                self.json_iterator(), self.config.tokenizer_parallel_batch_size
            )
            with process_pool as pool:
                map_fn = partial(self.text_processor, has_aux=True)
                next_batch = pool.map_async(
                    map_fn, next(batched_iterator),
                    chunksize=self.config.tokenizer_parallel_chunk_size
                )
                while True:
                    current_batch = next_batch
                    next_batch = pool.map_async(
                        map_fn, next(batched_iterator),
                        chunksize=self.config.tokenizer_parallel_chunk_size
                    )
                    for example in current_batch.get():
                        yield example

    def __iter__(self):
        """Yield `(batch, metrics)` pairs of `batch_size x seq_length` tokens.

        Metrics include the file location and example index of the most
        recently consumed example, the running token count, and two
        throughput figures in tokens per second: one averaged over the last
        `throughput_average_window_size` batches and one since iteration
        began.
        """
        chunk_size = self.config.batch_size * self.config.seq_length
        token_buffer = []
        loss_mask_buffer = []
        step_times = []
        start_time = time.time()
        # Start the per-step timer now. Starting it at 0.0 would record the
        # first step as the number of seconds since the Unix epoch and drag
        # the windowed average throughput towards zero.
        last_time = start_time
        start_tokens = self._total_tokens
        # Lower bound for measured intervals so that a coarse clock cannot
        # cause a division by zero.
        min_interval = 1e-9
        for tokens, loss_masks, loc, index in self.parallel_example_iterator():
            token_buffer.extend(tokens)
            loss_mask_buffer.extend(loss_masks)
            # One extra token is needed for the shifted targets.
            while len(token_buffer) > chunk_size + 1:
                self._total_tokens += chunk_size
                now = time.time()
                step_times.append(now - last_time)
                last_time = now
                if len(step_times) > self.config.throughput_average_window_size:
                    step_times = step_times[-self.config.throughput_average_window_size:]
                average_throughput = chunk_size / max(
                    np.mean(step_times), min_interval
                )
                accumulated_throughput = (
                    (self._total_tokens - start_tokens)
                    / max(time.time() - start_time, min_interval)
                )
                metrics = {
                    'dataset_file_loc': loc,
                    'dataset_example_index': index,
                    'dataset_total_tokens': self._total_tokens,
                    'dataset_accumulated_tps': accumulated_throughput,
                    'dataset_average_tps': average_throughput,
                }
                batch = {
                    'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
                        self.config.batch_size, -1
                    ),
                    'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
                        self.config.batch_size, -1
                    ),
                    'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
                        self.config.batch_size, -1
                    ),
                }
                if self.config.always_start_with_bos:
                    batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
                yield batch, metrics
                token_buffer = token_buffer[chunk_size:]
                loss_mask_buffer = loss_mask_buffer[chunk_size:]

    def get_state_dict(self):
        """Return the config plus the file position and counters."""
        return dict(
            config=self.config,
            index=self._index,
            file_loc=self._file_loc,
            total_tokens=self._total_tokens,
        )

    def load_state_dict(self, state_dict):
        """Restore state saved by `get_state_dict`, falling back to config."""
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._index = state_dict.get('index', self.config.example_index_at_start)
        self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def vocab_size(self):
        return len(self.tokenizer)

403
EasyLM/jax_utils.py Normal file
View File

@@ -0,0 +1,403 @@
import os
import math
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
from functools import partial
import re
import dataclasses
import random
from ml_collections import ConfigDict
from ml_collections.config_dict.config_dict import placeholder
import flax
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as PS
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
from jax.experimental.pjit import pjit
from jax.interpreters import pxla
import numpy as np
from transformers import FlaxLogitsWarper
class JaxRNG(object):
    """ A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
        pure function.
    """

    @classmethod
    def from_seed(cls, seed):
        """Create a wrapper around `jax.random.PRNGKey(seed)`."""
        return cls(jax.random.PRNGKey(seed))

    def __init__(self, rng):
        self.rng = rng

    def __call__(self, keys=None):
        """Advance the stored key and hand out fresh keys.

        Args:
            keys: None for a single key, an int for a tuple of that many
                keys, or an iterable of names for a dict of name -> key.
        """
        if keys is None:
            self.rng, fresh_key = jax.random.split(self.rng)
            return fresh_key

        wants_tuple = isinstance(keys, int)
        count = keys if wants_tuple else len(keys)
        # One extra key becomes the new internal state.
        split_rngs = jax.random.split(self.rng, num=count + 1)
        self.rng = split_rngs[0]
        fresh_keys = split_rngs[1:]
        if wants_tuple:
            return tuple(fresh_keys)
        return dict(zip(keys, fresh_keys))
class JaxDistributedConfig(object):
    """ Utility class for initializing JAX distributed. """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default config, with `updates` applied on top.

        All connection fields are typed placeholders that start out unset.
        """
        config = ConfigDict()
        config.initialize_jax_distributed = False
        config.coordinator_address = placeholder(str)
        config.num_processes = placeholder(int)
        config.process_id = placeholder(int)
        config.local_device_ids = placeholder(str)

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def initialize(cls, config):
        """Call `jax.distributed.initialize` if the config asks for it.

        `local_device_ids` is given as a comma-separated string and is parsed
        into a list of ints; it stays None when unset.
        """
        config = cls.get_default_config(config)
        if not config.initialize_jax_distributed:
            return

        local_device_ids = None
        if config.local_device_ids is not None:
            local_device_ids = [
                int(device_id)
                for device_id in config.local_device_ids.split(',')
            ]

        jax.distributed.initialize(
            coordinator_address=config.coordinator_address,
            num_processes=config.num_processes,
            process_id=config.process_id,
            local_device_ids=local_device_ids,
        )
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
    """ JIT traceable version of FlaxLogitsWarper that performs temperature scaling."""

    def __init__(self, temperature):
        self.temperature = temperature

    def __call__(self, input_ids, scores, cur_len):
        """Divide `scores` by the temperature, floored at 1e-8.

        `input_ids` and `cur_len` are accepted but not used.
        """
        # The lower bound is passed positionally: newer JAX releases rename
        # the `a_min` keyword to `min`, while the second positional argument
        # is the lower bound under both signatures.
        return scores / jnp.clip(self.temperature, 1e-8)
def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
    """ Create pytree of sharding and gathering functions from pytree of
        partition specs.

    Args:
        partition_specs: pytree of partition specs, one per tensor.
        dtype_specs: None, a single float dtype applied to every float
            tensor, or a pytree matching `partition_specs` whose leaves
            expose a `.dtype` to cast to.

    Returns:
        `(shard_fns, gather_fns)`, two pytrees of single-argument callables.
    """
    float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)

    def make_to_dtype_fn(dtype_spec):
        def to_dtype(tensor):
            # The first test deliberately looks at the outer `dtype_specs`:
            # a single float dtype applies to all float tensors.
            if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
                # Convert all float tensors to the same dtype
                return tensor.astype(dtype_specs)
            if hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
                return tensor.astype(dtype_spec.dtype)
            return tensor
        return to_dtype

    def make_shard_fn(partition_spec, dtype_spec=None):
        sharded_cast = pjit(
            make_to_dtype_fn(dtype_spec),
            in_shardings=None,
            out_shardings=partition_spec
        )

        def shard_fn(tensor):
            return sharded_cast(tensor).block_until_ready()
        return shard_fn

    def make_gather_fn(partition_spec, dtype_spec=None):
        gathered_cast = pjit(
            make_to_dtype_fn(dtype_spec),
            in_shardings=partition_spec,
            out_shardings=None
        )

        def gather_fn(tensor):
            return jax.device_get(gathered_cast(tensor))
        return gather_fn

    # Per-leaf dtype specs are mapped alongside the partition specs; a single
    # dtype (or none) is picked up through the closure instead.
    if dtype_specs is None or dtype_specs in float_dtypes:
        mapped_trees = (partition_specs,)
    else:
        mapped_trees = (partition_specs, dtype_specs)

    shard_fns = jax.tree_util.tree_map(make_shard_fn, *mapped_trees)
    gather_fns = jax.tree_util.tree_map(make_gather_fn, *mapped_trees)
    return shard_fns, gather_fns
def set_random_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and this module's JAX RNG."""
    random.seed(seed)
    np.random.seed(seed)
    init_rng(seed)
def get_jax_mesh(axis_dims, names):
    """Build a device `Mesh` from a textual axis description.

    Args:
        axis_dims: either comma-separated sizes ('1,-1,1') matched to `names`
            by position, or comma-separated 'name:size' pairs in any order.
            A leading '!' lays the devices out by plainly reshaping
            `jax.devices()` instead of using
            `mesh_utils.create_device_mesh`.
        names: the mesh axis names.

    Returns:
        A `Mesh` whose axes follow the order given in `axis_dims`.
    """
    # Allow splitting a physical mesh axis if needed
    mesh_axis_splitting = axis_dims.startswith('!')
    if mesh_axis_splitting:
        axis_dims = axis_dims[1:]

    if ':' in axis_dims:
        dims, dim_names = [], []
        for axis_entry in axis_dims.split(','):
            axis_name, axis_size = axis_entry.split(':')
            assert axis_name in names
            dims.append(int(axis_size))
            dim_names.append(axis_name)
        assert(set(dim_names) == set(names))
    else:
        dims = [int(size) for size in axis_dims.split(',')]
        dim_names = names
    assert len(dims) == len(names)

    # Reshaping a dummy index array resolves any -1 against the device count.
    device_indices = np.arange(jax.device_count())
    mesh_shape = device_indices.reshape(dims).shape
    if mesh_axis_splitting:
        physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
    else:
        physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
    return Mesh(physical_mesh, dim_names)
def names_in_current_mesh(*names):
    """ Check if current mesh axes contain these names. """
    current_axes = set(pxla.thread_resources.env.physical_mesh.axis_names)
    return set(names).issubset(current_axes)
def get_names_from_parition_spec(partition_specs):
    """ Return the distinct axis names mentioned in partition specs.

    Args:
        partition_specs: a partition spec, a (possibly nested) sequence of
            them, or a dict whose values are such objects. ``None`` entries
            are ignored. A bare string is treated as a single axis name and
            ``None`` as "no axes".

    Returns:
        A list of unique axis names, in no particular order.
    """
    if partition_specs is None:
        return []
    if isinstance(partition_specs, str):
        # Iterating a string would yield its characters, not the axis name.
        return [partition_specs]
    if isinstance(partition_specs, dict):
        partition_specs = partition_specs.values()

    names = set()
    for item in partition_specs:
        if item is None:
            continue
        if isinstance(item, str):
            names.add(item)
        else:
            # Nested spec (e.g. a tuple of axes sharing one dimension).
            names.update(get_names_from_parition_spec(item))
    return list(names)
def with_sharding_constraint(x, partition_specs):
    """ Apply a sharding constraint only when the current mesh supports it.

    The constraint is applied when every axis named in ``partition_specs``
    exists in the current mesh; otherwise ``x`` is returned untouched.
    """
    required_axes = get_names_from_parition_spec(partition_specs)
    if not names_in_current_mesh(*required_axes):
        return x
    return _with_sharding_constraint(x, partition_specs)
def wrap_function_with_rng(rng):
    """ Decorator factory that keeps an RNG key on behalf of the wrapped function.

    On every call the stored key is split: one half replaces the stored key,
    the other half is passed to the wrapped function as its first argument.
    """
    def decorator(function):
        def call_with_fresh_key(*args, **kwargs):
            nonlocal rng
            next_key, subkey = jax.random.split(rng)
            rng = next_key
            return function(subkey, *args, **kwargs)
        return call_with_fresh_key
    return decorator
def init_rng(seed):
    """ (Re)create the module-level RNG stream ``jax_utils_rng`` from a seed. """
    global jax_utils_rng
    fresh_stream = JaxRNG.from_seed(seed)
    jax_utils_rng = fresh_stream
def next_rng(*args, **kwargs):
    """ Call the module-level RNG stream, forwarding all arguments to it.

    ``init_rng`` (or ``set_random_seed``) must have been called first so that
    ``jax_utils_rng`` exists.
    """
    rng_stream = jax_utils_rng
    return rng_stream(*args, **kwargs)
def get_metrics(metrics, unreplicate=False, stack=False):
    """ Fetch metrics from device and convert them to host-side values.

    Args:
        metrics: a dict of scalar metrics, or, when ``stack`` is true, a
            sequence of metric pytrees with identical structure.
        unreplicate: if true, first pass the metrics through
            ``flax.jax_utils.unreplicate``.
        stack: if true, stack corresponding leaves of all pytrees with
            ``np.stack``; otherwise convert every dict value to ``float``.

    Returns:
        A pytree of stacked NumPy arrays (``stack=True``) or a dict of floats.
    """
    if unreplicate:
        metrics = flax.jax_utils.unreplicate(metrics)
    metrics = jax.device_get(metrics)
    if stack:
        # Use jax.tree_util.tree_map like the rest of this file; the
        # top-level jax.tree_map alias is deprecated.
        return jax.tree_util.tree_map(lambda *args: np.stack(args), *metrics)
    return {key: float(val) for key, val in metrics.items()}
def mse_loss(val, target, valid=None):
    """ Masked mean squared error between ``val`` and ``target``.

    Positions where ``valid`` is not positive contribute zero to the sum.
    The mean is taken over all elements, so masked positions still count in
    the denominator. Without ``valid`` every position is used; the default
    mask has shape ``(*target.shape[:2], 1)``.
    """
    if valid is None:
        mask_shape = (*target.shape[:2], 1)
        valid = jnp.ones(mask_shape)
    keep = valid.astype(jnp.float32) > 0.0
    squared_error = jnp.square(val - target)
    return jnp.mean(jnp.where(keep, squared_error, 0.0))
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
    """ Masked token-level cross entropy and accuracy.

    For every sequence, the log-probabilities of the target ``tokens`` (and
    the number of correct argmax predictions) are summed over valid
    positions and divided by the number of valid positions; the results are
    then averaged over sequences. Without ``valid`` every position counts.

    Returns:
        A ``(loss, accuracy)`` tuple.
    """
    if valid is None:
        valid = jnp.ones(tokens.shape[:2])
    valid = valid.astype(jnp.float32)
    is_valid = valid > 0.0
    # The lower bound avoids dividing by zero for fully masked sequences.
    valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)

    logits = logits.astype(jnp.float32)  # for numerical stability
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    picked = jnp.take_along_axis(
        log_probs, jnp.expand_dims(tokens, -1), axis=-1
    )
    token_log_prob = jnp.where(is_valid, jnp.squeeze(picked, -1), jnp.array(0.0))
    loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)

    predictions = jnp.argmax(logits, axis=-1)
    correct = jnp.where(is_valid, predictions == tokens, jnp.array(False))
    accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
    return loss, accuracy
def global_norm(tree):
    """ Return the global L2 norm of a pytree.

    Computed as the square root of the sum of squares over all leaves.
    """
    def sum_of_squares(leaf):
        return jnp.sum(jnp.square(leaf))

    per_leaf = jax.tree_util.tree_map(sum_of_squares, tree)
    flat, _unravel = jax.flatten_util.ravel_pytree(per_leaf)
    total = jnp.sum(flat)
    return jnp.sqrt(total)
def average_metrics(metrics):
    """ Average a sequence of metric pytrees leaf by leaf.

    Args:
        metrics: a sequence of pytrees with identical structure.

    Returns:
        A pytree of the same structure whose leaves are the mean of the
        corresponding leaves of all inputs.
    """
    from contextlib import nullcontext

    # Keep the original "allow_all" SPMD mode where this JAX build exposes
    # it, and fall back to a no-op context otherwise.
    spmd_mode = getattr(jax, 'spmd_mode', None)
    context = spmd_mode("allow_all") if spmd_mode is not None else nullcontext()
    with context:
        # Use jax.tree_util.tree_map like the rest of this file; the
        # top-level jax.tree_map alias is deprecated.
        return jax.tree_util.tree_map(
            lambda *args: jnp.mean(jnp.stack(args)),
            *metrics
        )
def get_float_dtype_by_name(dtype):
    """ Map a short or long dtype name to the matching ``jax.numpy`` float type.

    Accepted names: ``bf16``/``bfloat16``, ``fp16``/``float16``,
    ``fp32``/``float32`` and ``fp64``/``float64``. Any other name raises
    ``KeyError``.
    """
    aliases = (
        (('bf16', 'bfloat16'), jnp.bfloat16),
        (('fp16', 'float16'), jnp.float16),
        (('fp32', 'float32'), jnp.float32),
        (('fp64', 'float64'), jnp.float64),
    )
    by_name = {
        name: float_type
        for names, float_type in aliases
        for name in names
    }
    return by_name[dtype]
def float_tensor_to_dtype(tensor, dtype):
    """ Cast a floating point tensor to ``dtype``, leaving anything else alone.

    ``dtype`` may be a dtype object or a name understood by
    ``get_float_dtype_by_name``; ``None`` or ``''`` disables the cast.
    Objects without a ``dtype`` attribute and tensors whose dtype is not
    bfloat16/float16/float32/float64 are returned unchanged.
    """
    if dtype is None or dtype == '':
        return tensor
    target = get_float_dtype_by_name(dtype) if isinstance(dtype, str) else dtype
    float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
    if getattr(tensor, 'dtype', None) not in float_dtypes:
        return tensor
    return tensor.astype(target)
def float_to_dtype(tree, dtype):
    """ Apply ``float_tensor_to_dtype`` with ``dtype`` to every leaf of a pytree. """
    def cast_leaf(leaf):
        return float_tensor_to_dtype(leaf, dtype=dtype)

    return jax.tree_util.tree_map(cast_leaf, tree)
def get_gradient_checkpoint_policy(name):
    """ Look up a ``jax.checkpoint_policies`` policy by name.

    Only the four names listed below are accepted; any other name raises
    ``KeyError``.
    """
    supported = (
        'everything_saveable',
        'nothing_saveable',
        'checkpoint_dots',
        'checkpoint_dots_with_no_batch_dims',
    )
    policies = {
        policy_name: getattr(jax.checkpoint_policies, policy_name)
        for policy_name in supported
    }
    return policies[name]
def tree_path_to_string(path, sep=None):
    """ Convert a JAX key path into strings.

    Every path entry is rendered from its index, key or attribute name; any
    other kind of entry is rendered with ``str``. Returns a tuple of strings
    when ``sep`` is ``None``, otherwise the parts joined by ``sep``.
    """
    def render(key):
        if isinstance(key, jax.tree_util.SequenceKey):
            return str(key.idx)
        if isinstance(key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)):
            return str(key.key)
        if isinstance(key, jax.tree_util.GetAttrKey):
            return str(key.name)
        return str(key)

    parts = [render(key) for key in path]
    return tuple(parts) if sep is None else sep.join(parts)
def flatten_tree(xs, is_leaf=None, sep=None):
    """ Flatten a pytree into a dict keyed by the path of each leaf.

    Keys are produced by ``tree_path_to_string`` with the given ``sep``.
    """
    leaves_with_paths, _treedef = jax.tree_util.tree_flatten_with_path(
        xs, is_leaf=is_leaf
    )
    return {
        tree_path_to_string(path, sep=sep): leaf
        for path, leaf in leaves_with_paths
    }
def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
    """ Like ``jax.tree_util.tree_map``, but ``f`` also receives the leaf name.

    ``f`` is called as ``f(name, leaf, *other_leaves)`` where ``name`` is the
    leaf's path rendered by ``tree_path_to_string`` with the given ``sep``.
    """
    def call_with_name(path, leaf, *other_leaves):
        return f(tree_path_to_string(path, sep=sep), leaf, *other_leaves)

    return jax.tree_util.tree_map_with_path(
        call_with_name, tree, *rest, is_leaf=is_leaf
    )
def match_partition_rules(rules, params):
    """ Return a pytree of PartitionSpec chosen by regex rules.

    ``rules`` is a sequence of ``(pattern, partition_spec)`` pairs. For each
    leaf the first pattern found in its '/'-joined path wins. Scalar and
    single-element leaves always get an empty spec. A leaf with no matching
    rule raises ``ValueError``. Supports handling Flax TrainState and Optax
    optimizer state.
    """
    def get_partition_spec(name, leaf):
        # Don't partition scalar values.
        is_scalar = len(leaf.shape) == 0 or np.prod(leaf.shape) == 1
        if is_scalar:
            return PS()
        for pattern, spec in rules:
            if re.search(pattern, name) is not None:
                return spec
        raise ValueError(f'Partition rule not found for param: {name}')

    return named_tree_map(get_partition_spec, params, sep='/')
def get_weight_decay_mask(exclusions):
    """ Build a function that computes a weight decay mask for a params pytree.

    The returned function maps every leaf to ``False`` when one of the
    ``exclusions`` regexes is found in its '/'-joined path, ``True`` otherwise.
    """
    def decay(name, _):
        is_excluded = any(
            re.search(rule, name) is not None for rule in exclusions
        )
        return not is_excluded

    def weight_decay_mask(params):
        return named_tree_map(decay, params, sep='/')

    return weight_decay_mask
def tree_apply(fns, tree):
    """ Apply a pytree of functions to a pytree of matching structure, leaf by leaf. """
    def apply_one(fn, leaf):
        return fn(leaf)

    return jax.tree_util.tree_map(apply_one, fns, tree)

View File

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,396 @@
import pprint
from functools import partial
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import flax
from flax import linen as nn
from flax.jax_utils import prefetch_to_device
from flax.training.train_state import TrainState
import optax
from transformers import GenerationConfig, FlaxLogitsProcessorList
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.serving import LMServer
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
with_sharding_constraint, FlaxTemperatureLogitsWarper
)
from EasyLM.models.gptj.gptj_model import (
GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
initialize_jax_distributed=False,
mesh_dim='1,-1,1',
dtype='bf16',
input_length=1024,
seq_length=2048,
top_k=50,
top_p=1.0,
do_sample=True,
num_beams=1,
add_bos_token=False,
load_gptj_config='',
load_checkpoint='',
tokenizer=GPTJConfig.get_tokenizer_config(),
lm_server=LMServer.get_default_config(),
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
    """ Load a GPT-J checkpoint, shard it over the device mesh and serve it.

    Builds three pjit-compiled entry points (log-likelihood scoring, sampled
    generation and greedy generation) and exposes them through an
    ``LMServer`` subclass, then runs the server.
    """
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    set_random_seed(FLAGS.seed)

    # Prompts are truncated/padded on the left, continuations on the right.
    prefix_tokenizer = GPTJConfig.get_tokenizer(
        FLAGS.tokenizer, truncation_side='left', padding_side='left'
    )
    tokenizer = GPTJConfig.get_tokenizer(
        FLAGS.tokenizer, truncation_side='right', padding_side='right'
    )

    # Load config and parameters with the first CPU device as the default.
    with jax.default_device(jax.devices("cpu")[0]):
        gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
        # load_checkpoint must have the form '<type>::<path>'.
        load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
        if load_type == 'huggingface':
            params = gptj_config.load_pretrained(load_path)
        else:
            _, params = StreamingCheckpointer.load_trainstate_checkpoint(
                FLAGS.load_checkpoint, disallow_trainstate=True
            )

        hf_model = FlaxGPTJForCausalLM(
            gptj_config,
            input_shape=(1, FLAGS.seq_length),
            seed=FLAGS.seed,
            _do_init=False
        )

    model_ps = match_partition_rules(
        GPTJConfig.get_partition_rules(), params
    )
    shard_fns, _ = make_shard_and_gather_fns(
        model_ps, get_float_dtype_by_name(FLAGS.dtype)
    )

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS()),
        out_shardings=(PS(), PS(), PS())
    )
    def forward_loglikelihood(params, rng, batch):
        # Returns, per sequence: the log-likelihood summed over positions
        # selected by output_mask, whether the argmax prediction matches at
        # all of those positions, and a fresh rng.
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        input_tokens = batch['input_tokens']
        output_tokens = batch['output_tokens']
        input_mask = batch['input_mask']
        output_mask = batch['output_mask']

        logits = hf_model.module.apply(
            params, input_tokens, attention_mask=input_mask,
            deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
        ).logits
        # Push logits at vocabulary indices >= n_real_tokens to -1e8.
        if gptj_config.n_real_tokens is not None:
            logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)
        loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
            logits, output_tokens
        )
        loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
        match_count = jnp.sum(
            (jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
            axis=-1
        )
        total = jnp.sum(output_mask, axis=-1)
        is_greedy = match_count == total
        return loglikelihood, is_greedy, rng_generator()

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS(), PS()),
        out_shardings=(PS(), PS())
    )
    def forward_generate(params, rng, batch, temperature):
        # Generation configured by the do_sample/num_beams/top_k/top_p flags
        # with a temperature warper; only the newly generated tokens are
        # returned.
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        output = hf_model.generate(
            batch['input_tokens'],
            attention_mask=batch['attention_mask'],
            params=params['params'],
            prng_key=rng_generator(),
            logits_processor=FlaxLogitsProcessorList(
                [FlaxTemperatureLogitsWarper(temperature)]
            ),
            generation_config=GenerationConfig(
                max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=FLAGS.do_sample,
                num_beams=FLAGS.num_beams,
                top_k=FLAGS.top_k,
                top_p=FLAGS.top_p,
            )
        ).sequences[:, batch['input_tokens'].shape[1]:]
        return output, rng_generator()

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS()),
        out_shardings=(PS(), PS())
    )
    def forward_greedy_generate(params, rng, batch):
        # Same as forward_generate, but with do_sample=False and num_beams=1.
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        output = hf_model.generate(
            batch['input_tokens'],
            attention_mask=batch['attention_mask'],
            params=params['params'],
            prng_key=rng_generator(),
            generation_config=GenerationConfig(
                max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False,
                num_beams=1,
            )
        ).sequences[:, batch['input_tokens'].shape[1]:]
        return output, rng_generator()

    mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        params = tree_apply(shard_fns, params)
        sharded_rng = next_rng()

    class ModelServer(LMServer):
        # All handlers share `sharded_rng` from the enclosing scope and
        # replace it with the rng returned by each forward call.

        @staticmethod
        def loglikelihood(prefix_text, text):
            # Score `text` given `prefix_text`; only the `text` positions
            # are selected by output_mask.
            nonlocal sharded_rng
            prefix = prefix_tokenizer(
                prefix_text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.input_length,
                return_tensors='np',
            )
            inputs = tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.seq_length - FLAGS.input_length,
                return_tensors='np',
            )
            output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
            # Model inputs are the targets shifted right by one, with BOS
            # in the first position.
            bos_tokens = np.full(
                (output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
            )
            input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
            input_mask = np.concatenate(
                [prefix.attention_mask, inputs.attention_mask], axis=1
            )
            # The BOS position is attended to only when add_bos_token is set.
            if FLAGS.add_bos_token:
                bos_mask = np.ones_like(input_mask[:, :1])
            else:
                bos_mask = np.zeros_like(input_mask[:, :1])

            input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
            output_mask = np.concatenate(
                [np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
            )
            batch = dict(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                input_mask=input_mask,
                output_mask=output_mask,
            )
            with mesh:
                loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
                    params, sharded_rng, batch
                )
                loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
            return loglikelihood, is_greedy

        @staticmethod
        def loglikelihood_rolling(text):
            # Score arbitrarily long text in windows of FLAGS.seq_length
            # tokens and accumulate the per-window results.
            nonlocal sharded_rng
            inputs = tokenizer(
                text,
                padding='longest',
                truncation=False,
                max_length=np.iinfo(np.int32).max,
                return_tensors='np',
            )
            batch_size = inputs.input_ids.shape[0]
            output_tokens = inputs.input_ids
            attention_mask = inputs.attention_mask

            # Right-pad short inputs up to one full window.
            if output_tokens.shape[1] < FLAGS.seq_length:
                padding_length = FLAGS.seq_length - output_tokens.shape[1]
                pad_tokens = np.full(
                    (batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
                )
                output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
                pad_mask = np.zeros(
                    (batch_size, padding_length), dtype=inputs.attention_mask.dtype
                )
                attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)

            bos_tokens = np.full(
                (batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
            )
            input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
            # NOTE(review): bos_mask is computed but never used below.
            bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
            total_seq_length = output_tokens.shape[1]

            total_loglikelihood = 0.0
            total_is_greedy = True
            # Sliding window
            for i in range(0, total_seq_length, FLAGS.seq_length):
                # Last window
                if i + FLAGS.seq_length > total_seq_length:
                    # The final window is aligned to the end of the text;
                    # positions before index i were already scored by
                    # earlier windows, so they are masked out.
                    last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
                    last_output_mask[:, :i - total_seq_length] = 0.0

                    batch = dict(
                        input_tokens=input_tokens[:, -FLAGS.seq_length:],
                        output_tokens=output_tokens[:, -FLAGS.seq_length:],
                        input_mask=attention_mask[:, -FLAGS.seq_length:],
                        output_mask=last_output_mask,
                    )

                # Normal window
                else:
                    batch = dict(
                        input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
                        output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
                        input_mask=attention_mask[:, i:i + FLAGS.seq_length],
                        output_mask=attention_mask[:, i:i + FLAGS.seq_length],
                    )

                with mesh:
                    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
                        params, sharded_rng, batch
                    )
                    loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))

                total_loglikelihood += loglikelihood
                total_is_greedy = np.logical_and(is_greedy, total_is_greedy)

            return total_loglikelihood, total_is_greedy

        @staticmethod
        def generate(text, temperature):
            # Generate continuations for a batch of prompts; each decoded
            # output is cut at the first EOS token.
            nonlocal sharded_rng
            inputs = prefix_tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.input_length,
                return_tensors='np',
            )
            input_tokens = inputs.input_ids
            input_mask = inputs.attention_mask
            # Overwrite the first position with BOS when requested.
            if FLAGS.add_bos_token:
                input_tokens[:, 0] = tokenizer.bos_token_id
                input_mask[:, 0] = 1
            batch = dict(
                input_tokens=input_tokens,
                attention_mask=input_mask,
            )
            with mesh:
                output, sharded_rng = forward_generate(
                    params, sharded_rng, batch, temperature
                )
                output = jax.device_get(output)
            output_text = []
            for text in list(tokenizer.batch_decode(output)):
                if tokenizer.eos_token in text:
                    text = text.split(tokenizer.eos_token, maxsplit=1)[0]
                output_text.append(text)

            return output_text

        @staticmethod
        def greedy_until(prefix_text, until, max_length):
            # Greedily extend each prompt, one prompt at a time, until a
            # stop string appears or max_length tokens were generated.
            nonlocal sharded_rng
            all_outputs = []
            for pf, ut in zip(prefix_text, until):
                if isinstance(ut, str):
                    ut = [ut]
                total_length = 0
                total_generated = ''

                while total_length < max_length:
                    pf_tokens = tokenizer(
                        pf,
                        padding=False,
                        truncation=False,
                        max_length=np.iinfo(np.int32).max,
                        return_tensors='np',
                    )
                    input_tokens = pf_tokens.input_ids
                    attention_mask = pf_tokens.attention_mask

                    # Left-pad, or keep only the last input_length tokens,
                    # so the prompt is exactly FLAGS.input_length wide.
                    if input_tokens.shape[1] < FLAGS.input_length:
                        extra = FLAGS.input_length - input_tokens.shape[1]
                        pad_tokens = np.full(
                            (1, extra), tokenizer.pad_token_id, dtype=np.int32
                        )
                        input_tokens = np.concatenate(
                            [pad_tokens, input_tokens], axis=1
                        )
                        pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
                        attention_mask = np.concatenate(
                            [pad_attention, attention_mask], axis=1
                        )
                    elif input_tokens.shape[1] > FLAGS.input_length:
                        input_tokens = input_tokens[:, -FLAGS.input_length:]
                        attention_mask = attention_mask[:, -FLAGS.input_length:]

                    if FLAGS.add_bos_token:
                        input_tokens[:, 0] = tokenizer.bos_token_id
                        attention_mask[:, 0] = 1

                    batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)

                    with mesh:
                        output, sharded_rng = forward_greedy_generate(
                            params, sharded_rng, batch
                        )
                        output = jax.device_get(output)

                    total_length += output.shape[1]
                    output_text = tokenizer.batch_decode(output)[0]
                    total_generated = total_generated + output_text
                    # The generated text becomes part of the next prompt.
                    pf = pf + output_text

                    done = False
                    for s in ut:
                        if s in total_generated:
                            total_generated = total_generated.split(s, maxsplit=1)[0]
                            done = True
                    if done:
                        break

                all_outputs.append(total_generated)

            return all_outputs

    server = ModelServer(FLAGS.lm_server)
    server.run()
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,272 @@
import pprint
from functools import partial
from tqdm import tqdm, trange
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
set_random_seed, average_metrics, get_weight_decay_mask,
make_shard_and_gather_fns, tree_apply
)
from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
mesh_dim='1,-1,1',
dtype='fp32',
total_steps=10000,
load_gptj_config='',
update_gptj_config='',
load_checkpoint='',
load_dataset_state='',
log_freq=50,
save_model_freq=0,
save_milestone_freq=0,
eval_steps=0,
tokenizer=GPTJConfig.get_tokenizer_config(),
train_dataset=DatasetFactory.get_default_config(),
eval_dataset=DatasetFactory.get_default_config(),
optimizer=OptimizerFactory.get_default_config(),
checkpointer=StreamingCheckpointer.get_default_config(),
gptj=GPTJConfig.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
log_all_worker=False,
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
    """ Train a GPT-J causal language model with pjit-sharded train/eval steps.

    Sets up logging, datasets, model config and optimizer, builds the sharded
    init/train/eval functions, restores or initializes the train state, then
    runs the training loop with periodic logging, evaluation and
    checkpointing.
    """
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
    # Logging is enabled on process 0 only, unless log_all_worker is set.
    logger = mlxu.WandBLogger(
        config=FLAGS.logger,
        variant=variant,
        enable=FLAGS.log_all_worker or (jax.process_index() == 0),
    )
    set_random_seed(FLAGS.seed)

    tokenizer = GPTJConfig.get_tokenizer(FLAGS.tokenizer)
    dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
    if FLAGS.load_dataset_state != '':
        dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))

    if FLAGS.eval_steps > 0:
        eval_dataset = DatasetFactory.load_dataset(
            FLAGS.eval_dataset, dataset.tokenizer
        )
        eval_iterator = iter(eval_dataset)

    seq_length = dataset.seq_length

    if FLAGS.load_gptj_config != '':
        gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
    else:
        gptj_config = GPTJConfig(**FLAGS.gptj)

    if FLAGS.update_gptj_config != '':
        # NOTE(review): eval() executes arbitrary Python taken from the
        # --update_gptj_config flag; only pass trusted values here.
        gptj_config.update(dict(eval(FLAGS.update_gptj_config)))

    gptj_config.update(dict(
        bos_token_id=dataset.tokenizer.bos_token_id,
        eos_token_id=dataset.tokenizer.eos_token_id,
    ))
    # Grow the model vocabulary if the dataset's vocabulary is larger.
    if gptj_config.vocab_size < dataset.vocab_size:
        gptj_config.update(dict(vocab_size=dataset.vocab_size))

    model = FlaxGPTJForCausalLMModule(
        gptj_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
    )

    optimizer, optimizer_info = OptimizerFactory.get_optimizer(
        FLAGS.optimizer,
        get_weight_decay_mask(GPTJConfig.get_weight_decay_exclusions()),
    )

    def create_trainstate_from_params(params):
        # Wrap restored parameters in a fresh TrainState.
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def init_fn(rng):
        # Initialize parameters from dummy inputs of batch size 4.
        rng_generator = JaxRNG(rng)
        params = model.init(
            input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
            rngs=rng_generator(gptj_config.rng_keys()),
        )
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def train_step(train_state, rng, batch):
        # One optimizer update; returns the new state, a fresh rng and
        # training metrics.
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        def loss_and_accuracy(params):
            logits = model.apply(
                params, batch['input_tokens'], deterministic=False,
                rngs=rng_generator(gptj_config.rng_keys()),
            ).logits
            return cross_entropy_loss_and_accuracy(
                logits, batch['target_tokens'], batch['loss_masks']
            )
        grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
        (loss, accuracy), grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        # learning_rate and param_norm are read from the updated train state.
        metrics = dict(
            loss=loss,
            accuracy=accuracy,
            learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
            gradient_norm=global_norm(grads),
            param_norm=global_norm(train_state.params),
        )
        return train_state, rng_generator(), metrics

    def eval_step(train_state, rng, batch):
        # Forward pass only; returns a fresh rng and evaluation metrics.
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        logits = model.apply(
            train_state.params, batch['input_tokens'], deterministic=True,
            rngs=rng_generator(gptj_config.rng_keys()),
        ).logits
        loss, accuracy = cross_entropy_loss_and_accuracy(
            logits, batch['target_tokens'], batch['loss_masks']
        )
        metrics = dict(
            eval_loss=loss,
            eval_accuracy=accuracy,
        )
        return rng_generator(), metrics

    # Trace init_fn abstractly to obtain the train state's structure, then
    # derive partition specs and shard/gather functions from it.
    train_state_shapes = jax.eval_shape(init_fn, next_rng())
    train_state_partition = match_partition_rules(
        GPTJConfig.get_partition_rules(), train_state_shapes
    )

    shard_fns, gather_fns = make_shard_and_gather_fns(
        train_state_partition, train_state_shapes
    )
    # Checkpoint writing is enabled on process 0 only.
    checkpointer = StreamingCheckpointer(
        FLAGS.checkpointer, logger.output_dir,
        enable=jax.process_index() == 0,
    )

    sharded_init_fn = pjit(
        init_fn,
        in_shardings=PS(),
        out_shardings=train_state_partition
    )

    sharded_create_trainstate_from_params = pjit(
        create_trainstate_from_params,
        in_shardings=(train_state_partition.params, ),
        out_shardings=train_state_partition,
        donate_argnums=(0, ),
    )

    sharded_train_step = pjit(
        train_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(train_state_partition, PS(), PS()),
        donate_argnums=(0, 1),
    )

    sharded_eval_step = pjit(
        eval_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(PS(), PS()),
        donate_argnums=(1,),
    )

    def save_checkpoint(train_state, milestone=False):
        # Save the train state together with run metadata and dataset state.
        step = int(jax.device_get(train_state.step))
        metadata = dict(
            step=step,
            variant=variant,
            flags=flags_config_dict,
            gptj_config=gptj_config.to_dict(),
        )
        checkpointer.save_all(
            train_state=train_state,
            gather_fns=gather_fns,
            metadata=metadata,
            dataset=dataset.get_state_dict(),
            milestone=milestone,
        )

    mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        train_state, restored_params = None, None
        if FLAGS.load_checkpoint != '':
            # load_checkpoint must have the form '<type>::<path>'.
            load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
            if load_type == 'huggingface':
                restored_params = tree_apply(
                    shard_fns.params, gptj_config.load_pretrained(load_path)
                )
                train_state = None
            else:
                train_state, restored_params = checkpointer.load_trainstate_checkpoint(
                    FLAGS.load_checkpoint, train_state_shapes, shard_fns
                )

        if train_state is None and restored_params is None:
            # Initialize from scratch
            train_state = sharded_init_fn(next_rng())
        elif train_state is None and restored_params is not None:
            # Restore from params but initialize train_state
            train_state = sharded_create_trainstate_from_params(restored_params)
            del restored_params

        start_step = int(jax.device_get(train_state.step))

        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)

        sharded_rng = next_rng()

        step_counter = trange(start_step, FLAGS.total_steps, ncols=0)

        # zip() ends the loop when either the step range or the dataset is
        # exhausted.
        for step, (batch, dataset_metrics) in zip(step_counter, dataset):
            train_state, sharded_rng, metrics = sharded_train_step(
                train_state, sharded_rng, batch
            )

            if step % FLAGS.log_freq == 0:
                if FLAGS.eval_steps > 0:
                    eval_metric_list = []
                    for _ in range(FLAGS.eval_steps):
                        eval_batch, _ = next(eval_iterator)
                        sharded_rng, eval_metrics = sharded_eval_step(
                            train_state, sharded_rng, eval_batch
                        )
                        eval_metric_list.append(eval_metrics)
                    metrics.update(average_metrics(eval_metric_list))

                log_metrics = {"step": step}
                log_metrics.update(metrics)
                log_metrics.update(dataset_metrics)
                log_metrics = jax.device_get(log_metrics)
                logger.log(log_metrics)
                tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")

            # A milestone save takes precedence over a regular save when
            # both fall on the same step.
            if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
                save_checkpoint(train_state, milestone=True)
            elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
                save_checkpoint(train_state)

        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,338 @@
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
# Copyright 2023 Xinyang Geng
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts LLaMA model checkpoint trained by EasyLM to the
# HuggingFace transformers LLaMA PyTorch format, which can then be loaded
# by HuggingFace transformers.
import gc
import json
import math
import os
import shutil
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
import flax
from flax.traverse_util import flatten_dict
import torch
from transformers import LlamaConfig, LlamaForCausalLM
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_tensor_to_dtype
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
load_checkpoint='',
tokenizer_path='',
model_size='13b',
output_dir='',
)
LLAMA_STANDARD_CONFIGS = {
'small': {
'vocab_size': 64256,
'dim': 768,
'intermediate_size': 3072,
'n_layers': 12,
'n_heads': 12,
'norm_eps': 1e-6,
},
'medium': {
'vocab_size': 64256,
'dim': 1024,
'intermediate_size': 4096,
'n_layers': 24,
'n_heads': 16,
'norm_eps': 1e-6,
},
'large': {
'vocab_size': 64256,
'dim': 1536,
'intermediate_size': 6144,
'n_layers': 24,
'n_heads': 16,
'norm_eps': 1e-6,
},
'xlarge': {
'vocab_size': 64256,
'dim': 2048,
'intermediate_size': 8192,
'n_layers': 24,
'n_heads': 32,
'norm_eps': 1e-6,
},
'1b': {
'vocab_size': 64256,
'dim': 2048,
'intermediate_size': 5504,
'n_layers': 22,
'n_heads': 16,
'norm_eps': 1e-6,
},
'3b': {
'vocab_size': 64256,
'dim': 3200,
'intermediate_size': 8640,
'n_layers': 26,
'n_heads': 32,
'norm_eps': 1e-6,
},
'7b': {
'vocab_size': 64256,
'dim': 4096,
'intermediate_size': 11008,
'n_layers': 32,
'n_heads': 32,
'norm_eps': 1e-6,
},
'13b': {
'vocab_size': 64256,
'dim': 5120,
'intermediate_size': 13824,
'n_layers': 40,
'n_heads': 40,
'norm_eps': 1e-6,
},
'30b': {
'vocab_size': 64256,
'dim': 6656,
'intermediate_size': 17920,
'n_layers': 60,
'n_heads': 52,
'norm_eps': 1e-6,
},
'65b': {
'vocab_size': 64256,
'dim': 8192,
'intermediate_size': 22016,
'n_layers': 80,
'n_heads': 64,
'norm_eps': 1e-5,
},
}
def match_keywords(string, positives, negatives):
    """ Return True when ``string`` contains every positive and no negative keyword. """
    has_all_positives = all(keyword in string for keyword in positives)
    if not has_all_positives:
        return False
    return not any(keyword in string for keyword in negatives)
def load_and_convert_checkpoint(path):
    """ Load an EasyLM checkpoint and return its parameters as torch tensors.

    The parameter tree is flattened to '.'-joined names. Every tensor whose
    name contains "kernel" but neither "norm" nor "ln_f" is transposed. All
    tensors are converted to fp32 first and stored as ``torch.float16``.
    """
    _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
    flat_params = flatten_dict(flax_params['params'], sep='.')

    def to_torch(name, array):
        if match_keywords(name, ["kernel"], ["norm", 'ln_f']):
            array = array.T
        return torch.tensor(
            float_tensor_to_dtype(array, 'fp32'), dtype=torch.float16
        )

    return {name: to_torch(name, array) for name, array in flat_params.items()}
def read_json(path):
    """ Parse the JSON file at ``path`` and return the decoded object.

    The file is read as UTF-8 so the result does not depend on the
    platform's default text encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def write_json(text, path):
    """ Serialize the object ``text`` as JSON into the file at ``path``.

    An existing file is overwritten.
    """
    with open(path, "w") as out_file:
        json.dump(text, out_file)
def write_model(loaded, model_path, model_size):
    """ Write converted weights as a HuggingFace LLaMA model under ``model_path``.

    Args:
        loaded: dict of torch tensors keyed by flattened EasyLM parameter
            names (the keys read below, e.g. ``transformer.h.<i>.attention.wq.kernel``).
        model_path: output directory; a ``tmp`` sub-directory is used for
            intermediate shards and removed at the end.
        model_size: key into ``LLAMA_STANDARD_CONFIGS``.
    """
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)

    params = LLAMA_STANDARD_CONFIGS[model_size]

    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = 10000.0
    # Inverse frequencies stored as each layer's rotary_emb.inv_freq.
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

    # permute for sliced rotary
    def permute(w):
        # Expects a (dim, dim) weight; reorders rows within each head.
        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)

    param_count = 0
    index_dict = {"weight_map": {}}
    # One shard file per transformer layer, plus one final shard for the
    # embedding, final norm and LM head.
    for layer_i in range(n_layers):
        filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
        state_dict = {
            f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
                loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
            ),
            f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
                loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
            ),
            f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
            f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],

            f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
            f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
            f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],

            f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
            f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],

        }

        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))

    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
    # Unsharded
    state_dict = {
        "model.embed_tokens.weight": loaded["transformer.wte.embedding"],
        "model.norm.weight": loaded["transformer.ln_f.kernel"],
        "lm_head.weight": loaded["lm_head.kernel"],
    }

    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        param_count += v.numel()
    torch.save(state_dict, os.path.join(tmp_model_path, filename))

    # Write configs
    # total_size assumes 2 bytes per counted element.
    index_dict["metadata"] = {"total_size": param_count * 2}
    write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))

    config = LlamaConfig(
        vocab_size=params["vocab_size"],
        hidden_size=dim,
        intermediate_size=params["intermediate_size"],
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
    )
    config.save_pretrained(tmp_model_path)

    # Make space so we can load the model properly now.
    del state_dict
    del loaded
    gc.collect()

    print("Loading the checkpoint in a Llama model.")
    model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
    # Avoid saving this as part of the config.
    print("Model parameter count", model.num_parameters())
    del model.config._name_or_path

    print("Saving in the Transformers format.")
    model.save_pretrained(model_path, safe_serialization=True)
    shutil.rmtree(tmp_model_path)
def write_tokenizer(tokenizer_path, input_tokenizer_path):
    """ Write HuggingFace tokenizer files into the directory ``tokenizer_path``.

    Creates ``special_tokens_map.json`` and ``tokenizer_config.json`` and
    copies ``input_tokenizer_path`` to ``tokenizer.model``.
    """
    def special_token(content, with_type=False):
        # Keys are inserted in the order they must appear in the JSON output.
        entry = {"__type": "AddedToken"} if with_type else {}
        entry["content"] = content
        entry["lstrip"] = False
        entry["normalized"] = True
        entry["rstrip"] = False
        entry["single_word"] = False
        return entry

    token_contents = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
    }

    print(f"Fetching the tokenizer from {input_tokenizer_path}.")
    os.makedirs(tokenizer_path, exist_ok=True)

    special_tokens_map = {
        name: special_token(content)
        for name, content in token_contents.items()
    }
    write_json(
        special_tokens_map,
        os.path.join(tokenizer_path, "special_tokens_map.json")
    )

    tokenizer_config = {
        "add_bos_token": True,
        "add_eos_token": False,
        "model_max_length": 2048,
        "pad_token": None,
        "sp_model_kwargs": {},
        "tokenizer_class": "LlamaTokenizer",
        "clean_up_tokenization_spaces": False,
    }
    for name, content in token_contents.items():
        tokenizer_config[name] = special_token(content, with_type=True)
    write_json(
        tokenizer_config,
        os.path.join(tokenizer_path, "tokenizer_config.json"),
    )

    shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
def main(argv):
    """Convert the EasyLM checkpoint selected by FLAGS to Transformers format.

    Loads and converts ``FLAGS.load_checkpoint`` and hands the result to
    ``write_model`` together with ``FLAGS.output_dir`` and ``FLAGS.model_size``.

    Raises:
        ValueError: if ``load_checkpoint`` or ``output_dir`` is empty, or if
            ``model_size`` is not a key of ``LLAMA_STANDARD_CONFIGS``.
    """
    # Validate with explicit raises: ``assert`` statements are removed when
    # Python runs with -O, which would let bad flags through silently.
    if FLAGS.load_checkpoint == "" or FLAGS.output_dir == "":
        raise ValueError("Both load_checkpoint and output_dir must be set.")
    if FLAGS.model_size not in LLAMA_STANDARD_CONFIGS:
        raise ValueError(
            f"Unknown model_size {FLAGS.model_size!r}; "
            f"expected one of {list(LLAMA_STANDARD_CONFIGS)}."
        )
    # Tokenizer export is currently disabled (and so is the matching
    # tokenizer_path check). To re-enable it, restore:
    # write_tokenizer(
    #     tokenizer_path=FLAGS.output_dir,
    #     input_tokenizer_path=FLAGS.tokenizer_path,
    # )
    write_model(
        load_and_convert_checkpoint(FLAGS.load_checkpoint),
        model_path=FLAGS.output_dir,
        model_size=FLAGS.model_size,
    )
# Script entry point: hand main to mlxu.run when executed directly.
if __name__ == "__main__":
    mlxu.run(main)

View File

@@ -0,0 +1,196 @@
"""
Usage:
python convert_hf_to_easylm.py \
--checkpoint_dir /path/hf_format_dir/ \
--output_file /path/easylm_format.stream \
--model_size 7b \
--streaming
"""
import time
from pathlib import Path
import argparse
import mlxu
import torch
import flax
from EasyLM.checkpoint import StreamingCheckpointer
# Architecture hyper-parameters for each supported LLaMA size, keyed by the
# value accepted for ``--model_size``. Only ``dim``, ``n_heads`` and
# ``n_layers`` are read by the conversion code in this file.
LLAMA_STANDARD_CONFIGS = {
    "1b": dict(
        dim=2048, intermediate_size=5504, n_layers=22, n_heads=16, norm_eps=1e-6,
    ),
    "3b": dict(
        dim=3200, intermediate_size=8640, n_layers=26, n_heads=32, norm_eps=1e-6,
    ),
    "7b": dict(
        dim=4096, intermediate_size=11008, n_layers=32, n_heads=32, norm_eps=1e-6,
    ),
    "13b": dict(
        dim=5120, intermediate_size=13824, n_layers=40, n_heads=40, norm_eps=1e-6,
    ),
    "30b": dict(
        dim=6656, intermediate_size=17920, n_layers=60, n_heads=52, norm_eps=1e-6,
    ),
    "65b": dict(
        dim=8192, intermediate_size=22016, n_layers=80, n_heads=64, norm_eps=1e-5,
    ),
}
def inverse_permute(params, w):
    """Reorder the rows of a ``(dim, dim)`` projection matrix head by head.

    ``w`` is viewed as ``(n_heads, 2, dim // n_heads // 2, dim)``, the two
    middle axes are swapped, and the result is flattened back to
    ``(dim, dim)``. Within each head's block of rows this interleaves the
    first half of the rows with the second half.

    Presumably this undoes the row permutation applied to q/k projections
    when LLaMA weights are exported to the Hugging Face layout -- TODO confirm.

    Args:
        params: mapping providing ``"n_heads"`` and ``"dim"``.
        w: array exposing NumPy-style ``reshape``/``transpose`` and holding
            ``dim * dim`` elements.

    Returns:
        The row-reordered matrix with shape ``(dim, dim)``.
    """
    n_heads = params["n_heads"]
    dim = params["dim"]
    half_head_dim = dim // n_heads // 2

    reshaped_w = w.reshape(n_heads, 2, half_head_dim, dim)
    transposed_w = reshaped_w.transpose(0, 2, 1, 3)
    return transposed_w.reshape(dim, dim)
def main(args):
    """Convert Hugging Face LLaMA ``*.bin`` shards into an EasyLM checkpoint.

    All ``*.bin`` files in ``args.checkpoint_dir`` are loaded on CPU and
    merged into one state dict (a leading ``"model."`` is stripped from keys).
    The weights are then rearranged into the nested EasyLM parameter tree and
    saved to ``args.output_file``.

    Args:
        args: parsed CLI namespace providing ``checkpoint_dir``,
            ``output_file``, ``model_size`` and ``streaming``.

    Raises:
        FileNotFoundError: if ``checkpoint_dir`` contains no ``*.bin`` file.
        KeyError: if ``model_size`` is unknown or an expected weight is missing.
    """
    start = time.time()
    params = LLAMA_STANDARD_CONFIGS[args.model_size]

    ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
    if not ckpt_paths:
        # Fail early with a clear message instead of a KeyError on the
        # first weight lookup further down.
        raise FileNotFoundError(
            f"No '*.bin' checkpoint files found in {args.checkpoint_dir!r}"
        )

    ckpt = {}
    for ckpt_path in ckpt_paths:
        # NOTE(review): torch.load unpickles the file; only convert
        # checkpoints from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        for k, v in checkpoint.items():
            if k.startswith("model."):
                k = k[6:]  # drop the "model." prefix
            ckpt[k] = v

    def weight(name):
        # Raw weight as a NumPy array, in the stored orientation.
        return ckpt[name].numpy()

    def kernel(name):
        # Transposed weight, as used for every EasyLM projection kernel.
        return weight(name).transpose()

    def permuted_kernel(name):
        # q/k projections are passed through inverse_permute first.
        return inverse_permute(params, weight(name)).transpose()

    def convert_layer(layer):
        # Parameter subtree for one transformer block.
        prefix = f"layers.{layer}"
        return {
            "attention": {
                "wq": {"kernel": permuted_kernel(f"{prefix}.self_attn.q_proj.weight")},
                "wk": {"kernel": permuted_kernel(f"{prefix}.self_attn.k_proj.weight")},
                "wv": {"kernel": kernel(f"{prefix}.self_attn.v_proj.weight")},
                "wo": {"kernel": kernel(f"{prefix}.self_attn.o_proj.weight")},
            },
            "feed_forward": {
                "w1": {"kernel": kernel(f"{prefix}.mlp.gate_proj.weight")},
                "w2": {"kernel": kernel(f"{prefix}.mlp.down_proj.weight")},
                "w3": {"kernel": kernel(f"{prefix}.mlp.up_proj.weight")},
            },
            "attention_norm": {
                "kernel": weight(f"{prefix}.input_layernorm.weight")
            },
            "ffn_norm": {
                "kernel": weight(f"{prefix}.post_attention_layernorm.weight")
            },
        }

    print("Start convert weight to easylm format...")
    jax_weights = {
        "transformer": {
            "wte": {"embedding": weight("embed_tokens.weight")},
            "ln_f": {"kernel": weight("norm.weight")},
            "h": {
                "%d" % (layer): convert_layer(layer)
                for layer in range(params["n_layers"])
            },
        },
        "lm_head": {"kernel": kernel("lm_head.weight")},
    }
    print("Convert weight to easylm format finished...")

    print("Start to save...")
    if args.streaming:
        StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
    else:
        with mlxu.open_file(args.output_file, "wb") as fout:
            fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
    print(
        f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
    )
if __name__ == "__main__":
    # CLI entry point: parse the arguments, echo them, then run the conversion.
    parser = argparse.ArgumentParser(description="hf to easylm format script")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        required=True,  # main() cannot run without a directory to scan
        help="Need to be converted model weight dir. it is a dir",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        required=True,  # main() cannot run without a destination
        help="Save model weight file path, it is a file.",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        default="7b",
        # Offer every size defined in LLAMA_STANDARD_CONFIGS (including the
        # 1b/3b entries), so the table and the CLI cannot drift apart.
        choices=list(LLAMA_STANDARD_CONFIGS),
        help="model size",
    )
    parser.add_argument(
        "--streaming",
        action="store_true",
        default=True,
        help="whether is model weight saved stream format",
    )
    # ``--streaming`` defaults to True and is a store_true flag, so on its own
    # it can never be switched off; this companion flag makes the msgpack
    # branch of main() reachable.
    parser.add_argument(
        "--no_streaming",
        dest="streaming",
        action="store_false",
        help="save a single msgpack file instead of the stream format",
    )
    args = parser.parse_args()
    print(f"checkpoint_dir: {args.checkpoint_dir}")
    print(f"output_file: {args.output_file}")
    print(f"model_size: {args.model_size}")
    print(f"streaming: {args.streaming}")
    main(args)

View File

@@ -0,0 +1,68 @@
# This script converts the standard LLaMA PyTorch checkpoint released by Meta
# to the EasyLM checkpoint format. The converted checkpoint can then be loaded
# by EasyLM for fine-tuning or inference.
# This script is largely borrowed from https://github.com/Sea-Snell/JAX_llama
from pathlib import Path
import json
import numpy as np
import torch
import flax
import mlxu
from EasyLM.checkpoint import StreamingCheckpointer
# Command-line flags for the converter:
#   checkpoint_dir: directory searched for ``*.pth`` shards and ``params.json``.
#   output_file:    path the converted weights are written to.
#   streaming:      when true, save through StreamingCheckpointer; otherwise
#                   write a single msgpack-serialized file.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    checkpoint_dir='',
    output_file='',
    streaming=True,
)
def main(argv):
    """Merge sharded ``*.pth`` LLaMA weights into one EasyLM parameter tree.

    Every ``*.pth`` file in ``FLAGS.checkpoint_dir`` is loaded on CPU and
    ordered by the integer found between the first two dots of its file name.
    Sharded tensors are concatenated along a per-weight axis, norm weights are
    taken from the first shard, and the resulting tree is written to
    ``FLAGS.output_file`` (streaming format or a single msgpack blob).
    """
    checkpoint_dir = Path(FLAGS.checkpoint_dir)

    # Load each shard and index it by the number embedded in its file name.
    shards_by_index = {}
    for shard_path in sorted(checkpoint_dir.glob("*.pth")):
        shard = torch.load(shard_path, map_location="cpu")
        shard_index = int(shard_path.name.split('.', maxsplit=2)[1])
        shards_by_index[shard_index] = shard
    shards = [shards_by_index[index] for index in sorted(shards_by_index)]

    with open(checkpoint_dir / "params.json", "r") as f:
        params = json.load(f)

    def merged(name, axis):
        # Join the pieces of one sharded tensor along ``axis``.
        return np.concatenate([shard[name].numpy() for shard in shards], axis=axis)

    def merged_kernel(name, axis):
        # Sharded projection weight, transposed after merging.
        return merged(name, axis).transpose()

    def first_shard(name):
        # Norm weights are read from the first shard only.
        return shards[0][name].numpy()

    def convert_layer(layer):
        # Parameter subtree for one transformer block.
        return {
            'attention': {
                'wq': {'kernel': merged_kernel('layers.%d.attention.wq.weight' % (layer), 0)},
                'wk': {'kernel': merged_kernel('layers.%d.attention.wk.weight' % (layer), 0)},
                'wv': {'kernel': merged_kernel('layers.%d.attention.wv.weight' % (layer), 0)},
                'wo': {'kernel': merged_kernel('layers.%d.attention.wo.weight' % (layer), 1)},
            },
            'feed_forward': {
                'w1': {'kernel': merged_kernel('layers.%d.feed_forward.w1.weight' % (layer), 0)},
                'w2': {'kernel': merged_kernel('layers.%d.feed_forward.w2.weight' % (layer), 1)},
                'w3': {'kernel': merged_kernel('layers.%d.feed_forward.w3.weight' % (layer), 0)},
            },
            'attention_norm': {'kernel': first_shard('layers.%d.attention_norm.weight' % (layer))},
            'ffn_norm': {'kernel': first_shard('layers.%d.ffn_norm.weight' % (layer))},
        }

    jax_weights = {
        'transformer': {
            'wte': {'embedding': merged('tok_embeddings.weight', 1)},
            'ln_f': {'kernel': first_shard('norm.weight')},
            'h': {
                '%d' % (layer): convert_layer(layer)
                for layer in range(params['n_layers'])
            },
        },
        'lm_head': {'kernel': merged_kernel('output.weight', 0)},
    }

    if not FLAGS.streaming:
        with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
            fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
    else:
        StreamingCheckpointer.save_train_state_to_file(
            jax_weights, FLAGS.output_file
        )
# Script entry point: hand main to mlxu.run when executed directly.
if __name__ == '__main__':
    mlxu.run(main)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,386 @@
import pprint
from functools import partial
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import optax
from transformers import GenerationConfig, FlaxLogitsProcessorList
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.serving import LMServer
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
with_sharding_constraint, FlaxTemperatureLogitsWarper
)
from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
# Command-line flags for the LLaMA serving script. As used by main():
#   input_length: length prompts are padded/truncated to.
#   seq_length:   total sequence length; generation produces at most
#                 seq_length - input_length new tokens.
#   top_k / top_p / do_sample / num_beams: sampling settings for generate().
#   add_bos_token: whether position 0 of the prompt is forced to BOS.
#   load_llama_config / load_checkpoint: where the model config and
#                 parameters are loaded from.
#   mesh_dim, dtype: device mesh layout and float dtype used for sharding.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    seed=42,
    initialize_jax_distributed=False,
    mesh_dim='1,-1,1',
    dtype='bf16',
    input_length=1024,
    seq_length=2048,
    top_k=50,
    top_p=1.0,
    do_sample=True,
    num_beams=1,
    add_bos_token=True,
    load_llama_config='',
    load_checkpoint='',
    tokenizer=LLaMAConfig.get_tokenizer_config(),
    lm_server=LMServer.get_default_config(),
    jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
    """Load a LLaMA checkpoint, shard it over the mesh and run an LMServer.

    Builds three pjit-compiled forward functions (log-likelihood scoring,
    sampled generation, greedy generation) and exposes them through a
    ``ModelServer`` subclass of ``LMServer``, then starts the server.
    """
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    set_random_seed(FLAGS.seed)

    # Two tokenizer instances: prompts are padded/truncated on the left,
    # continuations on the right.
    prefix_tokenizer = LLaMAConfig.get_tokenizer(
        FLAGS.tokenizer, truncation_side='left', padding_side='left'
    )
    tokenizer = LLaMAConfig.get_tokenizer(
        FLAGS.tokenizer, truncation_side='right', padding_side='right'
    )

    # Load config and parameters with the CPU as default device; they are
    # sharded onto the mesh further down.
    with jax.default_device(jax.devices("cpu")[0]):
        llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
        _, params = StreamingCheckpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint, disallow_trainstate=True
        )

        hf_model = FlaxLLaMAForCausalLM(
            llama_config,
            input_shape=(1, FLAGS.seq_length),
            seed=FLAGS.seed,
            _do_init=False
        )

    # Partition specs for the parameters and the functions that shard them.
    model_ps = match_partition_rules(
        LLaMAConfig.get_partition_rules(), params
    )
    shard_fns, _ = make_shard_and_gather_fns(
        model_ps, get_float_dtype_by_name(FLAGS.dtype)
    )

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS()),
        out_shardings=(PS(), PS(), PS())
    )
    def forward_loglikelihood(params, rng, batch):
        """Score ``output_tokens`` given ``input_tokens``.

        Returns the log-likelihood summed over positions where
        ``output_mask`` is set, a flag telling whether the argmax prediction
        matched the target at every masked position, and the next value
        drawn from the RNG generator.
        """
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        input_tokens = batch['input_tokens']
        output_tokens = batch['output_tokens']
        input_mask = batch['input_mask']
        output_mask = batch['output_mask']

        logits = hf_model.module.apply(
            params, input_tokens, attention_mask=input_mask,
            deterministic=True, rngs=rng_generator(llama_config.rng_keys()),
        ).logits
        # if llama_config.n_real_tokens is not None:
        #     logits = logits.at[:, :, llama_config.n_real_tokens:].set(-1e8)
        # Negated cross entropy = per-token log-probability of the target.
        loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
            logits, output_tokens
        )
        loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
        match_count = jnp.sum(
            (jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
            axis=-1
        )
        total = jnp.sum(output_mask, axis=-1)
        # Greedy iff every masked position was predicted correctly.
        is_greedy = match_count == total
        return loglikelihood, is_greedy, rng_generator()

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS(), PS()),
        out_shardings=(PS(), PS())
    )
    def forward_generate(params, rng, batch, temperature):
        """Generate continuations using the sampling flags and ``temperature``.

        Returns only the newly generated columns (the prompt columns are
        sliced off) and the next value drawn from the RNG generator.
        """
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        output = hf_model.generate(
            batch['input_tokens'],
            attention_mask=batch['attention_mask'],
            params=params['params'],
            prng_key=rng_generator(),
            logits_processor=FlaxLogitsProcessorList(
                [FlaxTemperatureLogitsWarper(temperature)]
            ),
            generation_config=GenerationConfig(
                max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
                # EOS doubles as the padding token during generation.
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=FLAGS.do_sample,
                num_beams=FLAGS.num_beams,
                top_k=FLAGS.top_k,
                top_p=FLAGS.top_p,
            )
        ).sequences[:, batch['input_tokens'].shape[1]:]
        return output, rng_generator()

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS()),
        out_shardings=(PS(), PS())
    )
    def forward_greedy_generate(params, rng, batch):
        """Generate continuations with sampling disabled and a single beam.

        Returns only the newly generated columns and the next value drawn
        from the RNG generator.
        """
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        output = hf_model.generate(
            batch['input_tokens'],
            attention_mask=batch['attention_mask'],
            params=params['params'],
            prng_key=rng_generator(),
            generation_config=GenerationConfig(
                max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False,
                num_beams=1,
            )
        ).sequences[:, batch['input_tokens'].shape[1]:]
        return output, rng_generator()

    mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        # Apply the shard functions to the loaded parameters.
        params = tree_apply(shard_fns, params)
        # RNG state threaded through every forward call below.
        sharded_rng = next_rng()

    class ModelServer(LMServer):
        """LMServer whose endpoints call the pjit-compiled functions above."""

        @staticmethod
        def loglikelihood(prefix_text, text):
            """Score ``text`` as a continuation of ``prefix_text``.

            Returns the summed log-likelihood of the continuation tokens and
            the is-greedy flags computed by ``forward_loglikelihood``.
            """
            nonlocal sharded_rng
            prefix = prefix_tokenizer(
                prefix_text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.input_length,
                return_tensors='np',
            )
            inputs = tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.seq_length - FLAGS.input_length,
                return_tensors='np',
            )
            output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
            bos_tokens = np.full(
                (output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
            )
            # Model inputs are the targets shifted right by one, led by BOS.
            input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
            input_mask = np.concatenate(
                [prefix.attention_mask, inputs.attention_mask], axis=1
            )
            # The BOS position is attended to only when add_bos_token is set.
            if FLAGS.add_bos_token:
                bos_mask = np.ones_like(input_mask[:, :1])
            else:
                bos_mask = np.zeros_like(input_mask[:, :1])

            input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
            # Only continuation tokens count toward the score.
            output_mask = np.concatenate(
                [np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
            )
            batch = dict(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                input_mask=input_mask,
                output_mask=output_mask,
            )
            with mesh:
                loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
                    params, sharded_rng, batch
                )
                loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
            return loglikelihood, is_greedy

        @staticmethod
        def loglikelihood_rolling(text):
            """Score whole texts of any length in windows of ``seq_length``.

            Returns the log-likelihood summed over all windows and the
            logical AND of the per-window is-greedy flags.
            """
            nonlocal sharded_rng
            inputs = tokenizer(
                text,
                padding='longest',
                truncation=False,
                max_length=np.iinfo(np.int32).max,
                return_tensors='np',
            )
            batch_size = inputs.input_ids.shape[0]
            output_tokens = inputs.input_ids
            attention_mask = inputs.attention_mask

            # Pad short batches up to one full window.
            if output_tokens.shape[1] < FLAGS.seq_length:
                padding_length = FLAGS.seq_length - output_tokens.shape[1]
                pad_tokens = np.full(
                    (batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
                )
                output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
                pad_mask = np.zeros(
                    (batch_size, padding_length), dtype=inputs.attention_mask.dtype
                )
                attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)

            bos_tokens = np.full(
                (batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
            )
            # Model inputs are the targets shifted right by one, led by BOS.
            input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
            # NOTE(review): bos_mask is assigned here but never read below.
            bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
            total_seq_length = output_tokens.shape[1]

            total_loglikelihood = 0.0
            total_is_greedy = True
            # Sliding window
            for i in range(0, total_seq_length, FLAGS.seq_length):
                # Last window
                if i + FLAGS.seq_length > total_seq_length:
                    # Take the final seq_length tokens and zero the mask over
                    # the part that earlier windows already scored.
                    last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
                    last_output_mask[:, :i - total_seq_length] = 0.0

                    batch = dict(
                        input_tokens=input_tokens[:, -FLAGS.seq_length:],
                        output_tokens=output_tokens[:, -FLAGS.seq_length:],
                        input_mask=attention_mask[:, -FLAGS.seq_length:],
                        output_mask=last_output_mask,
                    )
                # Normal window
                else:
                    batch = dict(
                        input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
                        output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
                        input_mask=attention_mask[:, i:i + FLAGS.seq_length],
                        output_mask=attention_mask[:, i:i + FLAGS.seq_length],
                    )

                with mesh:
                    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
                        params, sharded_rng, batch
                    )
                    loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))

                total_loglikelihood += loglikelihood
                total_is_greedy = np.logical_and(is_greedy, total_is_greedy)

            return total_loglikelihood, total_is_greedy

        @staticmethod
        def generate(text, temperature):
            """Generate continuations for ``text`` at the given temperature.

            Returns a list of decoded strings, each cut at the first EOS
            token text if one occurs.
            """
            nonlocal sharded_rng
            inputs = prefix_tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.input_length,
                return_tensors='np',
            )
            input_tokens = inputs.input_ids
            input_mask = inputs.attention_mask
            # Overwrite position 0 with BOS and mark it as attended.
            if FLAGS.add_bos_token:
                input_tokens[:, 0] = tokenizer.bos_token_id
                input_mask[:, 0] = 1
            batch = dict(
                input_tokens=input_tokens,
                attention_mask=input_mask,
            )
            with mesh:
                output, sharded_rng = forward_generate(
                    params, sharded_rng, batch, temperature
                )
                output = jax.device_get(output)
            output_text = []
            # NOTE: the loop variable reuses the name of the ``text`` argument.
            for text in list(tokenizer.batch_decode(output)):
                if tokenizer.eos_token in text:
                    text = text.split(tokenizer.eos_token, maxsplit=1)[0]
                output_text.append(text)

            return output_text

        @staticmethod
        def greedy_until(prefix_text, until, max_length):
            """Greedily extend each prefix until a stop string shows up.

            For each (prefix, stop strings) pair, generation is repeated,
            appending the output to the prefix each round, until one of the
            stop strings appears in the generated text or at least
            ``max_length`` tokens were produced. Text from the stop string
            onward is dropped. Returns one string per prefix.
            """
            nonlocal sharded_rng
            all_outputs = []
            for pf, ut in zip(prefix_text, until):
                # Accept a single stop string as well as a list of them.
                if isinstance(ut, str):
                    ut = [ut]
                total_length = 0
                total_generated = ''

                while total_length < max_length:
                    pf_tokens = tokenizer(
                        pf,
                        padding=False,
                        truncation=False,
                        max_length=np.iinfo(np.int32).max,
                        return_tensors='np',
                    )
                    input_tokens = pf_tokens.input_ids
                    attention_mask = pf_tokens.attention_mask

                    # Left-pad short prompts; keep only the last
                    # input_length tokens of long ones.
                    if input_tokens.shape[1] < FLAGS.input_length:
                        extra = FLAGS.input_length - input_tokens.shape[1]
                        pad_tokens = np.full(
                            (1, extra), tokenizer.pad_token_id, dtype=np.int32
                        )
                        input_tokens = np.concatenate(
                            [pad_tokens, input_tokens], axis=1
                        )
                        pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
                        attention_mask = np.concatenate(
                            [pad_attention, attention_mask], axis=1
                        )
                    elif input_tokens.shape[1] > FLAGS.input_length:
                        input_tokens = input_tokens[:, -FLAGS.input_length:]
                        attention_mask = attention_mask[:, -FLAGS.input_length:]

                    if FLAGS.add_bos_token:
                        input_tokens[:, 0] = tokenizer.bos_token_id
                        attention_mask[:, 0] = 1

                    batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)

                    with mesh:
                        output, sharded_rng = forward_greedy_generate(
                            params, sharded_rng, batch
                        )
                        output = jax.device_get(output)

                    total_length += output.shape[1]
                    output_text = tokenizer.batch_decode(output)[0]
                    total_generated = total_generated + output_text
                    # Feed the generated text back in as part of the prompt.
                    pf = pf + output_text

                    done = False
                    for s in ut:
                        if s in total_generated:
                            total_generated = total_generated.split(s, maxsplit=1)[0]
                            done = True
                    if done:
                        break

                all_outputs.append(total_generated)

            return all_outputs

    server = ModelServer(FLAGS.lm_server)
    server.run()
# Script entry point: hand main to mlxu.run when executed directly.
if __name__ == "__main__":
    mlxu.run(main)

View File

@@ -0,0 +1,268 @@
import pprint
from functools import partial
from tqdm import tqdm, trange
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
set_random_seed, average_metrics, get_weight_decay_mask,
make_shard_and_gather_fns, with_sharding_constraint,
)
from EasyLM.models.llama.llama_model import (
LLaMAConfig, FlaxLLaMAForCausalLMModule
)
# Command-line flags for LLaMA training. As used by main():
#   dtype / param_dtype: float dtypes passed to the model module.
#   total_steps: upper bound of the training step counter.
#   load_llama_config / update_llama_config: config source and an optional
#       string of overrides applied on top of it.
#   load_checkpoint / load_dataset_state: optional resume sources.
#   log_freq, eval_freq, save_model_freq, save_milestone_freq: periods in
#       steps; a value of 0 disables the corresponding action.
#   log_all_worker: enable the logger on every process, not only process 0.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    seed=42,
    mesh_dim='1,-1,1',
    dtype='fp32',
    param_dtype='fp32',
    total_steps=10000,
    load_llama_config='',
    update_llama_config='',
    load_checkpoint='',
    load_dataset_state='',
    log_freq=50,
    save_model_freq=0,
    save_milestone_freq=0,
    eval_freq=0,
    tokenizer=LLaMAConfig.get_tokenizer_config(),
    train_dataset=DatasetFactory.get_default_config(),
    eval_dataset=DatasetFactory.get_default_config(),
    optimizer=OptimizerFactory.get_default_config(),
    checkpointer=StreamingCheckpointer.get_default_config(),
    llama=LLaMAConfig.get_default_config(),
    logger=mlxu.WandBLogger.get_default_config(),
    log_all_worker=False,
    jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
    """Train a LLaMA causal language model with pjit-sharded train/eval steps.

    Sets up logging, the dataset(s), the model config and the optimizer,
    builds sharded init/train/eval functions, restores or initializes the
    train state, then runs the training loop with periodic evaluation,
    logging and checkpointing controlled by FLAGS.
    """
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
    logger = mlxu.WandBLogger(
        config=FLAGS.logger,
        variant=variant,
        # Log from process 0 only unless log_all_worker is set.
        enable=FLAGS.log_all_worker or (jax.process_index() == 0),
    )
    set_random_seed(FLAGS.seed)

    tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
    dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
    if FLAGS.load_dataset_state != '':
        dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))

    # eval_dataset is only defined (and only used) when eval_freq > 0.
    if FLAGS.eval_freq > 0:
        eval_dataset = DatasetFactory.load_dataset(
            FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
        )

    seq_length = dataset.seq_length

    if FLAGS.load_llama_config != '':
        llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
    else:
        llama_config = LLaMAConfig(**FLAGS.llama)

    if FLAGS.update_llama_config != '':
        # NOTE(review): eval() runs the flag string as Python code; only
        # pass trusted values for update_llama_config.
        llama_config.update(dict(eval(FLAGS.update_llama_config)))

    llama_config.update(dict(
        bos_token_id=dataset.tokenizer.bos_token_id,
        eos_token_id=dataset.tokenizer.eos_token_id,
    ))
    # Grow the configured vocabulary if the dataset's one is larger.
    if llama_config.vocab_size < dataset.vocab_size:
        print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
        llama_config.update(dict(vocab_size=dataset.vocab_size))

    model = FlaxLLaMAForCausalLMModule(
        llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype), param_dtype=get_float_dtype_by_name(FLAGS.param_dtype)
    )

    optimizer, optimizer_info = OptimizerFactory.get_optimizer(
        FLAGS.optimizer,
        get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
    )

    def create_trainstate_from_params(params):
        """Wrap existing parameters in a TrainState using the optimizer."""
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def init_fn(rng):
        """Initialize model parameters and return a new TrainState."""
        rng_generator = JaxRNG(rng)
        # Dummy inputs with a batch of 4 sequences are used for init.
        params = model.init(
            input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
            rngs=rng_generator(llama_config.rng_keys()),
        )
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def train_step(train_state, rng, batch):
        """Run one optimization step.

        Returns the updated train state, the next value drawn from the RNG
        generator, and a metrics dict (loss, accuracy, learning rate,
        gradient norm, parameter norm).
        """
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        def loss_and_accuracy(params):
            logits = model.apply(
                params, batch['input_tokens'], deterministic=False,
                rngs=rng_generator(llama_config.rng_keys()),
            ).logits
            return cross_entropy_loss_and_accuracy(
                logits, batch['target_tokens'], batch['loss_masks']
            )
        # has_aux=True: the loss is differentiated, accuracy rides along.
        grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
        (loss, accuracy), grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        metrics = dict(
            loss=loss,
            accuracy=accuracy,
            learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
            gradient_norm=global_norm(grads),
            param_norm=global_norm(train_state.params),
        )
        return train_state, rng_generator(), metrics

    def eval_step(train_state, rng, batch):
        """Compute eval loss and accuracy for one batch without updating."""
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        logits = model.apply(
            train_state.params, batch['input_tokens'], deterministic=True,
            rngs=rng_generator(llama_config.rng_keys()),
        ).logits
        loss, accuracy = cross_entropy_loss_and_accuracy(
            logits, batch['target_tokens'], batch['loss_masks']
        )
        metrics = dict(
            eval_loss=loss,
            eval_accuracy=accuracy,
        )
        return rng_generator(), metrics

    # Derive train-state shapes without allocating, then the partition
    # specs and shard/gather functions that go with them.
    train_state_shapes = jax.eval_shape(init_fn, next_rng())
    train_state_partition = match_partition_rules(
        LLaMAConfig.get_partition_rules(), train_state_shapes
    )

    shard_fns, gather_fns = make_shard_and_gather_fns(
        train_state_partition, train_state_shapes
    )
    checkpointer = StreamingCheckpointer(
        FLAGS.checkpointer, logger.output_dir,
        enable=jax.process_index() == 0,
    )

    sharded_init_fn = pjit(
        init_fn,
        in_shardings=PS(),
        out_shardings=train_state_partition
    )

    sharded_create_trainstate_from_params = pjit(
        create_trainstate_from_params,
        in_shardings=(train_state_partition.params, ),
        out_shardings=train_state_partition,
        donate_argnums=(0, ),
    )

    sharded_train_step = pjit(
        train_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(train_state_partition, PS(), PS()),
        donate_argnums=(0, 1),
    )

    sharded_eval_step = pjit(
        eval_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(PS(), PS()),
        donate_argnums=(1,),
    )

    def save_checkpoint(train_state, milestone=False):
        """Save train state, run metadata and dataset state via the checkpointer."""
        step = int(jax.device_get(train_state.step))
        metadata = dict(
            step=step,
            variant=variant,
            flags=flags_config_dict,
            llama_config=llama_config.to_dict(),
        )
        checkpointer.save_all(
            train_state=train_state,
            gather_fns=gather_fns,
            metadata=metadata,
            dataset=dataset.get_state_dict(),
            milestone=milestone,
        )

    mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        train_state, restored_params = None, None
        if FLAGS.load_checkpoint != '':
            train_state, restored_params = checkpointer.load_trainstate_checkpoint(
                FLAGS.load_checkpoint, train_state_shapes, shard_fns
            )

        if train_state is None and restored_params is None:
            # Initialize from scratch
            train_state = sharded_init_fn(next_rng())
        elif train_state is None and restored_params is not None:
            # Restore from params but initialize train_state
            train_state = sharded_create_trainstate_from_params(restored_params)
            del restored_params

        start_step = int(jax.device_get(train_state.step))

        # Save once before training when periodic saving is enabled.
        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)

        sharded_rng = next_rng()

        step_counter = trange(start_step, FLAGS.total_steps, ncols=0)

        # zip() stops at whichever of the step counter or dataset ends first.
        for step, (batch, dataset_metrics) in zip(step_counter, dataset):
            train_state, sharded_rng, metrics = sharded_train_step(
                train_state, sharded_rng, batch
            )

            if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
                eval_metric_list = []
                # A fresh iterator: each evaluation walks eval_dataset from
                # wherever iter() starts it.
                eval_iterator = iter(eval_dataset)
                for eval_batch, _ in eval_iterator:
                    sharded_rng, eval_metrics = sharded_eval_step(
                        train_state, sharded_rng, eval_batch
                    )
                    eval_metric_list.append(eval_metrics)
                metrics.update(average_metrics(eval_metric_list))

            if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
                log_metrics = {"step": step + 1}
                log_metrics.update(metrics)
                log_metrics.update(dataset_metrics)
                log_metrics = jax.device_get(log_metrics)
                logger.log(log_metrics)
                tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")

            # A milestone save takes precedence over a regular save when
            # both periods hit on the same step.
            if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
                save_checkpoint(train_state, milestone=True)
            elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
                save_checkpoint(train_state)

        # Final save after the loop when periodic saving is enabled.
        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)
# Script entry point: hand main to mlxu.run when executed directly.
if __name__ == "__main__":
    mlxu.run(main)

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,307 @@
import dataclasses
import pprint
from functools import partial
import re
from tqdm import tqdm, trange
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, get_float_dtype_by_name,
cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
set_random_seed, average_metrics, get_weight_decay_mask,
make_shard_and_gather_fns, tree_apply
)
from EasyLM.models.roberta.roberta_model import (
RobertaConfig, FlaxRobertaForMaskedLMModule
)
# Command-line flags for RoBERTa masked-LM training. As used by main():
#   mask_token_probability: probability with which each token is selected
#       for alteration (and scored) in the masked-LM objective.
#   total_steps: upper bound of the training step counter.
#   load_roberta_config / update_roberta_config: config source and an
#       optional string of overrides applied on top of it.
#   load_checkpoint / load_dataset_state: optional resume sources.
#   log_freq: logging period in steps.
#   eval_steps: number of eval batches per evaluation; 0 disables it.
#   log_all_worker: enable the logger on every process, not only process 0.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    seed=42,
    mesh_dim='-1,1,1',
    dtype='fp32',
    mask_token_probability=0.15,
    total_steps=10000,
    load_roberta_config='',
    update_roberta_config='',
    load_checkpoint='',
    load_dataset_state='',
    log_freq=50,
    save_model_freq=0,
    save_milestone_freq=0,
    eval_steps=0,
    tokenizer=RobertaConfig.get_tokenizer_config(),
    train_dataset=DatasetFactory.get_default_config(),
    eval_dataset=DatasetFactory.get_default_config(),
    optimizer=OptimizerFactory.get_default_config(),
    checkpointer=StreamingCheckpointer.get_default_config(),
    roberta=RobertaConfig.get_default_config(),
    logger=mlxu.WandBLogger.get_default_config(),
    log_all_worker=False,
    jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
logger = mlxu.WandBLogger(
config=FLAGS.logger,
variant=variant,
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
)
set_random_seed(FLAGS.seed)
tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
if FLAGS.load_dataset_state != '':
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
if FLAGS.eval_steps > 0:
eval_dataset = DatasetFactory.load_dataset(
FLAGS.eval_dataset, dataset.tokenizer
)
eval_iterator = iter(eval_dataset)
seq_length = dataset.seq_length
if FLAGS.load_roberta_config != '':
roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
else:
roberta_config = RobertaConfig(**FLAGS.roberta)
if FLAGS.update_roberta_config != '':
roberta_config.update(dict(eval(FLAGS.update_roberta_config)))
roberta_config.update(dict(
bos_token_id=dataset.tokenizer.bos_token_id,
eos_token_id=dataset.tokenizer.eos_token_id,
pad_token_id=dataset.tokenizer.pad_token_id,
vocab_size=dataset.vocab_size,
))
model = FlaxRobertaForMaskedLMModule(
roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
)
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
FLAGS.optimizer,
get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
)
def create_trainstate_from_params(params):
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def init_fn(rng):
rng_generator = JaxRNG(rng)
params = model.init(
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
token_type_ids=None,
head_mask=None,
rngs=rng_generator(roberta_config.rng_keys()),
)
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def train_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
def loss_and_accuracy(params):
altered_tokens = jax.random.uniform(
rng_generator(), shape=tokens.shape
) < FLAGS.mask_token_probability
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
altered_by_mask = altered_tokens & (random_uniform < 0.8)
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
random_tokens = jax.random.randint(
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
)
inputs = jnp.where(altered_by_random, random_tokens, inputs)
logits = model.apply(
params, inputs,
attention_mask=jnp.ones_like(inputs),
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic=False,
rngs=rng_generator(roberta_config.rng_keys()),
).logits
return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
(loss, accuracy), grads = grad_fn(train_state.params)
train_state = train_state.apply_gradients(grads=grads)
metrics = dict(
loss=loss,
accuracy=accuracy,
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
gradient_norm=global_norm(grads),
param_norm=global_norm(train_state.params),
)
return train_state, rng_generator(), metrics
def eval_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
altered_tokens = jax.random.uniform(
rng_generator(), shape=tokens.shape
) < FLAGS.mask_token_probability
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
altered_by_mask = altered_tokens & (random_uniform < 0.8)
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
random_tokens = jax.random.randint(
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
)
inputs = jnp.where(altered_by_random, random_tokens, inputs)
logits = model.apply(
train_state.params, inputs,
attention_mask=jnp.ones_like(inputs),
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic=False,
rngs=rng_generator(roberta_config.rng_keys()),
).logits
loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
metrics = dict(
eval_loss=loss,
eval_accuracy=accuracy,
)
return rng_generator(), metrics
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
RobertaConfig.get_partition_rules(), train_state_shapes
)
shard_fns, gather_fns = make_shard_and_gather_fns(
train_state_partition, train_state_shapes
)
checkpointer = StreamingCheckpointer(
FLAGS.checkpointer, logger.output_dir,
enable=jax.process_index() == 0
)
sharded_init_fn = pjit(
init_fn,
in_shardings=PS(),
out_shardings=train_state_partition
)
sharded_create_trainstate_from_params = pjit(
create_trainstate_from_params,
in_shardings=(train_state_partition.params, ),
out_shardings=train_state_partition,
donate_argnums=(0, ),
)
sharded_train_step = pjit(
train_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(train_state_partition, PS(), PS()),
donate_argnums=(0, 1),
)
sharded_eval_step = pjit(
eval_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(PS(), PS()),
donate_argnums=(1,),
)
def save_checkpoint(train_state, milestone=False):
step = int(jax.device_get(train_state.step))
metadata = dict(
step=step,
variant=variant,
flags=flags_config_dict,
roberta_config=roberta_config.to_dict(),
)
checkpointer.save_all(
train_state=train_state,
gather_fns=gather_fns,
metadata=metadata,
dataset=dataset.get_state_dict(),
milestone=milestone,
)
mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
train_state, restored_params = None, None
if FLAGS.load_checkpoint != '':
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
if load_type == 'huggingface':
restored_params = tree_apply(
shard_fns.params, roberta_config.load_pretrained(load_path)
)
train_state = None
else:
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, train_state_shapes, shard_fns
)
if train_state is None and restored_params is None:
# Initialize from scratch
train_state = sharded_init_fn(next_rng())
elif train_state is None and restored_params is not None:
# Restore from params but initialize train_state
train_state = sharded_create_trainstate_from_params(restored_params)
del restored_params
start_step = int(jax.device_get(train_state.step))
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
sharded_rng = next_rng()
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
train_state, sharded_rng, metrics = sharded_train_step(
train_state, sharded_rng, batch
)
if step % FLAGS.log_freq == 0:
if FLAGS.eval_steps > 0:
eval_metric_list = []
for _ in range(FLAGS.eval_steps):
eval_batch, _ = next(eval_iterator)
sharded_rng, eval_metrics = sharded_eval_step(
train_state, sharded_rng, eval_batch
)
eval_metric_list.append(eval_metrics)
metrics.update(average_metrics(eval_metric_list))
log_metrics = {"step": step}
log_metrics.update(metrics)
log_metrics.update(dataset_metrics)
log_metrics = jax.device_get(log_metrics)
logger.log(log_metrics)
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
save_checkpoint(train_state, milestone=True)
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
save_checkpoint(train_state)
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
# Script entry point: when this file is executed directly, pass main to mlxu.run.
if __name__ == "__main__":
mlxu.run(main)

346
EasyLM/optimizers.py Normal file
View File

@@ -0,0 +1,346 @@
import os
import time
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
from functools import partial
import re
import dataclasses
import random
from ml_collections.config_dict import config_dict
from ml_collections import ConfigDict
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging
import optax
from EasyLM.jax_utils import float_to_dtype
class OptimizerFactory(object):
    """ Configurable optax optimizer factory.

    The class is only a namespace: instantiating it raises
    NotImplementedError, and everything is exposed through static and class
    methods.
    """

    def __init__(self):
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Build the default top-level optimizer config.

        Args:
            updates: optional overrides, anything ``ConfigDict`` accepts; they
                are merged on top of the defaults.

        Returns:
            A ``ConfigDict`` with the optimizer ``type``, the number of
            gradient accumulation steps and one sub-config per optimizer.
        """
        config = ConfigDict()
        config.accumulate_gradient_steps = 1
        config.type = 'adamw'
        config.palm_optimizer = PalmOptimizerFactory.get_default_config()
        config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
        config.lion_optimizer = LionOptimizerFactory.get_default_config()

        if updates is None:
            return config
        overrides = ConfigDict(updates).copy_and_resolve_references()
        config.update(overrides)
        return config

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Create the optimizer selected by ``config.type``.

        Args:
            config: overrides for the default config (may be None).
            weight_decay_mask: forwarded unchanged to the selected factory.

        Returns:
            ``(optimizer, optimizer_info)`` as produced by the selected
            factory; the optimizer is wrapped in ``optax.MultiSteps`` when
            ``accumulate_gradient_steps`` is greater than 1.

        Raises:
            ValueError: if ``config.type`` is not 'palm', 'adamw' or 'lion'.
        """
        config = cls.get_default_config(config)

        # Map each supported type to its factory and its own sub-config.
        candidates = {
            'palm': (PalmOptimizerFactory, config.palm_optimizer),
            'adamw': (AdamWOptimizerFactory, config.adamw_optimizer),
            'lion': (LionOptimizerFactory, config.lion_optimizer),
        }
        if config.type not in candidates:
            raise ValueError(f'Unknown optimizer type: {config.type}')

        factory, sub_config = candidates[config.type]
        optimizer, optimizer_info = factory.get_optimizer(
            sub_config, weight_decay_mask
        )

        accumulation = config.accumulate_gradient_steps
        if accumulation > 1:
            optimizer = optax.MultiSteps(optimizer, accumulation)
        return optimizer, optimizer_info
class PalmOptimizerFactory(object):
    """ PaLM optimizer factory. This optimizer implements the optimizer
        described in the PaLM paper: https://arxiv.org/abs/2204.02311

    The class is only a namespace; instantiating it raises
    NotImplementedError.
    """

    def __init__(self):
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Return the default PaLM optimizer config, merged with ``updates``
        when they are given. """
        config = ConfigDict()
        config.lr = 0.01
        config.lr_warmup_steps = 10000
        config.b1 = 0.9
        config.b2 = 0.99
        config.clip_gradient = 1.0
        config.weight_decay = 1e-4
        config.bf16_momentum = False

        if updates is None:
            return config
        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build the PaLM-style optimizer chain.

        The chain is: global-norm gradient clipping, unfactored adafactor
        scaled by parameter scale, then scheduled weight decay.

        Args:
            config: overrides for the default config (may be None).
            weight_decay_mask: mask handed to the scheduled weight decay.

        Returns:
            ``(optimizer, optimizer_info)``; ``optimizer_info`` holds the
            learning-rate and weight-decay schedule functions.
        """
        config = cls.get_default_config(config)

        def lr_schedule(step):
            # Inverse square root of the step count; steps below
            # lr_warmup_steps are clamped up to lr_warmup_steps.
            clamped_step = jnp.maximum(step, config.lr_warmup_steps)
            return (config.lr / 0.01) / jnp.sqrt(clamped_step)

        def wd_schedule(step):
            # Negative so that adding it to the update shrinks the parameter.
            return -(config.weight_decay / 1e-4) * jnp.square(lr_schedule(step))

        optimizer_info = {
            'learning_rate_schedule': lr_schedule,
            'weight_decay_schedule': wd_schedule,
        }

        momentum_dtype = jnp.bfloat16 if config.bf16_momentum else jnp.float32
        clipping = optax.clip_by_global_norm(config.clip_gradient)
        adafactor = optax.adafactor(
            learning_rate=lr_schedule,
            multiply_by_parameter_scale=True,
            momentum=config.b1,
            decay_rate=config.b2,
            factored=False,
            clipping_threshold=None,
            dtype_momentum=momentum_dtype,
        )
        decay = optax_add_scheduled_weight_decay(wd_schedule, weight_decay_mask)

        optimizer = optax.chain(clipping, adafactor, decay)
        return optimizer, optimizer_info
class AdamWOptimizerFactory(object):
    """ AdamW optimizer with cosine schedule.

    The class is only a namespace; instantiating it raises
    NotImplementedError.
    """

    def __init__(self):
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Return the default AdamW config, merged with ``updates`` when
        they are given. """
        config = ConfigDict()
        config.init_lr = 0.0
        config.end_lr = 0.001
        config.lr = 0.01
        config.lr_warmup_steps = 2000
        config.lr_decay_steps = 500000
        config.b1 = 0.9
        config.b2 = 0.95
        config.clip_gradient = 1.0
        config.weight_decay = 1e-4
        config.bf16_momentum = False
        config.multiply_by_parameter_scale = False

        if updates is None:
            return config
        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build a gradient-clipped optimizer with a warmup-cosine schedule.

        When ``multiply_by_parameter_scale`` is set, the chain uses
        unfactored adafactor followed by scheduled weight decay; otherwise it
        uses ``optax.adamw``.

        Args:
            config: overrides for the default config (may be None).
            weight_decay_mask: mask selecting which parameters are decayed.

        Returns:
            ``(optimizer, optimizer_info)``; ``optimizer_info`` holds the
            learning-rate schedule.
        """
        config = cls.get_default_config(config)

        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=config.init_lr,
            peak_value=config.lr,
            warmup_steps=config.lr_warmup_steps,
            decay_steps=config.lr_decay_steps,
            end_value=config.end_lr,
        )
        optimizer_info = {'learning_rate_schedule': lr_schedule}

        momentum_dtype = jnp.bfloat16 if config.bf16_momentum else jnp.float32
        clipping = optax.clip_by_global_norm(config.clip_gradient)

        if not config.multiply_by_parameter_scale:
            adamw = optax.adamw(
                learning_rate=lr_schedule,
                weight_decay=config.weight_decay,
                b1=config.b1,
                b2=config.b2,
                mask=weight_decay_mask,
                mu_dtype=momentum_dtype,
            )
            return optax.chain(clipping, adamw), optimizer_info

        def scheduled_decay(step):
            # Decay strength follows the learning rate; negative so that the
            # added term shrinks the parameter.
            return -lr_schedule(step) * config.weight_decay

        adafactor = optax.adafactor(
            learning_rate=lr_schedule,
            multiply_by_parameter_scale=True,
            momentum=config.b1,
            decay_rate=config.b2,
            factored=False,
            clipping_threshold=None,
            dtype_momentum=momentum_dtype,
        )
        decay = optax_add_scheduled_weight_decay(
            scheduled_decay, weight_decay_mask
        )
        return optax.chain(clipping, adafactor, decay), optimizer_info
class LionOptimizerFactory(object):
    """ Lion optimizer with a selectable learning-rate schedule (cosine by
    default).

    The class is only a namespace; instantiating it raises
    NotImplementedError.
    """

    def __init__(self):
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Return the default Lion config, merged with ``updates`` when
        they are given. """
        config = ConfigDict()
        config.init_lr = 0.0
        config.end_lr = 0.0001
        config.lr = 0.001
        config.lr_warmup_steps = 60000
        config.lr_constant_steps = 840000
        config.lr_decay_steps = 100000
        config.b1 = 0.9
        config.b2 = 0.98
        config.clip_gradient = 1.0
        config.weight_decay = 1e-3
        config.bf16_momentum = False
        config.lr_schedule_type = "warmup_cosine_decay_schedule"
        config.lr_decay_rate = 0.98

        if updates is None:
            return config
        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build a gradient-clipped Lion optimizer.

        Args:
            config: overrides for the default config (may be None).
            weight_decay_mask: mask selecting which parameters are decayed.

        Returns:
            ``(optimizer, optimizer_info)``; ``optimizer_info`` holds the
            learning-rate schedule.

        Raises:
            ValueError: if ``config.lr_schedule_type`` is not one of the
                supported schedule names.
        """
        config = cls.get_default_config(config)

        # Building blocks shared by several schedule types.
        def warmup_ramp():
            return optax.linear_schedule(
                init_value=config.init_lr,
                end_value=config.lr,
                transition_steps=config.lr_warmup_steps,
            )

        def plateau():
            return optax.constant_schedule(config.lr)

        def linear_decay():
            return optax.linear_schedule(
                init_value=config.lr,
                end_value=config.end_lr,
                transition_steps=config.lr_decay_steps,
            )

        def exponential_decay():
            return optax.exponential_decay(
                init_value=config.lr,
                transition_steps=config.lr_decay_steps,
                decay_rate=config.lr_decay_rate,
                transition_begin=0,
                staircase=False,
                end_value=config.end_lr,
            )

        def warmup_cosine():
            return optax.warmup_cosine_decay_schedule(
                init_value=config.init_lr,
                peak_value=config.lr,
                warmup_steps=config.lr_warmup_steps,
                decay_steps=config.lr_decay_steps,
                end_value=config.end_lr,
            )

        def warmup_constant():
            return optax.join_schedules(
                [warmup_ramp(), plateau()],
                [config.lr_warmup_steps],
            )

        # NOTE(review): the boundaries handed to join_schedules are
        # lr_warmup_steps and lr_constant_steps as-is; confirm that
        # lr_constant_steps is meant as the step at which decay starts rather
        # than the length of the constant phase.
        def warmup_constant_then(decay_builder):
            return optax.join_schedules(
                [warmup_ramp(), plateau(), decay_builder()],
                [config.lr_warmup_steps, config.lr_constant_steps],
            )

        schedule_builders = {
            "warmup_cosine_decay_schedule": warmup_cosine,
            "warmup_constant": warmup_constant,
            "warmup_constant_linear_decay":
                lambda: warmup_constant_then(linear_decay),
            "warmup_constant_exponential_decay":
                lambda: warmup_constant_then(exponential_decay),
            "exponential_decay": exponential_decay,
            "linear_decay": linear_decay,
        }
        build_schedule = schedule_builders.get(config.lr_schedule_type)
        if build_schedule is None:
            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')
        lr_schedule = build_schedule()

        optimizer_info = {'learning_rate_schedule': lr_schedule}

        momentum_dtype = jnp.bfloat16 if config.bf16_momentum else jnp.float32
        clipping = optax.clip_by_global_norm(config.clip_gradient)
        lion = optax.lion(
            learning_rate=lr_schedule,
            weight_decay=config.weight_decay,
            b1=config.b1,
            b2=config.b2,
            mask=weight_decay_mask,
            mu_dtype=momentum_dtype,
        )
        return optax.chain(clipping, lion), optimizer_info
class OptaxScheduledWeightDecayState(NamedTuple):
    """ State of the scheduled weight decay transformation. """

    # Step counter that is fed to the weight decay schedule.
    count: jax.Array
def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
    """ Apply weight decay with schedule.

    Args:
        schedule_fn: callable mapping the current step count to the decay
            coefficient that multiplies each parameter.
        mask: optional mask; when given, the transformation is wrapped in
            ``optax.masked`` with it.

    Returns:
        An ``optax.GradientTransformation`` (masked when ``mask`` is given)
        that adds ``schedule_fn(count) * param`` to every update.
    """

    def init_fn(params):
        # The state does not depend on the parameters, only on a counter.
        del params
        return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError('Params cannot be None for weight decay!')

        coefficient = schedule_fn(state.count)

        def add_decay(update, param):
            return update + coefficient * param

        decayed = jax.tree_util.tree_map(add_decay, updates, params)
        next_state = OptaxScheduledWeightDecayState(
            count=optax.safe_int32_increment(state.count)
        )
        return decayed, next_state

    transformation = optax.GradientTransformation(init_fn, update_fn)
    if mask is None:
        return transformation
    return optax.masked(transformation, mask)

View File

View File

@@ -0,0 +1,150 @@
from functools import partial
from time import time
import os
import numpy as np
import jax
import jax.flatten_util
import jax.numpy as jnp
import mlxu
from EasyLM.bpt import blockwise_attn
from EasyLM.jax_utils import (
get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
)
# Command-line flags for the attention benchmark (read in main below):
#   seed                         -> passed to set_random_seed
#   dtype                        -> resolved with get_float_dtype_by_name
#   embed_dim, n_heads           -> per-head size is embed_dim // n_heads
#   ref_attn_seq_len             -> sequence length for the reference attention
#   eff_attn_seq_len             -> sequence length for the blockwise attention
#   batch_size                   -> leading dimension of the random inputs
#   query_chunk_size, key_chunk_size -> forwarded to blockwise_attn
#   warmup_steps, steps          -> untimed and timed iterations respectively
FLAGS, _ = mlxu.define_flags_with_default(
seed=42,
dtype='fp32',
embed_dim=2048,
n_heads=16,
ref_attn_seq_len=2048,
eff_attn_seq_len=16384,
batch_size=1,
query_chunk_size=2048,
key_chunk_size=2048,
warmup_steps=40,
steps=200,
)
def main(argv):
    """ Benchmark blockwise attention against a plain reference attention.

    Both implementations run a forward and backward pass on random inputs.
    Each is compiled once, warmed up for ``FLAGS.warmup_steps`` iterations and
    then timed over ``FLAGS.steps`` iterations. The printed efficiency is the
    ratio of the two elapsed times scaled by the squared ratio of the two
    sequence lengths.

    Args:
        argv: positional arguments supplied by the runner; not used.
    """

    def random_kqv(rng_key, seq_len):
        # Three independent normal tensors of shape
        # (batch, seq_len, n_heads, embed_dim // n_heads).
        rng_generator = JaxRNG(rng_key)
        shape = (
            FLAGS.batch_size, seq_len, FLAGS.n_heads,
            FLAGS.embed_dim // FLAGS.n_heads,
        )
        return tuple(
            jax.random.normal(
                rng_generator(), shape,
                dtype=get_float_dtype_by_name(FLAGS.dtype)
            )
            for _ in range(3)
        )

    def reference_attn(query, key, value):
        # Plain causal softmax attention that materializes the full logits.
        dtype = get_float_dtype_by_name(FLAGS.dtype)
        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
        mask_value = jnp.finfo(logits.dtype).min
        _, q_seq_len, _, _ = query.shape
        _, kv_seq_len, _, _ = key.shape
        mask_shape = (q_seq_len, kv_seq_len)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        # Positions where the key index is ahead of the query index are masked.
        causal_mask = (row_ids < col_ids)[None, None, :, :]
        logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
        return out

    def efficient_attention(query, key, value):
        return blockwise_attn(
            query, key, value,
            bias=None,
            deterministic=True,
            dropout_rng=None,
            attn_pdrop=0.0,
            causal=True,
            query_chunk_size=FLAGS.query_chunk_size,
            key_chunk_size=FLAGS.key_chunk_size,
            dtype=get_float_dtype_by_name(FLAGS.dtype),
            policy=jax.checkpoint_policies.nothing_saveable(),
            precision=None,
            float32_logits=True,
            prevent_cse=True,
        )

    @partial(jax.jit, static_argnums=(1,))
    def reference_attn_forward_backward(rng_key, seq_len):
        @partial(jax.grad, argnums=(0, 1, 2))
        @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
        def grad_fn(query, key, value):
            out = reference_attn(query, key, value)
            return jnp.mean(out)

        query, key, value = random_kqv(rng_key, seq_len)
        # Reduce the gradient w.r.t. the key to a scalar so the result is cheap
        # to hold on to.
        return jax.flatten_util.ravel_pytree(
            grad_fn(query, key, value)[1]
        )[0].mean()

    @partial(jax.jit, static_argnums=(1,))
    def efficient_attn_forward_backward(rng_key, seq_len):
        @partial(jax.grad, argnums=(0, 1, 2))
        def grad_fn(query, key, value):
            out = efficient_attention(query, key, value)
            return jnp.mean(out)

        query, key, value = random_kqv(rng_key, seq_len)
        return jax.flatten_util.ravel_pytree(
            grad_fn(query, key, value)[1]
        )[0].mean()

    def timed_run(step_fn, seq_len):
        # Untimed warm-up followed by the timed iterations; returns seconds.
        warmup_results = [
            step_fn(next_rng(), seq_len) for _ in range(FLAGS.warmup_steps)
        ]
        jax.block_until_ready(warmup_results)

        start_time = time()
        timed_results = [
            step_fn(next_rng(), seq_len) for _ in range(FLAGS.steps)
        ]
        jax.block_until_ready(timed_results)
        return time() - start_time

    set_random_seed(FLAGS.seed)

    # Trigger compilation of both functions before any measurement.
    jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
    jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))

    elapsed_time_ref_attn = timed_run(
        reference_attn_forward_backward, FLAGS.ref_attn_seq_len
    )
    print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')

    elapsed_time_efficient_attn = timed_run(
        efficient_attn_forward_backward, FLAGS.eff_attn_seq_len
    )
    print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')

    flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
    efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
    print(f'Efficiency: {efficiency:.3f}')
# Script entry point: when this file is executed directly, pass main to mlxu.run.
if __name__ == '__main__':
mlxu.run(main)

View File

@@ -0,0 +1,42 @@
# This script converts model checkpoint trained by EasyLM to a standard
# msgpack checkpoint that can be loaded by huggingface transformers or
# flax.serialization.msgpack_restore. Such conversion allows models to be
# used by other frameworks that integrate with huggingface transformers.
import pprint
from functools import partial
import os
import numpy as np
import mlxu
import jax.numpy as jnp
import flax.serialization
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_to_dtype
# Command-line flags for the conversion script (read in main below):
#   load_checkpoint -> checkpoint to read; must be non-empty
#   output_file     -> destination path; must be non-empty
#   streaming       -> write with StreamingCheckpointer instead of msgpack
#   float_dtype     -> dtype name the parameters are converted to on save
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
load_checkpoint='',
output_file='',
streaming=False,
float_dtype='bf16',
)
def main(argv):
    """ Load the parameters of a checkpoint and write them to a new file.

    The parameters are read from ``FLAGS.load_checkpoint`` and written to
    ``FLAGS.output_file``, either through ``StreamingCheckpointer`` (when
    ``FLAGS.streaming`` is set) or as a msgpack file after conversion to
    ``FLAGS.float_dtype``.

    Args:
        argv: positional arguments supplied by the runner; not used.

    Raises:
        AssertionError: if the input or the output path flag is empty.
    """
    assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
    params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_checkpoint, disallow_trainstate=True
    )[1]['params']

    if FLAGS.streaming:
        StreamingCheckpointer.save_train_state_to_file(
            params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
        )
    else:
        params = float_to_dtype(params, FLAGS.float_dtype)
        # The flag is defined as output_file; no flag named "output" exists,
        # so the previous FLAGS.output lookup failed on this branch.
        with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
            fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
# Script entry point: when this file is executed directly, pass main to mlxu.run.
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,59 @@
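# This script computes the parameter-wise difference between a target and a
# base checkpoint trained by EasyLM (or, with recover_diff, adds a difference
# back onto the base). When streaming is disabled the result is written as a
# standard msgpack checkpoint that can be loaded by huggingface transformers or
# flax.serialization.msgpack_restore. Such conversion allows models to be
# used by other frameworks that integrate with huggingface transformers.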
import pprint
from functools import partial
import os
import numpy as np
import jax
import jax.numpy as jnp
import flax.serialization
import mlxu
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_to_dtype
# Command-line flags for the checkpoint diff script (read in main below):
#   recover_diff           -> add base and target instead of subtracting
#   load_base_checkpoint   -> base checkpoint; must be non-empty
#   load_target_checkpoint -> target (or diff) checkpoint; must be non-empty
#   output_file            -> destination path; must be non-empty
#   streaming              -> write with StreamingCheckpointer instead of msgpack
#   float_dtype            -> dtype name the parameters are converted to on save
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
recover_diff=False,
load_base_checkpoint='',
load_target_checkpoint='',
output_file='',
streaming=True,
float_dtype='bf16',
)
def main(argv):
    """ Compute or re-apply a parameter diff between two checkpoints.

    Without ``FLAGS.recover_diff`` the output is ``target - base`` for every
    parameter; with it the output is ``base + target``, i.e. the target
    checkpoint is treated as a previously computed diff. The result is written
    to ``FLAGS.output_file``, either through ``StreamingCheckpointer`` or as a
    msgpack file after conversion to ``FLAGS.float_dtype``.

    Args:
        argv: positional arguments supplied by the runner; not used.

    Raises:
        AssertionError: if any of the checkpoint or output path flags is empty.
    """
    assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
    assert FLAGS.output_file != ''

    base_params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_base_checkpoint, disallow_trainstate=True
    )[1]['params']
    target_params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_target_checkpoint, disallow_trainstate=True
    )[1]['params']

    if FLAGS.recover_diff:
        params = jax.tree_util.tree_map(
            lambda b, t: b + t, base_params, target_params
        )
    else:
        params = jax.tree_util.tree_map(
            lambda b, t: t - b, base_params, target_params
        )

    if FLAGS.streaming:
        StreamingCheckpointer.save_train_state_to_file(
            params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
        )
    else:
        params = float_to_dtype(params, FLAGS.float_dtype)
        # The flag is defined as output_file; no flag named "output" exists,
        # so the previous FLAGS.output lookup failed on this branch.
        with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
            fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
# Script entry point: when this file is executed directly, pass main to mlxu.run.
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,65 @@
# This script runs lm_eval_harness evaluations against a served language model.
# Typically, you need to run a language model server first, e.g.:
# python -m EasyLM.models.gptj.gptj_serve ...
import dataclasses
import pprint
from functools import partial
import os
from tqdm import tqdm, trange
import numpy as np
import mlxu
from flax.traverse_util import flatten_dict
from lm_eval import evaluator, tasks
from lm_eval.base import LM
from EasyLM.serving import LMClient
# Command-line flags for the lm_eval harness script (read in main below):
#   tasks     -> comma-separated task names, split and passed to get_task_dict
#   shots     -> passed positionally to evaluator.evaluate
#   limit     -> forwarded as limit; values <= 0 are turned into None
#   write_out -> forwarded as write_out
#   lm_client -> config for the LMClient that talks to the served model
#   logger    -> config for the WandBLogger
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
tasks='wsc,piqa,winogrande,openbookqa,logiqa',
shots=0,
limit=0,
write_out=False,
lm_client=LMClient.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
)
class LMEvalHarnessInterface(LM):
    """ Adapter exposing an LM client through the lm_eval ``LM`` interface.

    Every request is a sequence of tuples; the tuples are split into parallel
    sequences and handed to the wrapped client.
    """

    def __init__(self, lm_client):
        self.lm_client = lm_client

    def greedy_until(self, inputs):
        """ Forward (prefix, until) pairs to the client's ``greedy_until``. """
        prefixes, stop_sequences = zip(*inputs)
        return self.lm_client.greedy_until(prefixes, stop_sequences)

    def loglikelihood_rolling(self, inputs):
        """ Score ``inputs`` with the client's ``loglikelihood_rolling`` and
        return a list of (loglikelihood, is_greedy) pairs. """
        scores, greedy_flags = self.lm_client.loglikelihood_rolling(inputs)
        paired = zip(scores, greedy_flags)
        return list(paired)

    def loglikelihood(self, inputs):
        """ Score (prefix, text) pairs with the client's ``loglikelihood`` and
        return a list of (loglikelihood, is_greedy) pairs. """
        prefixes, continuations = zip(*inputs)
        scores, greedy_flags = self.lm_client.loglikelihood(
            prefixes, continuations
        )
        paired = zip(scores, greedy_flags)
        return list(paired)
def main(argv):
    """ Run the configured lm_eval tasks against a served model.

    The results are logged through the WandB logger and pretty-printed.

    Args:
        argv: positional arguments supplied by the runner; not used.
    """
    user_flags = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    logger = mlxu.WandBLogger(config=FLAGS.logger, variant=user_flags)

    client = LMClient(FLAGS.lm_client)
    model = LMEvalHarnessInterface(client)

    task_names = FLAGS.tasks.split(',')
    task_dict = tasks.get_task_dict(task_names)

    # A non-positive limit means "no limit".
    limit = FLAGS.limit if FLAGS.limit > 0 else None

    results = evaluator.evaluate(
        model, task_dict, False, FLAGS.shots,
        limit=limit,
        write_out=FLAGS.write_out,
    )

    flat_results = flatten_dict(results['results'], sep='/')
    logger.log(flat_results)
    pprint.pprint(results)
# Script entry point: when this file is executed directly, pass main to mlxu.run.
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,52 @@
import json
import mlxu
from EasyLM.serving import LMClient
# Command-line flags for the JSON evaluation script (read in main below):
#   input_file   -> JSON file with the request data
#   output_file  -> JSON file the results are written to
#   prefix_field, text_field, until_field -> keys looked up in the input JSON
#   eval_type    -> 'loglikelihood', 'loglikelihood_rolling', 'greedy_until'
#                   or 'generate'
#   lm_client    -> config for the LMClient that talks to the served model
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
input_file='',
output_file='',
prefix_field='prefix',
text_field='text',
until_field='until',
eval_type='loglikelihood',
lm_client=LMClient.get_default_config(),
)
def main(argv):
    """ Evaluate requests from a JSON file against a served model.

    The input JSON is read from ``FLAGS.input_file``, the request selected by
    ``FLAGS.eval_type`` is sent through the LM client, and the result is
    written as JSON to ``FLAGS.output_file``.

    Args:
        argv: positional arguments supplied by the runner; not used.

    Raises:
        ValueError: if ``FLAGS.eval_type`` is not a supported type.
    """
    client = LMClient(FLAGS.lm_client)
    with mlxu.open_file(FLAGS.input_file, 'r') as fin:
        request = json.load(fin)

    def scored(loglikelihoods, is_greedys):
        # Shared output layout of the two log-likelihood request types.
        return {
            'loglikelihood': loglikelihoods,
            'is_greedy': is_greedys,
        }

    mode = FLAGS.eval_type
    if mode == 'loglikelihood':
        prefixes = request[FLAGS.prefix_field]
        texts = request[FLAGS.text_field]
        output_data = scored(*client.loglikelihood(prefixes, texts))
    elif mode == 'loglikelihood_rolling':
        texts = request[FLAGS.text_field]
        output_data = scored(*client.loglikelihood_rolling(texts))
    elif mode == 'greedy_until':
        prefixes = request[FLAGS.prefix_field]
        stop_sequences = request[FLAGS.until_field]
        completions = client.greedy_until(prefixes, stop_sequences)
        output_data = {'output_text': completions}
    elif mode == 'generate':
        prefixes = request[FLAGS.prefix_field]
        output_data = {'output_text': client.generate(prefixes)}
    else:
        raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')

    with mlxu.open_file(FLAGS.output_file, 'w') as fout:
        json.dump(output_data, fout)
# Script entry point: when this file is executed directly, pass main to mlxu.run.
if __name__ == "__main__":
mlxu.run(main)

566
EasyLM/serving.py Normal file
View File

@@ -0,0 +1,566 @@
import dataclasses
import pprint
from functools import partial
import re
import os
from threading import Lock
import urllib
import time
from typing import List, Optional, Union
from pydantic import BaseModel
import absl.logging
from tqdm import tqdm, trange
import numpy as np
import mlxu
from ml_collections import ConfigDict
import uvicorn
from fastapi import FastAPI
import gradio as gr
import requests
from requests.exceptions import Timeout, ConnectionError
# Request body shared by the loglikelihood, loglikelihood-rolling, generate
# and greedy-until endpoints of LMServer. Every field is optional; each
# endpoint reads only the fields it needs.
class InferenceRequest(BaseModel):
# Prefix strings; the loglikelihood endpoint substitutes '' per text when None.
prefix_text: Optional[List[str]] = None
# Texts to score (loglikelihood and loglikelihood-rolling endpoints).
text: Optional[List[str]] = None
# Stop sequences for the greedy-until endpoint, sliced in step with prefix_text.
until: Optional[Union[List[str], List[List[str]]]] = None
# Sampling temperature for generate; the server default is used when None.
temperature: Optional[float] = None
# Request body of the /chat endpoint of LMServer.
class ChatRequest(BaseModel):
# The new user message.
prompt: str
# Conversation context so far; the prompt is appended to it.
context: str = ''
# Sampling temperature; the server default is used when None.
temperature: Optional[float] = None
class LMServer(object):
""" HTTP server for serving language models. """
@staticmethod
def get_default_config(updates=None):
    """ Return the default server config, merged with ``updates`` when they
    are given.

    Args:
        updates: optional overrides, anything ``ConfigDict`` accepts.

    Returns:
        A ``ConfigDict`` with network, batching, logging, pre-compilation and
        prompt-decoration settings.
    """
    defaults = (
        ('host', '0.0.0.0'),
        ('port', 5007),
        ('batch_size', 1),
        ('logging', False),
        ('pre_compile', 'loglikelihood'),
        ('default_temperature', 1.0),
        ('greedy_until_max_length', 5000),
        ('prepend_to_prefix', ''),
        ('append_to_prefix', ''),
        ('prepend_to_text', ''),
        ('append_to_text', ''),
        ('chat_prepend_text', ''),
        ('chat_user_prefix', ''),
        ('chat_user_suffix', ''),
        ('chat_lm_prefix', ''),
        ('chat_lm_suffix', ''),
        ('notes', ''),
    )
    config = ConfigDict()
    for field_name, field_value in defaults:
        setattr(config, field_name, field_value)

    if updates is None:
        return config
    config.update(ConfigDict(updates).copy_and_resolve_references())
    return config
def __init__(self, config):
    """ Set up the config, the request lock and the FastAPI application.

    Args:
        config: overrides for the default server config (may be None).
    """
    self.config = self.get_default_config(config)
    self.lock = Lock()
    self.app = FastAPI()

    # POST endpoints, registered in this order.
    post_routes = (
        ('/loglikelihood', self.serve_loglikelihood),
        ('/loglikelihood-rolling', self.serve_loglikelihood_rolling),
        ('/generate', self.serve_generate),
        ('/greedy-until', self.serve_greedy_until),
        ('/chat', self.serve_chat),
    )
    for path, handler in post_routes:
        self.app.post(path)(handler)
    self.app.get('/ready')(self.serve_ready)

    # The Gradio chat UI is mounted at the root path.
    chat_app = self.create_chat_app()
    self.app = gr.mount_gradio_app(self.app, chat_app, '/')
@staticmethod
def loglikelihood(prefix_text, text):
raise NotImplementedError()
@staticmethod
def loglikelihood_rolling(text):
raise NotImplementedError()
@staticmethod
def generate(text, temperature):
raise NotImplementedError()
@staticmethod
def greedy_until(prefix_text, until, max_length):
raise NotImplementedError()
@staticmethod
def to_list(x):
if isinstance(x, np.ndarray):
return x.tolist()
return x
def serve_ready(self):
    """ Handler of the /ready endpoint; returns a fixed readiness string. """
    return 'Ready!\n'
def serve_loglikelihood(self, data: InferenceRequest):
    """ Handler of the /loglikelihood endpoint.

    Scores ``data.text`` given ``data.prefix_text`` in batches of
    ``config.batch_size`` while holding the server lock. A short final batch
    is padded with 'a' entries whose results are discarded. When
    ``data.prefix_text`` is None it is replaced in place by one empty prefix
    per text.

    Returns:
        A dict with the request's ``prefix_text`` and ``text`` plus the
        ``log_likelihood`` and ``is_greedy`` lists.
    """
    cfg = self.config
    with self.lock:
        if cfg.logging:
            absl.logging.info(
                '\n========= Serving Log Likelihood Request ========= \n'
                f'{pprint.pformat(data)}\n'
            )

        if data.prefix_text is None:
            data.prefix_text = ['' for _ in data.text]

        prefix_text = [
            cfg.prepend_to_prefix + prefix + cfg.append_to_prefix
            for prefix in data.prefix_text
        ]
        text = [
            cfg.prepend_to_text + item + cfg.append_to_text
            for item in data.text
        ]

        log_likelihood, is_greedy = [], []
        for start in trange(0, len(text), cfg.batch_size, ncols=0):
            stop = start + cfg.batch_size
            chunk_prefix = prefix_text[start:stop]
            chunk_text = text[start:stop]

            # Pad a short final batch up to the full batch size.
            real_count = len(chunk_text)
            padding = ['a'] * (cfg.batch_size - real_count)
            chunk_prefix.extend(padding)
            chunk_text.extend(padding)

            chunk_scores, chunk_greedy = self.loglikelihood(
                chunk_prefix, chunk_text
            )
            chunk_scores = self.to_list(chunk_scores)
            chunk_greedy = self.to_list(chunk_greedy)
            # Keep only the results that belong to real inputs.
            log_likelihood.extend(chunk_scores[:real_count])
            is_greedy.extend(chunk_greedy[:real_count])

        output = {
            'prefix_text': data.prefix_text,
            'text': data.text,
            'log_likelihood': log_likelihood,
            'is_greedy': is_greedy,
        }
        if cfg.logging:
            absl.logging.info(
                '\n========= Output ========= \n'
                f'{pprint.pformat(output)}\n'
            )

    return output
def serve_loglikelihood_rolling(self, data: InferenceRequest):
    """ Handler of the /loglikelihood-rolling endpoint.

    Scores ``data.text`` in batches of ``config.batch_size`` while holding the
    server lock. A short final batch is padded with 'a' entries whose results
    are discarded.

    Returns:
        A dict with the request's ``text`` plus the ``log_likelihood`` and
        ``is_greedy`` lists.
    """
    cfg = self.config
    with self.lock:
        if cfg.logging:
            absl.logging.info(
                '\n========= Serving Log Likelihood Request ========= \n'
                f'{pprint.pformat(data)}\n'
            )

        text = [
            cfg.prepend_to_text + item + cfg.append_to_text
            for item in data.text
        ]

        log_likelihood, is_greedy = [], []
        for start in trange(0, len(text), cfg.batch_size, ncols=0):
            chunk_text = text[start:start + cfg.batch_size]

            # Pad a short final batch up to the full batch size.
            real_count = len(chunk_text)
            chunk_text.extend(['a'] * (cfg.batch_size - real_count))

            chunk_scores, chunk_greedy = self.loglikelihood_rolling(chunk_text)
            chunk_scores = self.to_list(chunk_scores)
            chunk_greedy = self.to_list(chunk_greedy)
            # Keep only the results that belong to real inputs.
            log_likelihood.extend(chunk_scores[:real_count])
            is_greedy.extend(chunk_greedy[:real_count])

        output = {
            'text': data.text,
            'log_likelihood': log_likelihood,
            'is_greedy': is_greedy,
        }
        if cfg.logging:
            absl.logging.info(
                '\n========= Output ========= \n'
                f'{pprint.pformat(output)}\n'
            )

    return output
def serve_generate(self, data: InferenceRequest):
    """ Handler of the /generate endpoint.

    Generates one output per entry of ``data.prefix_text`` in batches of
    ``config.batch_size`` while holding the server lock. A short final batch
    is padded with 'a' entries whose outputs are discarded. When
    ``data.temperature`` is None it is set in place to the configured default.

    Returns:
        A dict with the request's ``prefix_text``, the ``output_text`` list
        and the ``temperature`` that was used.
    """
    cfg = self.config
    with self.lock:
        if cfg.logging:
            absl.logging.info(
                '\n========= Serving Generate Request ========= \n'
                f'{pprint.pformat(data)}\n'
            )

        prefix_text = [
            cfg.prepend_to_prefix + prefix + cfg.append_to_prefix
            for prefix in data.prefix_text
        ]
        if data.temperature is None:
            data.temperature = cfg.default_temperature

        output_text = []
        for start in trange(0, len(prefix_text), cfg.batch_size, ncols=0):
            chunk_prefix = prefix_text[start:start + cfg.batch_size]

            # Pad a short final batch up to the full batch size.
            real_count = len(chunk_prefix)
            chunk_prefix.extend(['a'] * (cfg.batch_size - real_count))

            chunk_output = self.generate(
                chunk_prefix,
                temperature=data.temperature,
            )
            # Keep only the outputs that belong to real inputs.
            output_text.extend(self.to_list(chunk_output)[:real_count])

        output = {
            'prefix_text': data.prefix_text,
            'output_text': output_text,
            'temperature': data.temperature,
        }
        if cfg.logging:
            absl.logging.info(
                '\n========= Output ========= \n'
                f'{pprint.pformat(output)}\n'
            )

    return output
def serve_greedy_until(self, data: InferenceRequest):
    """ Handler of the /greedy-until endpoint.

    Runs ``greedy_until`` over ``data.prefix_text`` and the matching slice of
    ``data.until`` in batches of ``config.batch_size`` while holding the
    server lock. Unlike the other endpoints, a short final batch is not
    padded.

    Returns:
        A dict with the request's ``prefix_text`` and ``until``, the
        ``max_length`` that was used and the ``output_text`` list.
    """
    cfg = self.config
    with self.lock:
        if cfg.logging:
            absl.logging.info(
                '\n========= Serving Greedy Until Request ========= \n'
                f'{pprint.pformat(data)}\n'
            )

        prefix_text = [
            cfg.prepend_to_prefix + prefix + cfg.append_to_prefix
            for prefix in data.prefix_text
        ]
        until = data.until
        max_length = cfg.greedy_until_max_length

        output_text = []
        for start in range(0, len(prefix_text), cfg.batch_size):
            stop = start + cfg.batch_size
            chunk_prefix = prefix_text[start:stop]
            chunk_until = until[start:stop]
            chunk_output = self.greedy_until(
                chunk_prefix, chunk_until, max_length
            )
            output_text.extend(
                self.to_list(chunk_output)[:len(chunk_prefix)]
            )

        output = {
            'prefix_text': data.prefix_text,
            'until': data.until,
            'max_length': max_length,
            'output_text': output_text,
        }
        if cfg.logging:
            absl.logging.info(
                '\n========= Output ========= \n'
                f'{pprint.pformat(output)}\n'
            )

    return output
def process_chat(self, prompt, context, temperature):
context = (
context + self.config.chat_user_prefix
+ prompt + self.config.chat_user_suffix
+ self.config.chat_lm_prefix
)
response = self.generate(
[self.config.chat_prepend_text + context],
temperature=float(temperature),
)[0]
context = context + response + self.config.chat_lm_suffix
return response, context
def serve_chat(self, data: ChatRequest):
    """ Handler of the /chat endpoint.

    When ``data.temperature`` is None it is set in place to the configured
    default before the chat turn is processed.

    Returns:
        A dict with the model ``response``, the updated ``context`` and the
        ``temperature`` that was used.
    """
    if data.temperature is None:
        data.temperature = self.config.default_temperature

    reply, updated_context = self.process_chat(
        data.prompt, data.context,
        temperature=data.temperature,
    )
    result = {
        'response': reply,
        'context': updated_context,
        'temperature': data.temperature,
    }
    return result
# Build the Gradio chat UI that is mounted on the server: a chat history, a
# message box, Send / Regenerate / Reset buttons and a temperature slider.
# Returns the gr.Blocks object with its queue enabled.
def create_chat_app(self):
with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
gr.Markdown(self.config.notes)
chatbot = gr.Chatbot(label='Chat history')
msg = gr.Textbox(
placeholder='Type your message here...',
show_label=False
)
with gr.Row():
send = gr.Button('Send')
regenerate = gr.Button('Regenerate', interactive=False)
clear = gr.Button('Reset')
temp_slider = gr.Slider(
label='Temperature', minimum=0, maximum=2.0,
value=self.config.default_temperature
)
# Two-element state: [context used as input for the latest turn,
# context produced by the latest turn]. Keeping the input context lets
# Regenerate re-run the latest turn from the same starting point.
context_state = gr.State(['', ''])
# Step 1 of sending: disable the controls, append the user message to the
# history with an empty reply, and promote the latest context to be the
# input context of the new turn.
def user_fn(user_message, history, context):
return {
msg: gr.update(value='', interactive=False),
clear: gr.update(interactive=False),
send: gr.update(interactive=False),
regenerate: gr.update(interactive=False),
chatbot: history + [[user_message, None]],
context_state: [context[1], context[1]],
}
# Step 2: run the model on the last user message with the input context,
# fill in the reply, store the new context and re-enable the controls.
def model_fn(history, context, temperature):
history[-1][1], new_context = self.process_chat(
history[-1][0], context[0], temperature
)
return {
msg: gr.update(value='', interactive=True),
clear: gr.update(interactive=True),
send: gr.update(interactive=True),
chatbot: history,
context_state: [context[0], new_context],
regenerate: gr.update(interactive=True),
}
# First step of Regenerate: only disables the controls; history and context
# are left untouched so model_fn repeats the latest turn.
def regenerate_fn():
return {
msg: gr.update(value='', interactive=False),
clear: gr.update(interactive=False),
send: gr.update(interactive=False),
regenerate: gr.update(interactive=False),
}
# Reset: clear the history, the message box and both stored contexts.
def clear_fn():
return {
chatbot: None,
msg: '',
context_state: ['', ''],
regenerate: gr.update(interactive=False),
}
# Pressing Enter in the message box and clicking Send are wired identically:
# user_fn runs without the queue, then model_fn runs through the queue.
msg.submit(
user_fn,
inputs=[msg, chatbot, context_state],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=False
).then(
model_fn,
inputs=[chatbot, context_state, temp_slider],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=True
)
send.click(
user_fn,
inputs=[msg, chatbot, context_state],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=False
).then(
model_fn,
inputs=[chatbot, context_state, temp_slider],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=True
)
regenerate.click(
regenerate_fn,
inputs=None,
outputs=[msg, clear, send, regenerate],
queue=False
).then(
model_fn,
inputs=[chatbot, context_state, temp_slider],
outputs=[msg, clear, send, chatbot, context_state, regenerate],
queue=True
)
clear.click(
clear_fn,
inputs=None,
outputs=[chatbot, msg, context_state, regenerate],
queue=False
)
# The queue is configured with concurrency_count=1.
gradio_chatbot.queue(concurrency_count=1)
return gradio_chatbot
def run(self):
    """Run the requested pre-compile tasks, then serve the app with uvicorn.

    ``self.config.pre_compile`` selects the tasks: an empty string means
    none, ``'all'`` means every known task, and anything else is split on
    commas into task names. Each task calls the matching method once with
    placeholder inputs (a batch of ``'a'`` strings of length
    ``self.config.batch_size``), presumably so compilation happens before
    the first real request -- confirm against the method implementations.

    Raises:
        ValueError: if a requested task name is not one of
            ``loglikelihood``, ``generate``, ``greedy_until`` or ``chat``.
            Tasks listed before the invalid one have already run.
    """
    requested = self.config.pre_compile
    if requested != '':
        if requested == 'all':
            task_names = ['loglikelihood', 'generate', 'greedy_until', 'chat']
        else:
            task_names = requested.split(',')

        placeholder_batch = ['a'] * self.config.batch_size

        def warm_loglikelihood():
            # The loglikelihood task covers both the plain and rolling variants.
            self.loglikelihood(placeholder_batch, placeholder_batch)
            self.loglikelihood_rolling(placeholder_batch)

        def warm_generate():
            self.generate(placeholder_batch, 1.0)

        def warm_greedy_until():
            self.greedy_until(
                placeholder_batch,
                placeholder_batch,
                self.config.greedy_until_max_length,
            )

        def warm_chat():
            self.process_chat('a', 'a', 1.0)

        warmups = {
            'loglikelihood': warm_loglikelihood,
            'generate': warm_generate,
            'greedy_until': warm_greedy_until,
            'chat': warm_chat,
        }

        for task in task_names:
            warmup = warmups.get(task)
            if warmup is None:
                raise ValueError(f'Invalid precompile task: {task}!')
            warmup()

    uvicorn.run(self.app, host=self.config.host, port=self.config.port)
class LMClient(object):
    """ A simple client for the LM server.

    Every request method batches its inputs by ``config.batch_size`` and
    POSTs each batch as JSON to an endpoint under ``config.url``. When
    ``config.dummy`` is true, no network request is made and placeholder
    results of the right shape are returned instead.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default client config, overridden by ``updates`` if given.

        Args:
            updates: optional mapping/ConfigDict of overrides merged on top of
                the defaults.
        """
        config = ConfigDict()
        config.url = 'http://localhost:5007'
        config.batch_size = 1
        config.wait_for_ready = True
        config.dummy = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config=None):
        """Build the config and, if ``wait_for_ready`` is set, block until the
        server answers."""
        self.config = self.get_default_config(config)
        if self.config.wait_for_ready:
            self.wait_for_ready()

    def wait_for_ready(self):
        """Block until a GET on the server's ``ready`` endpoint succeeds.

        Returns immediately in dummy mode. Otherwise retries every 10 seconds
        for as long as the request times out or the connection fails.
        """
        if self.config.dummy:
            return
        ready_url = urllib.parse.urljoin(self.config.url, 'ready')
        while True:
            try:
                # An explicit timeout is required for the Timeout handler
                # below to ever fire; without it a stalled connection would
                # block forever instead of being retried.
                requests.get(ready_url, timeout=10)
                return
            except (Timeout, ConnectionError):
                time.sleep(10)

    @staticmethod
    def batched(iterator, batch_size):
        """Yield consecutive lists of up to ``batch_size`` items from ``iterator``.

        The final batch may be shorter; nothing is yielded for empty input.
        """
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    def _post(self, endpoint, payload):
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded JSON reply."""
        return requests.post(
            urllib.parse.urljoin(self.config.url, endpoint),
            json=payload,
        ).json()

    def loglikelihood(self, prefix, text):
        """Query the ``loglikelihood`` endpoint for (prefix, text) pairs.

        Args:
            prefix: iterable of prefix strings.
            text: iterable of continuation strings, paired with ``prefix``.

        Returns:
            A ``(log_likelihood, is_greedy)`` pair of lists collected from the
            server responses. In dummy mode: ``-1.0`` and ``False`` per text.
        """
        prefix, text = list(prefix), list(text)
        if self.config.dummy:
            return [-1.0 for _ in text], [False for _ in text]

        log_likelihood = []
        is_greedy = []
        # Materialized as a list so tqdm can display a total.
        batched_iterator = list(zip(
            self.batched(prefix, self.config.batch_size),
            self.batched(text, self.config.batch_size)
        ))
        for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
            response = self._post(
                'loglikelihood',
                {'prefix_text': batch_prefix, 'text': batch_text},
            )
            log_likelihood.extend(response['log_likelihood'])
            is_greedy.extend(response['is_greedy'])
        return log_likelihood, is_greedy

    def loglikelihood_rolling(self, text):
        """Query the ``loglikelihood-rolling`` endpoint for each text.

        Args:
            text: iterable of strings to score.

        Returns:
            A ``(log_likelihood, is_greedy)`` pair of lists collected from the
            server responses. In dummy mode: ``-1.0`` and ``False`` per text.
        """
        text = list(text)
        if self.config.dummy:
            return [-1.0 for _ in text], [False for _ in text]

        log_likelihood = []
        is_greedy = []
        batched_iterator = list(self.batched(text, self.config.batch_size))
        for batch_text in tqdm(batched_iterator, ncols=0):
            response = self._post('loglikelihood-rolling', {'text': batch_text})
            log_likelihood.extend(response['log_likelihood'])
            is_greedy.extend(response['is_greedy'])
        return log_likelihood, is_greedy

    def greedy_until(self, prefix, until):
        """Query the ``greedy-until`` endpoint for (prefix, until) pairs.

        Args:
            prefix: iterable of prefix strings.
            until: iterable of stop conditions, each a string or a sequence
                of strings, paired with ``prefix``.

        Returns:
            List of output texts from the server. In dummy mode, one
            ``'dummy text ' + stop`` string per ``until`` entry, using the
            first element when the entry is not a plain string.
        """
        prefix, until = list(prefix), list(until)
        if self.config.dummy:
            return [
                'dummy text ' + (u if isinstance(u, str) else u[0])
                for u in until
            ]

        batched_iterator = list(zip(
            self.batched(prefix, self.config.batch_size),
            self.batched(until, self.config.batch_size),
        ))
        output_text = []
        for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
            response = self._post(
                'greedy-until',
                {'prefix_text': batch_prefix, 'until': batch_until},
            )
            output_text.extend(response['output_text'])
        return output_text

    def generate(self, prefix, temperature=None):
        """Query the ``generate`` endpoint for each prefix.

        Args:
            prefix: iterable of prefix strings.
            temperature: value forwarded as-is in the request payload
                (``None`` is sent when not given).

        Returns:
            List of output texts from the server. In dummy mode, one empty
            string per prefix.
        """
        prefix = list(prefix)
        if self.config.dummy:
            return ['' for _ in prefix]

        output_text = []
        batched_iterator = list(self.batched(prefix, self.config.batch_size))
        for batch_prefix in tqdm(batched_iterator, ncols=0):
            response = self._post(
                'generate',
                {
                    'prefix_text': batch_prefix,
                    'temperature': temperature,
                },
            )
            output_text.extend(response['output_text'])
        return output_text

    def chat(self, prompt, context, temperature=None):
        """Send one chat turn to the ``chat`` endpoint.

        Args:
            prompt: the user message for this turn.
            context: conversation context forwarded to the server.
            temperature: value forwarded as-is in the request payload.

        Returns:
            A ``(response, context)`` tuple taken from the server reply. In
            dummy mode this is ``('', context)`` with the context passed
            through unchanged, so callers can unpack the result the same way
            in both modes.
        """
        if self.config.dummy:
            return '', context
        response = self._post(
            'chat',
            {
                'prompt': prompt,
                'context': context,
                'temperature': temperature,
            },
        )
        return response['response'], response['context']