# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# Credits: this is heavily inspired by the official implementation, present in
# https://github.com/sarthmit/Compositional-Attention
# Original author: Sarthak Mittal

# This is a simplified version, for the sake of clarity, and because some features could be exposed later
# via the library directly.
# In particular, code paths for TPUs, quantization and gumbel softmax have been removed
# We're also following the same dimension ordering as in the rest of the xformers library
# which is to say [Batch, Sequence, Embedding] wherever possible

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    register_attention,
)
from xformers.components.attention.core import _softmax
from xformers.components.input_projection import InputProjection, InputProjectionConfig


def _either_or(a: Optional[int], b: int) -> int:
    """Return ``a`` unless it is ``None``, in which case fall back to ``b``."""
    return a if a is not None else b


@dataclass
class CompositionalAttentionConfig(AttentionConfig):
    # The fields below mirror the keyword arguments of CompositionalAttention.__init__
    dim_model: int
    num_heads: int
    dim_attn: Optional[int] = None
    num_rules: Optional[int] = None
    dim_key: Optional[int] = None
    dim_value: Optional[int] = None
    dim_selection: Optional[int] = None
    # NOTE(review): no default although it follows defaulted fields; this is only valid
    # if AttentionConfig already declares `dropout` - confirm against the base class
    dropout: float
    qk_rule: bool = False
    nonlinear: bool = False
    q_compose: bool = False
    bias: bool = True
    causal: Optional[bool] = False
    in_proj_container: Optional[InputProjection] = None
    use_separate_proj_weight: Optional[bool] = False


