60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
# This script takes two checkpoints produced by EasyLM and computes the
# parameter difference between them (target - base). With --recover_diff, it
# instead treats the target checkpoint as such a diff and recovers the original
# weights (base + diff). The result is written either with EasyLM's
# StreamingCheckpointer or, with --streaming=False, as a standard msgpack
# checkpoint that can be loaded by flax.serialization.msgpack_restore.
|
|
|
|
import pprint
|
|
from functools import partial
|
|
import os
|
|
import numpy as np
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import flax.serialization
|
|
import mlxu
|
|
from EasyLM.checkpoint import StreamingCheckpointer
|
|
from EasyLM.jax_utils import float_to_dtype
|
|
|
|
|
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    # False: write (target - base).
    # True: treat load_target_checkpoint as a diff and write (base + diff).
    recover_diff=False,
    # Base checkpoint path; main() requires it to be non-empty.
    load_base_checkpoint='',
    # Target (or diff, with recover_diff) checkpoint path; required by main().
    load_target_checkpoint='',
    # Destination path for the result; required by main().
    output_file='',
    # True: save via StreamingCheckpointer; False: single msgpack blob.
    streaming=True,
    # Float dtype name handed to the saving path (float_to_dtype / checkpointer).
    float_dtype='bf16',
)
|
|
|
|
|
|
def main(argv):
    """Write the diff between two EasyLM checkpoints, or recover one from a diff.

    Both checkpoints are loaded with ``disallow_trainstate=True`` and only
    their ``'params'`` subtree is used.

    * ``--recover_diff=False`` (default): writes ``target - base``.
    * ``--recover_diff=True``: treats ``--load_target_checkpoint`` as a diff
      previously produced by this script and writes ``base + diff``.

    With ``--streaming`` the result is saved through
    ``StreamingCheckpointer.save_train_state_to_file``; otherwise it is cast
    to ``--float_dtype`` and written to ``--output_file`` as a single msgpack
    blob via ``flax.serialization.msgpack_serialize``.

    Args:
        argv: Positional command-line arguments forwarded by ``mlxu.run``
            (unused).

    Raises:
        ValueError: if either checkpoint path or the output path is empty.
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if FLAGS.load_base_checkpoint == '' or FLAGS.load_target_checkpoint == '':
        raise ValueError(
            '--load_base_checkpoint and --load_target_checkpoint must both be set'
        )
    if FLAGS.output_file == '':
        raise ValueError('--output_file must be set')

    # load_trainstate_checkpoint returns a pair; element [1] holds the
    # restored parameters (train state is disallowed here).
    base_params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_base_checkpoint, disallow_trainstate=True
    )[1]['params']

    target_params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_target_checkpoint, disallow_trainstate=True
    )[1]['params']

    # Leaf-wise arithmetic; both trees must share the same structure.
    if FLAGS.recover_diff:
        params = jax.tree_util.tree_map(
            lambda b, t: b + t, base_params, target_params
        )
    else:
        params = jax.tree_util.tree_map(
            lambda b, t: t - b, base_params, target_params
        )

    if FLAGS.streaming:
        StreamingCheckpointer.save_train_state_to_file(
            params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
        )
    else:
        params = float_to_dtype(params, FLAGS.float_dtype)
        # Previously used the undefined flag `FLAGS.output`; the declared
        # flag is `output_file`.
        with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
            fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
|
|
|
|
|
if __name__ == '__main__':
    # Hand control to mlxu, which parses the flags and then calls main(argv).
    mlxu.run(main)
|