初始化项目,由ModelHub XC社区提供模型

Model: Finnish-NLP/Ahma-7B
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-01 02:08:18 +08:00
commit be39ad8722
45 changed files with 297486 additions and 0 deletions

View File

@@ -0,0 +1,272 @@
import pprint
from functools import partial
from tqdm import tqdm, trange
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
set_random_seed, average_metrics, get_weight_decay_mask,
make_shard_and_gather_fns, tree_apply
)
from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
mesh_dim='1,-1,1',
dtype='fp32',
total_steps=10000,
load_gptj_config='',
update_gptj_config='',
load_checkpoint='',
load_dataset_state='',
log_freq=50,
save_model_freq=0,
save_milestone_freq=0,
eval_steps=0,
tokenizer=GPTJConfig.get_tokenizer_config(),
train_dataset=DatasetFactory.get_default_config(),
eval_dataset=DatasetFactory.get_default_config(),
optimizer=OptimizerFactory.get_default_config(),
checkpointer=StreamingCheckpointer.get_default_config(),
gptj=GPTJConfig.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
log_all_worker=False,
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
logger = mlxu.WandBLogger(
config=FLAGS.logger,
variant=variant,
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
)
set_random_seed(FLAGS.seed)
tokenizer = GPTJConfig.get_tokenizer(FLAGS.tokenizer)
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
if FLAGS.load_dataset_state != '':
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
if FLAGS.eval_steps > 0:
eval_dataset = DatasetFactory.load_dataset(
FLAGS.eval_dataset, dataset.tokenizer
)
eval_iterator = iter(eval_dataset)
seq_length = dataset.seq_length
if FLAGS.load_gptj_config != '':
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
else:
gptj_config = GPTJConfig(**FLAGS.gptj)
if FLAGS.update_gptj_config != '':
gptj_config.update(dict(eval(FLAGS.update_gptj_config)))
gptj_config.update(dict(
bos_token_id=dataset.tokenizer.bos_token_id,
eos_token_id=dataset.tokenizer.eos_token_id,
))
if gptj_config.vocab_size < dataset.vocab_size:
gptj_config.update(dict(vocab_size=dataset.vocab_size))
model = FlaxGPTJForCausalLMModule(
gptj_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
)
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
FLAGS.optimizer,
get_weight_decay_mask(GPTJConfig.get_weight_decay_exclusions()),
)
def create_trainstate_from_params(params):
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def init_fn(rng):
rng_generator = JaxRNG(rng)
params = model.init(
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
rngs=rng_generator(gptj_config.rng_keys()),
)
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def train_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
def loss_and_accuracy(params):
logits = model.apply(
params, batch['input_tokens'], deterministic=False,
rngs=rng_generator(gptj_config.rng_keys()),
).logits
return cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
(loss, accuracy), grads = grad_fn(train_state.params)
train_state = train_state.apply_gradients(grads=grads)
metrics = dict(
loss=loss,
accuracy=accuracy,
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
gradient_norm=global_norm(grads),
param_norm=global_norm(train_state.params),
)
return train_state, rng_generator(), metrics
def eval_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
logits = model.apply(
train_state.params, batch['input_tokens'], deterministic=True,
rngs=rng_generator(gptj_config.rng_keys()),
).logits
loss, accuracy = cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
metrics = dict(
eval_loss=loss,
eval_accuracy=accuracy,
)
return rng_generator(), metrics
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
GPTJConfig.get_partition_rules(), train_state_shapes
)
shard_fns, gather_fns = make_shard_and_gather_fns(
train_state_partition, train_state_shapes
)
checkpointer = StreamingCheckpointer(
FLAGS.checkpointer, logger.output_dir,
enable=jax.process_index() == 0,
)
sharded_init_fn = pjit(
init_fn,
in_shardings=PS(),
out_shardings=train_state_partition
)
sharded_create_trainstate_from_params = pjit(
create_trainstate_from_params,
in_shardings=(train_state_partition.params, ),
out_shardings=train_state_partition,
donate_argnums=(0, ),
)
sharded_train_step = pjit(
train_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(train_state_partition, PS(), PS()),
donate_argnums=(0, 1),
)
sharded_eval_step = pjit(
eval_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(PS(), PS()),
donate_argnums=(1,),
)
def save_checkpoint(train_state, milestone=False):
step = int(jax.device_get(train_state.step))
metadata = dict(
step=step,
variant=variant,
flags=flags_config_dict,
gptj_config=gptj_config.to_dict(),
)
checkpointer.save_all(
train_state=train_state,
gather_fns=gather_fns,
metadata=metadata,
dataset=dataset.get_state_dict(),
milestone=milestone,
)
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
train_state, restored_params = None, None
if FLAGS.load_checkpoint != '':
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
if load_type == 'huggingface':
restored_params = tree_apply(
shard_fns.params, gptj_config.load_pretrained(load_path)
)
train_state = None
else:
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, train_state_shapes, shard_fns
)
if train_state is None and restored_params is None:
# Initialize from scratch
train_state = sharded_init_fn(next_rng())
elif train_state is None and restored_params is not None:
# Restore from params but initialize train_state
train_state = sharded_create_trainstate_from_params(restored_params)
del restored_params
start_step = int(jax.device_get(train_state.step))
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
sharded_rng = next_rng()
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
train_state, sharded_rng, metrics = sharded_train_step(
train_state, sharded_rng, batch
)
if step % FLAGS.log_freq == 0:
if FLAGS.eval_steps > 0:
eval_metric_list = []
for _ in range(FLAGS.eval_steps):
eval_batch, _ = next(eval_iterator)
sharded_rng, eval_metrics = sharded_eval_step(
train_state, sharded_rng, eval_batch
)
eval_metric_list.append(eval_metrics)
metrics.update(average_metrics(eval_metric_list))
log_metrics = {"step": step}
log_metrics.update(metrics)
log_metrics.update(dataset_metrics)
log_metrics = jax.device_get(log_metrics)
logger.log(log_metrics)
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
save_checkpoint(train_state, milestone=True)
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
save_checkpoint(train_state)
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
if __name__ == "__main__":
mlxu.run(main)