First commit
This commit is contained in:
99
pkgs/xformers/components/input_projection.py
Normal file
99
pkgs/xformers/components/input_projection.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# CREDITS: Inspired by https://github.com/pytorch/text/blob/master/torchtext/nn/modules/multiheadattention.py
|
||||
# and the MultiHeadAttention implementation from PyTorch
|
||||
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# Module-level logger; named "xformers" (not __name__) so that all components
# of the library emit under one shared logging namespace.
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
@dataclass
class InputProjectionConfig:
    """Dimensions and bias flag for one input projection (``nn.Linear``).

    Attributes:
        in_features: size of each incoming sample.
        out_features: size of each projected sample.
        bias: whether the projection layer learns an additive bias.
    """

    in_features: int
    out_features: int
    bias: bool
|
||||
|
||||
|
||||
class InputProjection(nn.Module):
    """
    Handle all the input projections in one go, opportunistically fuse some operations.
    """

    def __init__(
        self,
        query_proj_params: InputProjectionConfig,
        key_proj_params: Optional[InputProjectionConfig],
        value_proj_params: Optional[InputProjectionConfig],
        use_separate_proj_weight: bool = True,
    ):
        """Build the query/key/value projection layers.

        Args:
            query_proj_params: dimensions and bias flag for the query projection.
            key_proj_params: optional config for the key projection; when ``None``
                the query projection module is reused for the keys.
            value_proj_params: optional config for the value projection; when
                ``None`` the query projection module is reused for the values.
            use_separate_proj_weight: when ``False``, the key and value layers
                are rebound to the query layer's weight tensor (biases, if any,
                stay independent).
        """
        super().__init__()

        self.out_features = query_proj_params.out_features

        def _linear_from(params: InputProjectionConfig) -> nn.Linear:
            # One dedicated projection layer built from a single config.
            return nn.Linear(params.in_features, params.out_features, params.bias)

        # Each input gets a separate projection
        self.q_proj = _linear_from(query_proj_params)

        if key_proj_params is None:
            logger.info(
                "No Key projection parameters were passed, assuming that the weights"
                + " are shared with the query projection"
            )
            self.k_proj = self.q_proj
        else:
            self.k_proj = _linear_from(key_proj_params)

        if value_proj_params is None:
            logger.info(
                "No Value projection parameters were passed, assuming that the weights"
                + " are shared with the query projection"
            )
            self.v_proj = self.q_proj
        else:
            self.v_proj = _linear_from(value_proj_params)

        if not use_separate_proj_weight:
            # Compute optimization used at times, share the parameters in between Q/K/V
            with torch.no_grad():
                self.k_proj.weight = self.q_proj.weight
                self.v_proj.weight = self.q_proj.weight

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project each input tensor through its dedicated layer.

        Returns:
            The projected ``(q, k, v)`` tensors, in that order.
        """
        # One projection per input tensor

        # NOTE: Would it make sense to catch self attention + shared weights, to skip a projection step ?

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        return q, k, v
|
||||
Reference in New Issue
Block a user