347 lines
12 KiB
Python
347 lines
12 KiB
Python
import os
|
|
import time
|
|
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
|
|
from functools import partial
|
|
import re
|
|
import dataclasses
|
|
import random
|
|
|
|
from ml_collections.config_dict import config_dict
|
|
from ml_collections import ConfigDict
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
from absl import logging
|
|
import optax
|
|
|
|
from EasyLM.jax_utils import float_to_dtype
|
|
|
|
|
|
class OptimizerFactory(object):
    """Top-level, config-driven factory for optax optimizers.

    ``config.type`` chooses which concrete factory (``'palm'``, ``'adamw'`` or
    ``'lion'``) builds the optimizer from its matching sub-config. When
    ``config.accumulate_gradient_steps`` is greater than 1, the resulting
    optimizer is wrapped in ``optax.MultiSteps`` with that many steps.
    """

    def __init__(self):
        # This class is only a namespace for static/class methods.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ConfigDict, overridden by ``updates`` if given."""
        config = ConfigDict()
        # Passed to optax.MultiSteps when > 1; values <= 1 leave the
        # optimizer unwrapped.
        config.accumulate_gradient_steps = 1
        # Selects which of the sub-configs below is used.
        config.type = 'adamw'
        config.palm_optimizer = PalmOptimizerFactory.get_default_config()
        config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
        config.lion_optimizer = LionOptimizerFactory.get_default_config()

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """Build ``(optimizer, optimizer_info)`` described by ``config``.

        Args:
          config: overrides merged into ``get_default_config()``.
          weight_decay_mask: forwarded unchanged to the selected factory.

        Raises:
          ValueError: if ``config.type`` names no known optimizer.
        """
        config = cls.get_default_config(config)

        # Maps the type name to (factory class, its sub-config).
        dispatch = {
            'palm': (PalmOptimizerFactory, config.palm_optimizer),
            'adamw': (AdamWOptimizerFactory, config.adamw_optimizer),
            'lion': (LionOptimizerFactory, config.lion_optimizer),
        }
        if config.type not in dispatch:
            raise ValueError(f'Unknown optimizer type: {config.type}')

        factory, sub_config = dispatch[config.type]
        optimizer, optimizer_info = factory.get_optimizer(
            sub_config, weight_decay_mask
        )

        accumulation_steps = config.accumulate_gradient_steps
        if accumulation_steps > 1:
            optimizer = optax.MultiSteps(optimizer, accumulation_steps)

        return optimizer, optimizer_info
|
|
|
|
|
|
class PalmOptimizerFactory(object):
    """Optimizer recipe following the PaLM paper (https://arxiv.org/abs/2204.02311).

    The chain is:
      1. global-norm gradient clipping;
      2. unfactored ``optax.adafactor`` with momentum and parameter-scale
         multiplication;
      3. additive weight decay whose coefficient is
         ``-(weight_decay / 1e-4) * learning_rate(step) ** 2``.
    """

    def __init__(self):
        # This class is only a namespace for static/class methods.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ConfigDict, overridden by ``updates`` if given."""
        config = ConfigDict()
        # Learning-rate schedule parameters.
        config.lr = 0.01
        config.lr_warmup_steps = 10000
        # Adafactor parameters (``b1`` -> momentum, ``b2`` -> decay_rate).
        config.b1 = 0.9
        config.b2 = 0.99
        config.clip_gradient = 1.0
        config.weight_decay = 1e-4
        config.bf16_momentum = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """Build ``(optimizer, optimizer_info)`` from ``config``.

        ``optimizer_info`` exposes both schedules under
        ``'learning_rate_schedule'`` and ``'weight_decay_schedule'``.
        """
        config = cls.get_default_config(config)

        def learning_rate_schedule(step):
            # Flat for the first lr_warmup_steps steps, then ~ 1 / sqrt(step),
            # scaled so that config.lr is measured relative to 0.01.
            clamped_step = jnp.maximum(step, config.lr_warmup_steps)
            return (config.lr / 0.01) / jnp.sqrt(clamped_step)

        def weight_decay_schedule(step):
            # Negative coefficient proportional to the squared learning rate,
            # scaled so that config.weight_decay is measured relative to 1e-4.
            return -(config.weight_decay / 1e-4) * jnp.square(
                learning_rate_schedule(step)
            )

        optimizer_info = dict(
            learning_rate_schedule=learning_rate_schedule,
            weight_decay_schedule=weight_decay_schedule,
        )

        momentum_dtype = jnp.bfloat16 if config.bf16_momentum else jnp.float32
        adaptive_step = optax.adafactor(
            learning_rate=learning_rate_schedule,
            multiply_by_parameter_scale=True,
            momentum=config.b1,
            # NOTE(review): per the optax docs, adafactor's `decay_rate` is the
            # exponent of a step-dependent second-moment decay, not a fixed
            # beta2 -- confirm that passing `b2` here is intended.
            decay_rate=config.b2,
            factored=False,
            clipping_threshold=None,
            dtype_momentum=momentum_dtype,
        )
        optimizer = optax.chain(
            optax.clip_by_global_norm(config.clip_gradient),
            adaptive_step,
            optax_add_scheduled_weight_decay(
                weight_decay_schedule, weight_decay_mask
            ),
        )
        return optimizer, optimizer_info