@register_attention("compositional", CompositionalAttentionConfig)
class CompositionalAttention(Attention):
    """Compositional Attention, as proposed in
    `Compositional Attention: Disentangling search and retrieval`_, S. Mittal et al.

    A key insight from this proposal is that the attention mechanism can be conceived as two steps:
    a search and a retrieval operation. When queried, the model can search for the most relevant information
    (Softmax(QKt)), then retrieve information given the Value.

    Contrary to the original attention proposal, which does not consider interactions in between heads,
    the compositional attention will consider all possible interactions and softmax over that dimension,
    so that the information retrieved covers the most relevant dimensions. The number of heads and rules to
    use is thus typically smaller than for a comparable traditional Transformer, and asking for the same number of heads
    may not fit in memory.

    Args:
        dim_model: dimension of the incoming latent space
        num_heads: number of heads *for the search operation*
        dim_attn: dimension (embedding) of the attention, defaults to ``dim_model``
        num_rules: number of rules to consider *for the retrieval operation*, defaults to ``num_heads``
        dim_selection: dimension of the scoring/selection space for the retrievals,
            defaults to ``dim_model // num_heads``
        dim_key, dim_value: dimensions of K and V, if different from Q (default to ``dim_model``)
        dropout: attention dropout probability
        qk_rule: QK product will drive the retrieval process
        nonlinear: use a non linear method to score the retrievals
        q_compose: build the selection queries from the per-head projected queries
            instead of from the unprojected input query
        in_proj_container: input projection to use; if ``None`` a new ``InputProjection`` is built
        use_separate_proj_weight: when building the input projection, give the keys and values
            their own projection parameters
        bias: use bias in the initial projection step
        causal: causal computations (attend to the past only)

    .. _`Compositional Attention: Disentangling search and retrieval`: https://arxiv.org/pdf/2110.09419v1.pdf
    """

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        dim_attn: Optional[int] = None,
        num_rules: Optional[int] = None,
        dim_selection: Optional[int] = None,
        dim_key: Optional[int] = None,
        dim_value: Optional[int] = None,
        dropout=0.0,
        qk_rule=False,
        nonlinear=False,
        q_compose=False,
        in_proj_container: Optional[InputProjection] = None,
        use_separate_proj_weight: Optional[bool] = False,
        bias=True,
        causal=False,
        *_,
        **__,
    ):
        super().__init__()

        # Define the inherited flags
        self.requires_skip_multi_head = (
            True  # This attention owns the multi-head mechanism
        )

        # Handle defaults / undefined values
        self.dim_model = dim_model
        num_rules = _either_or(num_rules, num_heads)
        dim_selection = _either_or(dim_selection, dim_model // num_heads)

        # All the initial definition plumbing
        dim_attn = _either_or(dim_attn, dim_model)
        dim_key = _either_or(dim_key, dim_model)
        dim_value = _either_or(dim_value, dim_model)

        self.in_proj_container = (
            in_proj_container
            if in_proj_container is not None
            else InputProjection(
                query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias),
                key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias)
                if use_separate_proj_weight
                else None,
                value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias)
                if use_separate_proj_weight
                else None,
            )
        )

        self.num_heads = num_heads
        self.num_rules = num_rules
        self.qk_rule = qk_rule
        self.dim_selection = dim_selection
        self.nonlinear = nonlinear
        self.q_compose = q_compose

        self.dropout_module = nn.Dropout(dropout)
        self.dim_head = dim_model // num_heads
        self.value_dim = dim_attn // num_rules

        assert (
            self.value_dim * num_rules == dim_attn
        ), "dim_attn must be divisible by num_rules"

        # Scaling applied to the search queries and (qk_rule only) to the selection queries
        self.scaling = self.dim_head**-0.5
        self.scaling_values = self.dim_selection**-0.5

        self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)

        # NOTE: the layers below are created in the same order in which they were historically
        # created (value_k, value_q, score_network), so that seeded initialization is unchanged
        if self.qk_rule:
            # Maps each retrieved value to the selection space (dot-product scoring)
            self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias)

        # The selection query is built the same way whatever the scoring method is
        if self.q_compose:
            # Computed per head, from the projected queries
            self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
        else:
            # Computed for all the heads at once, from the unprojected query
            self.value_q = nn.Linear(
                dim_model, self.dim_selection * self.num_heads, bias=bias
            )

        if not self.qk_rule:
            # Without qk_rule, a small network scores [retrieval, selection query] pairs
            if self.nonlinear:
                self.score_network: nn.Module = nn.Sequential(
                    nn.Linear(
                        self.dim_selection + self.value_dim,
                        self.dim_selection,
                        bias=bias,
                    ),
                    nn.ReLU(),
                    nn.Linear(self.dim_selection, 1, bias=bias),
                )
            else:
                self.score_network = nn.Linear(
                    self.dim_selection + self.value_dim, 1, bias=bias
                )

        self.causal = causal

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialize the selection, scoring and output projection weights."""
        # NOTE: in_proj_container is already initialized

        if self.qk_rule:
            nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.value_q.weight)

            # The score network only exists when qk_rule is off
            if self.nonlinear:
                nn.init.xavier_uniform_(self.score_network[0].weight)
                nn.init.xavier_uniform_(self.score_network[2].weight)
            else:
                nn.init.xavier_uniform_(self.score_network.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        att_mask: Optional[Tensor] = None,
        *args,
        **kwargs,
    ) -> Tensor:
        """
        Input shape: Batch x Sequence x Embedding

        Args:
            q: queries, [B, Sq, dim_model]
            k: keys, [B, Sk, ...]
            v: values
            att_mask (Tensor, optional): typically used to implement
                causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
                Boolean masks are converted with ``AttentionMask.from_bool``,
                any other tensor is used as an additive mask.

        Returns:
            The attended values after the output projection, [B, Sq, dim_model]
        """

        B, Sq, E = q.shape
        _, Sk, _ = k.shape

        assert E == self.dim_model

        # First define projected query/key/values
        # We keep the projected and original tensors in flight,
        # depending on the options the original values could be reused
        q_unprojected = q
        q, k, v = self.in_proj_container(query=q, key=k, value=v)
        q *= self.scaling

        # Init causal mask if needed, now that we know the context length
        # NOTE(review): the cached mask is compared against Sk but built as Sq x Sq,
        # which looks like it assumes Sq == Sk for causal attention - confirm
        if self.causal and (
            self._causal_mask is None or self._causal_mask.shape[0] != Sk
        ):
            self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device)

        # Convenience, create an attention mask if a tensor was passed
        # This sanitizes different mask types being passed, from now on it's additive
        if isinstance(att_mask, torch.Tensor):
            # By default we don't know of the causality, and a check would be expensive
            att_mask_additive: Optional[AttentionMask] = (
                AttentionMask.from_bool(att_mask)
                if att_mask.dtype == torch.bool
                else AttentionMask(att_mask, is_causal=False)
            )
        else:
            att_mask_additive = None

        # Handle the attention and key padding masks
        if self._causal_mask is not None:
            # Optionally add the causal mask
            if att_mask_additive is not None:
                att_mask_additive += self._causal_mask
            else:
                att_mask_additive = self._causal_mask

        # Flatten the heads or the rules
        q = (
            q.view(B, Sq, self.num_heads, self.dim_head)
            .movedim(2, 1)
            .flatten(0, 1)  # [B * num_heads, Sq, dim_head]
        )
        k = (
            k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1)
        )  # [B * num_heads, Sk, dim_head]
        v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1)

        # Compute the search: Softmax(QKt)
        attn_weights = torch.bmm(q, k.transpose(1, 2))  # [B * self.num_heads, Sq, Sk]

        if att_mask_additive is not None:
            attn_weights += att_mask_additive.values

        attn_weights = _softmax(attn_weights, causal=self.causal)

        attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
        attn_probs = self.dropout_module(attn_weights)

        # Now compute the information retrieval
        # keep all the heads in flight, we'll score the different possibilities
        # - compute all the possible retrievals
        v = v.view(B, 1, self.num_rules, Sk, self.value_dim)
        attn_probs = attn_probs.unsqueeze(2)
        attn = torch.matmul(attn_probs, v).view(
            B, self.num_heads, self.num_rules, Sq, self.value_dim
        )

        attn = attn.movedim(3, 1)  # [B, Sq, H, Rules, Values]

        # - search the most appropriate retrieval among all the values
        if self.q_compose:
            # q is [B * H, Sq, dim_head] at this point. Un-flatten the batch and head
            # dimensions and move the heads next to the features, so that the selection
            # query of (batch b, position s, head h) lands in v_q[b, s, h].
            # A plain transpose(0, 1) followed by view(B, Sq, ...) would reinterpret
            # the [Sq, B * H] memory layout as [B, Sq, H] and mix up batch and sequence.
            q_per_head = q.reshape(B, self.num_heads, Sq, self.dim_head).movedim(
                1, 2
            )  # [B, Sq, H, dim_head]
            v_q = self.value_q(q_per_head).unsqueeze(3)  # [B, Sq, H, 1, dim_selection]
        else:
            v_q = self.value_q(q_unprojected).view(
                B, Sq, self.num_heads, 1, self.dim_selection
            )

        if self.qk_rule:
            # Dot-product scoring between the selection queries and the projected retrievals
            v_q *= self.scaling_values
            v_k = (
                self.value_k(attn)
                .view(B, Sq, self.num_heads, self.num_rules, self.dim_selection)
                .transpose(4, 3)
                .contiguous()
            )
            v_score = torch.matmul(v_q, v_k).view(
                B, Sq, self.num_heads, self.num_rules, 1
            )
        else:
            # Learned scoring over the concatenated [retrieval, selection query] pairs
            v_q = v_q.expand(-1, -1, -1, self.num_rules, -1)
            v_in = torch.cat([attn, v_q], dim=-1)
            v_score = self.score_network(v_in).view(
                B, Sq, self.num_heads, self.num_rules, 1
            )

        # Normalize the scores over the rules dimension
        v_score = F.softmax(v_score, dim=3)

        # - extracted values are the original attention (inc. all the values) weighted by value score
        attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim)

        # Final attention projection, same as other mechanisms
        attn = self.out_proj(attn)

        return attn