First commit
This commit is contained in:
467
pkgs/xformers/ops/swiglu_op.py
Normal file
467
pkgs/xformers/ops/swiglu_op.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .common import BaseOperator, get_xformers_operator, register_operator
|
||||
from .unbind import stack_or_none, unbind
|
||||
|
||||
|
||||
@register_operator
|
||||
class DualGemmSiluOp(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("dual_gemm_silu_identity_mul")
|
||||
OPERATOR_CATEGORY = "swiglu"
|
||||
NAME = "dual_gemm_silu"
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls, x: torch.Tensor, w1: torch.Tensor, b1, w2: torch.Tensor, b2
|
||||
) -> int:
|
||||
"""NOTE: we neglect the impact of biases / pointwises"""
|
||||
M, N, K = x.shape[0], w1.shape[0], w1.shape[1]
|
||||
return M * N * K * 2 * 2
|
||||
|
||||
|
||||
@register_operator
|
||||
class GemmFusedSumOp(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("gemm_fused_operand_sum")
|
||||
OPERATOR_CATEGORY = "swiglu"
|
||||
NAME = "gemm_fused_operand_sum"
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(cls, a: torch.Tensor, b: torch.Tensor, out1, out2) -> int:
|
||||
M, N, K = a.shape[0], b.shape[1], a.shape[1]
|
||||
return M * N * K * 2
|
||||
|
||||
|
||||
class _SwiGLUDecomposedFunc(torch.autograd.Function):
|
||||
"""
|
||||
This is just an example implementation with all
|
||||
operations explicited. This implementation is worse
|
||||
than pytorch, because pytorch is able to fuse some operations
|
||||
(eg the linear forward ...) that are decomposed here.
|
||||
|
||||
The time measurements were made on the ViT-Giant setting:
|
||||
- A100/f16
|
||||
- input: [4440, 1536]
|
||||
- hidden: [4440, 4096]
|
||||
"""
|
||||
|
||||
NAME = "decomposed"
|
||||
FORCE_BW_F32 = False
|
||||
|
||||
def _silu_backward(dy, x):
|
||||
# https://github.com/pytorch/pytorch/blob/563b065f5a4b4055fa6b025c2514b566d5fd9439/aten/src/ATen/native/Activation.cpp#L483
|
||||
sigm = 1 / (1 + torch.exp(-x.float()))
|
||||
return (dy.float() * sigm * (1 + x.float() * (1 - sigm))).to(x.dtype)
|
||||
|
||||
# 952us
|
||||
@classmethod
|
||||
def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
|
||||
x1 = x @ w1.transpose(-2, -1) + b1 # 275us
|
||||
x2 = x @ w2.transpose(-2, -1) + b2 # 275us
|
||||
x3 = F.silu(x1) # 62us
|
||||
x4 = x3 * x2 # 90us
|
||||
x5 = x4 @ w3.transpose(-2, -1) + b3 # 250us
|
||||
|
||||
ctx.save_for_backward(x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5)
|
||||
return x5
|
||||
|
||||
# 1900us
|
||||
@classmethod
|
||||
def backward(cls, ctx, dx5):
|
||||
saved_tensors = ctx.saved_tensors
|
||||
if cls.FORCE_BW_F32:
|
||||
dx5 = dx5.float()
|
||||
saved_tensors = [t.float() for t in ctx.saved_tensors]
|
||||
x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5 = saved_tensors
|
||||
dx4 = dx5 @ w3 # 255us (nn)
|
||||
dw3 = dx5.transpose(-2, -1) @ x4 # 247us (nt)
|
||||
db3 = dx5.sum(0) # 25us
|
||||
dx3 = dx4 * x2 # 88us
|
||||
dx2 = dx4 * x3 # 88us
|
||||
dx1 = cls._silu_backward(dx3, x1) # 90us
|
||||
dx = dx2 @ w2 # 260us (nn)
|
||||
dw2 = dx2.transpose(-2, -1) @ x # 245us (nt)
|
||||
db2 = dx2.sum(0) # 50us
|
||||
dx += dx1 @ w1 # 260us (nn)
|
||||
dw1 = dx1.transpose(-2, -1) @ x # 245us (nt)
|
||||
db1 = dx1.sum(0) # 50us
|
||||
return (dx, dw1, db1, dw2, db2, dw3, db3)
|
||||
|
||||
|
||||
class _SwiGLUFusedFunc(torch.autograd.Function):
|
||||
NAME = "fused.py"
|
||||
|
||||
@classmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
|
||||
x1, x2, x4 = DualGemmSiluOp.OPERATOR(x, w1, b1, w2, b2)
|
||||
|
||||
x5 = F.linear(x4, w3, b3)
|
||||
ctx.save_for_backward(x, w1, w2, w3, x1, x2)
|
||||
ctx.bias = [b1 is not None, b2 is not None, b3 is not None]
|
||||
return x5
|
||||
|
||||
@staticmethod
|
||||
def _linear_bw(
|
||||
dy: torch.Tensor, x: torch.Tensor, bias: bool
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if not bias:
|
||||
return (dy.transpose(-2, -1) @ x), None
|
||||
db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
|
||||
dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)
|
||||
GemmFusedSumOp.OPERATOR(dy.transpose(-2, -1), x, dw, db)
|
||||
return dw, db
|
||||
|
||||
@classmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(cls, ctx, dx5):
|
||||
x, w1, w2, w3, x1, x2 = ctx.saved_tensors
|
||||
w1w2 = stack_or_none([w1, w2], dim=0)
|
||||
|
||||
dx4 = dx5 @ w3 # 255us (nn)
|
||||
dx1dx2, x4 = torch.ops.xformers.silu_bw_fused(x1, x2, dx4)
|
||||
dx1, dx2 = dx1dx2.unbind(1)
|
||||
del x1, x2, dx4
|
||||
|
||||
dw3, db3 = cls._linear_bw(dx5, x4, bias=ctx.bias[2])
|
||||
del x4, dx5
|
||||
if w1w2 is not None:
|
||||
assert dx1dx2.is_contiguous()
|
||||
assert w1w2.is_contiguous()
|
||||
w1w2 = w1w2.view([w1.shape[0] * 2, w1.shape[1]])
|
||||
dx = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]) @ w1w2
|
||||
|
||||
# backward of linear1 + linear2 - packed
|
||||
dw1dw2 = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]).transpose(-2, -1) @ x
|
||||
dw1dw2, db1db2 = cls._linear_bw(
|
||||
dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]), x, bias=ctx.bias[0]
|
||||
)
|
||||
dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0)
|
||||
if ctx.bias[0]:
|
||||
db1db2 = db1db2.view([2, dx1.shape[1]])
|
||||
db1, db2 = torch.unbind(db1db2, dim=0)
|
||||
else:
|
||||
db1 = db2 = None
|
||||
else:
|
||||
dx = dx2 @ w2 # 260us (nn)
|
||||
torch.addmm(
|
||||
dx, dx1, w1.to(dx1.dtype), beta=1, alpha=1, out=dx
|
||||
) # dx += dx1 @ w1
|
||||
dw2, db2 = cls._linear_bw(dx2, x, bias=ctx.bias[1])
|
||||
dw1, db1 = cls._linear_bw(dx1, x, bias=ctx.bias[0])
|
||||
return (dx, dw1, db1, dw2, db2, dw3, db3)
|
||||
|
||||
|
||||
class SwiGLUOp:
|
||||
"""Base class for any swiglu operator in :attr:`xformers.ops.swiglu`"""
|
||||
|
||||
def __init__(self, op, packed_weights: bool, name: str, constraints):
|
||||
self.NAME = name
|
||||
self.PACKED_WEIGHTS = packed_weights
|
||||
self.op = op
|
||||
self.constraints = constraints
|
||||
|
||||
def supports(self, op: "SwiGLUOpDispatch") -> bool:
|
||||
if self.PACKED_WEIGHTS and not op.packed_weights:
|
||||
return False
|
||||
return all(c(op) for c in self.constraints)
|
||||
|
||||
def __call__(self, *args: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SwiGLUOp:{self.NAME}"
|
||||
|
||||
|
||||
class _ForwardToPythonAutogradFunc(SwiGLUOp):
|
||||
def supports(self, op: "SwiGLUOpDispatch") -> bool:
|
||||
# Let's disable autocast in bf16 until this issue is fixed
|
||||
# https://github.com/pytorch/pytorch/issues/87979
|
||||
if op.dtype_autocast_gpu == torch.bfloat16:
|
||||
return False
|
||||
return super().supports(op)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.op.apply(*args, **kwargs)
|
||||
|
||||
|
||||
class _ForwardToFunc(SwiGLUOp):
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.op(*args, **kwargs)
|
||||
|
||||
def info(self):
|
||||
if self.op.__name__ == "no_such_operator":
|
||||
return "not built"
|
||||
return "available"
|
||||
|
||||
|
||||
def _eager_functional_swiglu(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
b1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
b2: torch.Tensor,
|
||||
w3: torch.Tensor,
|
||||
b3: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x1 = F.linear(x, w1, b1)
|
||||
x2 = F.linear(x, w2, b2)
|
||||
hidden = F.silu(x1) * x2
|
||||
return F.linear(hidden, w3, b3)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SwiGLUOpDispatch:
|
||||
"""Dispatcher to automatically select
|
||||
the best operator in :attr:`xformers.ops.swiglu`
|
||||
"""
|
||||
|
||||
device: Union[torch.device, str]
|
||||
dtype: torch.dtype
|
||||
dtype_autocast_gpu: Optional[torch.dtype]
|
||||
packed_weights: bool
|
||||
bias_enabled: bool
|
||||
|
||||
@property
|
||||
def op(self) -> SwiGLUOp:
|
||||
"""Computes the best operator
|
||||
|
||||
Returns:
|
||||
SwiGLUOp: The best operator for the configuration
|
||||
"""
|
||||
priorities: Sequence[SwiGLUOp] = [
|
||||
SwiGLUPackedFusedOp,
|
||||
SwiGLUFusedOp,
|
||||
]
|
||||
for op in priorities:
|
||||
if op.supports(self):
|
||||
return op
|
||||
return SwiGLUEagerOp
|
||||
|
||||
@staticmethod
|
||||
def from_arguments(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
b1: Optional[torch.Tensor],
|
||||
w2: torch.Tensor,
|
||||
b2: Optional[torch.Tensor],
|
||||
w3: torch.Tensor,
|
||||
b3: Optional[torch.Tensor],
|
||||
) -> "SwiGLUOpDispatch":
|
||||
return SwiGLUOpDispatch(
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
packed_weights=stack_or_none((w1, w2), dim=0) is not None,
|
||||
dtype_autocast_gpu=torch.get_autocast_gpu_dtype()
|
||||
if torch.is_autocast_enabled()
|
||||
else w1.dtype,
|
||||
bias_enabled=b1 is not None and b2 is not None and b3 is not None,
|
||||
)
|
||||
|
||||
|
||||
def _only_sm80(op: SwiGLUOpDispatch) -> bool:
|
||||
device_type = op.device if isinstance(op.device, str) else op.device.type
|
||||
return device_type == "cuda" and torch.cuda.get_device_capability(op.device)[0] >= 8
|
||||
|
||||
|
||||
def _only_half_or_autocast(op: SwiGLUOpDispatch) -> bool:
|
||||
HALF_DTYPES = [torch.half, torch.bfloat16]
|
||||
return op.dtype in HALF_DTYPES or (
|
||||
op.dtype_autocast_gpu is not None and op.dtype_autocast_gpu in HALF_DTYPES
|
||||
)
|
||||
|
||||
|
||||
def _bias_enabled(op: SwiGLUOpDispatch) -> bool:
|
||||
return op.bias_enabled
|
||||
|
||||
|
||||
_SwiGLUDecomposedOp = _ForwardToPythonAutogradFunc(
|
||||
_SwiGLUDecomposedFunc, False, "decomposed", constraints=[_bias_enabled]
|
||||
)
|
||||
SwiGLUFusedOp = _ForwardToPythonAutogradFunc(
|
||||
_SwiGLUFusedFunc, False, "fused", constraints=[_only_sm80, _only_half_or_autocast]
|
||||
)
|
||||
SwiGLUPackedFusedOp = _ForwardToFunc(
|
||||
get_xformers_operator("swiglu_packedw"),
|
||||
True,
|
||||
"fused.p.cpp",
|
||||
constraints=[_only_sm80, _only_half_or_autocast],
|
||||
)
|
||||
SwiGLUEagerOp = _ForwardToFunc(
|
||||
_eager_functional_swiglu,
|
||||
False,
|
||||
"eager",
|
||||
constraints=[],
|
||||
)
|
||||
|
||||
|
||||
def _info() -> Dict[str, str]:
|
||||
return {op.NAME: op.info() for op in [SwiGLUPackedFusedOp]}
|
||||
|
||||
|
||||
def swiglu(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
b1: Optional[torch.Tensor],
|
||||
w2: torch.Tensor,
|
||||
b2: Optional[torch.Tensor],
|
||||
w3: torch.Tensor,
|
||||
b3: Optional[torch.Tensor],
|
||||
*,
|
||||
op: SwiGLUOp = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes a SwiGLU block given the weights/bias of the 3
|
||||
linear layers.
|
||||
|
||||
- It is recommended to keep ``op=None`` so the best implementation \
|
||||
available for the inputs will be used.
|
||||
|
||||
|
||||
:Equivalent pytorch code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x1 = F.linear(x, w1, b1)
|
||||
x2 = F.linear(x, w2, b2)
|
||||
hidden = F.silu(x1) * x2
|
||||
return F.linear(hidden, w3, b3)
|
||||
|
||||
:Packing weights:
|
||||
|
||||
To allow faster implementations, it's recommended to have w1/w2 come from the same storage, as in:
|
||||
.. code-block:: python
|
||||
|
||||
w1, w2 = xformers.ops.unbind(w12, 0)
|
||||
|
||||
:Supported hardware:
|
||||
|
||||
This operator is only optimized on A100+ on ``torch.half`` or ``torch.bfloat16`` \
|
||||
(autocast is supported), and will fallback to a functional pytorch \
|
||||
implementation otherwise.
|
||||
"""
|
||||
|
||||
batch_shape = x.shape[:-1]
|
||||
x = x.reshape([-1, x.shape[-1]])
|
||||
if w1.ndim != 2 or w1.shape != w2.shape:
|
||||
raise ValueError(f"Invalid shapes for w1: {w1.shape} / w2: {w2.shape}")
|
||||
if b1 is not None:
|
||||
if b1.ndim != 1 or b1.shape[0] != w1.shape[0]:
|
||||
raise ValueError(f"Invalid shapes for b1: {b1.shape}")
|
||||
if b2 is not None:
|
||||
if b2.ndim != 1 or b2.shape[0] != w2.shape[0]:
|
||||
raise ValueError(f"Invalid shapes for b2: {b2.shape}")
|
||||
if w3.ndim != 2 or w3.shape[1] != w2.shape[0]:
|
||||
raise ValueError(f"Invalid shape for w3: {w3.shape}")
|
||||
if b3 is not None:
|
||||
if b3.ndim != 1 or b3.shape[0] != w3.shape[0]:
|
||||
raise ValueError(f"Invalid shapes for w3: {w3.shape} / b3: {b3.shape}")
|
||||
|
||||
if op is None:
|
||||
op = SwiGLUOpDispatch.from_arguments(x, w1, b1, w2, b2, w3, b3).op
|
||||
|
||||
if not op.PACKED_WEIGHTS:
|
||||
return op(x, w1, b1, w2, b2, w3, b3).reshape([*batch_shape, -1])
|
||||
w1w2 = stack_or_none((w1, w2), dim=0)
|
||||
if b1 is not None and b2 is not None:
|
||||
b1b2: Optional[torch.Tensor] = stack_or_none((b1, b2), dim=0)
|
||||
if b1b2 is None:
|
||||
raise NotImplementedError("b1/b2 needs to be properly packed")
|
||||
else:
|
||||
b1b2 = None
|
||||
assert b1 is None and b2 is None
|
||||
|
||||
if w1w2 is None:
|
||||
raise NotImplementedError("w1/w2 needs to be properly packed")
|
||||
return op(x, w1w2, b1b2, w3, b3).reshape([*batch_shape, -1])
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
"""
|
||||
A Module that encapsulates the call to :attr:`xformers.ops.swiglu`,
|
||||
and holds the weights for the 3 linear layers
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int,
|
||||
out_features: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
*,
|
||||
_pack_weights: bool = True,
|
||||
) -> None:
|
||||
"""Create a SwiGLU module
|
||||
|
||||
Args:
|
||||
in_features (int): Number of features of the input
|
||||
hidden_features (int): Number of hidden features
|
||||
out_features (Optional[int], optional): Number of features of the input. Defaults to None.
|
||||
bias (bool, optional): Whether linear layers also include a bias. Defaults to True.
|
||||
"""
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
||||
self.w12: Optional[nn.Linear]
|
||||
if _pack_weights:
|
||||
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
||||
else:
|
||||
self.w12 = None
|
||||
self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
|
||||
self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
|
||||
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
||||
|
||||
self.hidden_features = hidden_features
|
||||
self.out_features = out_features
|
||||
self.in_features = in_features
|
||||
self.op = None
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes :attr:`swiglu` with the module's weights
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): A Tensor of shape ``[..., in_features]``
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A Tensor of shape ``[..., out_features]``
|
||||
"""
|
||||
return swiglu(x, *self._ordered_params(), op=self.op)
|
||||
|
||||
def _ordered_params(self):
|
||||
"""Used for testing - returns ordered arguments for operators"""
|
||||
b1: Optional[torch.Tensor]
|
||||
b2: Optional[torch.Tensor]
|
||||
if self.w12 is not None:
|
||||
w1w2 = self.w12.weight
|
||||
b1b2 = self.w12.bias
|
||||
w1, w2 = unbind(
|
||||
w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
|
||||
dim=0,
|
||||
)
|
||||
if b1b2 is not None:
|
||||
b1, b2 = unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
|
||||
else:
|
||||
b1, b2 = None, None
|
||||
else:
|
||||
w1, w2 = self.w1.weight, self.w2.weight
|
||||
b1, b2 = self.w1.bias, self.w2.bias
|
||||
return [
|
||||
w1,
|
||||
b1,
|
||||
w2,
|
||||
b2,
|
||||
self.w3.weight,
|
||||
self.w3.bias,
|
||||
]
|
||||
Reference in New Issue
Block a user