Files
enginex-bi_series-vllm/pkgs/xformers/factory/weight_init.py

294 lines
9.8 KiB
Python
Raw Normal View History

2025-08-05 19:02:46 +08:00
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Reusing a lot of code from the Timm repo
# main difference is probably the handling of deepnorm init, and adapting to some xformers specificities
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import logging
import math
from enum import Enum
from typing import Callable
import torch
import torch.nn as nn
from torch.nn.init import (
_calculate_fan_in_and_fan_out,
_no_grad_trunc_normal_,
_no_grad_uniform_,
)
# Module-wide logger, registered under the "xformers" name
logger = logging.getLogger("xformers")
# When True, `_maybe_report_no_init` fails hard instead of only logging a warning
_assert_if_not_initialized = False
class xFormerWeightInit(str, Enum):
    """Identifiers of the weight initialization schemes offered by this module.

    Every member is also a plain `str`, so it compares equal to its lowercase
    value and can be built from it, e.g. `xFormerWeightInit("timm")`.
    """

    # Scheme reproducing the original Timm ViT initialization
    Timm = "timm"

    # Scheme matching the JAX (Flax) ViT initialization
    ViT = "vit"

    # Scheme matching the moco-v3 ViT initialization
    Moco = "moco"

    # Scheme following the `Transformer Without Tears` initialization
    Small = "small"
def get_weight_init_fn(init_choice: xFormerWeightInit):
    """
    Return the weight init routine associated with `init_choice`.

    The available routines are:

    - Small: the method outlined in `Transformer Without Tears`_
    - ViT: the initialization of the reference ViT_ codebase
    - Timm: the initialization of the reference Timm_ codebase
    - Moco: the initialization of the reference MocoV3_ codebase

    Args:
        init_choice: the :class:`xFormerWeightInit` member naming the scheme

    Returns:
        The matching module-level `_init_weights_*` function.

    Raises:
        KeyError: if `init_choice` does not match any known scheme.

    .. _`Transformer Without Tears`: https://arxiv.org/abs/1910.05895
    .. _ViT: https://github.com/google-research/vision_transformer
    .. _Timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    .. _MocoV3: https://github.com/facebookresearch/moco-v3
    """
    # Dispatch table: one init routine per supported scheme
    init_routines = {
        xFormerWeightInit.Small: _init_weights_small,
        xFormerWeightInit.Moco: _init_weights_vit_moco,
        xFormerWeightInit.ViT: _init_weights_vit_jax,
        xFormerWeightInit.Timm: _init_weights_vit_timm,
    }
    return init_routines[init_choice]
# Define pattern matches
def is_ffn(n):
    """Tell whether the module name `n` designates a feedforward block.

    A name matches when it contains "feedforward", or when it contains
    "wrap_ff" without ending in "norm".
    """
    if "feedforward" in n:
        return True

    if "wrap_ff" not in n:
        return False

    # Inside a feedforward wrapper, everything but the trailing norm counts
    return not n.endswith("norm")
def is_mha_input_projection(n):
    """Tell whether the module name `n` contains "q_proj", "k_proj" or "v_proj"."""
    projection_markers = ("q_proj", "k_proj", "v_proj")
    return any(marker in n for marker in projection_markers)
