初始化项目,由ModelHub XC社区提供模型

Model: Finnish-NLP/Ahma-7B
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-01 02:08:18 +08:00
commit be39ad8722
45 changed files with 297486 additions and 0 deletions

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,396 @@
import pprint
from functools import partial
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import flax
from flax import linen as nn
from flax.jax_utils import prefetch_to_device
from flax.training.train_state import TrainState
import optax
from transformers import GenerationConfig, FlaxLogitsProcessorList
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.serving import LMServer
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
with_sharding_constraint, FlaxTemperatureLogitsWarper
)
from EasyLM.models.gptj.gptj_model import (
GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
initialize_jax_distributed=False,
mesh_dim='1,-1,1',
dtype='bf16',
input_length=1024,
seq_length=2048,
top_k=50,
top_p=1.0,
do_sample=True,
num_beams=1,
add_bos_token=False,
load_gptj_config='',
load_checkpoint='',
tokenizer=GPTJConfig.get_tokenizer_config(),
lm_server=LMServer.get_default_config(),
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
set_random_seed(FLAGS.seed)
prefix_tokenizer = GPTJConfig.get_tokenizer(
FLAGS.tokenizer, truncation_side='left', padding_side='left'
)
tokenizer = GPTJConfig.get_tokenizer(
FLAGS.tokenizer, truncation_side='right', padding_side='right'
)
with jax.default_device(jax.devices("cpu")[0]):
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
if load_type == 'huggingface':
params = gptj_config.load_pretrained(load_path)
else:
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True
)
hf_model = FlaxGPTJForCausalLM(
gptj_config,
input_shape=(1, FLAGS.seq_length),
seed=FLAGS.seed,
_do_init=False
)
model_ps = match_partition_rules(
GPTJConfig.get_partition_rules(), params
)
shard_fns, _ = make_shard_and_gather_fns(
model_ps, get_float_dtype_by_name(FLAGS.dtype)
)
@partial(
pjit,
in_shardings=(model_ps, PS(), PS()),
out_shardings=(PS(), PS(), PS())
)
def forward_loglikelihood(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
input_tokens = batch['input_tokens']
output_tokens = batch['output_tokens']
input_mask = batch['input_mask']
output_mask = batch['output_mask']
logits = hf_model.module.apply(
params, input_tokens, attention_mask=input_mask,
deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
).logits
if gptj_config.n_real_tokens is not None:
logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
logits, output_tokens
)
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
match_count = jnp.sum(
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
axis=-1
)
total = jnp.sum(output_mask, axis=-1)
is_greedy = match_count == total
return loglikelihood, is_greedy, rng_generator()
@partial(
pjit,
in_shardings=(model_ps, PS(), PS(), PS()),
out_shardings=(PS(), PS())
)
def forward_generate(params, rng, batch, temperature):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
output = hf_model.generate(
batch['input_tokens'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
logits_processor=FlaxLogitsProcessorList(
[FlaxTemperatureLogitsWarper(temperature)]
),
generation_config=GenerationConfig(
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=FLAGS.do_sample,
num_beams=FLAGS.num_beams,
top_k=FLAGS.top_k,
top_p=FLAGS.top_p,
)
).sequences[:, batch['input_tokens'].shape[1]:]
return output, rng_generator()
@partial(
pjit,
in_shardings=(model_ps, PS(), PS()),
out_shardings=(PS(), PS())
)
def forward_greedy_generate(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
output = hf_model.generate(
batch['input_tokens'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
generation_config=GenerationConfig(
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=False,
num_beams=1,
)
).sequences[:, batch['input_tokens'].shape[1]:]
return output, rng_generator()
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
params = tree_apply(shard_fns, params)
sharded_rng = next_rng()
class ModelServer(LMServer):
@staticmethod
def loglikelihood(prefix_text, text):
nonlocal sharded_rng
prefix = prefix_tokenizer(
prefix_text,
padding='max_length',
truncation=True,
max_length=FLAGS.input_length,
return_tensors='np',
)
inputs = tokenizer(
text,
padding='max_length',
truncation=True,
max_length=FLAGS.seq_length - FLAGS.input_length,
return_tensors='np',
)
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
bos_tokens = np.full(
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
)
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
input_mask = np.concatenate(
[prefix.attention_mask, inputs.attention_mask], axis=1
)
if FLAGS.add_bos_token:
bos_mask = np.ones_like(input_mask[:, :1])
else:
bos_mask = np.zeros_like(input_mask[:, :1])
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
output_mask = np.concatenate(
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
)
batch = dict(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_mask=input_mask,
output_mask=output_mask,
)
with mesh:
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
params, sharded_rng, batch
)
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
return loglikelihood, is_greedy
@staticmethod
def loglikelihood_rolling(text):
nonlocal sharded_rng
inputs = tokenizer(
text,
padding='longest',
truncation=False,
max_length=np.iinfo(np.int32).max,
return_tensors='np',
)
batch_size = inputs.input_ids.shape[0]
output_tokens = inputs.input_ids
attention_mask = inputs.attention_mask
if output_tokens.shape[1] < FLAGS.seq_length:
padding_length = FLAGS.seq_length - output_tokens.shape[1]
pad_tokens = np.full(
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
)
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
pad_mask = np.zeros(
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
)
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
bos_tokens = np.full(
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
)
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
total_seq_length = output_tokens.shape[1]
total_loglikelihood = 0.0
total_is_greedy = True
# Sliding window
for i in range(0, total_seq_length, FLAGS.seq_length):
# Last window
if i + FLAGS.seq_length > total_seq_length:
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
last_output_mask[:, :i - total_seq_length] = 0.0
batch = dict(
input_tokens=input_tokens[:, -FLAGS.seq_length:],
output_tokens=output_tokens[:, -FLAGS.seq_length:],
input_mask=attention_mask[:, -FLAGS.seq_length:],
output_mask=last_output_mask,
)
# Normal window
else:
batch = dict(
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
)
with mesh:
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
params, sharded_rng, batch
)
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
total_loglikelihood += loglikelihood
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
return total_loglikelihood, total_is_greedy
@staticmethod
def generate(text, temperature):
nonlocal sharded_rng
inputs = prefix_tokenizer(
text,
padding='max_length',
truncation=True,
max_length=FLAGS.input_length,
return_tensors='np',
)
input_tokens = inputs.input_ids
input_mask = inputs.attention_mask
if FLAGS.add_bos_token:
input_tokens[:, 0] = tokenizer.bos_token_id
input_mask[:, 0] = 1
batch = dict(
input_tokens=input_tokens,
attention_mask=input_mask,
)
with mesh:
output, sharded_rng = forward_generate(
params, sharded_rng, batch, temperature
)
output = jax.device_get(output)
output_text = []
for text in list(tokenizer.batch_decode(output)):
if tokenizer.eos_token in text:
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
output_text.append(text)
return output_text
@staticmethod
def greedy_until(prefix_text, until, max_length):
nonlocal sharded_rng
all_outputs = []
for pf, ut in zip(prefix_text, until):
if isinstance(ut, str):
ut = [ut]
total_length = 0
total_generated = ''
while total_length < max_length:
pf_tokens = tokenizer(
pf,
padding=False,
truncation=False,
max_length=np.iinfo(np.int32).max,
return_tensors='np',
)
input_tokens = pf_tokens.input_ids
attention_mask = pf_tokens.attention_mask
if input_tokens.shape[1] < FLAGS.input_length:
extra = FLAGS.input_length - input_tokens.shape[1]
pad_tokens = np.full(
(1, extra), tokenizer.pad_token_id, dtype=np.int32
)
input_tokens = np.concatenate(
[pad_tokens, input_tokens], axis=1
)
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
attention_mask = np.concatenate(
[pad_attention, attention_mask], axis=1
)
elif input_tokens.shape[1] > FLAGS.input_length:
input_tokens = input_tokens[:, -FLAGS.input_length:]
attention_mask = attention_mask[:, -FLAGS.input_length:]
if FLAGS.add_bos_token:
input_tokens[:, 0] = tokenizer.bos_token_id
attention_mask[:, 0] = 1
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
with mesh:
output, sharded_rng = forward_greedy_generate(
params, sharded_rng, batch
)
output = jax.device_get(output)
total_length += output.shape[1]
output_text = tokenizer.batch_decode(output)[0]
total_generated = total_generated + output_text
pf = pf + output_text
done = False
for s in ut:
if s in total_generated:
total_generated = total_generated.split(s, maxsplit=1)[0]
done = True
if done:
break
all_outputs.append(total_generated)
return all_outputs
server = ModelServer(FLAGS.lm_server)
server.run()
if __name__ == "__main__":
mlxu.run(main)

View File

@@ -0,0 +1,272 @@
import pprint
from functools import partial
from tqdm import tqdm, trange
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
set_random_seed, average_metrics, get_weight_decay_mask,
make_shard_and_gather_fns, tree_apply
)
from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
mesh_dim='1,-1,1',
dtype='fp32',
total_steps=10000,
load_gptj_config='',
update_gptj_config='',
load_checkpoint='',
load_dataset_state='',
log_freq=50,
save_model_freq=0,
save_milestone_freq=0,
eval_steps=0,
tokenizer=GPTJConfig.get_tokenizer_config(),
train_dataset=DatasetFactory.get_default_config(),
eval_dataset=DatasetFactory.get_default_config(),
optimizer=OptimizerFactory.get_default_config(),
checkpointer=StreamingCheckpointer.get_default_config(),
gptj=GPTJConfig.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
log_all_worker=False,
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
logger = mlxu.WandBLogger(
config=FLAGS.logger,
variant=variant,
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
)
set_random_seed(FLAGS.seed)
tokenizer = GPTJConfig.get_tokenizer(FLAGS.tokenizer)
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
if FLAGS.load_dataset_state != '':
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
if FLAGS.eval_steps > 0:
eval_dataset = DatasetFactory.load_dataset(
FLAGS.eval_dataset, dataset.tokenizer
)
eval_iterator = iter(eval_dataset)
seq_length = dataset.seq_length
if FLAGS.load_gptj_config != '':
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
else:
gptj_config = GPTJConfig(**FLAGS.gptj)
if FLAGS.update_gptj_config != '':
gptj_config.update(dict(eval(FLAGS.update_gptj_config)))
gptj_config.update(dict(
bos_token_id=dataset.tokenizer.bos_token_id,
eos_token_id=dataset.tokenizer.eos_token_id,
))
if gptj_config.vocab_size < dataset.vocab_size:
gptj_config.update(dict(vocab_size=dataset.vocab_size))
model = FlaxGPTJForCausalLMModule(
gptj_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
)
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
FLAGS.optimizer,
get_weight_decay_mask(GPTJConfig.get_weight_decay_exclusions()),
)
def create_trainstate_from_params(params):
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def init_fn(rng):
rng_generator = JaxRNG(rng)
params = model.init(
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
rngs=rng_generator(gptj_config.rng_keys()),
)
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def train_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
def loss_and_accuracy(params):
logits = model.apply(
params, batch['input_tokens'], deterministic=False,
rngs=rng_generator(gptj_config.rng_keys()),
).logits
return cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
(loss, accuracy), grads = grad_fn(train_state.params)
train_state = train_state.apply_gradients(grads=grads)
metrics = dict(
loss=loss,
accuracy=accuracy,
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
gradient_norm=global_norm(grads),
param_norm=global_norm(train_state.params),
)
return train_state, rng_generator(), metrics
def eval_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
logits = model.apply(
train_state.params, batch['input_tokens'], deterministic=True,
rngs=rng_generator(gptj_config.rng_keys()),
).logits
loss, accuracy = cross_entropy_loss_and_accuracy(
logits, batch['target_tokens'], batch['loss_masks']
)
metrics = dict(
eval_loss=loss,
eval_accuracy=accuracy,
)
return rng_generator(), metrics
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
GPTJConfig.get_partition_rules(), train_state_shapes
)
shard_fns, gather_fns = make_shard_and_gather_fns(
train_state_partition, train_state_shapes
)
checkpointer = StreamingCheckpointer(
FLAGS.checkpointer, logger.output_dir,
enable=jax.process_index() == 0,
)
sharded_init_fn = pjit(
init_fn,
in_shardings=PS(),
out_shardings=train_state_partition
)
sharded_create_trainstate_from_params = pjit(
create_trainstate_from_params,
in_shardings=(train_state_partition.params, ),
out_shardings=train_state_partition,
donate_argnums=(0, ),
)
sharded_train_step = pjit(
train_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(train_state_partition, PS(), PS()),
donate_argnums=(0, 1),
)
sharded_eval_step = pjit(
eval_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(PS(), PS()),
donate_argnums=(1,),
)
def save_checkpoint(train_state, milestone=False):
step = int(jax.device_get(train_state.step))
metadata = dict(
step=step,
variant=variant,
flags=flags_config_dict,
gptj_config=gptj_config.to_dict(),
)
checkpointer.save_all(
train_state=train_state,
gather_fns=gather_fns,
metadata=metadata,
dataset=dataset.get_state_dict(),
milestone=milestone,
)
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
train_state, restored_params = None, None
if FLAGS.load_checkpoint != '':
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
if load_type == 'huggingface':
restored_params = tree_apply(
shard_fns.params, gptj_config.load_pretrained(load_path)
)
train_state = None
else:
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, train_state_shapes, shard_fns
)
if train_state is None and restored_params is None:
# Initialize from scratch
train_state = sharded_init_fn(next_rng())
elif train_state is None and restored_params is not None:
# Restore from params but initialize train_state
train_state = sharded_create_trainstate_from_params(restored_params)
del restored_params
start_step = int(jax.device_get(train_state.step))
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
sharded_rng = next_rng()
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
train_state, sharded_rng, metrics = sharded_train_step(
train_state, sharded_rng, batch
)
if step % FLAGS.log_freq == 0:
if FLAGS.eval_steps > 0:
eval_metric_list = []
for _ in range(FLAGS.eval_steps):
eval_batch, _ = next(eval_iterator)
sharded_rng, eval_metrics = sharded_eval_step(
train_state, sharded_rng, eval_batch
)
eval_metric_list.append(eval_metrics)
metrics.update(average_metrics(eval_metric_list))
log_metrics = {"step": step}
log_metrics.update(metrics)
log_metrics.update(dataset_metrics)
log_metrics = jax.device_get(log_metrics)
logger.log(log_metrics)
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
save_checkpoint(train_state, milestone=True)
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
save_checkpoint(train_state)
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
if __name__ == "__main__":
mlxu.run(main)