初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
346
EasyLM/optimizers.py
Normal file
346
EasyLM/optimizers.py
Normal file
@@ -0,0 +1,346 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
|
||||
from functools import partial
|
||||
import re
|
||||
import dataclasses
|
||||
import random
|
||||
|
||||
from ml_collections.config_dict import config_dict
|
||||
from ml_collections import ConfigDict
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from absl import logging
|
||||
import optax
|
||||
|
||||
from EasyLM.jax_utils import float_to_dtype
|
||||
|
||||
|
||||
class OptimizerFactory(object):
    """ Top-level optax optimizer factory.

    The ``type`` field of the config selects which concrete factory
    (PaLM, AdamW or Lion) builds the optimizer; each of those factories owns
    its own sub-config stored under ``<name>_optimizer``.
    """

    def __init__(self):
        # Only used as a namespace for static/class methods.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Build the default config and overlay ``updates`` when given.

        Args:
            updates: optional mapping / ConfigDict of overrides.

        Returns:
            A ConfigDict holding the defaults merged with ``updates``.
        """
        defaults = ConfigDict()
        defaults.accumulate_gradient_steps = 1
        defaults.type = 'adamw'
        defaults.palm_optimizer = PalmOptimizerFactory.get_default_config()
        defaults.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
        defaults.lion_optimizer = LionOptimizerFactory.get_default_config()

        if updates is None:
            return defaults

        overrides = ConfigDict(updates).copy_and_resolve_references()
        defaults.update(overrides)
        return defaults

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Construct the optimizer selected by ``config.type``.

        Args:
            config: overrides applied on top of ``get_default_config()``.
            weight_decay_mask: forwarded unchanged to the selected factory.

        Returns:
            A ``(optimizer, optimizer_info)`` tuple as produced by the
            selected factory. The optimizer is wrapped in
            ``optax.MultiSteps`` when ``accumulate_gradient_steps > 1``.

        Raises:
            ValueError: if ``config.type`` is not 'palm', 'adamw' or 'lion'.
        """
        resolved = cls.get_default_config(config)

        # Map each supported type to its factory and matching sub-config.
        backends = {
            'palm': (PalmOptimizerFactory, resolved.palm_optimizer),
            'adamw': (AdamWOptimizerFactory, resolved.adamw_optimizer),
            'lion': (LionOptimizerFactory, resolved.lion_optimizer),
        }
        if resolved.type not in backends:
            raise ValueError(f'Unknown optimizer type: {resolved.type}')

        factory, sub_config = backends[resolved.type]
        optimizer, optimizer_info = factory.get_optimizer(
            sub_config, weight_decay_mask
        )

        accumulation = resolved.accumulate_gradient_steps
        if accumulation > 1:
            optimizer = optax.MultiSteps(optimizer, accumulation)

        return optimizer, optimizer_info
