404 lines
13 KiB
Python
404 lines
13 KiB
Python
|
|
import os
|
||
|
|
import math
|
||
|
|
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
|
||
|
|
from functools import partial
|
||
|
|
import re
|
||
|
|
import dataclasses
|
||
|
|
import random
|
||
|
|
from ml_collections import ConfigDict
|
||
|
|
from ml_collections.config_dict.config_dict import placeholder
|
||
|
|
|
||
|
|
import flax
|
||
|
|
import jax
|
||
|
|
import jax.numpy as jnp
|
||
|
|
from jax.sharding import PartitionSpec as PS
|
||
|
|
from jax.sharding import Mesh
|
||
|
|
from jax.experimental import mesh_utils
|
||
|
|
from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
|
||
|
|
from jax.experimental.pjit import pjit
|
||
|
|
from jax.interpreters import pxla
|
||
|
|
import numpy as np
|
||
|
|
from transformers import FlaxLogitsWarper
|
||
|
|
|
||
|
|
|
||
|
|
class JaxRNG:
    """Stateful wrapper around a JAX PRNG key.

    Every call splits the held key: one piece becomes the new internal
    state and the remaining pieces are handed out. This lets code draw fresh
    keys without threading them through every call by hand.
    """

    @classmethod
    def from_seed(cls, seed):
        """Build a wrapper holding ``jax.random.PRNGKey(seed)``."""
        return cls(jax.random.PRNGKey(seed))

    def __init__(self, rng):
        # The current key; replaced on every draw.
        self.rng = rng

    def __call__(self, keys=None):
        """Draw fresh keys and advance the internal state.

        Args:
            keys: ``None`` to get a single key; an ``int`` ``n`` to get a
                tuple of ``n`` keys; otherwise a sized iterable of names,
                producing a dict that maps each name to its own key.
        """
        if keys is None:
            self.rng, fresh = jax.random.split(self.rng)
            return fresh

        want_tuple = isinstance(keys, int)
        count = keys if want_tuple else len(keys)
        carry, *drawn = jax.random.split(self.rng, num=count + 1)
        self.rng = carry
        if want_tuple:
            return tuple(drawn)
        return dict(zip(keys, drawn))
