初始化项目,由ModelHub XC社区提供模型

Model: Finnish-NLP/Ahma-7B
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-01 02:08:18 +08:00
commit be39ad8722
45 changed files with 297486 additions and 0 deletions

View File

View File

@@ -0,0 +1,150 @@
from functools import partial
from time import time
import os
import numpy as np
import jax
import jax.flatten_util
import jax.numpy as jnp
import mlxu
from EasyLM.bpt import blockwise_attn
from EasyLM.jax_utils import (
get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
)
# Command-line flags for the attention benchmark. The values below are the
# defaults handed to mlxu.define_flags_with_default; the second return value
# is not needed by this script and is discarded.
FLAGS, _ = mlxu.define_flags_with_default(
    seed=42,  # passed to set_random_seed() before any benchmarking
    dtype='fp32',  # dtype name resolved through get_float_dtype_by_name()
    embed_dim=2048,  # per-head size is embed_dim // n_heads
    n_heads=16,  # size of the heads axis of the random q/k/v tensors
    ref_attn_seq_len=2048,  # sequence length used for the reference attention
    eff_attn_seq_len=16384,  # sequence length used for the blockwise attention
    batch_size=1,  # leading axis of the random q/k/v tensors
    query_chunk_size=2048,  # forwarded to blockwise_attn(query_chunk_size=...)
    key_chunk_size=2048,  # forwarded to blockwise_attn(key_chunk_size=...)
    warmup_steps=40,  # untimed iterations run before each timed loop
    steps=200,  # timed iterations per attention implementation
)
def main(argv):
    """Benchmark a dense reference causal attention against ``blockwise_attn``.

    Both implementations are run forward and backward on freshly sampled
    random query/key/value tensors. Each one is compiled once, warmed up for
    ``FLAGS.warmup_steps`` iterations and then timed for ``FLAGS.steps``
    iterations. The elapsed times and a length-normalized efficiency figure
    are printed to stdout.

    Args:
        argv: Positional command-line arguments supplied by ``mlxu.run``;
            unused.
    """
    # Local import: the file only imports ``time`` from the ``time`` module.
    # ``perf_counter`` is monotonic, so the measured intervals cannot be
    # distorted by system clock adjustments.
    from time import perf_counter

    # Resolve flag-derived constants once instead of on every helper call.
    float_dtype = get_float_dtype_by_name(FLAGS.dtype)
    head_dim = FLAGS.embed_dim // FLAGS.n_heads

    def random_kqv(rng_key, seq_len):
        """Sample three independent normal tensors for one attention call.

        Each tensor has shape (batch_size, seq_len, n_heads, head_dim).
        """
        rng_generator = JaxRNG(rng_key)
        shape = (FLAGS.batch_size, seq_len, FLAGS.n_heads, head_dim)
        return tuple(
            jax.random.normal(rng_generator(), shape, dtype=float_dtype)
            for _ in range(3)
        )

    def reference_attn(query, key, value):
        """Dense causal softmax attention that materializes the full logits."""
        query = query / jnp.sqrt(query.shape[-1]).astype(float_dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
        mask_value = jnp.finfo(logits.dtype).min
        _, q_seq_len, _, _ = query.shape
        _, kv_seq_len, _, _ = key.shape
        mask_shape = (q_seq_len, kv_seq_len)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        # True where the key position lies after the query position; those
        # entries receive the most negative finite value before the softmax.
        causal_mask = (row_ids < col_ids)[None, None, :, :]
        logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        return jnp.einsum("bhqk,bkhc->bqhc", weights, value)

    def efficient_attention(query, key, value):
        """Causal attention computed chunk by chunk via ``blockwise_attn``."""
        return blockwise_attn(
            query, key, value,
            bias=None,
            deterministic=True,
            dropout_rng=None,
            attn_pdrop=0.0,
            causal=True,
            query_chunk_size=FLAGS.query_chunk_size,
            key_chunk_size=FLAGS.key_chunk_size,
            dtype=float_dtype,
            policy=jax.checkpoint_policies.nothing_saveable(),
            precision=None,
            float32_logits=True,
            prevent_cse=True,
        )

    @partial(jax.jit, static_argnums=(1,))
    def reference_attn_forward_backward(rng_key, seq_len):
        """One forward/backward pass of the reference attention.

        The loss is the mean of the attention output. Gradients are taken
        with respect to query, key and value, but only the key gradient
        (index 1) is reduced to the returned scalar.
        """
        @partial(jax.grad, argnums=(0, 1, 2))
        @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
        def grad_fn(query, key, value):
            out = reference_attn(query, key, value)
            return jnp.mean(out)

        query, key, value = random_kqv(rng_key, seq_len)
        return jax.flatten_util.ravel_pytree(
            grad_fn(query, key, value)[1]
        )[0].mean()

    @partial(jax.jit, static_argnums=(1,))
    def efficient_attn_forward_backward(rng_key, seq_len):
        """One forward/backward pass of the blockwise attention.

        Mirrors ``reference_attn_forward_backward`` without the outer
        ``jax.checkpoint`` wrapper; ``blockwise_attn`` receives its own
        checkpoint policy instead.
        """
        @partial(jax.grad, argnums=(0, 1, 2))
        def grad_fn(query, key, value):
            out = efficient_attention(query, key, value)
            return jnp.mean(out)

        query, key, value = random_kqv(rng_key, seq_len)
        return jax.flatten_util.ravel_pytree(
            grad_fn(query, key, value)[1]
        )[0].mean()

    def timed_run(step_fn, seq_len):
        """Warm up ``step_fn`` and return the seconds taken by the timed steps.

        Results are collected and blocked on as a group so that the timer
        covers the actual device work, not only the asynchronous dispatch.
        """
        warmup_results = [
            step_fn(next_rng(), seq_len) for _ in range(FLAGS.warmup_steps)
        ]
        jax.block_until_ready(warmup_results)

        start_time = perf_counter()
        timed_results = [
            step_fn(next_rng(), seq_len) for _ in range(FLAGS.steps)
        ]
        jax.block_until_ready(timed_results)
        return perf_counter() - start_time

    set_random_seed(FLAGS.seed)

    # Trigger compilation of both jitted functions before anything is timed.
    jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
    jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))

    elapsed_time_ref_attn = timed_run(
        reference_attn_forward_backward, FLAGS.ref_attn_seq_len
    )
    print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')

    elapsed_time_efficient_attn = timed_run(
        efficient_attn_forward_backward, FLAGS.eff_attn_seq_len
    )
    print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')

    # The two runs use different sequence lengths, so the time ratio is
    # scaled by the squared length ratio to make the two comparable.
    flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
    efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
    print(f'Efficiency: {efficiency:.3f}')


if __name__ == '__main__':
    mlxu.run(main)

View File

@@ -0,0 +1,42 @@
# This script converts a model checkpoint trained by EasyLM to a standard
# msgpack checkpoint that can be loaded by huggingface transformers or
# flax.serialization.msgpack_restore. Such conversion allows models to be
# used by other frameworks that integrate with huggingface transformers.
import pprint
from functools import partial
import os
import numpy as np
import mlxu
import jax.numpy as jnp
import flax.serialization
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_to_dtype
# Command-line flags for the checkpoint conversion script, with defaults.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    load_checkpoint='',  # source passed to load_trainstate_checkpoint(); required
    output_file='',  # destination of the converted checkpoint; required
    streaming=False,  # True: write via save_train_state_to_file; False: msgpack
    float_dtype='bf16',  # float dtype name applied to the params when saving
)
def main(argv):
    """Load the params of an EasyLM checkpoint and re-save them.

    With ``FLAGS.streaming`` the params are written through
    ``StreamingCheckpointer.save_train_state_to_file``. Otherwise they are
    cast with ``float_to_dtype`` and written to ``FLAGS.output_file`` as a
    single msgpack blob.

    Args:
        argv: Positional command-line arguments supplied by ``mlxu.run``;
            unused.

    Raises:
        AssertionError: If ``load_checkpoint`` or ``output_file`` is empty.
    """
    assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
    # The loader returns a tuple; element 1 holds the restored parameters.
    params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_checkpoint, disallow_trainstate=True
    )[1]['params']

    if FLAGS.streaming:
        StreamingCheckpointer.save_train_state_to_file(
            params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
        )
    else:
        params = float_to_dtype(params, FLAGS.float_dtype)
        # The defined flag is ``output_file``; ``FLAGS.output`` does not
        # exist, so the msgpack path previously failed on attribute lookup.
        with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
            fout.write(flax.serialization.msgpack_serialize(params, in_place=True))


if __name__ == "__main__":
    mlxu.run(main)

View File

@@ -0,0 +1,59 @@
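# This script computes the element-wise difference between two model
# checkpoints trained by EasyLM (target - base). With --recover_diff it
# instead adds the two checkpoints (base + target), which reconstructs the
# original weights when the "target" checkpoint is a previously saved diff.
# The result is written either as a streaming checkpoint or as a standard
# msgpack checkpoint that can be loaded by flax.serialization.msgpack_restore.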
import pprint
from functools import partial
import os
import numpy as np
import jax
import jax.numpy as jnp
import flax.serialization
import mlxu
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_to_dtype
# Command-line flags for the checkpoint diff/recover script, with defaults.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    recover_diff=False,  # True: output base + target; False: output target - base
    load_base_checkpoint='',  # base checkpoint to load; required
    load_target_checkpoint='',  # target (or diff) checkpoint to load; required
    output_file='',  # destination of the resulting params; required
    streaming=True,  # True: write via save_train_state_to_file; False: msgpack
    float_dtype='bf16',  # float dtype name applied to the params when saving
)
def main(argv):
    """Combine a base and a target checkpoint leaf by leaf and save the result.

    Without ``FLAGS.recover_diff`` the output is ``target - base``; with it
    the output is ``base + target``. The result is saved either through
    ``StreamingCheckpointer.save_train_state_to_file`` or, when
    ``FLAGS.streaming`` is false, as a msgpack blob at ``FLAGS.output_file``.

    Args:
        argv: Positional command-line arguments supplied by ``mlxu.run``;
            unused.

    Raises:
        AssertionError: If either checkpoint path or ``output_file`` is empty.
    """
    assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
    assert FLAGS.output_file != ''

    # The loader returns a tuple; element 1 holds the restored parameters.
    base_params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_base_checkpoint, disallow_trainstate=True
    )[1]['params']
    target_params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_target_checkpoint, disallow_trainstate=True
    )[1]['params']

    if FLAGS.recover_diff:
        params = jax.tree_util.tree_map(
            lambda b, t: b + t, base_params, target_params
        )
    else:
        params = jax.tree_util.tree_map(
            lambda b, t: t - b, base_params, target_params
        )

    if FLAGS.streaming:
        StreamingCheckpointer.save_train_state_to_file(
            params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
        )
    else:
        params = float_to_dtype(params, FLAGS.float_dtype)
        # The defined flag is ``output_file``; ``FLAGS.output`` does not
        # exist, so the msgpack path previously failed on attribute lookup.
        with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
            fout.write(flax.serialization.msgpack_serialize(params, in_place=True))


if __name__ == "__main__":
    mlxu.run(main)

View File

@@ -0,0 +1,65 @@
# This script runs lm_eval_harness evaluations against a served language model.
# Typically, you need to run a language model server first, e.g.:
# python -m EasyLM.models.gptj.gptj_serve ...
import dataclasses
import pprint
from functools import partial
import os
from tqdm import tqdm, trange
import numpy as np
import mlxu
from flax.traverse_util import flatten_dict
from lm_eval import evaluator, tasks
from lm_eval.base import LM
from EasyLM.serving import LMClient
# Command-line flags for the lm_eval harness runner, with defaults.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    tasks='wsc,piqa,winogrande,openbookqa,logiqa',  # comma-separated task names
    shots=0,  # passed positionally to evaluator.evaluate (presumably the few-shot count)
    limit=0,  # values <= 0 are translated to limit=None
    write_out=False,  # forwarded to evaluator.evaluate(write_out=...)
    lm_client=LMClient.get_default_config(),  # config used to build the LMClient
    logger=mlxu.WandBLogger.get_default_config(),  # config used to build the WandBLogger
)
class LMEvalHarnessInterface(LM):
    """Adapter that serves lm_eval ``LM`` requests through an LM client.

    Each method reshapes the harness's request list into the arguments the
    wrapped client expects and reshapes the client's reply back into one
    entry per request.
    """

    def __init__(self, lm_client):
        """Store the client object that all requests are delegated to."""
        self.lm_client = lm_client

    def greedy_until(self, inputs):
        """Delegate generation requests given as (prefix, until) pairs."""
        # Transpose the list of pairs into one tuple per field.
        contexts, stop_sequences = zip(*inputs)
        return self.lm_client.greedy_until(contexts, stop_sequences)

    def loglikelihood_rolling(self, inputs):
        """Score whole texts; returns a list of (loglikelihood, is_greedy)."""
        scores, greedy_flags = self.lm_client.loglikelihood_rolling(inputs)
        return [*zip(scores, greedy_flags)]

    def loglikelihood(self, inputs):
        """Score (prefix, text) pairs; returns (loglikelihood, is_greedy) tuples."""
        # Transpose the list of pairs into one tuple per field.
        contexts, continuations = zip(*inputs)
        scores, greedy_flags = self.lm_client.loglikelihood(contexts, continuations)
        return [*zip(scores, greedy_flags)]
def main(argv):
    """Evaluate a served language model on the configured lm_eval tasks.

    Builds a W&B logger and an LM client from the flags, runs
    ``evaluator.evaluate`` over the comma-separated ``FLAGS.tasks``, logs the
    flattened ``results['results']`` mapping and pretty-prints everything.

    Args:
        argv: Positional command-line arguments supplied by ``mlxu.run``;
            unused.
    """
    wandb_logger = mlxu.WandBLogger(
        config=FLAGS.logger,
        variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF),
    )
    harness_model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
    task_dict = tasks.get_task_dict(FLAGS.tasks.split(','))

    # A non-positive limit means "no limit" for the evaluator.
    if FLAGS.limit <= 0:
        example_limit = None
    else:
        example_limit = FLAGS.limit

    # The third and fourth arguments are positional in the original call;
    # presumably provide_description and num_fewshot (verify against lm_eval).
    results = evaluator.evaluate(
        harness_model, task_dict, False, FLAGS.shots,
        limit=example_limit,
        write_out=FLAGS.write_out,
    )

    wandb_logger.log(flatten_dict(results['results'], sep='/'))
    pprint.pprint(results)


if __name__ == "__main__":
    mlxu.run(main)

View File

@@ -0,0 +1,52 @@
import json
import mlxu
from EasyLM.serving import LMClient
# Command-line flags for the JSON-driven evaluation client, with defaults.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    input_file='',  # JSON file read with json.load
    output_file='',  # JSON file the results are dumped to
    prefix_field='prefix',  # key of the input JSON holding the prefixes
    text_field='text',  # key of the input JSON holding the texts
    until_field='until',  # key of the input JSON holding the stop sequences
    eval_type='loglikelihood',  # loglikelihood | loglikelihood_rolling | greedy_until | generate
    lm_client=LMClient.get_default_config(),  # config used to build the LMClient
)
def main(argv):
    """Run one kind of LM client request on a JSON input file.

    The input JSON is indexed by the ``*_field`` flags, the request selected
    by ``FLAGS.eval_type`` is sent through the LM client, and the reply is
    dumped as JSON to ``FLAGS.output_file``.

    Args:
        argv: Positional command-line arguments supplied by ``mlxu.run``;
            unused.

    Raises:
        ValueError: If ``FLAGS.eval_type`` is not a supported request kind.
    """
    client = LMClient(FLAGS.lm_client)
    with mlxu.open_file(FLAGS.input_file, 'r') as source:
        payload = json.load(source)

    def run_loglikelihood():
        prefixes = payload[FLAGS.prefix_field]
        texts = payload[FLAGS.text_field]
        scores, greedy_flags = client.loglikelihood(prefixes, texts)
        return {'loglikelihood': scores, 'is_greedy': greedy_flags}

    def run_loglikelihood_rolling():
        texts = payload[FLAGS.text_field]
        scores, greedy_flags = client.loglikelihood_rolling(texts)
        return {'loglikelihood': scores, 'is_greedy': greedy_flags}

    def run_greedy_until():
        prefixes = payload[FLAGS.prefix_field]
        stops = payload[FLAGS.until_field]
        return {'output_text': client.greedy_until(prefixes, stops)}

    def run_generate():
        prefixes = payload[FLAGS.prefix_field]
        return {'output_text': client.generate(prefixes)}

    # Maps each supported eval_type to the closure that performs it.
    handlers = {
        'loglikelihood': run_loglikelihood,
        'loglikelihood_rolling': run_loglikelihood_rolling,
        'greedy_until': run_greedy_until,
        'generate': run_generate,
    }
    handler = handlers.get(FLAGS.eval_type)
    if handler is None:
        raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
    result = handler()

    with mlxu.open_file(FLAGS.output_file, 'w') as sink:
        json.dump(result, sink)


if __name__ == "__main__":
    mlxu.run(main)