初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
0
EasyLM/models/roberta/__init__.py
Normal file
0
EasyLM/models/roberta/__init__.py
Normal file
1694
EasyLM/models/roberta/roberta_model.py
Normal file
1694
EasyLM/models/roberta/roberta_model.py
Normal file
File diff suppressed because it is too large
Load Diff
307
EasyLM/models/roberta/roberta_train.py
Normal file
307
EasyLM/models/roberta/roberta_train.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import dataclasses
|
||||
import pprint
|
||||
from functools import partial
|
||||
import re
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
import mlxu
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.pjit import pjit, with_sharding_constraint
|
||||
from jax.sharding import PartitionSpec as PS
|
||||
from flax.training.train_state import TrainState
|
||||
|
||||
from EasyLM.data import DatasetFactory
|
||||
from EasyLM.checkpoint import StreamingCheckpointer
|
||||
from EasyLM.optimizers import OptimizerFactory
|
||||
from EasyLM.jax_utils import (
|
||||
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, get_float_dtype_by_name,
|
||||
cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
|
||||
set_random_seed, average_metrics, get_weight_decay_mask,
|
||||
make_shard_and_gather_fns, tree_apply
|
||||
)
|
||||
from EasyLM.models.roberta.roberta_model import (
|
||||
RobertaConfig, FlaxRobertaForMaskedLMModule
|
||||
)
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
seed=42,
|
||||
mesh_dim='-1,1,1',
|
||||
dtype='fp32',
|
||||
mask_token_probability=0.15,
|
||||
total_steps=10000,
|
||||
load_roberta_config='',
|
||||
update_roberta_config='',
|
||||
load_checkpoint='',
|
||||
load_dataset_state='',
|
||||
log_freq=50,
|
||||
save_model_freq=0,
|
||||
save_milestone_freq=0,
|
||||
eval_steps=0,
|
||||
tokenizer=RobertaConfig.get_tokenizer_config(),
|
||||
train_dataset=DatasetFactory.get_default_config(),
|
||||
eval_dataset=DatasetFactory.get_default_config(),
|
||||
optimizer=OptimizerFactory.get_default_config(),
|
||||
checkpointer=StreamingCheckpointer.get_default_config(),
|
||||
roberta=RobertaConfig.get_default_config(),
|
||||
logger=mlxu.WandBLogger.get_default_config(),
|
||||
log_all_worker=False,
|
||||
jax_distributed=JaxDistributedConfig.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
||||
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
||||
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
||||
logger = mlxu.WandBLogger(
|
||||
config=FLAGS.logger,
|
||||
variant=variant,
|
||||
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
||||
)
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
|
||||
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
||||
if FLAGS.load_dataset_state != '':
|
||||
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
||||
|
||||
if FLAGS.eval_steps > 0:
|
||||
eval_dataset = DatasetFactory.load_dataset(
|
||||
FLAGS.eval_dataset, dataset.tokenizer
|
||||
)
|
||||
eval_iterator = iter(eval_dataset)
|
||||
|
||||
seq_length = dataset.seq_length
|
||||
|
||||
if FLAGS.load_roberta_config != '':
|
||||
roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
|
||||
else:
|
||||
roberta_config = RobertaConfig(**FLAGS.roberta)
|
||||
|
||||
if FLAGS.update_roberta_config != '':
|
||||
roberta_config.update(dict(eval(FLAGS.update_roberta_config)))
|
||||
|
||||
roberta_config.update(dict(
|
||||
bos_token_id=dataset.tokenizer.bos_token_id,
|
||||
eos_token_id=dataset.tokenizer.eos_token_id,
|
||||
pad_token_id=dataset.tokenizer.pad_token_id,
|
||||
vocab_size=dataset.vocab_size,
|
||||
))
|
||||
|
||||
model = FlaxRobertaForMaskedLMModule(
|
||||
roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
|
||||
)
|
||||
|
||||
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
||||
FLAGS.optimizer,
|
||||
get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
|
||||
)
|
||||
|
||||
def create_trainstate_from_params(params):
|
||||
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||
|
||||
def init_fn(rng):
|
||||
rng_generator = JaxRNG(rng)
|
||||
params = model.init(
|
||||
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
||||
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
||||
token_type_ids=None,
|
||||
head_mask=None,
|
||||
rngs=rng_generator(roberta_config.rng_keys()),
|
||||
)
|
||||
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
||||
|
||||
def train_step(train_state, rng, batch):
|
||||
rng_generator = JaxRNG(rng)
|
||||
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
|
||||
def loss_and_accuracy(params):
|
||||
altered_tokens = jax.random.uniform(
|
||||
rng_generator(), shape=tokens.shape
|
||||
) < FLAGS.mask_token_probability
|
||||
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
|
||||
altered_by_mask = altered_tokens & (random_uniform < 0.8)
|
||||
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
|
||||
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
|
||||
random_tokens = jax.random.randint(
|
||||
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
|
||||
)
|
||||
inputs = jnp.where(altered_by_random, random_tokens, inputs)
|
||||
logits = model.apply(
|
||||
params, inputs,
|
||||
attention_mask=jnp.ones_like(inputs),
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
deterministic=False,
|
||||
rngs=rng_generator(roberta_config.rng_keys()),
|
||||
).logits
|
||||
return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
|
||||
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
||||
(loss, accuracy), grads = grad_fn(train_state.params)
|
||||
train_state = train_state.apply_gradients(grads=grads)
|
||||
metrics = dict(
|
||||
loss=loss,
|
||||
accuracy=accuracy,
|
||||
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
||||
gradient_norm=global_norm(grads),
|
||||
param_norm=global_norm(train_state.params),
|
||||
)
|
||||
return train_state, rng_generator(), metrics
|
||||
|
||||
def eval_step(train_state, rng, batch):
|
||||
rng_generator = JaxRNG(rng)
|
||||
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
|
||||
altered_tokens = jax.random.uniform(
|
||||
rng_generator(), shape=tokens.shape
|
||||
) < FLAGS.mask_token_probability
|
||||
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
|
||||
altered_by_mask = altered_tokens & (random_uniform < 0.8)
|
||||
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
|
||||
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
|
||||
random_tokens = jax.random.randint(
|
||||
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
|
||||
)
|
||||
inputs = jnp.where(altered_by_random, random_tokens, inputs)
|
||||
logits = model.apply(
|
||||
train_state.params, inputs,
|
||||
attention_mask=jnp.ones_like(inputs),
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
deterministic=False,
|
||||
rngs=rng_generator(roberta_config.rng_keys()),
|
||||
).logits
|
||||
loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
|
||||
metrics = dict(
|
||||
eval_loss=loss,
|
||||
eval_accuracy=accuracy,
|
||||
)
|
||||
return rng_generator(), metrics
|
||||
|
||||
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
||||
train_state_partition = match_partition_rules(
|
||||
RobertaConfig.get_partition_rules(), train_state_shapes
|
||||
)
|
||||
|
||||
shard_fns, gather_fns = make_shard_and_gather_fns(
|
||||
train_state_partition, train_state_shapes
|
||||
)
|
||||
checkpointer = StreamingCheckpointer(
|
||||
FLAGS.checkpointer, logger.output_dir,
|
||||
enable=jax.process_index() == 0
|
||||
)
|
||||
|
||||
sharded_init_fn = pjit(
|
||||
init_fn,
|
||||
in_shardings=PS(),
|
||||
out_shardings=train_state_partition
|
||||
)
|
||||
|
||||
sharded_create_trainstate_from_params = pjit(
|
||||
create_trainstate_from_params,
|
||||
in_shardings=(train_state_partition.params, ),
|
||||
out_shardings=train_state_partition,
|
||||
donate_argnums=(0, ),
|
||||
)
|
||||
|
||||
sharded_train_step = pjit(
|
||||
train_step,
|
||||
in_shardings=(train_state_partition, PS(), PS()),
|
||||
out_shardings=(train_state_partition, PS(), PS()),
|
||||
donate_argnums=(0, 1),
|
||||
)
|
||||
|
||||
sharded_eval_step = pjit(
|
||||
eval_step,
|
||||
in_shardings=(train_state_partition, PS(), PS()),
|
||||
out_shardings=(PS(), PS()),
|
||||
donate_argnums=(1,),
|
||||
)
|
||||
|
||||
def save_checkpoint(train_state, milestone=False):
|
||||
step = int(jax.device_get(train_state.step))
|
||||
metadata = dict(
|
||||
step=step,
|
||||
variant=variant,
|
||||
flags=flags_config_dict,
|
||||
roberta_config=roberta_config.to_dict(),
|
||||
)
|
||||
checkpointer.save_all(
|
||||
train_state=train_state,
|
||||
gather_fns=gather_fns,
|
||||
metadata=metadata,
|
||||
dataset=dataset.get_state_dict(),
|
||||
milestone=milestone,
|
||||
)
|
||||
|
||||
mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
|
||||
with mesh:
|
||||
train_state, restored_params = None, None
|
||||
if FLAGS.load_checkpoint != '':
|
||||
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
||||
if load_type == 'huggingface':
|
||||
restored_params = tree_apply(
|
||||
shard_fns.params, roberta_config.load_pretrained(load_path)
|
||||
)
|
||||
train_state = None
|
||||
else:
|
||||
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
||||
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
||||
)
|
||||
|
||||
if train_state is None and restored_params is None:
|
||||
# Initialize from scratch
|
||||
train_state = sharded_init_fn(next_rng())
|
||||
elif train_state is None and restored_params is not None:
|
||||
# Restore from params but initialize train_state
|
||||
train_state = sharded_create_trainstate_from_params(restored_params)
|
||||
del restored_params
|
||||
|
||||
start_step = int(jax.device_get(train_state.step))
|
||||
|
||||
if FLAGS.save_model_freq > 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
sharded_rng = next_rng()
|
||||
|
||||
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
||||
|
||||
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
||||
train_state, sharded_rng, metrics = sharded_train_step(
|
||||
train_state, sharded_rng, batch
|
||||
)
|
||||
|
||||
if step % FLAGS.log_freq == 0:
|
||||
if FLAGS.eval_steps > 0:
|
||||
eval_metric_list = []
|
||||
for _ in range(FLAGS.eval_steps):
|
||||
eval_batch, _ = next(eval_iterator)
|
||||
sharded_rng, eval_metrics = sharded_eval_step(
|
||||
train_state, sharded_rng, eval_batch
|
||||
)
|
||||
eval_metric_list.append(eval_metrics)
|
||||
metrics.update(average_metrics(eval_metric_list))
|
||||
|
||||
log_metrics = {"step": step}
|
||||
log_metrics.update(metrics)
|
||||
log_metrics.update(dataset_metrics)
|
||||
log_metrics = jax.device_get(log_metrics)
|
||||
logger.log(log_metrics)
|
||||
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
||||
|
||||
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
||||
save_checkpoint(train_state, milestone=True)
|
||||
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
if FLAGS.save_model_freq > 0:
|
||||
save_checkpoint(train_state)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
Reference in New Issue
Block a user