enginex-bi_series-vllm/pkgs/xformers/factory/block_factory.py

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import asdict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from xformers._deprecation_warning import deprecated_function
from xformers.components import (
    PatchEmbeddingConfig,
    PostNorm,
    PreNorm,
    Residual,
    ResidualNormStyle,
    build_multi_head_attention,
    build_patch_embedding,
)
from xformers.components.attention import AttentionMask
from xformers.components.feedforward import build_feedforward
from xformers.components.positional_embedding import build_positional_embedding
from xformers.components.residual import get_deepnorm_coefficients
from xformers.components.simplicial_embedding import SimplicialEmbedding
from xformers.factory.block_configs import (
    NormalizationType,
    xFormerDecoderConfig,
    xFormerEncoderConfig,
)

logger = logging.getLogger("xformers")

def _get_ln_factory(
    d_model: int,
    residual_norm_style: Optional[ResidualNormStyle],
    use_triton: bool,
    residual: bool,
    normalization: NormalizationType = NormalizationType.LayerNorm,
    residual_scale: float = 1.0,
):
    """
    Handle all the supported residual path configurations.

    .. Note: we return the appropriate constructor, not an actual layer
    """

    def get_layer_wrapper(
        d_model: int,
        sublayer: nn.Module,
        residual_norm_style: Optional[ResidualNormStyle],
        residual: bool,
        residual_scale: float,
    ):
        if residual:
            if residual_norm_style == ResidualNormStyle.Pre:
                return Residual(
                    layer=PreNorm(d_model, sublayer, normalization, use_triton),
                    scale=None,
                )
            elif residual_norm_style == ResidualNormStyle.Post:
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=None),
                    normalization,
                    use_triton,
                )
            elif residual_norm_style == ResidualNormStyle.DeepNorm:
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=residual_scale),
                    normalization,
                    use_triton=use_triton,
                )
            else:
                raise ValueError(
                    f"Unsupported residual norm style: {residual_norm_style}"
                )

        return (
            PreNorm(d_model, sublayer, normalization, use_triton)
            if residual_norm_style == ResidualNormStyle.Pre
            else PostNorm(d_model, sublayer, normalization, use_triton)
        )

    def ln_factory(sublayer: nn.Module):
        return get_layer_wrapper(
            d_model, sublayer, residual_norm_style, residual, residual_scale
        )

    return ln_factory
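

# For reference, the wrappers built above compose the residual path as follows
# (assuming Residual adds its -- optionally scaled -- first input to the wrapped
# layer's output, and PreNorm/PostNorm normalize before/after the wrapped layer):
#   Pre:      y = x + sublayer(norm(x))
#   Post:     y = norm(x + sublayer(x))
#   DeepNorm: y = norm(residual_scale * x + sublayer(x))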


class xFormerEncoderBlock(torch.nn.Module):
    r"""A vanilla Transformer Encoder block"""

    def __init__(self, config: xFormerEncoderConfig, **kwargs):
        super().__init__()
        deprecated_function(self)

        self.reversible_f = None
        self.reversible_g = None
        self.residual_norm_style = config.residual_norm_style
        self.dim_model = config.dim_model

        # If this layer is the first one, and a pose encoding has been requested
        if (
            config.position_encoding_config is not None
            and config.layer_position.is_first()
        ):
            self.pose_encoding = build_positional_embedding(
                asdict(config.position_encoding_config)
            )
            pos_encoding_dim = config.position_encoding_config.dim_model
            mha_dim = config.multi_head_config["dim_model"]
            if pos_encoding_dim != mha_dim:
                logger.warning(
                    f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer."  # noqa
                )
                self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
        else:
            self.pose_encoding = None

        if config.residual_norm_style == ResidualNormStyle.DeepNorm:
            # Just use the layer norm coefficient here,
            # the init will be handled at the xformers level (knows about encoder and decoder blocks)
            deep_norm_coefficients, _ = get_deepnorm_coefficients(
                encoder_layers=config.num_layers, decoder_layers=0
            )
            assert deep_norm_coefficients is not None
            residual_scale = deep_norm_coefficients.alpha
        else:
            residual_scale = 1.0

        # Mini helper: builds a normalization layer with the right Pre/Post config,
        # residuals, and the right dimensions
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            use_triton=config.use_triton,
            residual=True,
            residual_scale=residual_scale,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config)
        feedforward = build_feedforward(asdict(config.feedforward_config))

        # Expose attention specific capabilities
        self.supports_attention_mask = mha.attention.supports_attention_mask
        self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
        self.causal = (
            mha.attention.causal if hasattr(mha.attention, "causal") else False
        )

        # Wrappers handle the different layer norm styles (pre- and post-) and the residual path
        self.wrap_att = ln_factory(mha)
        self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
        if (
            config.residual_norm_style == ResidualNormStyle.Pre
            and config.layer_position.is_last()
        ):
            self.wrap_ff = PostNorm(
                config.dim_model,
                self.wrap_ff,
                normalization=config.normalization,
                use_triton=config.use_triton,
            )

        # Simplicial embeddings are only used if specified, and on the last layer
        self.simplicial_embedding: Optional[SimplicialEmbedding] = None
        if (
            config.simplicial_embeddings is not None
            and config.layer_position.is_last()
        ):
            self.simplicial_embedding = SimplicialEmbedding(
                **config.simplicial_embeddings
            )

        # Optional patch embedding
        self.patch_emb: Optional[nn.Module] = None
        if config.patch_embedding_config is not None:
            self.patch_emb = build_patch_embedding(
                PatchEmbeddingConfig(**config.patch_embedding_config)
            )
    @classmethod
    def from_config(cls, config: xFormerEncoderConfig):
        return cls(config)

    @staticmethod
    def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]:
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            residual=False,
            use_triton=config.use_triton,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config)
        feedforward = build_feedforward(asdict(config.feedforward_config))

        reversible_f = ln_factory(mha)
        reversible_g = ln_factory(feedforward)
        return reversible_f, reversible_g
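
    # The (reversible_f, reversible_g) pair above is meant to be consumed by a
    # RevNet-style reversible wrapper elsewhere in the factory, which provides the
    # residual coupling itself; hence residual=False is requested from _get_ln_factory.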
    def forward(
        self,
        x: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        input_mask: Optional[torch.Tensor] = None,
    ):
        if self.patch_emb is not None:
            x = self.patch_emb(x)

        if self.pose_encoding is not None:
            x = self.pose_encoding(x)
            if hasattr(self, "embedding_projector"):
                x = self.embedding_projector(x)

        # Handle the optional input masking, differs on Q, K, V
        if input_mask is not None:
            q = x
            k = x * input_mask.unsqueeze(-1)
            v = k
        else:
            q, k, v = x, x, x

        # Pre/Post norms and residual paths are already handled
        x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
        x = self.wrap_ff(inputs=[x])

        # Optional simplicial embeddings
        if self.simplicial_embedding is not None:
            x = self.simplicial_embedding(x)

        return x


class xFormerDecoderBlock(torch.nn.Module):
    r"""A vanilla Transformer Decoder block

    ... note: this implementation is not (yet ?) reversible"""

    def __init__(self, config: xFormerDecoderConfig, **kwargs):
        super().__init__()
        deprecated_function(self)

        # If this layer is the first one, and a pose encoding has been requested
        if (
            config.position_encoding_config is not None
            and config.layer_position.is_first()
        ):
            self.pose_encoding = build_positional_embedding(
                config.position_encoding_config
            )
            pos_encoding_dim = config.position_encoding_config.dim_model
            mha_dim = config.multi_head_config_masked["dim_model"]
            if pos_encoding_dim != mha_dim:
                logger.warning(
                    f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer."  # noqa
                )
                self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
        else:
            self.pose_encoding = None

        if config.residual_norm_style == ResidualNormStyle.DeepNorm:
            # Just use the layer norm coefficient here,
            # the init will be handled at the xformers level (knows about encoder and decoder blocks)
            _, deep_norm_coefficients = get_deepnorm_coefficients(
                encoder_layers=0, decoder_layers=config.num_layers
            )
            assert deep_norm_coefficients is not None
            residual_scale = deep_norm_coefficients.alpha
        else:
            residual_scale = 1.0

        # Mini helper: builds a LayerNorm with the right Pre/Post config and the right dimensions
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            use_triton=config.use_triton,
            residual=True,
            residual_scale=residual_scale,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config_masked)
        cross_mha = build_multi_head_attention(config.multi_head_config_cross)
        feedforward = build_feedforward(config.feedforward_config)

        # Expose attention or feedforward specific capabilities
        self.supports_attention_mask = mha.attention.supports_attention_mask
        self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
        self.requires_squared_context_length = (
            feedforward.requires_squared_context
            or mha.attention.requires_squared_context
        )
        self.causal_attention = (
            mha.attention.causal if hasattr(mha.attention, "causal") else False
        )

        # Wrappers handle the different layer norm styles (pre- and post-) and the residual path
        self.wrap_att = ln_factory(mha)
        self.wrap_cross = ln_factory(cross_mha)
        self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
        if (
            config.residual_norm_style == ResidualNormStyle.Pre
            and config.layer_position.is_last()
        ):
            self.wrap_ff = PostNorm(
                config.dim_model,
                self.wrap_ff,
                normalization=NormalizationType.LayerNorm,
            )
    @classmethod
    def from_config(cls, config: xFormerDecoderConfig):
        return cls(config)

    def forward(
        self,
        target: torch.Tensor,
        memory: torch.Tensor,
        encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        input_mask: Optional[torch.Tensor] = None,
    ):
        if self.pose_encoding is not None:
            target = self.pose_encoding(target)
            if hasattr(self, "embedding_projector"):
                target = self.embedding_projector(target)

        # Handle the optional input masking, differs on Q, K, V
        if input_mask is not None:
            target_q = target
            target_k = target * input_mask.unsqueeze(-1)
            target_v = target_k
        else:
            target_q, target_k, target_v = target, target, target

        # Self-attention on the target sequence, cross-attention over the encoder memory,
        # then the feedforward; residuals and norms are handled by the wrappers
        x = self.wrap_att(
            inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask
        )
        x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask)
        x = self.wrap_ff(inputs=[x])

        return x
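

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the factory itself).
# It is patterned after the upstream xformers block-factory howto: the
# "scaled_dot_product" attention and "MLP" feedforward names, and the nested
# dict layout below, are assumptions about what is registered in this build;
# adjust them to the components actually available. Note that constructing
# these blocks emits a deprecation warning.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    BATCH, SEQ, EMB = 2, 64, 128

    encoder_config = {
        "dim_model": EMB,
        "residual_norm_style": "pre",
        "multi_head_config": {
            "num_heads": 4,
            "residual_dropout": 0.0,
            "attention": {
                "name": "scaled_dot_product",
                "dropout": 0.0,
                "causal": False,
            },
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": 0.0,
            "activation": "relu",
            "hidden_layer_multiplier": 4,
        },
    }

    # Constructing the config performs the type checking / dict promotion
    config = xFormerEncoderConfig(**encoder_config)
    encoder = xFormerEncoderBlock.from_config(config)

    # Dummy forward pass: the block is shape-preserving, (B, S, E) in and out
    x = torch.rand(BATCH, SEQ, EMB)
    y = encoder(x)
    print(y.shape)  # expected: torch.Size([2, 64, 128])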