|
|
|
|
|
|
class AdamWOptimizerFactory(object):
    """AdamW-style optimizer with a warmup + cosine-decay learning rate.

    By default the chain is global-norm clipping followed by ``optax.adamw``.
    When ``multiply_by_parameter_scale`` is set, the adaptive step is instead an
    unfactored ``optax.adafactor`` with parameter-scale multiplication. In that
    case weight decay is added separately with coefficient
    ``-learning_rate(step) * weight_decay``.
    """

    def __init__(self):
        # This class is only a namespace for static/class methods.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ConfigDict, overridden by ``updates`` if given."""
        config = ConfigDict()
        # Schedule: init_lr -> lr over lr_warmup_steps, cosine toward end_lr;
        # lr_decay_steps is handed to optax as the schedule's decay_steps.
        config.init_lr = 0.0
        config.end_lr = 0.001
        config.lr = 0.01
        config.lr_warmup_steps = 2000
        config.lr_decay_steps = 500000
        # Optimizer hyperparameters.
        config.b1 = 0.9
        config.b2 = 0.95
        config.clip_gradient = 1.0
        config.weight_decay = 1e-4
        config.bf16_momentum = False
        # Switches the adaptive step from adamw to adafactor (see class doc).
        config.multiply_by_parameter_scale = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """Build ``(optimizer, optimizer_info)`` from ``config``.

        ``optimizer_info`` exposes the schedule under ``'learning_rate_schedule'``.
        """
        config = cls.get_default_config(config)

        learning_rate_schedule = optax.warmup_cosine_decay_schedule(
            init_value=config.init_lr,
            peak_value=config.lr,
            warmup_steps=config.lr_warmup_steps,
            decay_steps=config.lr_decay_steps,
            end_value=config.end_lr,
        )
        optimizer_info = dict(learning_rate_schedule=learning_rate_schedule)

        momentum_dtype = jnp.bfloat16 if config.bf16_momentum else jnp.float32

        if not config.multiply_by_parameter_scale:
            # Plain AdamW path: weight decay handled inside optax.adamw.
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.clip_gradient),
                optax.adamw(
                    learning_rate=learning_rate_schedule,
                    weight_decay=config.weight_decay,
                    b1=config.b1,
                    b2=config.b2,
                    mask=weight_decay_mask,
                    mu_dtype=momentum_dtype,
                ),
            )
            return optimizer, optimizer_info

        def scaled_weight_decay(step):
            # Decay coefficient tied to the current learning rate. It is
            # negative, so it pulls parameters toward zero, assuming updates
            # are added to params (as optax.apply_updates does).
            return -learning_rate_schedule(step) * config.weight_decay

        adaptive_step = optax.adafactor(
            learning_rate=learning_rate_schedule,
            multiply_by_parameter_scale=True,
            momentum=config.b1,
            # NOTE(review): adafactor's `decay_rate` is (per optax docs) an
            # exponent for a step-dependent second-moment decay rather than a
            # fixed beta2 -- confirm reusing `b2` here is intended.
            decay_rate=config.b2,
            factored=False,
            clipping_threshold=None,
            dtype_momentum=momentum_dtype,
        )
        optimizer = optax.chain(
            optax.clip_by_global_norm(config.clip_gradient),
            adaptive_step,
            optax_add_scheduled_weight_decay(
                scaled_weight_decay,
                weight_decay_mask,
            ),
        )
        return optimizer, optimizer_info