# Define distribution helpers
def _small_init_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
r"""Fills the input `Tensor` with values according to the method
described in `Transformer Without Tears`_, using a uniform distribution.
This is a variation of the Xavier init. The resulting tensor will have values sampled from
:math:`\mathcal{U}(-a, a)` where
.. math::
a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + 4 * \text{fan\_out}}}
Also known as Glorot initialization.
Args:
tensor: an n-dimensional `torch.Tensor`
gain: an optional scaling factor
.. _`Transformer Without Tears`: https://arxiv.org/abs/1910.05895
"""
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / float(fan_in + 4 * fan_out))
a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return _no_grad_uniform_(tensor, -a, a)
def _lecun_normal(tensor, gain=1.0):
fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
denom = fan_in
variance = gain / denom
# constant is stddev of standard normal truncated to (-2, 2)
_no_grad_trunc_normal_(
tensor,
mean=0.0,
std=math.sqrt(variance) / 0.87962566103423978,
a=-2.0,
b=2.0,
)
# Helpers to keep all the functions typesafe, and handle corner cases and common behaviours in one place
def _maybe_init_tensor(module: nn.Module, attr: str, distribution_: Callable, **kwargs):
# Small helper to catch all the corner cases, while staying type safe
if hasattr(module, attr):
maybe_tensor = getattr(module, attr)
if maybe_tensor is not None and isinstance(maybe_tensor, torch.Tensor):
distribution_(maybe_tensor, **kwargs)
def _maybe_report_no_init(module, name):
    """Warn about a leaf module whose weights were left uninitialized.

    Only modules without children which expose a `weight` or `bias` attribute
    are considered. `LayerNorm` and `Embedding` modules are silently skipped.

    Args:
        module: the module which was not handled by the init routine
        name: the name of this module, used in the warning message

    Raises:
        AssertionError: if the module-level `_assert_if_not_initialized` flag
            is set and an unhandled module is found.
    """
    has_children = bool(list(module.named_children()))
    if has_children:
        return

    if not (hasattr(module, "weight") or hasattr(module, "bias")):
        return

    # Skip layer norm, this is ok
    # Skip nn.Embedding, we typically initialize it one level up, else Pytorch has a valid default
    if isinstance(module, (torch.nn.LayerNorm, torch.nn.Embedding)):
        return

    # This is unexpected, warn about a possible unhandled weight
    logger.warning(
        f"Not initializing weights in {name}, this could be a mistake.\nModule {module}"
    )

    if _assert_if_not_initialized:
        # Raised explicitly: an `assert` statement would be stripped under `python -O`
        raise AssertionError(
            f"Uninitialized weight found in {module}."
            + " If you have a custom module, please provide a `init_weights()` method"
        )
