120 lines
3.8 KiB
Python
120 lines
3.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import math
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
|
|
|
from xformers.components.activations import Activation
|
|
from xformers.triton.k_activations import get_triton_activation_index
|
|
from xformers.triton.k_fused_matmul_bw import fused_matmul_backward
|
|
from xformers.triton.k_fused_matmul_fw import fused_matmul
|
|
|
|
|
|
class _fused_linear_triton(torch.autograd.Function):
    """
    Autograd function wrapping the fused Triton "matmul + bias + activation" kernels.

    ``forward`` delegates to ``fused_matmul`` and ``backward`` delegates to
    ``fused_matmul_backward``.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
        ctx,
        x,
        weight,
        bias,
        activation,
        trainable_weight,
        trainable_bias,
        save_activation_inputs,
    ):
        """
        Run the fused kernel and record what the backward pass will need.

        Args:
            ctx: autograd context used to carry state over to ``backward``.
            x: input tensor.
            weight: weight tensor forwarded to ``fused_matmul``.
            bias: bias forwarded to ``fused_matmul`` (FusedLinear passes
                ``None`` when the layer has no bias).
            activation: activation identifier forwarded to the kernel and
                reused as ``activation_grad`` in ``backward``.
            trainable_weight: whether a weight gradient should be computed.
            trainable_bias: whether a bias gradient should be computed.
            save_activation_inputs: forwarded to ``fused_matmul``; presumably
                asks the kernel to also return the pre-activation values
                (TODO confirm against ``fused_matmul``).

        Returns:
            The first output of ``fused_matmul``.
        """
        # Kick the fused Triton kernel, handling bias and activation in one go
        y, activation_inputs = fused_matmul(
            x, weight, bias, activation, save_activation_inputs
        )

        # Non-tensor state travels on the context as plain attributes
        ctx.activation = activation
        ctx.trainable_weight = trainable_weight
        ctx.trainable_bias = trainable_bias

        # Micro-optimization: only keep the tensors alive when some gradient
        # may actually be requested later on
        if x.requires_grad or ctx.trainable_weight or ctx.trainable_bias:
            ctx.save_for_backward(weight, activation_inputs, x)

        return y

    @staticmethod
    @custom_bwd
    def backward(
        ctx: Any, grad_out: torch.Tensor
    ) -> Any:  # pragma: no cover  # this is covered, but called directly from C++
        """
        Compute the gradients with respect to x, weight and bias.

        The actual computation is delegated to ``fused_matmul_backward``, which
        receives the ``trainable_weight`` / ``trainable_bias`` flags recorded
        in ``forward``.

        Returns:
            A 7-tuple with one slot per ``forward`` input (after ``ctx``):
            ``(grad_input, grad_weight, grad_bias, None, None, None, None)``.
            The trailing ``None`` entries correspond to the non-differentiable
            arguments (activation and the three boolean flags).
        """
        (weight, activation_inputs, x) = ctx.saved_tensors

        grad_input, grad_weight, grad_bias = fused_matmul_backward(
            grad_out=grad_out,
            inputs=x,
            act_in=activation_inputs,
            weight=weight,
            trainable_weight=ctx.trainable_weight,
            trainable_bias=ctx.trainable_bias,
            activation_grad=ctx.activation,
        )

        # Exactly one gradient per forward() input: x, weight, bias,
        # activation, trainable_weight, trainable_bias, save_activation_inputs
        return grad_input, grad_weight, grad_bias, None, None, None, None
|
|
|
|
|
|
class FusedLinear(nn.Module):
    """
    Linear layer and activation evaluated together by a single fused kernel.

    Plays the same role as `torch.nn.Linear`_ followed by the given activation.
    The whole transform is :math:`y = activation(xA^T + b)`.

    This is typically significantly faster than PyTorch while using fp16 and
    non-sigmoid activations, as of September 2021.

    .. _torch.nn.Linear: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        activation: Optional[Activation] = None,
        **_,
    ):
        """
        Args:
            in_features: size of the last dimension of the weight matrix.
            out_features: size of the first dimension of the weight matrix,
                and length of the bias vector when there is one.
            bias: whether to create a learnable bias (off by default).
            activation: activation to fuse, translated to a kernel index via
                ``get_triton_activation_index``.
            **_: any extra keyword argument is accepted and ignored.
        """
        super().__init__()

        # Parameters are allocated uninitialized, reset_parameters() fills them
        weight_storage = torch.empty(out_features, in_features)
        self.weight = nn.Parameter(weight_storage, requires_grad=True)

        if bias:
            bias_storage = torch.empty(out_features)
            self.bias = nn.Parameter(bias_storage, requires_grad=True)
        else:
            self.bias = None

        self._activation_index = get_triton_activation_index(activation)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        (Re)initialize the weight, and the bias if there is one.

        The weight is drawn with ``kaiming_uniform_(a=sqrt(5))``; the bias is
        drawn uniformly in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``, or set within
        ``[0, 0]`` when ``fan_in`` is not positive.
        """
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is None:
            return

        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
        limit = 0 if fan_in <= 0 else 1 / math.sqrt(fan_in)
        torch.nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """
        Apply the fused linear + activation transform to ``x``.

        Returns whatever ``_fused_linear_triton.apply`` returns for this input.
        """
        weight_is_trainable = self.weight.requires_grad
        bias_is_trainable = False if self.bias is None else self.bias.requires_grad

        # Presumably index 0 stands for "no activation" - TODO confirm against
        # get_triton_activation_index.
        # NOTE(review): this flag ignores trainable weights/bias when x itself
        # does not require grad - confirm the backward kernel copes with that.
        keep_activation_inputs = (
            self.training and x.requires_grad and self._activation_index > 0
        )

        return _fused_linear_triton.apply(
            x,
            self.weight,
            self.bias,
            self._activation_index,
            weight_is_trainable,
            bias_is_trainable,
            keep_activation_inputs,
        )
|