# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional

import torch


# Reshapes key padding mask from (batch_size, src_len) -> (batch_size * num_heads, 1, src_len)
def reshape_key_padding_mask(
    key_padding_mask: torch.Tensor, batched_dim: int
) -> torch.Tensor:
    assert key_padding_mask.ndim == 2
    batch_size, src_len = key_padding_mask.size()
    num_heads = batched_dim // batch_size
    return _reshape_key_padding_mask(key_padding_mask, batch_size, src_len, num_heads)
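

# Usage sketch (illustrative addition, not from the original file). With a (2, 3)
# boolean mask and batched_dim = batch_size * num_heads = 2 * 4, the mask is
# broadcast across heads and flattened to shape (8, 1, 3):
#
#     mask = torch.tensor([[True, True, False], [True, False, False]])
#     reshaped = reshape_key_padding_mask(mask, batched_dim=8)
#     assert reshaped.shape == (8, 1, 3)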


def _reshape_key_padding_mask(
    key_padding_mask: torch.Tensor, batch_size: int, src_len: int, num_heads: int
) -> torch.Tensor:
    assert key_padding_mask.shape == (batch_size, src_len)
    key_padding_mask = (
        key_padding_mask.view(batch_size, 1, 1, src_len)
        .expand(-1, num_heads, -1, -1)
        .reshape(batch_size * num_heads, 1, src_len)
    )
    return key_padding_mask


# Combine the attention mask and key padding mask into a single mask
# Taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
# Additive masking not yet supported
def maybe_merge_masks(
    att_mask: Optional[torch.Tensor],
    key_padding_mask: Optional[torch.Tensor],
    batch_size: int,
    src_len: int,
    num_heads: int,
    tgt_len: Optional[int] = None,
) -> Optional[torch.Tensor]:
    if tgt_len is None:
        tgt_len = src_len
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (batch_size, src_len)
        key_padding_mask = _reshape_key_padding_mask(
            key_padding_mask, batch_size, src_len, num_heads
        )
        if att_mask is None:
            # make sure dimensions of key padding mask are the same as those expected for att_mask
            att_mask = key_padding_mask.expand(-1, tgt_len, -1)
        # Assumption is that False means to mask.
        elif att_mask.dtype == torch.bool:
            att_mask = att_mask.logical_and(key_padding_mask)
        else:
            att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf"))

    return att_mask
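

# Usage sketch (illustrative addition, not from the original file): fold a per-sequence
# key padding mask into a per-head boolean attention mask, where True means "keep":
#
#     key_padding_mask = torch.tensor([[True, True, False], [True, False, False]])
#     merged = maybe_merge_masks(
#         att_mask=None,
#         key_padding_mask=key_padding_mask,
#         batch_size=2,
#         src_len=3,
#         num_heads=4,
#     )
#     assert merged.shape == (2 * 4, 3, 3)  # (batch_size * num_heads, tgt_len, src_len)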


# Assumes that the matrix passed in has had softmax applied to it.
def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=False):
    """
    Compute the Moore-Penrose inverse.
    Use an iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose inverse via efficient
    matrix-matrix multiplications.
    """

    i = torch.eye(
        softmax_mat.size(-1), device=softmax_mat.device, dtype=softmax_mat.dtype
    )
    k = softmax_mat

    # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
    if pinverse_original_init:
        # The original implementation is more conservative in computing the coefficient of Z_0.
        v = 1 / torch.max(torch.sum(k, dim=-2)) * k.transpose(-1, -2)
    else:
        # This is the exact coefficient computation, 1 / ||K||_1, for the initialization of Z_0,
        # leading to faster convergence.
        v = (
            1
            / torch.max(torch.sum(k, dim=-2), dim=-1).values[:, None, None]
            * k.transpose(-1, -2)
        )

    for _ in range(n_iter):
        kv = torch.matmul(k, v)
        v = torch.matmul(
            0.25 * v,
            13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
        )
    return v
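

# Usage sketch (illustrative addition, not from the original file). The default
# initialization indexes with [:, None, None], so it expects a batched (B, n, n)
# softmax matrix; each iteration refines V toward the pseudo-inverse of K via
# V <- 0.25 * V @ (13*I - K@V @ (15*I - K@V @ (7*I - K@V))):
#
#     attn = torch.softmax(torch.randn(2, 16, 16), dim=-1)
#     approx_pinv = iterative_pinv(attn)  # shape (2, 16, 16)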


def bool_mask_to_additive(
    mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
) -> torch.Tensor:
    assert (
        mask.dtype == torch.bool
    ), "This util is meant to convert between bool masks and additive ones"

    mask_ = torch.zeros_like(mask, dtype=dtype)
    mask_[~mask] = float("-inf")
    return mask_
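

# Usage sketch (illustrative addition, not from the original file): True entries map
# to 0.0 and False entries to -inf, so the result can be added to attention logits
# before the softmax:
#
#     bool_mask = torch.tensor([[True, False], [True, True]])
#     additive = bool_mask_to_additive(bool_mask)
#     # -> tensor([[0., -inf], [0., 0.]])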