|
|
|
|
class LionOptimizerFactory(object):
    """ Lion optimizer with a configurable learning-rate schedule.

    The optimizer is global-norm gradient clipping followed by ``optax.lion``.
    Its weight decay is restricted by ``weight_decay_mask``.
    ``config.lr_schedule_type`` selects the schedule. All ``*_steps`` fields
    are phase lengths, measured in optimizer steps:

      * ``"warmup_cosine_decay_schedule"``: ``optax.warmup_cosine_decay_schedule``
        from ``init_lr`` up to ``lr`` over ``lr_warmup_steps``, then cosine
        toward ``end_lr`` (``lr_decay_steps`` is passed as optax's
        ``decay_steps``).
      * ``"warmup_constant"``: linear warmup ``init_lr`` -> ``lr`` over
        ``lr_warmup_steps`` steps, then constant ``lr``.
      * ``"warmup_constant_linear_decay"``: warmup as above, ``lr`` held for
        ``lr_constant_steps`` steps, then linear decay ``lr`` -> ``end_lr``
        over ``lr_decay_steps`` steps.
      * ``"warmup_constant_exponential_decay"``: warmup and constant phases as
        above, then exponential decay from ``lr`` by ``lr_decay_rate`` per
        ``lr_decay_steps`` steps, bounded by ``end_lr``.
      * ``"exponential_decay"``: the exponential decay alone, from step 0.
      * ``"linear_decay"``: the linear decay alone, from step 0.
    """

    def __init__(self):
        # This class is only a namespace for static/class methods.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ConfigDict, overridden by ``updates`` if given."""
        config = ConfigDict()
        config.init_lr = 0.0
        config.end_lr = 0.0001
        config.lr = 0.001
        # Phase lengths; with the defaults, warmup + constant + decay = 1M steps.
        config.lr_warmup_steps = 60000
        config.lr_constant_steps = 840000
        config.lr_decay_steps = 100000
        config.b1 = 0.9
        config.b2 = 0.98
        config.clip_gradient = 1.0
        config.weight_decay = 1e-3
        config.bf16_momentum = False
        config.lr_schedule_type = "warmup_cosine_decay_schedule"
        # Used only by the exponential-decay schedule types.
        config.lr_decay_rate = 0.98

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @staticmethod
    def _warmup_schedule(config):
        """Linear ramp from ``init_lr`` to ``lr`` over ``lr_warmup_steps``."""
        return optax.linear_schedule(
            init_value=config.init_lr,
            end_value=config.lr,
            transition_steps=config.lr_warmup_steps,
        )

    @staticmethod
    def _linear_decay_schedule(config):
        """Linear ramp from ``lr`` down to ``end_lr`` over ``lr_decay_steps``."""
        return optax.linear_schedule(
            init_value=config.lr,
            end_value=config.end_lr,
            transition_steps=config.lr_decay_steps,
        )

    @staticmethod
    def _exponential_decay_schedule(config):
        """Smooth exponential decay from ``lr``, bounded by ``end_lr``."""
        return optax.exponential_decay(
            init_value=config.lr,
            transition_steps=config.lr_decay_steps,
            decay_rate=config.lr_decay_rate,
            transition_begin=0,
            staircase=False,
            end_value=config.end_lr,
        )

    @classmethod
    def _get_learning_rate_schedule(cls, config):
        """Build the schedule selected by ``config.lr_schedule_type``.

        Raises:
          ValueError: for an unknown ``lr_schedule_type``.
        """
        schedule_type = config.lr_schedule_type
        warmup_end = config.lr_warmup_steps
        # optax.join_schedules takes *absolute* step boundaries, while
        # lr_constant_steps is the *length* of the constant phase. The constant
        # phase therefore ends at warmup + constant. Using lr_constant_steps
        # directly would shorten that phase, and would start the decay during
        # warmup whenever lr_constant_steps < lr_warmup_steps.
        constant_end = config.lr_warmup_steps + config.lr_constant_steps

        if schedule_type == "warmup_cosine_decay_schedule":
            return optax.warmup_cosine_decay_schedule(
                init_value=config.init_lr,
                peak_value=config.lr,
                warmup_steps=config.lr_warmup_steps,
                decay_steps=config.lr_decay_steps,
                end_value=config.end_lr,
            )
        if schedule_type == "warmup_constant":
            return optax.join_schedules(
                [
                    cls._warmup_schedule(config),
                    optax.constant_schedule(config.lr),
                ],
                [warmup_end],
            )
        if schedule_type == "warmup_constant_linear_decay":
            return optax.join_schedules(
                [
                    cls._warmup_schedule(config),
                    optax.constant_schedule(config.lr),
                    cls._linear_decay_schedule(config),
                ],
                [warmup_end, constant_end],
            )
        if schedule_type == "warmup_constant_exponential_decay":
            return optax.join_schedules(
                [
                    cls._warmup_schedule(config),
                    optax.constant_schedule(config.lr),
                    cls._exponential_decay_schedule(config),
                ],
                [warmup_end, constant_end],
            )
        if schedule_type == "exponential_decay":
            return cls._exponential_decay_schedule(config)
        if schedule_type == "linear_decay":
            return cls._linear_decay_schedule(config)
        raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """Build ``(optimizer, optimizer_info)`` from ``config``.

        ``optimizer_info`` exposes the schedule under ``'learning_rate_schedule'``.

        Raises:
          ValueError: for an unknown ``config.lr_schedule_type``.
        """
        config = cls.get_default_config(config)
        learning_rate_schedule = cls._get_learning_rate_schedule(config)

        optimizer_info = dict(
            learning_rate_schedule=learning_rate_schedule,
        )

        optimizer = optax.chain(
            optax.clip_by_global_norm(config.clip_gradient),
            optax.lion(
                learning_rate=learning_rate_schedule,
                weight_decay=config.weight_decay,
                b1=config.b1,
                b2=config.b2,
                mask=weight_decay_mask,
                mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
            ),
        )

        return optimizer, optimizer_info
