# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from .common import BaseOperator, get_xformers_operator, register_operator
from .unbind import stack_or_none, unbind


@register_operator
class DualGemmSiluOp(BaseOperator):
    OPERATOR = get_xformers_operator("dual_gemm_silu_identity_mul")
    OPERATOR_CATEGORY = "swiglu"
    NAME = "dual_gemm_silu"

    @classmethod
    # type: ignore
    def operator_flop(
        cls, x: torch.Tensor, w1: torch.Tensor, b1, w2: torch.Tensor, b2
    ) -> int:
        """NOTE: we neglect the impact of biases / pointwises"""
        M, N, K = x.shape[0], w1.shape[0], w1.shape[1]
        return M * N * K * 2 * 2


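# FLOP accounting for DualGemmSiluOp, for reference: the operator performs two
# GEMMs of shape [M, K] x [K, N] (x @ w1^T and x @ w2^T), and one such matmul
# costs 2 * M * N * K FLOPs (a multiply and an add per accumulated term), hence
# the `M * N * K * 2 * 2` returned by `operator_flop` above. The SiLU and the
# elementwise multiply are O(M * N) and are neglected, as noted in the docstring.

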
@register_operator
class GemmFusedSumOp(BaseOperator):
    OPERATOR = get_xformers_operator("gemm_fused_operand_sum")
    OPERATOR_CATEGORY = "swiglu"
    NAME = "gemm_fused_operand_sum"

    @classmethod
    # type: ignore
    def operator_flop(cls, a: torch.Tensor, b: torch.Tensor, out1, out2) -> int:
        M, N, K = a.shape[0], b.shape[1], a.shape[1]
        return M * N * K * 2


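# GemmFusedSumOp wraps the "gemm_fused_operand_sum" kernel. In this file it is
# only used in `_SwiGLUFusedFunc._linear_bw`, where a single call produces both
# the weight gradient (dw = dy^T @ x) and the bias gradient (db = dy.sum(0)).
# `operator_flop` therefore only counts the single GEMM (2 * M * N * K) and
# ignores the fused reduction.

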
class _SwiGLUDecomposedFunc(torch.autograd.Function):
    """
    This is just a reference implementation with every
    operation written out explicitly. It is slower than the
    plain PyTorch version, because PyTorch is able to fuse some of the
    operations (e.g. the linear forward) that are decomposed here.

    The time measurements were made on the ViT-Giant setting:
    - A100/f16
    - input: [4440, 1536]
    - hidden: [4440, 4096]
    """

    NAME = "decomposed"
    FORCE_BW_F32 = False

    def _silu_backward(dy, x):
        # https://github.com/pytorch/pytorch/blob/563b065f5a4b4055fa6b025c2514b566d5fd9439/aten/src/ATen/native/Activation.cpp#L483
        sigm = 1 / (1 + torch.exp(-x.float()))
        return (dy.float() * sigm * (1 + x.float() * (1 - sigm))).to(x.dtype)

    # 952us
    @classmethod
    def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
        x1 = x @ w1.transpose(-2, -1) + b1  # 275us
        x2 = x @ w2.transpose(-2, -1) + b2  # 275us
        x3 = F.silu(x1)  # 62us
        x4 = x3 * x2  # 90us
        x5 = x4 @ w3.transpose(-2, -1) + b3  # 250us

        ctx.save_for_backward(x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5)
        return x5

    # 1900us
    @classmethod
    def backward(cls, ctx, dx5):
        saved_tensors = ctx.saved_tensors
        if cls.FORCE_BW_F32:
            dx5 = dx5.float()
            saved_tensors = [t.float() for t in ctx.saved_tensors]
        x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5 = saved_tensors
        dx4 = dx5 @ w3  # 255us (nn)
        dw3 = dx5.transpose(-2, -1) @ x4  # 247us (nt)
        db3 = dx5.sum(0)  # 25us
        dx3 = dx4 * x2  # 88us
        dx2 = dx4 * x3  # 88us
        dx1 = cls._silu_backward(dx3, x1)  # 90us
        dx = dx2 @ w2  # 260us (nn)
        dw2 = dx2.transpose(-2, -1) @ x  # 245us (nt)
        db2 = dx2.sum(0)  # 50us
        dx += dx1 @ w1  # 260us (nn)
        dw1 = dx1.transpose(-2, -1) @ x  # 245us (nt)
        db1 = dx1.sum(0)  # 50us
        return (dx, dw1, db1, dw2, db2, dw3, db3)


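# Derivation behind `_silu_backward`, for reference: with s(x) = sigmoid(x),
# silu(x) = x * s(x) and s'(x) = s(x) * (1 - s(x)), so by the product rule
#   d/dx silu(x) = s(x) + x * s(x) * (1 - s(x)) = s(x) * (1 + x * (1 - s(x))),
# which is exactly the factor multiplying `dy` in `_silu_backward` above.

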
class _SwiGLUFusedFunc(torch.autograd.Function):
    NAME = "fused.py"

    @classmethod
    @torch.cuda.amp.custom_fwd
    def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
        x1, x2, x4 = DualGemmSiluOp.OPERATOR(x, w1, b1, w2, b2)

        x5 = F.linear(x4, w3, b3)
        ctx.save_for_backward(x, w1, w2, w3, x1, x2)
        ctx.bias = [b1 is not None, b2 is not None, b3 is not None]
        return x5

    @staticmethod
    def _linear_bw(
        dy: torch.Tensor, x: torch.Tensor, bias: bool
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if not bias:
            return (dy.transpose(-2, -1) @ x), None
        db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
        dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)
        GemmFusedSumOp.OPERATOR(dy.transpose(-2, -1), x, dw, db)
        return dw, db

    @classmethod
    @torch.cuda.amp.custom_bwd
    def backward(cls, ctx, dx5):
        x, w1, w2, w3, x1, x2 = ctx.saved_tensors
        w1w2 = stack_or_none([w1, w2], dim=0)

        dx4 = dx5 @ w3  # 255us (nn)
        dx1dx2, x4 = torch.ops.xformers.silu_bw_fused(x1, x2, dx4)
        dx1, dx2 = dx1dx2.unbind(1)
        del x1, x2, dx4

        dw3, db3 = cls._linear_bw(dx5, x4, bias=ctx.bias[2])
        del x4, dx5
        if w1w2 is not None:
            assert dx1dx2.is_contiguous()
            assert w1w2.is_contiguous()
            w1w2 = w1w2.view([w1.shape[0] * 2, w1.shape[1]])
            dx = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]) @ w1w2

            # backward of linear1 + linear2 - packed
            dw1dw2, db1db2 = cls._linear_bw(
                dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]), x, bias=ctx.bias[0]
            )
            dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0)
            if ctx.bias[0]:
                db1db2 = db1db2.view([2, dx1.shape[1]])
                db1, db2 = torch.unbind(db1db2, dim=0)
            else:
                db1 = db2 = None
        else:
            dx = dx2 @ w2  # 260us (nn)
            torch.addmm(
                dx, dx1, w1.to(dx1.dtype), beta=1, alpha=1, out=dx
            )  # dx += dx1 @ w1
            dw2, db2 = cls._linear_bw(dx2, x, bias=ctx.bias[1])
            dw1, db1 = cls._linear_bw(dx1, x, bias=ctx.bias[0])
        return (dx, dw1, db1, dw2, db2, dw3, db3)


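# Note on the packed path in `_SwiGLUFusedFunc.backward`: when w1 and w2 are two
# contiguous halves of the same storage, `stack_or_none` recovers the packed
# [2, hidden, in] tensor. Since x1 = x @ w1^T and x2 = x @ w2^T, the input
# gradient dx = dx1 @ w1 + dx2 @ w2 can be computed as a single GEMM
# [dx1 | dx2] @ [w1; w2], and the weight/bias gradients of both layers come out
# of one `_linear_bw` call; otherwise the code falls back to two separate GEMMs.

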
class SwiGLUOp:
    """Base class for any swiglu operator in :attr:`xformers.ops.swiglu`"""

    def __init__(self, op, packed_weights: bool, name: str, constraints):
        self.NAME = name
        self.PACKED_WEIGHTS = packed_weights
        self.op = op
        self.constraints = constraints

    def supports(self, op: "SwiGLUOpDispatch") -> bool:
        if self.PACKED_WEIGHTS and not op.packed_weights:
            return False
        return all(c(op) for c in self.constraints)

    def __call__(self, *args: Optional[torch.Tensor]) -> torch.Tensor:
        pass

    def __str__(self) -> str:
        return f"SwiGLUOp:{self.NAME}"


class _ForwardToPythonAutogradFunc(SwiGLUOp):
    def supports(self, op: "SwiGLUOpDispatch") -> bool:
        # Let's disable autocast in bf16 until this issue is fixed
        # https://github.com/pytorch/pytorch/issues/87979
        if op.dtype_autocast_gpu == torch.bfloat16:
            return False
        return super().supports(op)

    def __call__(self, *args, **kwargs):
        return self.op.apply(*args, **kwargs)


class _ForwardToFunc(SwiGLUOp):
    def __call__(self, *args, **kwargs):
        return self.op(*args, **kwargs)

    def info(self):
        if self.op.__name__ == "no_such_operator":
            return "not built"
        return "available"


def _eager_functional_swiglu(
    x: torch.Tensor,
    w1: torch.Tensor,
    b1: torch.Tensor,
    w2: torch.Tensor,
    b2: torch.Tensor,
    w3: torch.Tensor,
    b3: torch.Tensor,
) -> torch.Tensor:
    x1 = F.linear(x, w1, b1)
    x2 = F.linear(x, w2, b2)
    hidden = F.silu(x1) * x2
    return F.linear(hidden, w3, b3)


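# In formula form, `_eager_functional_swiglu` computes
#   SwiGLU(x) = (silu(x @ w1^T + b1) * (x @ w2^T + b2)) @ w3^T + b3
# with silu(t) = t * sigmoid(t); this is the reference behaviour that every
# other operator in this file is expected to match.

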
@dataclass
class SwiGLUOpDispatch:
    """Dispatcher to automatically select
    the best operator in :attr:`xformers.ops.swiglu`
    """

    device: Union[torch.device, str]
    dtype: torch.dtype
    dtype_autocast_gpu: Optional[torch.dtype]
    packed_weights: bool
    bias_enabled: bool

    @property
    def op(self) -> SwiGLUOp:
        """Computes the best operator

        Returns:
            SwiGLUOp: The best operator for the configuration
        """
        priorities: Sequence[SwiGLUOp] = [
            SwiGLUPackedFusedOp,
            SwiGLUFusedOp,
        ]
        for op in priorities:
            if op.supports(self):
                return op
        return SwiGLUEagerOp

    @staticmethod
    def from_arguments(
        x: torch.Tensor,
        w1: torch.Tensor,
        b1: Optional[torch.Tensor],
        w2: torch.Tensor,
        b2: Optional[torch.Tensor],
        w3: torch.Tensor,
        b3: Optional[torch.Tensor],
    ) -> "SwiGLUOpDispatch":
        return SwiGLUOpDispatch(
            device=x.device,
            dtype=x.dtype,
            packed_weights=stack_or_none((w1, w2), dim=0) is not None,
            dtype_autocast_gpu=torch.get_autocast_gpu_dtype()
            if torch.is_autocast_enabled()
            else w1.dtype,
            bias_enabled=b1 is not None and b2 is not None and b3 is not None,
        )


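# Example of querying the dispatcher directly (this mirrors what `swiglu` does
# internally when `op=None`):
#
#   op = SwiGLUOpDispatch.from_arguments(x, w1, b1, w2, b2, w3, b3).op
#   out = swiglu(x, w1, b1, w2, b2, w3, b3, op=op)
#
# `SwiGLUPackedFusedOp` is preferred, then `SwiGLUFusedOp`, with `SwiGLUEagerOp`
# as the final fallback when no fused operator supports the inputs.

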
def _only_sm80(op: SwiGLUOpDispatch) -> bool:
    device_type = op.device if isinstance(op.device, str) else op.device.type
    return device_type == "cuda" and torch.cuda.get_device_capability(op.device)[0] >= 8


def _only_half_or_autocast(op: SwiGLUOpDispatch) -> bool:
    HALF_DTYPES = [torch.half, torch.bfloat16]
    return op.dtype in HALF_DTYPES or (
        op.dtype_autocast_gpu is not None and op.dtype_autocast_gpu in HALF_DTYPES
    )


def _bias_enabled(op: SwiGLUOpDispatch) -> bool:
    return op.bias_enabled


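# `_only_sm80` gates the fused kernels to CUDA devices of compute capability 8.0
# or higher (A100 and newer), and `_only_half_or_autocast` to fp16/bf16 inputs
# (directly or via autocast); these are the constraints attached to the fused
# operators instantiated below.

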
_SwiGLUDecomposedOp = _ForwardToPythonAutogradFunc(
    _SwiGLUDecomposedFunc, False, "decomposed", constraints=[_bias_enabled]
)
SwiGLUFusedOp = _ForwardToPythonAutogradFunc(
    _SwiGLUFusedFunc, False, "fused", constraints=[_only_sm80, _only_half_or_autocast]
)
SwiGLUPackedFusedOp = _ForwardToFunc(
    get_xformers_operator("swiglu_packedw"),
    True,
    "fused.p.cpp",
    constraints=[_only_sm80, _only_half_or_autocast],
)
SwiGLUEagerOp = _ForwardToFunc(
    _eager_functional_swiglu,
    False,
    "eager",
    constraints=[],
)


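# Of the operators above, only SwiGLUPackedFusedOp, SwiGLUFusedOp and
# SwiGLUEagerOp take part in automatic dispatch (see `SwiGLUOpDispatch.op`);
# `_SwiGLUDecomposedOp` wraps the reference `_SwiGLUDecomposedFunc` and is only
# used when passed explicitly through the `op=` argument of `swiglu`.

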
def _info() -> Dict[str, str]:
    return {op.NAME: op.info() for op in [SwiGLUPackedFusedOp]}


def swiglu(
    x: torch.Tensor,
    w1: torch.Tensor,
    b1: Optional[torch.Tensor],
    w2: torch.Tensor,
    b2: Optional[torch.Tensor],
    w3: torch.Tensor,
    b3: Optional[torch.Tensor],
    *,
    op: Optional[SwiGLUOp] = None,
) -> torch.Tensor:
    """
    Computes a SwiGLU block given the weights/bias of the 3
    linear layers.

    - It is recommended to keep ``op=None`` so the best implementation \
    available for the inputs will be used.


    :Equivalent pytorch code:

    .. code-block:: python

        x1 = F.linear(x, w1, b1)
        x2 = F.linear(x, w2, b2)
        hidden = F.silu(x1) * x2
        return F.linear(hidden, w3, b3)

    :Packing weights:

    To allow faster implementations, it is recommended to have ``w1``/``w2`` come from the same storage, as in:

    .. code-block:: python

        w1, w2 = xformers.ops.unbind(w12, 0)

    :Supported hardware:

    This operator is only optimized on A100+ on ``torch.half`` or ``torch.bfloat16`` \
    (autocast is supported), and will fall back to a functional pytorch \
    implementation otherwise.
    """

    batch_shape = x.shape[:-1]
    x = x.reshape([-1, x.shape[-1]])
    if w1.ndim != 2 or w1.shape != w2.shape:
        raise ValueError(f"Invalid shapes for w1: {w1.shape} / w2: {w2.shape}")
    if b1 is not None:
        if b1.ndim != 1 or b1.shape[0] != w1.shape[0]:
            raise ValueError(f"Invalid shapes for b1: {b1.shape}")
    if b2 is not None:
        if b2.ndim != 1 or b2.shape[0] != w2.shape[0]:
            raise ValueError(f"Invalid shapes for b2: {b2.shape}")
    if w3.ndim != 2 or w3.shape[1] != w2.shape[0]:
        raise ValueError(f"Invalid shape for w3: {w3.shape}")
    if b3 is not None:
        if b3.ndim != 1 or b3.shape[0] != w3.shape[0]:
            raise ValueError(f"Invalid shapes for w3: {w3.shape} / b3: {b3.shape}")

    if op is None:
        op = SwiGLUOpDispatch.from_arguments(x, w1, b1, w2, b2, w3, b3).op

    if not op.PACKED_WEIGHTS:
        return op(x, w1, b1, w2, b2, w3, b3).reshape([*batch_shape, -1])
    w1w2 = stack_or_none((w1, w2), dim=0)
    if b1 is not None and b2 is not None:
        b1b2: Optional[torch.Tensor] = stack_or_none((b1, b2), dim=0)
        if b1b2 is None:
            raise NotImplementedError("b1/b2 needs to be properly packed")
    else:
        b1b2 = None
        assert b1 is None and b2 is None

    if w1w2 is None:
        raise NotImplementedError("w1/w2 needs to be properly packed")
    return op(x, w1w2, b1b2, w3, b3).reshape([*batch_shape, -1])


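# Functional usage sketch (shapes are illustrative, taken from the ViT-Giant
# setting mentioned in `_SwiGLUDecomposedFunc`):
#
#   x = torch.randn([4440, 1536], device="cuda", dtype=torch.half)
#   # w1, w2: [4096, 1536], b1, b2: [4096], w3: [1536, 4096], b3: [1536]
#   y = swiglu(x, w1, b1, w2, b2, w3, b3)  # y: [4440, 1536]

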
class SwiGLU(nn.Module):
    """
    A Module that encapsulates the call to :attr:`xformers.ops.swiglu`,
    and holds the weights for the 3 linear layers
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
        *,
        _pack_weights: bool = True,
    ) -> None:
        """Create a SwiGLU module

        Args:
            in_features (int): Number of features of the input
            hidden_features (int): Number of hidden features
            out_features (Optional[int], optional): Number of features of the output. Defaults to ``in_features`` when ``None``.
            bias (bool, optional): Whether linear layers also include a bias. Defaults to True.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.w12: Optional[nn.Linear]
        if _pack_weights:
            self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        else:
            self.w12 = None
            self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
            self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

        self.hidden_features = hidden_features
        self.out_features = out_features
        self.in_features = in_features
        self.op = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes :attr:`swiglu` with the module's weights

        Args:
            x (torch.Tensor): A Tensor of shape ``[..., in_features]``

        Returns:
            torch.Tensor: A Tensor of shape ``[..., out_features]``
        """
        return swiglu(x, *self._ordered_params(), op=self.op)

    def _ordered_params(self):
        """Used for testing - returns ordered arguments for operators"""
        b1: Optional[torch.Tensor]
        b2: Optional[torch.Tensor]
        if self.w12 is not None:
            w1w2 = self.w12.weight
            b1b2 = self.w12.bias
            w1, w2 = unbind(
                w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
                dim=0,
            )
            if b1b2 is not None:
                b1, b2 = unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
            else:
                b1, b2 = None, None
        else:
            w1, w2 = self.w1.weight, self.w2.weight
            b1, b2 = self.w1.bias, self.w2.bias
        return [
            w1,
            b1,
            w2,
            b2,
            self.w3.weight,
            self.w3.bias,
        ]
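

# Minimal usage sketch of the SwiGLU module (illustrative smoke test only). On
# CPU / float32 the dispatcher falls back to the eager PyTorch implementation,
# so no fused kernels are required to run it.
if __name__ == "__main__":
    _m = SwiGLU(in_features=1536, hidden_features=4096)
    _x = torch.randn([2, 16, 1536])
    _y = _m(_x)
    assert _y.shape == (2, 16, 1536)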