初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
0
EasyLM/scripts/__init__.py
Normal file
0
EasyLM/scripts/__init__.py
Normal file
150
EasyLM/scripts/benchmark_attention.py
Normal file
150
EasyLM/scripts/benchmark_attention.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from functools import partial
|
||||
from time import time
|
||||
import os
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.flatten_util
|
||||
import jax.numpy as jnp
|
||||
import mlxu
|
||||
from EasyLM.bpt import blockwise_attn
|
||||
from EasyLM.jax_utils import (
|
||||
get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
|
||||
)
|
||||
|
||||
|
||||
FLAGS, _ = mlxu.define_flags_with_default(
|
||||
seed=42,
|
||||
dtype='fp32',
|
||||
embed_dim=2048,
|
||||
n_heads=16,
|
||||
ref_attn_seq_len=2048,
|
||||
eff_attn_seq_len=16384,
|
||||
batch_size=1,
|
||||
query_chunk_size=2048,
|
||||
key_chunk_size=2048,
|
||||
warmup_steps=40,
|
||||
steps=200,
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
|
||||
def random_kqv(rng_key, seq_len):
|
||||
rng_generator = JaxRNG(rng_key)
|
||||
kqv = []
|
||||
for i in range(3):
|
||||
kqv.append(
|
||||
jax.random.normal(
|
||||
rng_generator(),
|
||||
(FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
|
||||
dtype=get_float_dtype_by_name(FLAGS.dtype)
|
||||
)
|
||||
)
|
||||
return tuple(kqv)
|
||||
|
||||
def reference_attn(query, key, value):
|
||||
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
||||
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
|
||||
mask_value = jnp.finfo(logits.dtype).min
|
||||
_, q_seq_len, _, _ = query.shape
|
||||
_, kv_seq_len, _, _ = key.shape
|
||||
mask_shape = (q_seq_len, kv_seq_len)
|
||||
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
|
||||
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
|
||||
causal_mask = (row_ids < col_ids)[None, None, :, :]
|
||||
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
|
||||
weights = jax.nn.softmax(logits, axis=-1)
|
||||
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
|
||||
return out
|
||||
|
||||
def efficient_attention(query, key, value):
|
||||
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
||||
return blockwise_attn(
|
||||
query, key, value,
|
||||
bias=None,
|
||||
deterministic=True,
|
||||
dropout_rng=None,
|
||||
attn_pdrop=0.0,
|
||||
causal=True,
|
||||
query_chunk_size=FLAGS.query_chunk_size,
|
||||
key_chunk_size=FLAGS.key_chunk_size,
|
||||
dtype=get_float_dtype_by_name(FLAGS.dtype),
|
||||
policy=jax.checkpoint_policies.nothing_saveable(),
|
||||
precision=None,
|
||||
float32_logits=True,
|
||||
prevent_cse=True,
|
||||
)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1,))
|
||||
def reference_attn_forward_backward(rng_key, seq_len):
|
||||
@partial(jax.grad, argnums=(0, 1, 2))
|
||||
@partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
|
||||
def grad_fn(query, key, value):
|
||||
out = reference_attn(query, key, value)
|
||||
return jnp.mean(out)
|
||||
|
||||
query, key, value = random_kqv(rng_key, seq_len)
|
||||
return jax.flatten_util.ravel_pytree(
|
||||
grad_fn(query, key, value)[1]
|
||||
)[0].mean()
|
||||
|
||||
@partial(jax.jit, static_argnums=(1,))
|
||||
def efficient_attn_forward_backward(rng_key, seq_len):
|
||||
@partial(jax.grad, argnums=(0, 1, 2))
|
||||
def grad_fn(query, key, value):
|
||||
out = efficient_attention(query, key, value)
|
||||
return jnp.mean(out)
|
||||
|
||||
query, key, value = random_kqv(rng_key, seq_len)
|
||||
return jax.flatten_util.ravel_pytree(
|
||||
grad_fn(query, key, value)[1]
|
||||
)[0].mean()
|
||||
|
||||
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||
jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||
|
||||
all_results = []
|
||||
for i in range(FLAGS.warmup_steps):
|
||||
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||
jax.block_until_ready(all_results)
|
||||
|
||||
start_time = time()
|
||||
all_results = []
|
||||
for i in range(FLAGS.steps):
|
||||
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||
|
||||
jax.block_until_ready(all_results)
|
||||
elapsed_time_ref_attn = time() - start_time
|
||||
print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')
|
||||
|
||||
|
||||
all_results = []
|
||||
for i in range(FLAGS.warmup_steps):
|
||||
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||
jax.block_until_ready(all_results)
|
||||
|
||||
|
||||
start_time = time()
|
||||
all_results = []
|
||||
for i in range(FLAGS.steps):
|
||||
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||
|
||||
jax.block_until_ready(all_results)
|
||||
elapsed_time_efficient_attn = time() - start_time
|
||||
print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')
|
||||
|
||||
flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
|
||||
efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
|
||||
print(f'Efficiency: {efficiency:.3f}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mlxu.run(main)
|
||||
|
||||
|
||||
|
||||
42
EasyLM/scripts/convert_checkpoint.py
Normal file
42
EasyLM/scripts/convert_checkpoint.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# This script converts model checkpoint trained by EsayLM to a standard
|
||||
# mspack checkpoint that can be loaded by huggingface transformers or
|
||||
# flax.serialization.msgpack_restore. Such conversion allows models to be
|
||||
# used by other frameworks that integrate with huggingface transformers.
|
||||
|
||||
import pprint
|
||||
from functools import partial
|
||||
import os
|
||||
import numpy as np
|
||||
import mlxu
|
||||
import jax.numpy as jnp
|
||||
import flax.serialization
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.jax_utils import float_to_dtype
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
load_checkpoint='',
|
||||
output_file='',
|
||||
streaming=False,
|
||||
float_dtype='bf16',
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
|
||||
params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_checkpoint, disallow_trainstate=True
|
||||
)[1]['params']
|
||||
|
||||
if FLAGS.streaming:
|
||||
StreamingCheckpointer.save_train_state_to_file(
|
||||
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
|
||||
)
|
||||
else:
|
||||
params = float_to_dtype(params, FLAGS.float_dtype)
|
||||
with mlxu.open_file(FLAGS.output, 'wb') as fout:
|
||||
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
59
EasyLM/scripts/diff_checkpoint.py
Normal file
59
EasyLM/scripts/diff_checkpoint.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# This script converts model checkpoint trained by EsayLM to a standard
|
||||
# mspack checkpoint that can be loaded by huggingface transformers or
|
||||
# flax.serialization.msgpack_restore. Such conversion allows models to be
|
||||
# used by other frameworks that integrate with huggingface transformers.
|
||||
|
||||
import pprint
|
||||
from functools import partial
|
||||
import os
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import flax.serialization
|
||||
import mlxu
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.jax_utils import float_to_dtype
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
recover_diff=False,
|
||||
load_base_checkpoint='',
|
||||
load_target_checkpoint='',
|
||||
output_file='',
|
||||
streaming=True,
|
||||
float_dtype='bf16',
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
|
||||
assert FLAGS.output_file != ''
|
||||
base_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_base_checkpoint, disallow_trainstate=True
|
||||
)[1]['params']
|
||||
|
||||
target_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_target_checkpoint, disallow_trainstate=True
|
||||
)[1]['params']
|
||||
|
||||
if FLAGS.recover_diff:
|
||||
params = jax.tree_util.tree_map(
|
||||
lambda b, t: b + t, base_params, target_params
|
||||
)
|
||||
else:
|
||||
params = jax.tree_util.tree_map(
|
||||
lambda b, t: t - b, base_params, target_params
|
||||
)
|
||||
|
||||
if FLAGS.streaming:
|
||||
StreamingCheckpointer.save_train_state_to_file(
|
||||
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
|
||||
)
|
||||
else:
|
||||
params = float_to_dtype(params, FLAGS.float_dtype)
|
||||
with mlxu.open_file(FLAGS.output, 'wb') as fout:
|
||||
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
65
EasyLM/scripts/lm_eval_harness.py
Normal file
65
EasyLM/scripts/lm_eval_harness.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# This script runs lm_eval_harness evaluations against a served language model.
|
||||
# Typically, you need to run a language model server first, e.g.:
|
||||
# python -m EasyLM.models.gptj.gptj_serve ...
|
||||
|
||||
import dataclasses
|
||||
import pprint
|
||||
from functools import partial
|
||||
import os
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
from flax.traverse_util import flatten_dict
|
||||
from lm_eval import evaluator, tasks
|
||||
from lm_eval.base import LM
|
||||
|
||||
from EasyLM.serving import LMClient
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
tasks='wsc,piqa,winogrande,openbookqa,logiqa',
|
||||
shots=0,
|
||||
limit=0,
|
||||
write_out=False,
|
||||
lm_client=LMClient.get_default_config(),
|
||||
logger=mlxu.WandBLogger.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
class LMEvalHarnessInterface(LM):
|
||||
|
||||
def __init__(self, lm_client):
|
||||
self.lm_client = lm_client
|
||||
|
||||
def greedy_until(self, inputs):
|
||||
prefix, until = zip(*inputs)
|
||||
return self.lm_client.greedy_until(prefix, until)
|
||||
|
||||
def loglikelihood_rolling(self, inputs):
|
||||
loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
|
||||
return list(zip(loglikelihood, is_greedy))
|
||||
|
||||
def loglikelihood(self, inputs):
|
||||
prefix, text = zip(*inputs)
|
||||
loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
|
||||
return list(zip(loglikelihood, is_greedy))
|
||||
|
||||
|
||||
def main(argv):
|
||||
logger = mlxu.WandBLogger(
|
||||
config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||
)
|
||||
model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
|
||||
task_list = FLAGS.tasks.split(',')
|
||||
results = evaluator.evaluate(
|
||||
model, tasks.get_task_dict(task_list), False, FLAGS.shots,
|
||||
limit=None if FLAGS.limit <= 0 else FLAGS.limit,
|
||||
write_out=FLAGS.write_out,
|
||||
)
|
||||
logger.log(flatten_dict(results['results'], sep='/'))
|
||||
pprint.pprint(results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
52
EasyLM/scripts/lm_eval_json.py
Normal file
52
EasyLM/scripts/lm_eval_json.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import json
|
||||
import mlxu
|
||||
from EasyLM.serving import LMClient
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
input_file='',
|
||||
output_file='',
|
||||
prefix_field='prefix',
|
||||
text_field='text',
|
||||
until_field='until',
|
||||
eval_type='loglikelihood',
|
||||
lm_client=LMClient.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
lm_client = LMClient(FLAGS.lm_client)
|
||||
with mlxu.open_file(FLAGS.input_file, 'r') as fin:
|
||||
input_data = json.load(fin)
|
||||
|
||||
if FLAGS.eval_type == 'loglikelihood':
|
||||
prefix = input_data[FLAGS.prefix_field]
|
||||
text = input_data[FLAGS.text_field]
|
||||
loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
|
||||
output_data = {
|
||||
'loglikelihood': loglikelihoods,
|
||||
'is_greedy': is_greedys,
|
||||
}
|
||||
elif FLAGS.eval_type == 'loglikelihood_rolling':
|
||||
text = input_data[FLAGS.text_field]
|
||||
loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
|
||||
output_data = {
|
||||
'loglikelihood': loglikelihoods,
|
||||
'is_greedy': is_greedys,
|
||||
}
|
||||
elif FLAGS.eval_type == 'greedy_until':
|
||||
prefix = input_data[FLAGS.prefix_field]
|
||||
until = input_data[FLAGS.until_field]
|
||||
output_data = {'output_text': lm_client.greedy_until(prefix, until)}
|
||||
elif FLAGS.eval_type == 'generate':
|
||||
prefix = input_data[FLAGS.prefix_field]
|
||||
output_data = {'output_text': lm_client.generate(prefix)}
|
||||
else:
|
||||
raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
|
||||
|
||||
with mlxu.open_file(FLAGS.output_file, 'w') as fout:
|
||||
json.dump(output_data, fout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
Reference in New Issue
Block a user