# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: Reusing a lot of code from the Timm repo
# main difference is probably the handling of deepnorm init, and adapting to some xformers specificities
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

import logging
import math
from enum import Enum
from typing import Callable

import torch
import torch.nn as nn
from torch.nn.init import (
    _calculate_fan_in_and_fan_out,
    _no_grad_trunc_normal_,
    _no_grad_uniform_,
)

logger = logging.getLogger("xformers")

_assert_if_not_initialized = False


class xFormerWeightInit(str, Enum):
    Timm = "timm"
    ViT = "vit"
    Moco = "moco"
    Small = "small"


def get_weight_init_fn(init_choice: xFormerWeightInit):
    """
    Provide the xFormers factory with weight init routines.

    Supported initializations are:
    - Small: follow the method outlined in `Transformer Without Tears`_
    - ViT: follow the initialization in the reference ViT_ codebase
    - Timm: follow the initialization in the reference Timm_ codebase
    - Moco: follow the initialization in the reference MocoV3_ codebase

    .. _ViT: https://github.com/google-research/vision_transformer
    .. _Timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    .. _MocoV3: https://github.com/facebookresearch/moco-v3
    """
    return {
        xFormerWeightInit.Timm: _init_weights_vit_timm,
        xFormerWeightInit.ViT: _init_weights_vit_jax,
        xFormerWeightInit.Moco: _init_weights_vit_moco,
        xFormerWeightInit.Small: _init_weights_small,
    }[init_choice]


# Define pattern matches
def is_ffn(n):
    return "feedforward" in n or ("wrap_ff" in n and not n.endswith("norm"))


def is_mha_input_projection(n):
    return "q_proj" in n or "k_proj" in n or "v_proj" in n


# Define distribution helpers
def _small_init_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method described in
    `Transformer Without Tears`_, using a uniform distribution.

    This is a variation of the Xavier init (Xavier is also known as Glorot initialization).
    The resulting tensor will have values sampled from :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + 4 \times \text{fan\_out}}}

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    .. _`Transformer Without Tears`: https://arxiv.org/abs/1910.05895
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + 4 * fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a)


def _lecun_normal(tensor, gain=1.0):
    fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
    denom = fan_in
    variance = gain / denom

    # constant is stddev of standard normal truncated to (-2, 2)
    _no_grad_trunc_normal_(
        tensor,
        mean=0.0,
        std=math.sqrt(variance) / 0.87962566103423978,
        a=-2.0,
        b=2.0,
    )

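# Worked example for the bound used by _small_init_ (purely illustrative values):
# for a nn.Linear(512, 2048) weight, fan_in = 512 and fan_out = 2048, so with
# gain = 1.0 the values are drawn from U(-a, a) with
# a = sqrt(6 / (512 + 4 * 2048)) ~ 0.026, noticeably tighter than the plain
# Xavier uniform bound sqrt(6 / (512 + 2048)) ~ 0.048.
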
# Helpers to keep all the functions typesafe, and handle corner cases and common behaviours in one place
def _maybe_init_tensor(module: nn.Module, attr: str, distribution_: Callable, **kwargs):
    # Small helper to catch all the corner cases, while staying type safe
    if hasattr(module, attr):
        maybe_tensor = getattr(module, attr)
        if maybe_tensor is not None and isinstance(maybe_tensor, torch.Tensor):
            distribution_(maybe_tensor, **kwargs)


def _maybe_report_no_init(module, name):
    if len(list(module.named_children())) == 0 and (
        hasattr(module, "weight") or hasattr(module, "bias")
    ):
        # Skip layer norm, this is ok
        if isinstance(module, torch.nn.LayerNorm):
            return

        # Skip nn.Embedding, we typically initialize it one level up, else Pytorch has a valid default
        if isinstance(module, torch.nn.Embedding):
            return

        # This is unexpected, warn about a possible unhandled weight
        logger.warning(
            f"Not initializing weights in {name}, this could be a mistake.\nModule {module}"
        )

        if _assert_if_not_initialized:
            assert False, (
                f"Uninitialized weight found in {module}."
                + " If you have a custom module, please provide a `init_weights()` method"
            )


# Define the different initialization schemes
def _init_weights_vit_jax(
    module: nn.Module,
    name: str = "",
    head_bias: float = 0.0,
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """ViT weight initialization, matching JAX (Flax) impl"""

    if is_ffn(name):
        _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        if deepnorm_style and (
            "q_proj" in name.split(".") or "k_proj" in name.split(".")
        ):
            gain = 1.0

        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif isinstance(module, nn.Conv2d):
        _maybe_init_tensor(module, "weight", _lecun_normal, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights()  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            _init_weights_vit_jax(child_module, f"{name}.{child_name}", head_bias, gain)


def _init_weights_vit_moco(
    module: nn.Module,
    name: str = "",
    gain: float = 1.0,
    **kwargs,
):
    """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed"""

    assert (
        "deepnorm_style" not in kwargs.keys()
    ), "This initialization method does not support deepnorm"

    if is_ffn(name):
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        if isinstance(module.weight, torch.Tensor):
            val = (
                math.sqrt(6.0 / float(module.weight.shape[0] + module.weight.shape[1]))
                * gain
            )
            _maybe_init_tensor(module, "weight", nn.init.uniform_, a=-val, b=val)

        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights(gain=gain)  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            _init_weights_vit_moco(child_module, child_name, gain)


def _init_weights_small(
    module: nn.Module,
    name: str = "",
    head_bias: float = 0.0,
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """Follow the `Transformer Without Tears`_ initialization for self-attention"""

    if is_ffn(name):
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        # "small init" only scales the attention layers init, not the FFN
        if deepnorm_style and (
            "q_proj" in name.split(".") or "k_proj" in name.split(".")
        ):
            gain = 1.0

        _maybe_init_tensor(module, "weight", _small_init_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif isinstance(module, nn.Conv2d):
        _maybe_init_tensor(module, "weight", _lecun_normal)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights()  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            _init_weights_small(child_module, f"{name}.{child_name}", head_bias, gain)


def _init_weights_vit_timm(
    module: nn.Module,
    name: str = "",
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """
    ViT weight initialization, original timm impl (for reproducibility).
    See DeepNet_ for all the DeepNorm specific codepaths
    """

    if isinstance(module, nn.Linear):
        if deepnorm_style and (
            "q_proj" in name.split(".") or "k_proj" in name.split(".")
        ):
            gain = 1

        std = 0.02 * gain
        a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
        _maybe_init_tensor(
            module, "weight", _no_grad_trunc_normal_, mean=0.0, std=std, a=-a, b=a
        )
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights(gain=gain)  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            _init_weights_vit_timm(child_module, child_name, gain)
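

# A minimal usage sketch (illustrative only): pick an init scheme through the
# factory above and apply it to a model. The toy nn.Sequential below is an
# assumption made purely for demonstration; any nn.Module works, with custom
# modules opting out of the recursion by exposing their own `init_weights()`.
if __name__ == "__main__":
    model = nn.Sequential(
        nn.Linear(16, 32),
        nn.GELU(),
        nn.Linear(32, 16),
    )

    init_fn = get_weight_init_fn(xFormerWeightInit.ViT)
    init_fn(model, name="")

    # The nn.Linear children are reached through the recursion over named_children()
    print(f"first linear weight std: {model[0].weight.std().item():.4f}")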