# Source-viewer header from the original paste: 397 lines, 15 KiB, Python.
import pprint
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
import mlxu
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from jax.experimental.pjit import pjit
|
|
from jax.sharding import PartitionSpec as PS
|
|
import flax
|
|
from flax import linen as nn
|
|
from flax.jax_utils import prefetch_to_device
|
|
from flax.training.train_state import TrainState
|
|
import optax
|
|
from transformers import GenerationConfig, FlaxLogitsProcessorList
|
|
|
|
from EasyLM.checkpoint import StreamingCheckpointer
|
|
from EasyLM.serving import LMServer
|
|
from EasyLM.jax_utils import (
|
|
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
|
|
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
|
|
with_sharding_constraint, FlaxTemperatureLogitsWarper
|
|
)
|
|
from EasyLM.models.gptj.gptj_model import (
|
|
GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
|
|
)
|
|
|
|
|
|
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    # --- Runtime / reproducibility -------------------------------------
    seed=42,
    # NOTE(review): this flag is not read by main(), which initializes
    # distributed JAX from `jax_distributed` instead — confirm before removal.
    initialize_jax_distributed=False,
    mesh_dim='1,-1,1',  # passed to GPTJConfig.get_jax_mesh
    dtype='bf16',  # passed to get_float_dtype_by_name for the shard fns
    # --- Token budget ----------------------------------------------------
    # Prompts/prefixes get `input_length` tokens; the remaining
    # `seq_length - input_length` tokens are for the continuation.
    input_length=1024,
    seq_length=2048,
    # --- Generation settings used by the sampled `generate` endpoint -----
    top_k=50,
    top_p=1.0,
    do_sample=True,
    num_beams=1,
    add_bos_token=False,
    # --- Model, checkpoint and tokenizer sources --------------------------
    load_gptj_config='',
    load_checkpoint='',
    tokenizer=GPTJConfig.get_tokenizer_config(),
    # --- Nested configs for the server and distributed JAX setup --------
    lm_server=LMServer.get_default_config(),
    jax_distributed=JaxDistributedConfig.get_default_config(),
)
|
|
|
|
|
|
def main(argv):
    """Load a GPT-J checkpoint, shard it over a JAX device mesh and serve it.

    All configuration comes from the module-level ``FLAGS``. The function
    builds three ``pjit``-compiled forward passes (log-likelihood scoring,
    sampled generation and greedy generation), exposes them through an
    ``LMServer`` subclass and finally calls ``server.run()``.

    Args:
        argv: Positional arguments forwarded by ``mlxu.run``; not used.

    Raises:
        ValueError: If ``--load_checkpoint`` is not of the form
            ``'<load_type>::<path>'``.
    """
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    set_random_seed(FLAGS.seed)

    # Prompts / prefixes are padded and truncated on the LEFT, so the tokens
    # closest to the continuation are the ones that survive truncation.
    prefix_tokenizer = GPTJConfig.get_tokenizer(
        FLAGS.tokenizer, truncation_side='left', padding_side='left'
    )
    # Continuations / whole documents are padded and truncated on the RIGHT.
    tokenizer = GPTJConfig.get_tokenizer(
        FLAGS.tokenizer, truncation_side='right', padding_side='right'
    )

    # Build config, params and model wrapper with the first host CPU device
    # as the default device; params are moved onto the mesh by `shard_fns`.
    with jax.default_device(jax.devices("cpu")[0]):
        gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
        if '::' not in FLAGS.load_checkpoint:
            # Without this check the unpacking below fails with an opaque
            # "not enough values to unpack" error (e.g. for the default '').
            raise ValueError(
                "--load_checkpoint must have the form '<load_type>::<path>', "
                f"got {FLAGS.load_checkpoint!r}"
            )
        load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
        if load_type == 'huggingface':
            params = gptj_config.load_pretrained(load_path)
        else:
            # Every other load type is handed to the streaming checkpointer
            # together with the full '<type>::<path>' spec.
            _, params = StreamingCheckpointer.load_trainstate_checkpoint(
                FLAGS.load_checkpoint, disallow_trainstate=True
            )

        # _do_init=False: the wrapper does not create its own weights; the
        # loaded `params` are passed explicitly to every call below.
        hf_model = FlaxGPTJForCausalLM(
            gptj_config,
            input_shape=(1, FLAGS.seq_length),
            seed=FLAGS.seed,
            _do_init=False
        )

    model_ps = match_partition_rules(
        GPTJConfig.get_partition_rules(), params
    )
    shard_fns, _ = make_shard_and_gather_fns(
        model_ps, get_float_dtype_by_name(FLAGS.dtype)
    )

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS()),
        out_shardings=(PS(), PS(), PS())
    )
    def forward_loglikelihood(params, rng, batch):
        """Score ``batch['output_tokens']`` under the model.

        Returns ``(loglikelihood, is_greedy, next_rng)``:
        ``loglikelihood`` is the per-row sum of token log-probabilities over
        positions where ``output_mask`` is set; ``is_greedy`` is True for a
        row when the argmax prediction matches every masked output token.
        """
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        input_tokens = batch['input_tokens']
        output_tokens = batch['output_tokens']
        input_mask = batch['input_mask']
        output_mask = batch['output_mask']

        logits = hf_model.module.apply(
            params, input_tokens, attention_mask=input_mask,
            deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
        ).logits
        # Vocabulary slots at or beyond `n_real_tokens` are forced to a huge
        # negative logit so they can win neither the softmax nor the argmax.
        if gptj_config.n_real_tokens is not None:
            logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)
        loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
            logits, output_tokens
        )
        loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
        match_count = jnp.sum(
            (jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
            axis=-1
        )
        total = jnp.sum(output_mask, axis=-1)
        is_greedy = match_count == total
        return loglikelihood, is_greedy, rng_generator()

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS(), PS()),
        out_shardings=(PS(), PS())
    )
    def forward_generate(params, rng, batch, temperature):
        """Generate continuations using the sampling settings from FLAGS.

        Returns ``(new_tokens, next_rng)``. ``new_tokens`` holds only the
        tokens produced after the prompt.
        """
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        output = hf_model.generate(
            batch['input_tokens'],
            attention_mask=batch['attention_mask'],
            params=params['params'],
            prng_key=rng_generator(),
            logits_processor=FlaxLogitsProcessorList(
                [FlaxTemperatureLogitsWarper(temperature)]
            ),
            generation_config=GenerationConfig(
                max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=FLAGS.do_sample,
                num_beams=FLAGS.num_beams,
                top_k=FLAGS.top_k,
                top_p=FLAGS.top_p,
            )
        ).sequences[:, batch['input_tokens'].shape[1]:]
        return output, rng_generator()

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS()),
        out_shardings=(PS(), PS())
    )
    def forward_greedy_generate(params, rng, batch):
        """Greedy (no sampling, single beam) variant of ``forward_generate``."""
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        output = hf_model.generate(
            batch['input_tokens'],
            attention_mask=batch['attention_mask'],
            params=params['params'],
            prng_key=rng_generator(),
            generation_config=GenerationConfig(
                max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False,
                num_beams=1,
            )
        ).sequences[:, batch['input_tokens'].shape[1]:]
        return output, rng_generator()

    mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        params = tree_apply(shard_fns, params)
        sharded_rng = next_rng()

    class ModelServer(LMServer):
        """LMServer backed by the sharded GPT-J forward passes above.

        Every endpoint advances the shared `sharded_rng` of the enclosing
        `main` (hence the `nonlocal` declarations).
        """

        @staticmethod
        def loglikelihood(prefix_text, text):
            """Score each ``text`` continuation given its ``prefix_text``.

            Returns ``(loglikelihood, is_greedy)`` host arrays with one entry
            per input pair. Only continuation tokens are scored.
            """
            nonlocal sharded_rng
            prefix = prefix_tokenizer(
                prefix_text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.input_length,
                return_tensors='np',
            )
            inputs = tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.seq_length - FLAGS.input_length,
                return_tensors='np',
            )
            output_tokens = np.concatenate(
                [prefix.input_ids, inputs.input_ids], axis=1
            )
            # Teacher forcing: model inputs are the targets shifted right by
            # one position, with BOS in the first slot.
            bos_tokens = np.full(
                (output_tokens.shape[0], 1), tokenizer.bos_token_id,
                dtype=np.int32
            )
            input_tokens = np.concatenate(
                [bos_tokens, output_tokens[:, :-1]], axis=-1
            )
            input_mask = np.concatenate(
                [prefix.attention_mask, inputs.attention_mask], axis=1
            )
            # Shift the mask the same way; the BOS slot is only attended to
            # when --add_bos_token is set.
            if FLAGS.add_bos_token:
                bos_mask = np.ones_like(input_mask[:, :1])
            else:
                bos_mask = np.zeros_like(input_mask[:, :1])

            input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
            # Prefix positions never contribute to the score.
            output_mask = np.concatenate(
                [np.zeros_like(prefix.attention_mask), inputs.attention_mask],
                axis=1
            )
            batch = dict(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                input_mask=input_mask,
                output_mask=output_mask,
            )
            with mesh:
                loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
                    params, sharded_rng, batch
                )
                loglikelihood, is_greedy = jax.device_get(
                    (loglikelihood, is_greedy)
                )
            return loglikelihood, is_greedy

        @staticmethod
        def loglikelihood_rolling(text):
            """Score whole documents with non-overlapping seq_length windows.

            Each token is scored exactly once. The final window is aligned
            to the end of the (padded) sequence and masks out positions
            already covered by earlier windows. Returns
            ``(total_loglikelihood, total_is_greedy)`` per document.
            """
            nonlocal sharded_rng
            inputs = tokenizer(
                text,
                padding='longest',
                truncation=False,
                max_length=np.iinfo(np.int32).max,
                return_tensors='np',
            )
            batch_size = inputs.input_ids.shape[0]
            output_tokens = inputs.input_ids
            attention_mask = inputs.attention_mask

            # Right-pad short batches up to one full window.
            if output_tokens.shape[1] < FLAGS.seq_length:
                padding_length = FLAGS.seq_length - output_tokens.shape[1]
                pad_tokens = np.full(
                    (batch_size, padding_length), tokenizer.pad_token_id,
                    dtype=np.int32
                )
                output_tokens = np.concatenate(
                    [output_tokens, pad_tokens], axis=-1
                )
                pad_mask = np.zeros(
                    (batch_size, padding_length),
                    dtype=inputs.attention_mask.dtype
                )
                attention_mask = np.concatenate(
                    [attention_mask, pad_mask], axis=-1
                )

            # Inputs are the targets shifted right by one with BOS in front.
            bos_tokens = np.full(
                (batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
            )
            input_tokens = np.concatenate(
                [bos_tokens, output_tokens[:, :-1]], axis=-1
            )
            total_seq_length = output_tokens.shape[1]

            total_loglikelihood = 0.0
            total_is_greedy = True
            for i in range(0, total_seq_length, FLAGS.seq_length):
                if i + FLAGS.seq_length > total_seq_length:
                    # Last, partial window: take the final seq_length
                    # positions. `i - total_seq_length` is negative, so the
                    # slice zeroes exactly the columns for positions < i,
                    # which earlier windows already scored.
                    last_output_mask = np.copy(
                        attention_mask[:, -FLAGS.seq_length:]
                    )
                    last_output_mask[:, :i - total_seq_length] = 0

                    batch = dict(
                        input_tokens=input_tokens[:, -FLAGS.seq_length:],
                        output_tokens=output_tokens[:, -FLAGS.seq_length:],
                        input_mask=attention_mask[:, -FLAGS.seq_length:],
                        output_mask=last_output_mask,
                    )
                else:
                    # Full window [i, i + seq_length).
                    window = slice(i, i + FLAGS.seq_length)
                    batch = dict(
                        input_tokens=input_tokens[:, window],
                        output_tokens=output_tokens[:, window],
                        input_mask=attention_mask[:, window],
                        output_mask=attention_mask[:, window],
                    )

                with mesh:
                    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
                        params, sharded_rng, batch
                    )
                    loglikelihood, is_greedy = jax.device_get(
                        (loglikelihood, is_greedy)
                    )

                total_loglikelihood += loglikelihood
                total_is_greedy = np.logical_and(is_greedy, total_is_greedy)

            return total_loglikelihood, total_is_greedy

        @staticmethod
        def generate(text, temperature):
            """Sample one continuation per prompt in ``text``.

            Returns a list of decoded strings, each cut at the first EOS
            token string if one occurs.
            """
            nonlocal sharded_rng
            inputs = prefix_tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.input_length,
                return_tensors='np',
            )
            input_tokens = inputs.input_ids
            input_mask = inputs.attention_mask
            if FLAGS.add_bos_token:
                # NOTE: overwrites slot 0, which holds a real token when the
                # prompt fills the whole input_length.
                input_tokens[:, 0] = tokenizer.bos_token_id
                input_mask[:, 0] = 1
            batch = dict(
                input_tokens=input_tokens,
                attention_mask=input_mask,
            )
            with mesh:
                output, sharded_rng = forward_generate(
                    params, sharded_rng, batch, temperature
                )
                output = jax.device_get(output)
            output_text = []
            for decoded in list(tokenizer.batch_decode(output)):
                if tokenizer.eos_token in decoded:
                    decoded = decoded.split(tokenizer.eos_token, maxsplit=1)[0]
                output_text.append(decoded)

            return output_text

        @staticmethod
        def greedy_until(prefix_text, until, max_length):
            """Greedily extend each prompt until a stop string, EOS, or budget.

            For every ``(prefix, stops)`` pair, generation runs in chunks of
            ``seq_length - input_length`` tokens. It re-feeds the growing
            text each time until one of the ``stops`` strings appears, the
            model emits EOS, or at least ``max_length`` tokens have been
            generated. The returned text is cut before the first stop string
            and never contains the EOS token text.
            """
            nonlocal sharded_rng
            eos_token = tokenizer.eos_token
            all_outputs = []
            for pf, ut in zip(prefix_text, until):
                if isinstance(ut, str):
                    ut = [ut]
                total_length = 0
                total_generated = ''

                while total_length < max_length:
                    pf_tokens = tokenizer(
                        pf,
                        padding=False,
                        truncation=False,
                        max_length=np.iinfo(np.int32).max,
                        return_tensors='np',
                    )
                    input_tokens = pf_tokens.input_ids
                    attention_mask = pf_tokens.attention_mask

                    # Fit the prompt into exactly input_length tokens:
                    # left-pad short prompts, keep the tail of long ones.
                    if input_tokens.shape[1] < FLAGS.input_length:
                        extra = FLAGS.input_length - input_tokens.shape[1]
                        pad_tokens = np.full(
                            (1, extra), tokenizer.pad_token_id, dtype=np.int32
                        )
                        input_tokens = np.concatenate(
                            [pad_tokens, input_tokens], axis=1
                        )
                        pad_attention = np.zeros(
                            (1, extra), dtype=attention_mask.dtype
                        )
                        attention_mask = np.concatenate(
                            [pad_attention, attention_mask], axis=1
                        )
                    elif input_tokens.shape[1] > FLAGS.input_length:
                        input_tokens = input_tokens[:, -FLAGS.input_length:]
                        attention_mask = attention_mask[:, -FLAGS.input_length:]

                    if FLAGS.add_bos_token:
                        input_tokens[:, 0] = tokenizer.bos_token_id
                        attention_mask[:, 0] = 1

                    batch = dict(
                        input_tokens=input_tokens,
                        attention_mask=attention_mask,
                    )

                    with mesh:
                        output, sharded_rng = forward_greedy_generate(
                            params, sharded_rng, batch
                        )
                        output = jax.device_get(output)

                    total_length += output.shape[1]
                    output_text = tokenizer.batch_decode(output)[0]

                    # Generation pads with the EOS id (pad_token_id above)
                    # after the model stops. Cut at EOS so its text neither
                    # leaks into the result nor is fed back into the next
                    # prompt, and stop requesting more tokens.
                    hit_eos = eos_token is not None and eos_token in output_text
                    if hit_eos:
                        output_text = output_text.split(eos_token, maxsplit=1)[0]

                    total_generated = total_generated + output_text
                    pf = pf + output_text

                    done = hit_eos
                    for s in ut:
                        if s in total_generated:
                            total_generated = total_generated.split(
                                s, maxsplit=1
                            )[0]
                            done = True
                    if done:
                        break

                all_outputs.append(total_generated)

            return all_outputs

    server = ModelServer(FLAGS.lm_server)
    # Hand control to the LMServer; presumably this blocks while serving
    # requests — TODO confirm against LMServer.run.
    server.run()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: mlxu.run presumably parses the command-line flags
    # into FLAGS and then invokes main(argv).
    mlxu.run(main)
|