# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from xformers._deprecation_warning import deprecated_function
from xformers.components import reversible as rv
from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients
from xformers.factory.block_configs import (
    xFormerBlockConfig,
    xFormerDecoderConfig,
    xFormerEncoderConfig,
)
from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock
from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit

logger = logging.getLogger("xformers")


@dataclass(init=False)
class xFormerConfig:
"""
The configuration structure to define a full Transformer.
This can include a stack of encoder layers, and a stack of decoder layers.
It is optionally possible to share the embedding weights in between
the encoder and decoder positional encoding, as proposed for instance by
`Using the Output Embedding to Improve Language Models`, Press et al.
A full config example is for instance as follows:
::
xformer_config = [
{
"reversible": False, # Turn on to test the effect of using reversible layers
"block_type": "encoder",
"num_layers": LAYERS,
"dim_model": EMB,
"residual_norm_style": "pre",
"position_encoding_config": {
"name": "vocab",
"seq_len": CONTEXT,
"vocab_size": VOCAB_SIZE,
},
"multi_head_config": {
"num_heads": NUM_HEADS,
"residual_dropout": RES_DROP,
"use_rotary_embeddings": True,
"attention": {
"name": ATTENTION_MECHANISM_STR,
"dropout": ATTN_DROP,
"causal": True,
"seq_len": CONTEXT,
},
},
"feedforward_config": {
"name": "FusedMLP", # Use MLP if Triton is not available
"dropout": MLP_DROP,
"activation": "gelu",
"hidden_layer_multiplier": MLP_MULTIPLIER,
},
}
]
.. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
"""

    stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
    tie_embedding_weights: bool = False
    weight_init: xFormerWeightInit = xFormerWeightInit.ViT

    def __init__(
        self,
        stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        # Type all the configurations. Possible typos are caught here
        if isinstance(stack_configs, dict):
            self.stack_configs = {}
            for k, config in stack_configs.items():
                if config["block_type"] == "encoder":
                    self.stack_configs[k] = xFormerEncoderConfig(**config)
                else:
                    self.stack_configs[k] = xFormerDecoderConfig(**config)
        else:
            self.stack_configs = []
            for config in stack_configs:
                if config["block_type"] == "encoder":
                    self.stack_configs.append(xFormerEncoderConfig(**config))
                else:
                    self.stack_configs.append(xFormerDecoderConfig(**config))

        self.tie_embedding_weights = tie_embedding_weights
        self.weight_init = weight_init
        deprecated_function(self)
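

# A hedged illustration (not part of the original module): stack_configs may be
# given either as a list or as a dict keyed by arbitrary stack names; each entry
# is typed according to its "block_type" ("encoder" -> xFormerEncoderConfig,
# anything else -> xFormerDecoderConfig). The keys and values below are
# illustrative assumptions only.
#
#   config = xFormerConfig(
#       stack_configs={
#           "my_encoder": {"block_type": "encoder", "num_layers": 2, ...},
#           "my_decoder": {"block_type": "decoder", "num_layers": 2, ...},
#       }
#   )

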
class xFormer(torch.nn.Module):
    def __init__(
        self,
        stack_configs: Union[
            xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
        ],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        """
        Given a serialized configuration, generate the corresponding model.
        This is only a helper and can easily be bypassed.
        """
        super().__init__()
        deprecated_function(self)

        if isinstance(stack_configs, Dict):
            stack_configs = list(stack_configs.values())

        # Convenience: users can pass either a list of configs or a single one
        if not isinstance(stack_configs, List):
            stack_configs = [stack_configs]

        # Sanity checks, some config combinations do not make sense
        self._verify_reversible(stack_configs)
        self._verify_deepnorm(stack_configs)

        encoders: List[torch.nn.Module] = []
        decoders: List[torch.nn.Module] = []

        self.reversible_encoder = False
        self.rev_enc_pose_encoding = None

        # Unroll the configs and build the model
        for config in stack_configs:
            # Handle either Encoder or Decoder stacks
            builder = (
                xFormerEncoderBlock.from_config
                if isinstance(config, xFormerEncoderConfig)
                else xFormerDecoderBlock.from_config
            )
            recipient = (
                encoders if isinstance(config, xFormerEncoderConfig) else decoders
            )

            # Build up the stack
            for i in range(config.num_layers):
                # Label where this layer is in the stack
                # (for instance useful for the positional encoding, or late layer norm)
                if len(recipient) > 0:
                    config.layer_position.mark_not_first()

                if config != stack_configs[-1] or i < config.num_layers - 1:
                    config.layer_position.mark_not_last()

                block = builder(config)  # type: ignore

                # If reversible: extract the reversible sub-parts, else append the block as-is
                if config.reversible:
                    # WARNING: only one pose encoding is saved here
                    # (not Focal Transformer compatible for instance)
                    assert isinstance(config, xFormerEncoderConfig)
                    if block.pose_encoding is not None:
                        self.rev_enc_pose_encoding = block.pose_encoding
                    self.reversible_encoder = True

                    f, g = xFormerEncoderBlock.get_reversible_layer(config)
                    recipient.append(torch.nn.ModuleList([f, g]))
                else:
                    recipient.append(block)  # type: ignore

        # Tie embedding weights, if requested and possible
        assert (
            not tie_embedding_weights or not self.reversible_encoder
        ), "Reversible layers and tied embeddings are not supported for now"

        if (
            tie_embedding_weights
            and encoders
            and encoders[0].pose_encoding
            and decoders
            and decoders[0].pose_encoding
            and not config.reversible
        ):
            logger.info("Tying encoder and decoder embeddings, as requested")
            encoders[0].pose_encoding = decoders[0].pose_encoding

        self.encoders: torch.nn.Module = (
            rv.ReversibleSequence(torch.nn.ModuleList(encoders))
            if self.reversible_encoder
            else torch.nn.ModuleList(encoders)
        )
        self.decoders = torch.nn.ModuleList(decoders)

        use_deepnorm = (
            stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm
        )

        assert (
            not use_deepnorm or not self.reversible_encoder
        ), "Reversible layers and deepnorm are not supported for now"

        self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm)

    @classmethod
    def from_config(cls, config: xFormerConfig):
        return cls(
            config.stack_configs, config.tie_embedding_weights, config.weight_init
        )

    def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]):
        reversible = [
            c.reversible
            for c in filter(lambda x: x.block_type == "encoder", stack_configs)
        ]

        assert all(reversible) or not any(reversible), (
            "All layers need to have the same reversibility setting. "
            + f"Currently {reversible}"
        )

    def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]):
        deepnorm = [
            c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs
        ]

        assert all(deepnorm) or not any(deepnorm), (
            "All layers need to have the same deepnorm setting. "
            + f"Currently {deepnorm}"
        )

    def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool):
        # The deepnorm weight initialization method requires different gain factors
        # for the encoder and the decoder, depending on the general model structure
        # (number of respective layers)
        if use_deep_norm:
            encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients(
                encoder_layers=len(self.encoders), decoder_layers=len(self.decoders)  # type: ignore
            )
        else:
            encoder_coefficients, decoder_coefficients = None, None

        encoder_gain = (
            encoder_coefficients.beta if encoder_coefficients is not None else 1.0
        )
        decoder_gain = (
            decoder_coefficients.beta if decoder_coefficients is not None else 1.0
        )

        # Pick the desired init function
        init_fn = get_weight_init_fn(weight_init)

        # Initialize all the encoder weights
        for name, module in self.encoders.named_children():
            init_fn(module=module, name=name, gain=encoder_gain)

        # Same for the decoder weights, with the decoder-specific gain
        for name, module in self.decoders.named_children():
            init_fn(module=module, name=name, gain=decoder_gain)

    def forward(
        self,
        src: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        encoder_input_mask: Optional[torch.Tensor] = None,
        decoder_input_mask: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Encode to latent space if encoder is present
        if len(list(self.encoders.parameters())) > 0:
            encoders = self.encoders
            memory = src.clone()

            if isinstance(encoders, torch.nn.ModuleList):
                for encoder in encoders:
                    memory = encoder(memory, input_mask=encoder_input_mask)
            else:
                if self.rev_enc_pose_encoding:
                    memory = self.rev_enc_pose_encoding(src)

                # Reversible Encoder: the reversible blocks expect the activations
                # duplicated along the feature dimension
                x = torch.cat([memory, memory], dim=-1)

                # Apply the optional input masking
                if encoder_input_mask is not None:
                    if x.dim() - encoder_input_mask.dim() > 1:
                        # Add the missing batch dimension so the mask broadcasts
                        encoder_input_mask = encoder_input_mask.unsqueeze(0)
                    x += encoder_input_mask.unsqueeze(-1)

                x = encoders(x)

                # Fold the duplicated activations back to the model dimension
                memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)

            if not self.decoders:
                return memory

        # If decoder: either use the encoder output, or just decode; both options are possible
        if len(self.decoders) > 0:
            tgt = src.clone() if tgt is None else tgt

            for decoder in self.decoders:
                tgt = decoder(
                    target=tgt,
                    # pyre-fixme[61]: `memory` is not always initialized here.
                    memory=memory,
                    input_mask=decoder_input_mask,
                )

            return tgt

        return None
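

if __name__ == "__main__":
    # A minimal, hedged usage sketch (not part of the original module): build a
    # tiny encoder-only model from an xFormerConfig and run a forward pass.
    # All numeric values and component names below are illustrative assumptions.
    example_config = xFormerConfig(
        stack_configs=[
            {
                "block_type": "encoder",
                "num_layers": 2,
                "dim_model": 64,
                "residual_norm_style": "pre",
                "position_encoding_config": {
                    "name": "vocab",
                    "seq_len": 128,
                    "vocab_size": 1024,
                },
                "multi_head_config": {
                    "num_heads": 4,
                    "residual_dropout": 0.0,
                    "attention": {
                        "name": "scaled_dot_product",
                        "dropout": 0.0,
                        "causal": False,
                        "seq_len": 128,
                    },
                },
                "feedforward_config": {
                    "name": "MLP",
                    "dropout": 0.0,
                    "activation": "gelu",
                    "hidden_layer_multiplier": 4,
                },
            }
        ]
    )

    model = xFormer.from_config(example_config)

    # (batch, seq_len) token ids; the "vocab" position encoding embeds them to dim_model
    tokens = torch.randint(0, 1024, (1, 128))
    latent = model(tokens)  # encoder-only stack returns the encoded memory
    print(latent.shape)  # expected: torch.Size([1, 128, 64])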