|
||||
|
||||
|
||||
class PalmOptimizerFactory(object):
    """ Factory for the optimizer described in the PaLM paper
    (https://arxiv.org/abs/2204.02311): gradient clipping, unfactored
    Adafactor with parameter scaling, and a scheduled weight decay.
    """

    def __init__(self):
        # Only used as a namespace for static/class methods.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Build the default config and overlay ``updates`` when given. """
        defaults = ConfigDict()
        defaults.lr = 0.01
        defaults.lr_warmup_steps = 10000
        defaults.b1 = 0.9
        defaults.b2 = 0.99
        defaults.clip_gradient = 1.0
        defaults.weight_decay = 1e-4
        defaults.bf16_momentum = False

        if updates is None:
            return defaults

        overrides = ConfigDict(updates).copy_and_resolve_references()
        defaults.update(overrides)
        return defaults

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Construct the PaLM-style optimizer.

        Args:
            config: overrides applied on top of ``get_default_config()``.
            weight_decay_mask: optional mask handed to the scheduled weight
                decay transformation.

        Returns:
            ``(optimizer, optimizer_info)`` where ``optimizer_info`` exposes
            the ``learning_rate_schedule`` and ``weight_decay_schedule``
            callables.
        """
        config = cls.get_default_config(config)

        def lr_schedule(step):
            # Inverse-sqrt decay; the step is clamped from below by
            # lr_warmup_steps so the rate stays flat until then.
            scale = config.lr / 0.01
            clamped_step = jnp.maximum(step, config.lr_warmup_steps)
            return scale / jnp.sqrt(clamped_step)

        def wd_schedule(step):
            # Decay coefficient follows the squared learning rate. It is
            # negative because it is added to already-negated updates.
            scale = config.weight_decay / 1e-4
            current_lr = lr_schedule(step)
            return -scale * jnp.square(current_lr)

        optimizer_info = {
            'learning_rate_schedule': lr_schedule,
            'weight_decay_schedule': wd_schedule,
        }

        if config.bf16_momentum:
            momentum_dtype = jnp.bfloat16
        else:
            momentum_dtype = jnp.float32

        clipping = optax.clip_by_global_norm(config.clip_gradient)
        adafactor = optax.adafactor(
            learning_rate=lr_schedule,
            multiply_by_parameter_scale=True,
            momentum=config.b1,
            decay_rate=config.b2,
            factored=False,
            clipping_threshold=None,
            dtype_momentum=momentum_dtype,
        )
        weight_decay = optax_add_scheduled_weight_decay(
            wd_schedule, weight_decay_mask
        )

        optimizer = optax.chain(clipping, adafactor, weight_decay)
        return optimizer, optimizer_info
|
||||
|
||||
|
||||
class AdamWOptimizerFactory(object):
    """ AdamW optimizer driven by a warmup + cosine decay schedule.

    When ``multiply_by_parameter_scale`` is set, unfactored Adafactor with
    parameter scaling plus a scheduled weight decay is used instead of
    ``optax.adamw``.
    """

    def __init__(self):
        # Only used as a namespace for static/class methods.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Build the default config and overlay ``updates`` when given. """
        defaults = ConfigDict()
        defaults.init_lr = 0.0
        defaults.end_lr = 0.001
        defaults.lr = 0.01
        defaults.lr_warmup_steps = 2000
        defaults.lr_decay_steps = 500000
        defaults.b1 = 0.9
        defaults.b2 = 0.95
        defaults.clip_gradient = 1.0
        defaults.weight_decay = 1e-4
        defaults.bf16_momentum = False
        defaults.multiply_by_parameter_scale = False

        if updates is None:
            return defaults

        overrides = ConfigDict(updates).copy_and_resolve_references()
        defaults.update(overrides)
        return defaults

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Construct the optimizer.

        Args:
            config: overrides applied on top of ``get_default_config()``.
            weight_decay_mask: optional mask restricting which parameters
                receive weight decay.

        Returns:
            ``(optimizer, optimizer_info)`` where ``optimizer_info`` exposes
            the ``learning_rate_schedule`` callable.
        """
        config = cls.get_default_config(config)

        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=config.init_lr,
            peak_value=config.lr,
            warmup_steps=config.lr_warmup_steps,
            decay_steps=config.lr_decay_steps,
            end_value=config.end_lr,
        )
        optimizer_info = {'learning_rate_schedule': lr_schedule}

        if config.bf16_momentum:
            momentum_dtype = jnp.bfloat16
        else:
            momentum_dtype = jnp.float32

        # Both variants clip the global gradient norm first.
        clipping = optax.clip_by_global_norm(config.clip_gradient)

        if not config.multiply_by_parameter_scale:
            adamw = optax.adamw(
                learning_rate=lr_schedule,
                weight_decay=config.weight_decay,
                b1=config.b1,
                b2=config.b2,
                mask=weight_decay_mask,
                mu_dtype=momentum_dtype,
            )
            return optax.chain(clipping, adamw), optimizer_info

        def scheduled_weight_decay(step):
            # Negative because it is added to already-negated updates.
            return -lr_schedule(step) * config.weight_decay

        adafactor = optax.adafactor(
            learning_rate=lr_schedule,
            multiply_by_parameter_scale=True,
            momentum=config.b1,
            decay_rate=config.b2,
            factored=False,
            clipping_threshold=None,
            dtype_momentum=momentum_dtype,
        )
        weight_decay = optax_add_scheduled_weight_decay(
            scheduled_weight_decay, weight_decay_mask
        )
        optimizer = optax.chain(clipping, adafactor, weight_decay)
        return optimizer, optimizer_info
|
||||
|
||||
class LionOptimizerFactory(object):
    """ Lion optimizer with a selectable learning-rate schedule.

    ``config.lr_schedule_type`` picks one of:
      * "warmup_cosine_decay_schedule"
      * "warmup_constant"
      * "warmup_constant_linear_decay"
      * "warmup_constant_exponential_decay"
      * "exponential_decay"
      * "linear_decay"
    """

    def __init__(self):
        # Only used as a namespace for static/class methods.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Build the default config and overlay ``updates`` when given.

        ``lr_warmup_steps``, ``lr_constant_steps`` and ``lr_decay_steps`` are
        the *durations* of the warmup, constant and decay phases.
        """
        config = ConfigDict()
        config.init_lr = 0.0
        config.end_lr = 0.0001
        config.lr = 0.001
        config.lr_warmup_steps = 60000
        config.lr_constant_steps = 840000
        config.lr_decay_steps = 100000
        config.b1 = 0.9
        config.b2 = 0.98
        config.clip_gradient = 1.0
        config.weight_decay = 1e-3
        config.bf16_momentum = False
        config.lr_schedule_type = "warmup_cosine_decay_schedule"
        config.lr_decay_rate = 0.98

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @staticmethod
    def _build_learning_rate_schedule(config):
        """ Return the optax schedule selected by ``config.lr_schedule_type``.

        Raises:
            ValueError: if the schedule type is not one of the supported names.
        """
        schedule_type = config.lr_schedule_type

        def linear_warmup():
            return optax.linear_schedule(
                init_value=config.init_lr,
                end_value=config.lr,
                transition_steps=config.lr_warmup_steps,
            )

        def linear_decay():
            return optax.linear_schedule(
                init_value=config.lr,
                end_value=config.end_lr,
                transition_steps=config.lr_decay_steps,
            )

        def exponential_decay():
            return optax.exponential_decay(
                init_value=config.lr,
                transition_steps=config.lr_decay_steps,
                decay_rate=config.lr_decay_rate,
                transition_begin=0,
                staircase=False,
                end_value=config.end_lr,
            )

        if schedule_type == "warmup_cosine_decay_schedule":
            return optax.warmup_cosine_decay_schedule(
                init_value=config.init_lr,
                peak_value=config.lr,
                warmup_steps=config.lr_warmup_steps,
                decay_steps=config.lr_decay_steps,
                end_value=config.end_lr,
            )
        if schedule_type == "exponential_decay":
            return exponential_decay()
        if schedule_type == "linear_decay":
            return linear_decay()
        if schedule_type == "warmup_constant":
            return optax.join_schedules(
                [linear_warmup(), optax.constant_schedule(config.lr)],
                [config.lr_warmup_steps],
            )

        three_phase_decays = {
            "warmup_constant_linear_decay": linear_decay,
            "warmup_constant_exponential_decay": exponential_decay,
        }
        if schedule_type in three_phase_decays:
            # optax.join_schedules takes boundaries as absolute step indices,
            # so the decay phase starts after warmup AND constant phases have
            # both elapsed, i.e. at the cumulative sum of their durations.
            decay_begin = config.lr_warmup_steps + config.lr_constant_steps
            return optax.join_schedules(
                [
                    linear_warmup(),
                    optax.constant_schedule(config.lr),
                    three_phase_decays[schedule_type](),
                ],
                [config.lr_warmup_steps, decay_begin],
            )

        raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Construct the Lion optimizer.

        Args:
            config: overrides applied on top of ``get_default_config()``.
            weight_decay_mask: optional mask restricting which parameters
                receive weight decay (passed to ``optax.lion``).

        Returns:
            ``(optimizer, optimizer_info)`` where ``optimizer_info`` exposes
            the ``learning_rate_schedule`` callable.

        Raises:
            ValueError: if ``config.lr_schedule_type`` is not supported.
        """
        config = cls.get_default_config(config)

        learning_rate_schedule = cls._build_learning_rate_schedule(config)

        optimizer_info = dict(
            learning_rate_schedule=learning_rate_schedule,
        )

        optimizer = optax.chain(
            optax.clip_by_global_norm(config.clip_gradient),
            optax.lion(
                learning_rate=learning_rate_schedule,
                weight_decay=config.weight_decay,
                b1=config.b1,
                b2=config.b2,
                mask=weight_decay_mask,
                mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
            ),
        )

        return optimizer, optimizer_info
|
||||
|
||||
|
||||
class OptaxScheduledWeightDecayState(NamedTuple):
    """ Optimizer state carried by the scheduled weight decay transform. """

    # Step counter that is fed to the weight decay schedule.
    count: jax.Array
|
||||
|
||||
|
||||
def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
    """ Build a gradient transformation adding ``schedule_fn(step) * param``
    to every update.

    Args:
        schedule_fn: callable mapping the step counter to the weight decay
            coefficient for that step.
        mask: optional mask; when given the transformation is wrapped in
            ``optax.masked`` with it.

    Returns:
        An ``optax.GradientTransformation`` (masked when ``mask`` is given).
    """

    def initial_state(params):
        # The state is only a step counter; parameters are not needed.
        del params
        step_counter = jnp.zeros(shape=[], dtype=jnp.int32)
        return OptaxScheduledWeightDecayState(count=step_counter)

    def apply_decay(updates, state, params):
        if params is None:
            raise ValueError('Params cannot be None for weight decay!')

        weight_decay = schedule_fn(state.count)

        def add_decay_term(update, param):
            return update + weight_decay * param

        decayed_updates = jax.tree_util.tree_map(
            add_decay_term, updates, params
        )
        next_state = OptaxScheduledWeightDecayState(
            count=optax.safe_int32_increment(state.count)
        )
        return decayed_updates, next_state

    transformation = optax.GradientTransformation(initial_state, apply_decay)
    if mask is None:
        return transformation
    return optax.masked(transformation, mask)
|
||||
Reference in New Issue
Block a user