|
||
|
|
|
||
|
|
|
||
|
|
class JaxDistributedConfig:
    """ Utility class for initializing JAX distributed. """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ConfigDict, optionally overlaid with ``updates``.

        ``initialize_jax_distributed`` defaults to False. The remaining
        fields are typed ``placeholder`` entries. ``updates`` (anything
        ``ConfigDict`` accepts) has its references resolved and is then
        merged in via ``ConfigDict.update``.
        """
        config = ConfigDict()
        config.initialize_jax_distributed = False
        config.coordinator_address = placeholder(str)
        config.num_processes = placeholder(int)
        config.process_id = placeholder(int)
        config.local_device_ids = placeholder(str)

        if updates is None:
            return config
        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def initialize(cls, config):
        """Call ``jax.distributed.initialize`` when the config enables it.

        ``config`` is merged over the defaults first. ``local_device_ids`` is
        parsed from a comma-separated string of integers, when set.
        """
        config = cls.get_default_config(config)
        if not config.initialize_jax_distributed:
            return

        device_ids = None
        if config.local_device_ids is not None:
            device_ids = list(map(int, config.local_device_ids.split(',')))

        jax.distributed.initialize(
            coordinator_address=config.coordinator_address,
            num_processes=config.num_processes,
            process_id=config.process_id,
            local_device_ids=device_ids,
        )
|
||
|
|
|
||
|
|
|
||
|
|
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
    """ JIT traceable version of FlaxLogitsWarper that performs temperature scaling."""

    def __init__(self, temperature):
        # Stored as-is with no Python-side checks, so the value may also be
        # an array rather than a Python float.
        self.temperature = temperature

    def __call__(self, input_ids, scores, cur_len):
        """Divide ``scores`` by the temperature, floored at 1e-8.

        ``input_ids`` and ``cur_len`` are accepted for interface
        compatibility and are not used.
        """
        # jnp.maximum is equivalent to jnp.clip(t, a_min=1e-8) but does not
        # depend on clip's keyword names; ``a_min`` is deprecated in newer
        # JAX releases in favour of ``min``.
        return scores / jnp.maximum(self.temperature, 1e-8)
|
||
|
|
|
||
|
|
|
||
|
|
def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
    """ Create pytree of sharding and gathering functions from pytree of
    partition specs.

    Args:
        partition_specs: pytree whose leaves are the shardings passed to
            ``pjit`` (``out_shardings`` for sharding, ``in_shardings`` for
            gathering).
        dtype_specs: ``None``; or a single float dtype, which every float
            tensor is cast to; or a pytree matching ``partition_specs`` whose
            leaves expose a ``dtype`` attribute to cast each tensor to.

    Returns:
        ``(shard_fns, gather_fns)``, both pytrees shaped like
        ``partition_specs``. A shard fn casts and places its input with the
        leaf's output sharding, then blocks until the result is ready. A
        gather fn takes an input sharded per the leaf, casts it, and returns
        it fetched with ``jax.device_get``.
    """
    float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
    # True when one global float dtype is requested for all float tensors.
    cast_all_floats = dtype_specs in float_dtypes

    def _caster(dtype_spec):
        def to_dtype(tensor):
            if cast_all_floats and getattr(tensor, 'dtype', None) in float_dtypes:
                # Convert all float tensors to the same dtype
                return tensor.astype(dtype_specs)
            if hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
                return tensor.astype(dtype_spec.dtype)
            return tensor
        return to_dtype

    def _shard_fn(partition_spec, dtype_spec=None):
        placed = pjit(
            _caster(dtype_spec),
            in_shardings=None,
            out_shardings=partition_spec,
        )

        def shard_fn(tensor):
            return placed(tensor).block_until_ready()
        return shard_fn

    def _gather_fn(partition_spec, dtype_spec=None):
        collected = pjit(
            _caster(dtype_spec),
            in_shardings=partition_spec,
            out_shardings=None,
        )

        def gather_fn(tensor):
            return jax.device_get(collected(tensor))
        return gather_fn

    # Per-leaf dtypes are zipped alongside the specs; otherwise only the
    # specs are mapped and the caster relies on the global setting.
    if dtype_specs is None or cast_all_floats:
        trees = (partition_specs,)
    else:
        trees = (partition_specs, dtype_specs)

    shard_fns = jax.tree_util.tree_map(_shard_fn, *trees)
    gather_fns = jax.tree_util.tree_map(_gather_fn, *trees)
    return shard_fns, gather_fns
|
||
|
|
|
||
|
|
|
||
|
|
def set_random_seed(seed):
    """Seed NumPy's global RNG, Python's ``random`` and the module JAX RNG.

    All three receive the same ``seed``. The JAX side is (re)created through
    ``init_rng``.
    """
    for seeder in (np.random.seed, random.seed, init_rng):
        seeder(seed)
|
||
|
|
|
||
|
|
|
||
|
|
def get_jax_mesh(axis_dims, names):
    """Build a ``jax.sharding.Mesh`` from a comma-separated axis spec.

    Args:
        axis_dims: either plain sizes matched positionally to ``names``
            (e.g. ``'1,-1,1'``), or ``name:size`` pairs (e.g.
            ``'dp:1,fsdp:-1,mp:1'``). With pairs, the mesh axes follow the
            order written. Whitespace around entries is ignored. Sizes are
            resolved with ``numpy.reshape`` over ``jax.device_count()``
            devices, so a single ``-1`` may absorb the remainder. A leading
            ``'!'`` lays devices out in ``jax.devices()`` order instead of
            using ``mesh_utils.create_device_mesh``, which allows splitting a
            physical mesh axis.
        names: the axis names the mesh must define.

    Returns:
        A ``Mesh`` over all devices.

    Raises:
        ValueError: if the spec uses an unknown axis name, does not list every
            name exactly once, has the wrong number of sizes, or the sizes do
            not multiply to the device count.
    """
    mesh_axis_splitting = axis_dims.startswith('!')
    if mesh_axis_splitting:
        # Allow splitting a physical mesh axis if needed
        axis_dims = axis_dims[1:]

    entries = [entry.strip() for entry in axis_dims.split(',')]
    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let malformed specs through silently.
    if ':' in axis_dims:
        dims, dim_names = [], []
        for entry in entries:
            name, dim = (part.strip() for part in entry.split(':'))
            if name not in names:
                raise ValueError(
                    f'Unknown mesh axis {name!r} in {axis_dims!r}; '
                    f'expected one of {tuple(names)}'
                )
            dims.append(int(dim))
            dim_names.append(name)
        if len(set(dim_names)) != len(dim_names) or set(dim_names) != set(names):
            raise ValueError(
                f'Mesh spec {axis_dims!r} must name each of {tuple(names)} '
                f'exactly once'
            )
    else:
        dims = [int(entry) for entry in entries]
        dim_names = list(names)

    if len(dims) != len(names):
        raise ValueError(
            f'Mesh spec {axis_dims!r} has {len(dims)} sizes for '
            f'{len(names)} axes {tuple(names)}'
        )

    # reshape resolves a -1 entry and raises ValueError when the sizes do not
    # multiply to the device count.
    mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
    if mesh_axis_splitting:
        physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
    else:
        physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
    return Mesh(physical_mesh, tuple(dim_names))
|
||
|
|
|
||
|
|
|
||
|
|
def names_in_current_mesh(*names):
    """Return True if every name in ``names`` is an axis of the current mesh.

    The current mesh is the physical mesh of
    ``pxla.thread_resources.env``. Calling with no names returns True.
    """
    current_axes = set(pxla.thread_resources.env.physical_mesh.axis_names)
    return current_axes.issuperset(names)
|
||
|
|
|
||
|
|
|
||
|
|
def get_names_from_parition_spec(partition_specs):
    """ Return axis names from partition specs.

    Walks ``partition_specs`` recursively:

    - strings are collected as axis names;
    - ``None`` entries are skipped;
    - mappings contribute their values (not their keys);
    - any other iterable (e.g. a ``PartitionSpec`` or a tuple of axis names)
      is recursed into.

    Returns:
        A list of the distinct axis names, in no particular order.
    """
    names = set()
    if isinstance(partition_specs, Mapping):
        # Any Mapping, not just ``dict``: iterating a non-dict mapping
        # directly would yield its keys, which are not axis names.
        partition_specs = partition_specs.values()
    for item in partition_specs:
        if item is None:
            continue
        if isinstance(item, str):
            names.add(item)
        else:
            names.update(get_names_from_parition_spec(item))
    return list(names)
|
||
|
|
|
||
|
|
|
||
|
|
def with_sharding_constraint(x, partition_specs):
    """Apply a sharding constraint only when the current mesh supports it.

    If every axis name mentioned in ``partition_specs`` is an axis of the
    current mesh, ``x`` is passed through JAX's ``with_sharding_constraint``.
    Otherwise ``x`` is returned unchanged.
    """
    required_axes = get_names_from_parition_spec(partition_specs)
    if not names_in_current_mesh(*required_axes):
        return x
    return _with_sharding_constraint(x, partition_specs)
|
||
|
|
|
||
|
|
|
||
|
|
def wrap_function_with_rng(rng):
    """Decorator factory that feeds a fresh PRNG key into each call.

    The decorated function receives a newly split key as its first
    positional argument, followed by the caller's arguments. The key state
    lives in this factory's closure, so every function decorated through
    the same returned decorator advances the same key.
    """
    holder = [rng]  # mutable cell carrying the current key between calls

    def decorator(function):
        def wrapped(*args, **kwargs):
            holder[0], subkey = jax.random.split(holder[0])
            return function(subkey, *args, **kwargs)
        return wrapped

    return decorator
|
||
|
|
|
||
|
|
|
||
|
|
def init_rng(seed):
    """(Re)create the module-level ``jax_utils_rng`` as a ``JaxRNG`` seeded with ``seed``."""
    global jax_utils_rng
    fresh = JaxRNG.from_seed(seed)
    jax_utils_rng = fresh
|
||
|
|
|
||
|
|
|
||
|
|
def next_rng(*args, **kwargs):
    """Draw key(s) from the module-level ``jax_utils_rng``.

    Arguments are forwarded unchanged to that RNG's ``__call__``.
    ``init_rng`` must have run first; otherwise the global name is
    undefined.
    """
    draw = jax_utils_rng
    return draw(*args, **kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
def get_metrics(metrics, unreplicate=False, stack=False):
    """Fetch metrics to host memory as plain Python / NumPy values.

    Args:
        metrics: when ``stack`` is False, a mapping from metric name to a
            scalar. When ``stack`` is True, a sequence of metric pytrees with
            identical structure.
        unreplicate: if True, first pass ``metrics`` through
            ``flax.jax_utils.unreplicate``.
        stack: if True, stack the sequence leaf-by-leaf with ``np.stack``.

    Returns:
        With ``stack``, a pytree of stacked NumPy arrays. Otherwise, a dict
        mapping each name to ``float``.
    """
    if unreplicate:
        metrics = flax.jax_utils.unreplicate(metrics)
    metrics = jax.device_get(metrics)
    if stack:
        # jax.tree_util.tree_map rather than the deprecated jax.tree_map
        # alias, which newer JAX releases no longer provide.
        return jax.tree_util.tree_map(lambda *args: np.stack(args), *metrics)
    return {key: float(val) for key, val in metrics.items()}
|
||
|
|
|
||
|
|
|
||
|
|
def mse_loss(val, target, valid=None):
    """Mean squared error between ``val`` and ``target`` over valid entries.

    Args:
        val: predictions; must broadcast against ``target``.
        target: regression targets.
        valid: optional mask that broadcasts against the squared error. Only
            entries with ``valid > 0`` contribute. When omitted, every element
            is valid.

    Returns:
        Scalar: the sum of squared errors over valid elements divided by the
        number of valid elements, or 0 if nothing is valid. Normalising by the
        valid count (not the total element count) keeps the loss independent
        of how much padding is masked out, consistent with
        ``cross_entropy_loss_and_accuracy``.
    """
    squared_error = jnp.square(val - target)
    if valid is None:
        # Every element is valid, so the masked mean is a plain mean. This
        # also avoids broadcasting a (*shape[:2], 1) mask, which fails for
        # targets of rank >= 4.
        return jnp.mean(squared_error)

    mask = jnp.asarray(valid).astype(jnp.float32) > 0.0
    masked = jnp.where(mask, squared_error, 0.0)
    # Count valid entries at the broadcast shape, so that e.g. a
    # (batch, seq, 1) mask counts once per feature, matching the numerator.
    count = jnp.sum(jnp.broadcast_to(mask, masked.shape).astype(jnp.float32))
    return jnp.sum(masked) / jnp.maximum(count, 1.0)
|
||
|
|
|
||
|
|
|
||
|
|
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
    """Token-level cross-entropy loss and accuracy, averaged per sequence.

    Args:
        logits: scores whose last axis indexes the vocabulary; cast to
            float32 before use.
        tokens: integer targets indexing that last axis.
        valid: optional mask over ``tokens.shape[:2]``. Entries ``> 0`` count.
            Defaults to all ones.

    Returns:
        ``(loss, accuracy)``. Each sequence's summed negative log-likelihood
        (resp. correct-prediction count) over valid tokens is divided by its
        valid-token count, then averaged over the leading axis.
    """
    if valid is None:
        valid = jnp.ones(tokens.shape[:2])
    mask = valid.astype(jnp.float32)
    # Floor avoids division by zero for rows with no valid tokens.
    seq_lengths = jnp.maximum(mask.sum(axis=-1), 1e-10)
    keep = mask > 0.0

    logits32 = logits.astype(jnp.float32)  # for numerical stability
    log_probs = jax.nn.log_softmax(logits32, axis=-1)
    picked = jnp.take_along_axis(log_probs, tokens[..., None], axis=-1)[..., 0]
    picked = jnp.where(keep, picked, jnp.array(0.0))
    loss = -jnp.mean(picked.sum(axis=-1) / seq_lengths)

    hits = jnp.where(keep, logits32.argmax(axis=-1) == tokens, jnp.array(False))
    accuracy = jnp.mean(hits.sum(axis=-1) / seq_lengths)
    return loss, accuracy
|
||
|
|
|
||
|
|
|
||
|
|
def global_norm(tree):
    """ Return the global L2 norm of a pytree.

    Computes ``sqrt(sum over leaves of sum(x ** 2))``. An empty tree yields 0.
    """
    # Accumulate per-leaf sums of squares directly instead of going through
    # jax.flatten_util.ravel_pytree. That submodule is not loaded by
    # ``import jax`` in every release, and concatenating the per-leaf
    # scalars first bought nothing.
    leaf_sums = (jnp.sum(jnp.square(x)) for x in jax.tree_util.tree_leaves(tree))
    return jnp.sqrt(sum(leaf_sums, 0.0))
|
||
|
|
|
||
|
|
|
||
|
|
def average_metrics(metrics):
    """Average a sequence of metric pytrees leaf-by-leaf.

    Args:
        metrics: a non-empty sequence of pytrees with identical structure.

    Returns:
        A pytree of the same structure. Each leaf is ``jnp.mean`` of the
        stacked corresponding leaves, i.e. a scalar averaged over every
        element.
    """
    import contextlib

    # jax.spmd_mode("allow_all") permits eager jnp ops on arrays that are not
    # fully addressable from this process. It may be absent in newer JAX
    # releases, so fall back to a no-op context there.
    spmd_mode = getattr(jax, 'spmd_mode', None)
    context = spmd_mode('allow_all') if spmd_mode is not None else contextlib.nullcontext()
    with context:
        # jax.tree_util.tree_map rather than the deprecated jax.tree_map alias.
        return jax.tree_util.tree_map(
            lambda *args: jnp.mean(jnp.stack(args)),
            *metrics
        )
|
||
|
|
|
||
|
|
|
||
|
|
def get_float_dtype_by_name(dtype):
    """Map a short or long float dtype name to its ``jnp`` scalar type.

    Accepts ``bf16``/``bfloat16``, ``fp16``/``float16``, ``fp32``/``float32``
    and ``fp64``/``float64``. Any other name raises ``KeyError``.
    """
    aliases_by_type = (
        (('bf16', 'bfloat16'), jnp.bfloat16),
        (('fp16', 'float16'), jnp.float16),
        (('fp32', 'float32'), jnp.float32),
        (('fp64', 'float64'), jnp.float64),
    )
    lookup = {alias: kind for aliases, kind in aliases_by_type for alias in aliases}
    return lookup[dtype]
|
||
|
|
|
||
|
|
|
||
|
|
def float_tensor_to_dtype(tensor, dtype):
    """Cast ``tensor`` to ``dtype`` only if it is a floating-point tensor.

    ``dtype`` may be ``None`` or ``''`` (no-op), a name understood by
    ``get_float_dtype_by_name``, or a dtype object. Non-float tensors and
    objects without a ``dtype`` are returned unchanged.
    """
    if dtype is None or dtype == '':
        return tensor
    target = get_float_dtype_by_name(dtype) if isinstance(dtype, str) else dtype
    is_float = getattr(tensor, 'dtype', None) in (
        jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
    )
    return tensor.astype(target) if is_float else tensor
|
||
|
|
|
||
|
|
|
||
|
|
def float_to_dtype(tree, dtype):
    """Apply ``float_tensor_to_dtype(leaf, dtype=dtype)`` to every leaf of ``tree``."""
    def convert(leaf):
        return float_tensor_to_dtype(leaf, dtype=dtype)

    return jax.tree_util.tree_map(convert, tree)
|
||
|
|
|
||
|
|
|
||
|
|
def get_gradient_checkpoint_policy(name):
    """Look up a ``jax.checkpoint_policies`` policy by attribute name.

    Only the four names in ``supported`` are accepted. Any other name raises
    ``KeyError``.
    """
    supported = (
        'everything_saveable',
        'nothing_saveable',
        'checkpoint_dots',
        'checkpoint_dots_with_no_batch_dims',
    )
    policies = {key: getattr(jax.checkpoint_policies, key) for key in supported}
    return policies[name]
|
||
|
|
|
||
|
|
|
||
|
|
def tree_path_to_string(path, sep=None):
    """Render a pytree key path as strings.

    Sequence keys contribute their index, dict and flattened-index keys
    their key, attribute keys their name, and anything else its ``str()``.

    Returns:
        A tuple of the parts when ``sep`` is None, otherwise the parts joined
        with ``sep``.
    """
    def part(key):
        if isinstance(key, jax.tree_util.SequenceKey):
            return str(key.idx)
        if isinstance(key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)):
            return str(key.key)
        if isinstance(key, jax.tree_util.GetAttrKey):
            return str(key.name)
        return str(key)

    parts = tuple(part(key) for key in path)
    return parts if sep is None else sep.join(parts)
|
||
|
|
|
||
|
|
|
||
|
|
def flatten_tree(xs, is_leaf=None, sep=None):
    """Flatten ``xs`` into a dict mapping rendered key paths to leaves.

    Paths are rendered by ``tree_path_to_string`` with ``sep``. ``is_leaf``
    is forwarded to ``jax.tree_util.tree_flatten_with_path``.
    """
    leaves_with_paths = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)[0]
    return {tree_path_to_string(path, sep=sep): leaf for path, leaf in leaves_with_paths}
|
||
|
|
|
||
|
|
|
||
|
|
def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
    """ An extended version of jax.tree_util.tree_map, where the mapped function
    f takes both the name (path) and the tree leaf as input.

    ``f`` is called as ``f(name, leaf, *matching_leaves_from_rest)``, where
    ``name`` is the leaf's path rendered by ``tree_path_to_string`` with
    ``sep``.
    """
    def call_with_name(path, leaf, *others):
        return f(tree_path_to_string(path, sep=sep), leaf, *others)

    return jax.tree_util.tree_map_with_path(
        call_with_name, tree, *rest, is_leaf=is_leaf
    )
|
||
|
|
|
||
|
|
|
||
|
|
def match_partition_rules(rules, params):
    """ Returns a pytree of PartitionSpec according to rules. Supports handling
    Flax TrainState and Optax optimizer state.

    Args:
        rules: ordered sequence of ``(regex, PartitionSpec)`` pairs. The first
            regex that ``re.search`` finds in a leaf's '/'-joined path wins.
        params: pytree of arrays (or other leaves).

    Returns:
        A pytree shaped like ``params`` whose leaves are the matched specs.
        Scalars and single-element leaves always get ``PS()``.

    Raises:
        ValueError: if a multi-element leaf matches no rule.
    """
    def get_partition_spec(name, leaf):
        # np.shape reads ``.shape`` when present and also handles leaves
        # without one (e.g. plain Python numbers), which are then treated as
        # scalars instead of raising AttributeError.
        shape = np.shape(leaf)
        if len(shape) == 0 or np.prod(shape) == 1:
            # Don't partition scalar values.
            return PS()
        for rule, ps in rules:
            if re.search(rule, name) is not None:
                return ps
        raise ValueError(f'Partition rule not found for param: {name}')

    return named_tree_map(get_partition_spec, params, sep='/')
|
||
|
|
|
||
|
|
|
||
|
|
def get_weight_decay_mask(exclusions):
    """Build a weight-decay mask function from regex exclusion rules.

    The returned function maps a params pytree to a same-shaped pytree of
    booleans. A leaf is False when any regex in ``exclusions`` is found
    (``re.search``) in its '/'-joined path, and True otherwise.
    """
    def decay(name, _):
        return not any(re.search(rule, name) is not None for rule in exclusions)

    def weight_decay_mask(params):
        return named_tree_map(decay, params, sep='/')

    return weight_decay_mask
|
||
|
|
|
||
|
|
|
||
|
|
def tree_apply(fns, tree):
    """Call each function leaf of ``fns`` on the matching leaf of ``tree``.

    ``fns`` supplies the structure; ``tree`` must match it.
    """
    def apply(fn, x):
        return fn(x)

    return jax.tree_util.tree_map(apply, fns, tree)
|
||
|
|
|