# Define the different initialization schemes
def _init_weights_vit_jax(
    module: nn.Module,
    name: str = "",
    head_bias: float = 0.0,
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """ViT weight initialization, matching JAX (Flax) impl.

    `module` is initialized in place according to its name and type, then its
    children are visited recursively, unless the module brings its own
    `init_weights()` method.

    Args:
        module: the module to initialize
        name: dotted path of `module`, used for the name-based pattern matching
        head_bias: not used by this routine, only forwarded to the recursive calls
        gain: scaling factor passed to the weight distributions
        deepnorm_style: if True, projections whose dotted name has a `q_proj`
            or `k_proj` component are initialized with a gain of 1.0
        **kwargs: accepted for signature compatibility, ignored
    """
    if is_ffn(name):
        _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        name_parts = name.split(".")
        if deepnorm_style and ("q_proj" in name_parts or "k_proj" in name_parts):
            gain = 1.0

        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif isinstance(module, nn.Conv2d):
        _maybe_init_tensor(module, "weight", _lecun_normal, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights()  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            # deepnorm_style has to be forwarded explicitly, else the nested
            # q_proj / k_proj modules would fall back to the default (False)
            _init_weights_vit_jax(
                child_module,
                f"{name}.{child_name}",
                head_bias,
                gain,
                deepnorm_style=deepnorm_style,
            )
def _init_weights_vit_moco(
    module: nn.Module,
    name: str = "",
    gain: float = 1.0,
    **kwargs,
):
    """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed.

    `module` is initialized in place according to its name and type, then its
    children are visited recursively, unless the module brings its own
    `init_weights()` method.

    Args:
        module: the module to initialize
        name: name of `module`, used for the name-based pattern matching
        gain: scaling factor applied to the weight distributions, also passed
            to a custom `init_weights(gain=...)` method
        **kwargs: must not contain `deepnorm_style`, which is not supported here

    Raises:
        AssertionError: if `deepnorm_style` is passed as a keyword argument.
    """
    assert (
        "deepnorm_style" not in kwargs.keys()
    ), "This initialization method does not support deepnorm"

    if is_ffn(name):
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        # A module can land here through its name alone, so the weight may be
        # missing or may not be a tensor: only compute the bounds when it is one
        weight = getattr(module, "weight", None)
        if isinstance(weight, torch.Tensor):
            val = math.sqrt(6.0 / float(weight.shape[0] + weight.shape[1])) * gain
            _maybe_init_tensor(module, "weight", nn.init.uniform_, a=-val, b=val)

        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights(gain=gain)  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            # NOTE(review): only the child's own name is forwarded, not the
            # dotted path from the root - confirm this is intended
            _init_weights_vit_moco(child_module, child_name, gain)
def _init_weights_small(
    module: nn.Module,
    name: str = "",
    head_bias: float = 0.0,
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """Follow the `Transformer Without Tears`_ initialization for self-attention.

    `module` is initialized in place according to its name and type, then its
    children are visited recursively, unless the module brings its own
    `init_weights()` method.

    Args:
        module: the module to initialize
        name: dotted path of `module`, used for the name-based pattern matching
        head_bias: not used by this routine, only forwarded to the recursive calls
        gain: scaling factor passed to the weight distributions
        deepnorm_style: if True, projections whose dotted name has a `q_proj`
            or `k_proj` component are initialized with a gain of 1.0
        **kwargs: accepted for signature compatibility, ignored

    .. _`Transformer Without Tears`: https://arxiv.org/abs/1910.05895
    """
    if is_ffn(name):
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        # "small init" only scales the attention layers init, not the FFN
        name_parts = name.split(".")
        if deepnorm_style and ("q_proj" in name_parts or "k_proj" in name_parts):
            gain = 1.0

        _maybe_init_tensor(module, "weight", _small_init_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif isinstance(module, nn.Conv2d):
        _maybe_init_tensor(module, "weight", _lecun_normal)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights()  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            # deepnorm_style has to be forwarded explicitly, else the nested
            # q_proj / k_proj modules would fall back to the default (False)
            _init_weights_small(
                child_module,
                f"{name}.{child_name}",
                head_bias,
                gain,
                deepnorm_style=deepnorm_style,
            )
def _init_weights_vit_timm(
    module: nn.Module,
    name: str = "",
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """
    ViT weight initialization, original timm impl (for reproducibility).

    See DeepNet_ for all the DeepNorm specific codepaths

    `module` is initialized in place according to its type, then its children
    are visited recursively, unless the module brings its own `init_weights()`
    method.

    Args:
        module: the module to initialize
        name: name of `module`, used to spot the `q_proj` / `k_proj` layers
        gain: scaling factor applied to the 0.02 base standard deviation, also
            passed to a custom `init_weights(gain=...)` method
        deepnorm_style: if True, linear layers whose name has a `q_proj` or
            `k_proj` component are initialized with a gain of 1
        **kwargs: accepted for signature compatibility, ignored

    .. _DeepNet: https://arxiv.org/abs/2203.00555
    """
    if isinstance(module, nn.Linear):
        name_parts = name.split(".")
        if deepnorm_style and ("q_proj" in name_parts or "k_proj" in name_parts):
            gain = 1

        std = 0.02 * gain
        a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

        _maybe_init_tensor(
            module, "weight", _no_grad_trunc_normal_, mean=0.0, std=std, a=-a, b=a
        )
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights(gain=gain)  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            # deepnorm_style has to be forwarded explicitly, else the nested
            # q_proj / k_proj layers would fall back to the default (False)
            _init_weights_vit_timm(
                child_module, child_name, gain, deepnorm_style=deepnorm_style
            )