First commit
This commit is contained in:
93
pkgs/xformers/components/attention/base.py
Normal file
93
pkgs/xformers/components/attention/base.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.attention import AttentionMask
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionConfig:
|
||||
"""Parameters required for all Attentions.
|
||||
Can accept and store extra parameters.
|
||||
"""
|
||||
|
||||
name: str # the registered name for this attention mechanism
|
||||
dropout: float # dropout probability
|
||||
|
||||
|
||||
Self = TypeVar("Self", bound="Attention")
|
||||
|
||||
|
||||
# Define the common interface, every attention block needs to derive from it
|
||||
class Attention(nn.Module, metaclass=ABCMeta):
|
||||
r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention"""
|
||||
|
||||
_causal_mask: Optional[AttentionMask] = None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
# Requires the inputs to be projected
|
||||
self.requires_input_projection = True
|
||||
|
||||
# Whether the head dimension needs to be present (if not it can be folded into the batch dimension)
|
||||
self.requires_head_dimension = False
|
||||
|
||||
# key padding mask and attention mask must be passed in as separate arguments instead of a merged attention mask
|
||||
self.requires_separate_masks = False
|
||||
|
||||
# Requires that K and Q have the same sequence length
|
||||
self.requires_same_k_q_dimensions = False
|
||||
|
||||
# Whether the attention owns the single head/multihead mechanism
|
||||
# so that the MHA wrapper should skip it
|
||||
self.requires_skip_multi_head = False
|
||||
|
||||
# This attention requires a context length which is squared, often due to 2D pooling
|
||||
self.requires_squared_context = False
|
||||
|
||||
# Whether this attention mechanism supports attention masks
|
||||
self.supports_attention_mask = True
|
||||
self.supports_key_padding_mask = False
|
||||
|
||||
@classmethod
|
||||
def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
|
||||
# Generate the class inputs from the config
|
||||
fields = asdict(config)
|
||||
|
||||
# Skip all Nones so that default values are used
|
||||
fields = {k: v for k, v in fields.items() if v is not None}
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor):
|
||||
"""
|
||||
If the sequence is shorter than the mask, return a padded view
|
||||
"""
|
||||
if x.shape[-2] != mask.shape[-1]:
|
||||
assert x.shape[-2] < mask.shape[-1], (
|
||||
"Sequence is bigger than the provided mask, cannot infer what to do with it."
|
||||
" Please update your attention mask"
|
||||
)
|
||||
|
||||
pad_size = (0, 0, 0, mask.shape[-1] - x.shape[-2], 0, 0)
|
||||
return torch.nn.functional.pad(x, pad_size, mode="constant", value=0.0)
|
||||
|
||||
return x
|
||||
Reference in New Issue
Block a user