# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# CREDITS: Inspired by https://github.com/pytorch/text/blob/master/torchtext/nn/modules/multiheadattention.py
# and the MultiHeadAttention implementation from PyTorch


import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn

logger = logging.getLogger("xformers")


@dataclass
class InputProjectionConfig:
    in_features: int
    out_features: int
    bias: bool


class InputProjection(nn.Module):
    """
    Handle all the input projections in one go, opportunistically fusing some operations.
    """

    def __init__(
        self,
        query_proj_params: InputProjectionConfig,
        key_proj_params: Optional[InputProjectionConfig],
        value_proj_params: Optional[InputProjectionConfig],
        use_separate_proj_weight: bool = True,
    ):
        super().__init__()

        self.out_features = query_proj_params.out_features

        # Each input gets a separate projection
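        # (nn.Linear's third positional argument is the bias flag)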
        self.q_proj = nn.Linear(
            query_proj_params.in_features,
            query_proj_params.out_features,
            query_proj_params.bias,
        )

        if key_proj_params is not None:
            self.k_proj = nn.Linear(
                key_proj_params.in_features,
                key_proj_params.out_features,
                key_proj_params.bias,
            )
        else:
            logger.info(
                "No Key projection parameters were passed, assuming that the weights"
                + " are shared with the query projection"
            )
            self.k_proj = self.q_proj

        if value_proj_params is not None:
            self.v_proj = nn.Linear(
                value_proj_params.in_features,
                value_proj_params.out_features,
                value_proj_params.bias,
            )
        else:
            logger.info(
                "No Value projection parameters were passed, assuming that the weights"
                + " are shared with the query projection"
            )
            self.v_proj = self.q_proj

        if not use_separate_proj_weight:
            # Compute optimization used at times: share the parameters between Q, K and V
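            # NOTE: only the weight matrices are tied here; the bias parameters (if any)
            # of each projection remain independent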
            with torch.no_grad():
                self.k_proj.weight = self.q_proj.weight
                self.v_proj.weight = self.q_proj.weight

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # One projection per input tensor

        # NOTE: would it make sense to catch self-attention with shared weights, to skip a projection step?

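        # map() with two iterables applies the lambda pairwise:
        # q = self.q_proj(query), k = self.k_proj(key), v = self.v_proj(value)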
        q, k, v = map(
            lambda fn, x: fn(x),
            [self.q_proj, self.k_proj, self.v_proj],
            [query, key, value],
        )

        return q, k, v
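

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library): project a
    # (batch, sequence, embedding) input, letting K and V fall back to the Q projection.
    # The sizes below are hypothetical and only serve the example.
    BATCH, SEQ, EMB = 2, 16, 64

    config = InputProjectionConfig(in_features=EMB, out_features=EMB, bias=True)
    projection = InputProjection(
        query_proj_params=config,
        key_proj_params=None,  # share the query projection weights
        value_proj_params=None,  # share the query projection weights
    )

    x = torch.randn(BATCH, SEQ, EMB)
    q, k, v = projection(x, x, x)  # self-attention style call: same tensor for Q, K and V
    assert q.shape == k.shape == v.shape == (BATCH, SEQ, EMB)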