初始化项目,由ModelHub XC社区提供模型

Model: Finnish-NLP/Ahma-7B
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-01 02:08:18 +08:00
commit be39ad8722
45 changed files with 297486 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
# This script converts model checkpoint trained by EsayLM to a standard
# mspack checkpoint that can be loaded by huggingface transformers or
# flax.serialization.msgpack_restore. Such conversion allows models to be
# used by other frameworks that integrate with huggingface transformers.
import pprint
from functools import partial
import os
import numpy as np
import mlxu
import jax.numpy as jnp
import flax.serialization
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_to_dtype
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
load_checkpoint='',
output_file='',
streaming=False,
float_dtype='bf16',
)
def main(argv):
assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True
)[1]['params']
if FLAGS.streaming:
StreamingCheckpointer.save_train_state_to_file(
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
)
else:
params = float_to_dtype(params, FLAGS.float_dtype)
with mlxu.open_file(FLAGS.output, 'wb') as fout:
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
if __name__ == "__main__":
mlxu.run(main)