|
|
|
|
|
|
class OptaxScheduledWeightDecayState(NamedTuple):
    """Optimizer state for ``optax_add_scheduled_weight_decay``."""

    # Number of update calls so far; it is the argument fed to the schedule.
    count: jax.Array
|
|
|
|
|
|
def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
    """ Apply weight decay with schedule.

    Each update adds ``schedule_fn(count) * param`` to the matching update leaf.
    ``count`` is the number of update calls made so far, starting at 0. The
    schedule must return the already-signed coefficient; the callers in this
    module pass negative values.

    Args:
      schedule_fn: callable mapping the int32 step count to a scalar
        coefficient.
      mask: optional mask forwarded to ``optax.masked``; when given, only the
        selected leaves receive the decay term.

    Returns:
      An ``optax.GradientTransformation``.

    Raises:
      ValueError: from the update function when ``params`` is not provided.
    """

    def init_fn(params):
        del params
        return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))

    # `params=None` matches the optax update-function protocol. A call without
    # params therefore reaches the explicit ValueError below instead of
    # failing with a TypeError about a missing positional argument.
    def update_fn(updates, state, params=None):
        if params is None:
            raise ValueError('Params cannot be None for weight decay!')

        weight_decay = schedule_fn(state.count)
        updates = jax.tree_util.tree_map(
            lambda g, p: g + weight_decay * p, updates, params
        )
        return updates, OptaxScheduledWeightDecayState(
            count=optax.safe_int32_increment(state.count)
        )

    transformation = optax.GradientTransformation(init_fn, update_fn)
    if mask is not None:
        return optax.masked(transformation, mask)
    return transformation
|