初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
59
EasyLM/scripts/diff_checkpoint.py
Normal file
59
EasyLM/scripts/diff_checkpoint.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# This script converts model checkpoint trained by EsayLM to a standard
|
||||
# mspack checkpoint that can be loaded by huggingface transformers or
|
||||
# flax.serialization.msgpack_restore. Such conversion allows models to be
|
||||
# used by other frameworks that integrate with huggingface transformers.
|
||||
|
||||
import pprint
|
||||
from functools import partial
|
||||
import os
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import flax.serialization
|
||||
import mlxu
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.jax_utils import float_to_dtype
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
recover_diff=False,
|
||||
load_base_checkpoint='',
|
||||
load_target_checkpoint='',
|
||||
output_file='',
|
||||
streaming=True,
|
||||
float_dtype='bf16',
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
|
||||
assert FLAGS.output_file != ''
|
||||
base_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_base_checkpoint, disallow_trainstate=True
|
||||
)[1]['params']
|
||||
|
||||
target_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_target_checkpoint, disallow_trainstate=True
|
||||
)[1]['params']
|
||||
|
||||
if FLAGS.recover_diff:
|
||||
params = jax.tree_util.tree_map(
|
||||
lambda b, t: b + t, base_params, target_params
|
||||
)
|
||||
else:
|
||||
params = jax.tree_util.tree_map(
|
||||
lambda b, t: t - b, base_params, target_params
|
||||
)
|
||||
|
||||
if FLAGS.streaming:
|
||||
StreamingCheckpointer.save_train_state_to_file(
|
||||
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
|
||||
)
|
||||
else:
|
||||
params = float_to_dtype(params, FLAGS.float_dtype)
|
||||
with mlxu.open_file(FLAGS.output, 'wb') as fout:
|
||||
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
Reference in New Issue
Block a user