# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Inspired by https://github.com/pytorch/text/blob/master/torchtext/nn/modules/multiheadattention.py
# and the MultiHeadAttention implementation from PyTorch

import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn

logger = logging.getLogger("xformers")


@dataclass
class InputProjectionConfig:
    in_features: int
    out_features: int
    bias: bool


class InputProjection(nn.Module):
    """
    Handle all the input projections in one go, opportunistically fuse some operations.
    """

    def __init__(
        self,
        query_proj_params: InputProjectionConfig,
        key_proj_params: Optional[InputProjectionConfig],
        value_proj_params: Optional[InputProjectionConfig],
        use_separate_proj_weight: bool = True,
    ):
        super().__init__()

        self.out_features = query_proj_params.out_features

        # Each input gets a separate projection
        self.q_proj = nn.Linear(
            query_proj_params.in_features,
            query_proj_params.out_features,
            query_proj_params.bias,
        )

        if key_proj_params is not None:
            self.k_proj = nn.Linear(
                key_proj_params.in_features,
                key_proj_params.out_features,
                key_proj_params.bias,
            )
        else:
            logger.info(
                "No Key projection parameters were passed, assuming that the weights"
                + " are shared with the query projection"
            )
            self.k_proj = self.q_proj

        if value_proj_params is not None:
            self.v_proj = nn.Linear(
                value_proj_params.in_features,
                value_proj_params.out_features,
                value_proj_params.bias,
            )
        else:
            logger.info(
                "No Value projection parameters were passed, assuming that the weights"
                + " are shared with the query projection"
            )
            self.v_proj = self.q_proj

        if not use_separate_proj_weight:
            # Compute optimization used at times, share the parameters in between Q/K/V
            with torch.no_grad():
                self.k_proj.weight = self.q_proj.weight
                self.v_proj.weight = self.q_proj.weight
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # One projection per input tensor
        # NOTE: Would it make sense to catch self-attention + shared weights, to skip a projection step?
        q, k, v = map(
            lambda fn, x: fn(x),
            [self.q_proj, self.k_proj, self.v_proj],
            [query, key, value],
        )

        return q, k, v
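

# Usage sketch: a minimal, illustrative way to wire up InputProjection. The
# embedding size and tensor shapes below are arbitrary assumptions for the
# example, not values required by the module.
if __name__ == "__main__":
    batch, seq_len, embed_dim = 2, 16, 64

    proj_cfg = InputProjectionConfig(in_features=embed_dim, out_features=embed_dim, bias=True)

    # Separate Q/K/V projections; passing None for the key/value configs would
    # share the query projection weights instead.
    in_proj = InputProjection(
        query_proj_params=proj_cfg,
        key_proj_params=proj_cfg,
        value_proj_params=proj_cfg,
    )

    x = torch.randn(batch, seq_len, embed_dim)
    # Self-attention style call: the same tensor is used for query, key and value.
    q, k, v = in_proj(x, x, x)
    assert q.shape == k.shape == v.shape == (batch, seq_len, embed_dim)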