# (Removed: non-source file-viewer metadata — "Files", timestamp, line/size counts.)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from .common import BaseOperator, get_xformers_operator, register_operator
from .unbind import stack_or_none, unbind
@register_operator
class DualGemmSiluOp(BaseOperator):
    """Fused kernel computing both GEMMs of a SwiGLU hidden layer plus silu/multiply."""

    OPERATOR = get_xformers_operator("dual_gemm_silu_identity_mul")
    OPERATOR_CATEGORY = "swiglu"
    NAME = "dual_gemm_silu"

    @classmethod
    # type: ignore
    def operator_flop(
        cls, x: torch.Tensor, w1: torch.Tensor, b1, w2: torch.Tensor, b2
    ) -> int:
        """NOTE: we neglect the impact of biases / pointwises"""
        # Two [M, K] @ [K, N] GEMMs; each multiply-add counts as 2 flops.
        m = x.shape[0]
        n, k = w1.shape[0], w1.shape[1]
        return 2 * (2 * m * n * k)
@register_operator
class GemmFusedSumOp(BaseOperator):
    """GEMM whose epilogue also reduces an operand (used for fused bias gradients)."""

    OPERATOR = get_xformers_operator("gemm_fused_operand_sum")
    OPERATOR_CATEGORY = "swiglu"
    NAME = "gemm_fused_operand_sum"

    @classmethod
    # type: ignore
    def operator_flop(cls, a: torch.Tensor, b: torch.Tensor, out1, out2) -> int:
        # Single [M, K] @ [K, N] GEMM at 2 flops per multiply-add.
        rows, inner = a.shape[0], a.shape[1]
        cols = b.shape[1]
        return 2 * rows * cols * inner
class _SwiGLUDecomposedFunc(torch.autograd.Function):
"""
This is just an example implementation with all
operations explicited. This implementation is worse
than pytorch, because pytorch is able to fuse some operations
(eg the linear forward ...) that are decomposed here.
The time measurements were made on the ViT-Giant setting:
- A100/f16
- input: [4440, 1536]
- hidden: [4440, 4096]
"""
NAME = "decomposed"
FORCE_BW_F32 = False
def _silu_backward(dy, x):
# https://github.com/pytorch/pytorch/blob/563b065f5a4b4055fa6b025c2514b566d5fd9439/aten/src/ATen/native/Activation.cpp#L483
sigm = 1 / (1 + torch.exp(-x.float()))
return (dy.float() * sigm * (1 + x.float() * (1 - sigm))).to(x.dtype)
# 952us
@classmethod
def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
x1 = x @ w1.transpose(-2, -1) + b1 # 275us
x2 = x @ w2.transpose(-2, -1) + b2 # 275us
x3 = F.silu(x1) # 62us
x4 = x3 * x2 # 90us
x5 = x4 @ w3.transpose(-2, -1) + b3 # 250us
ctx.save_for_backward(x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5)
return x5
# 1900us
@classmethod
def backward(cls, ctx, dx5):
saved_tensors = ctx.saved_tensors
if cls.FORCE_BW_F32:
dx5 = dx5.float()
saved_tensors = [t.float() for t in ctx.saved_tensors]
x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5 = saved_tensors
dx4 = dx5 @ w3 # 255us (nn)
dw3 = dx5.transpose(-2, -1) @ x4 # 247us (nt)
db3 = dx5.sum(0) # 25us
dx3 = dx4 * x2 # 88us
dx2 = dx4 * x3 # 88us
dx1 = cls._silu_backward(dx3, x1) # 90us
dx = dx2 @ w2 # 260us (nn)
dw2 = dx2.transpose(-2, -1) @ x # 245us (nt)
db2 = dx2.sum(0) # 50us
dx += dx1 @ w1 # 260us (nn)
dw1 = dx1.transpose(-2, -1) @ x # 245us (nt)
db1 = dx1.sum(0) # 50us
return (dx, dw1, db1, dw2, db2, dw3, db3)
class _SwiGLUFusedFunc(torch.autograd.Function):
NAME = "fused.py"
@classmethod
@torch.cuda.amp.custom_fwd
def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
x1, x2, x4 = DualGemmSiluOp.OPERATOR(x, w1, b1, w2, b2)
x5 = F.linear(x4, w3, b3)
ctx.save_for_backward(x, w1, w2, w3, x1, x2)
ctx.bias = [b1 is not None, b2 is not None, b3 is not None]
return x5
@staticmethod
def _linear_bw(
dy: torch.Tensor, x: torch.Tensor, bias: bool
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if not bias:
return (dy.transpose(-2, -1) @ x), None
db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)
GemmFusedSumOp.OPERATOR(dy.transpose(-2, -1), x, dw, db)
return dw, db
@classmethod
@torch.cuda.amp.custom_bwd
def backward(cls, ctx, dx5):
x, w1, w2, w3, x1, x2 = ctx.saved_tensors
w1w2 = stack_or_none([w1, w2], dim=0)
dx4 = dx5 @ w3 # 255us (nn)
dx1dx2, x4 = torch.ops.xformers.silu_bw_fused(x1, x2, dx4)
dx1, dx2 = dx1dx2.unbind(1)
del x1, x2, dx4
dw3, db3 = cls._linear_bw(dx5, x4, bias=ctx.bias[2])
del x4, dx5
if w1w2 is not None:
assert dx1dx2.is_contiguous()
assert w1w2.is_contiguous()
w1w2 = w1w2.view([w1.shape[0] * 2, w1.shape[1]])
dx = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]) @ w1w2
# backward of linear1 + linear2 - packed
dw1dw2 = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]).transpose(-2, -1) @ x
dw1dw2, db1db2 = cls._linear_bw(
dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]), x, bias=ctx.bias[0]
)
dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0)
if ctx.bias[0]:
db1db2 = db1db2.view([2, dx1.shape[1]])
db1, db2 = torch.unbind(db1db2, dim=0)
else:
db1 = db2 = None
else:
dx = dx2 @ w2 # 260us (nn)
torch.addmm(
dx, dx1, w1.to(dx1.dtype), beta=1, alpha=1, out=dx
) # dx += dx1 @ w1
dw2, db2 = cls._linear_bw(dx2, x, bias=ctx.bias[1])
dw1, db1 = cls._linear_bw(dx1, x, bias=ctx.bias[0])
return (dx, dw1, db1, dw2, db2, dw3, db3)
class SwiGLUOp:
    """Base class for any swiglu operator in :attr:`xformers.ops.swiglu`"""

    def __init__(self, op, packed_weights: bool, name: str, constraints):
        # Upper-case attribute names mirror the operator classes above.
        self.NAME = name
        self.PACKED_WEIGHTS = packed_weights
        self.op = op
        self.constraints = constraints

    def supports(self, op: "SwiGLUOpDispatch") -> bool:
        """True when weights are packed if required, and every constraint accepts ``op``."""
        packed_ok = op.packed_weights or not self.PACKED_WEIGHTS
        return packed_ok and all(check(op) for check in self.constraints)

    def __call__(self, *args: Optional[torch.Tensor]) -> torch.Tensor:
        """Abstract placeholder — concrete subclasses dispatch to ``self.op``."""

    def __str__(self) -> str:
        return "SwiGLUOp:{}".format(self.NAME)
class _ForwardToPythonAutogradFunc(SwiGLUOp):
    """Operator wrapper that dispatches via a ``torch.autograd.Function.apply``."""

    def supports(self, op: "SwiGLUOpDispatch") -> bool:
        # Let's disable autocast in bf16 until this issue is fixed
        # https://github.com/pytorch/pytorch/issues/87979
        bf16_autocast = op.dtype_autocast_gpu == torch.bfloat16
        return (not bf16_autocast) and super().supports(op)

    def __call__(self, *args, **kwargs):
        return self.op.apply(*args, **kwargs)
class _ForwardToFunc(SwiGLUOp):
    """Operator wrapper that calls a plain function (eg a compiled C++ op)."""

    def __call__(self, *args, **kwargs):
        return self.op(*args, **kwargs)

    def info(self):
        """Return "available" when the compiled operator was built, else "not built"."""
        if self.op.__name__ == "no_such_operator":
            return "not built"
        return "available"
def _eager_functional_swiglu(
x: torch.Tensor,
w1: torch.Tensor,
b1: torch.Tensor,
w2: torch.Tensor,
b2: torch.Tensor,
w3: torch.Tensor,
b3: torch.Tensor,
) -> torch.Tensor:
x1 = F.linear(x, w1, b1)
x2 = F.linear(x, w2, b2)
hidden = F.silu(x1) * x2
return F.linear(hidden, w3, b3)
@dataclass
class SwiGLUOpDispatch:
    """Dispatcher to automatically select
    the best operator in :attr:`xformers.ops.swiglu`
    """

    # Device the op will run on (string form accepted as well).
    device: Union[torch.device, str]
    # dtype of the input tensors.
    dtype: torch.dtype
    # Target dtype when GPU autocast is active, else the weight dtype.
    dtype_autocast_gpu: Optional[torch.dtype]
    # Whether w1/w2 alias a single contiguous storage.
    packed_weights: bool
    # Whether all three linear layers carry a bias.
    bias_enabled: bool

    @property
    def op(self) -> SwiGLUOp:
        """Computes the best operator

        Returns:
            SwiGLUOp: The best operator for the configuration
        """
        # Most-specialized first; eager pytorch is the universal fallback.
        for candidate in (SwiGLUPackedFusedOp, SwiGLUFusedOp):
            if candidate.supports(self):
                return candidate
        return SwiGLUEagerOp

    @staticmethod
    def from_arguments(
        x: torch.Tensor,
        w1: torch.Tensor,
        b1: Optional[torch.Tensor],
        w2: torch.Tensor,
        b2: Optional[torch.Tensor],
        w3: torch.Tensor,
        b3: Optional[torch.Tensor],
    ) -> "SwiGLUOpDispatch":
        """Build a dispatch configuration from the tensors passed to :attr:`swiglu`."""
        if torch.is_autocast_enabled():
            autocast_dtype: Optional[torch.dtype] = torch.get_autocast_gpu_dtype()
        else:
            autocast_dtype = w1.dtype
        return SwiGLUOpDispatch(
            device=x.device,
            dtype=x.dtype,
            dtype_autocast_gpu=autocast_dtype,
            packed_weights=stack_or_none((w1, w2), dim=0) is not None,
            bias_enabled=b1 is not None and b2 is not None and b3 is not None,
        )
def _only_sm80(op: SwiGLUOpDispatch) -> bool:
    """Constraint: only CUDA devices of compute capability >= 8.0 (A100 or newer)."""
    if isinstance(op.device, str):
        device_type = op.device
    else:
        device_type = op.device.type
    if device_type != "cuda":
        return False
    major = torch.cuda.get_device_capability(op.device)[0]
    return major >= 8
def _only_half_or_autocast(op: SwiGLUOpDispatch) -> bool:
    """Constraint: inputs are fp16/bf16, either directly or through GPU autocast."""
    half_dtypes = (torch.half, torch.bfloat16)
    if op.dtype in half_dtypes:
        return True
    return op.dtype_autocast_gpu is not None and op.dtype_autocast_gpu in half_dtypes
def _bias_enabled(op: SwiGLUOpDispatch) -> bool:
    """Constraint: the operator requires all three biases to be present."""
    return op.bias_enabled
# Reference python implementation with every op explicit; requires all biases.
_SwiGLUDecomposedOp = _ForwardToPythonAutogradFunc(
    _SwiGLUDecomposedFunc, False, "decomposed", constraints=[_bias_enabled]
)
# Fused CUDA kernels with unpacked w1/w2; A100+ and half/bf16 (or autocast) only.
SwiGLUFusedOp = _ForwardToPythonAutogradFunc(
    _SwiGLUFusedFunc, False, "fused", constraints=[_only_sm80, _only_half_or_autocast]
)
# Fully-fused C++ operator taking packed w1w2 weights; same hardware constraints.
SwiGLUPackedFusedOp = _ForwardToFunc(
    get_xformers_operator("swiglu_packedw"),
    True,
    "fused.p.cpp",
    constraints=[_only_sm80, _only_half_or_autocast],
)
# Plain pytorch fallback; no constraints, always supported.
SwiGLUEagerOp = _ForwardToFunc(
    _eager_functional_swiglu,
    False,
    "eager",
    constraints=[],
)
def _info() -> Dict[str, str]:
    """Report build status ("available"/"not built") of the compiled swiglu operators."""
    result: Dict[str, str] = {}
    for candidate in (SwiGLUPackedFusedOp,):
        result[candidate.NAME] = candidate.info()
    return result
def swiglu(
    x: torch.Tensor,
    w1: torch.Tensor,
    b1: Optional[torch.Tensor],
    w2: torch.Tensor,
    b2: Optional[torch.Tensor],
    w3: torch.Tensor,
    b3: Optional[torch.Tensor],
    *,
    # FIX: was `op: SwiGLUOp = None` — an implicit-Optional annotation,
    # deprecated by PEP 484; the default and all call sites are unchanged.
    op: Optional[SwiGLUOp] = None,
) -> torch.Tensor:
    """
    Computes a SwiGLU block given the weights/bias of the 3
    linear layers.

    - It is recommended to keep ``op=None`` so the best implementation \
    available for the inputs will be used.

    :Equivalent pytorch code:

    .. code-block:: python

        x1 = F.linear(x, w1, b1)
        x2 = F.linear(x, w2, b2)
        hidden = F.silu(x1) * x2
        return F.linear(hidden, w3, b3)

    :Packing weights:

    To allow faster implementations, it's recommended to have w1/w2 come from the same storage, as in:

    .. code-block:: python

        w1, w2 = xformers.ops.unbind(w12, 0)

    :Supported hardware:

    This operator is only optimized on A100+ on ``torch.half`` or ``torch.bfloat16`` \
    (autocast is supported), and will fallback to a functional pytorch \
    implementation otherwise.
    """
    batch_shape = x.shape[:-1]
    # Operators work on 2D inputs: flatten all leading batch dimensions.
    x = x.reshape([-1, x.shape[-1]])
    if w1.ndim != 2 or w1.shape != w2.shape:
        raise ValueError(f"Invalid shapes for w1: {w1.shape} / w2: {w2.shape}")
    if b1 is not None:
        if b1.ndim != 1 or b1.shape[0] != w1.shape[0]:
            raise ValueError(f"Invalid shapes for b1: {b1.shape}")
    if b2 is not None:
        if b2.ndim != 1 or b2.shape[0] != w2.shape[0]:
            raise ValueError(f"Invalid shapes for b2: {b2.shape}")
    if w3.ndim != 2 or w3.shape[1] != w2.shape[0]:
        raise ValueError(f"Invalid shape for w3: {w3.shape}")
    if b3 is not None:
        if b3.ndim != 1 or b3.shape[0] != w3.shape[0]:
            raise ValueError(f"Invalid shapes for w3: {w3.shape} / b3: {b3.shape}")
    if op is None:
        # Auto-select the best operator for this device/dtype/packing combination.
        op = SwiGLUOpDispatch.from_arguments(x, w1, b1, w2, b2, w3, b3).op
    if not op.PACKED_WEIGHTS:
        return op(x, w1, b1, w2, b2, w3, b3).reshape([*batch_shape, -1])
    # Packed path: w1/w2 (and b1/b2) must alias one contiguous storage.
    w1w2 = stack_or_none((w1, w2), dim=0)
    if b1 is not None and b2 is not None:
        b1b2: Optional[torch.Tensor] = stack_or_none((b1, b2), dim=0)
        if b1b2 is None:
            raise NotImplementedError("b1/b2 needs to be properly packed")
    else:
        b1b2 = None
        assert b1 is None and b2 is None
    if w1w2 is None:
        raise NotImplementedError("w1/w2 needs to be properly packed")
    return op(x, w1w2, b1b2, w3, b3).reshape([*batch_shape, -1])
class SwiGLU(nn.Module):
    """
    A Module that encapsulates the call to :attr:`xformers.ops.swiglu`,
    and holds the weights for the 3 linear layers
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
        *,
        _pack_weights: bool = True,
    ) -> None:
        """Create a SwiGLU module

        Args:
            in_features (int): Number of features of the input
            hidden_features (int): Number of hidden features
            out_features (Optional[int], optional): Number of features of the output. Defaults to None (same as ``in_features``).
            bias (bool, optional): Whether linear layers also include a bias. Defaults to True.
            _pack_weights (bool, optional): Store w1/w2 in a single ``w12`` Linear,
                which enables the faster packed operators. Defaults to True.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12: Optional[nn.Linear]
        if _pack_weights:
            # Packed layout: first half of the output rows is w1, second half w2.
            self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        else:
            # Unpacked layout: separate Linear modules for w1 and w2.
            self.w12 = None
            self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
            self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.in_features = in_features
        # Operator override; None means :attr:`swiglu` auto-dispatches.
        self.op: Optional[SwiGLUOp] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes :attr:`swiglu` with the module's weights

        Args:
            x (torch.Tensor): A Tensor of shape ``[..., in_features]``

        Returns:
            torch.Tensor: A Tensor of shape ``[..., out_features]``
        """
        return swiglu(x, *self._ordered_params(), op=self.op)

    def _ordered_params(self):
        """Used for testing - returns ordered arguments for operators"""
        b1: Optional[torch.Tensor]
        b2: Optional[torch.Tensor]
        if self.w12 is not None:
            # Packed storage: split the stacked [2*hidden, in] weight back into w1/w2
            # as views so the packed operators can detect the shared storage.
            w1w2 = self.w12.weight
            b1b2 = self.w12.bias
            w1, w2 = unbind(
                w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
                dim=0,
            )
            if b1b2 is not None:
                b1, b2 = unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
            else:
                b1, b2 = None, None
        else:
            w1, w2 = self.w1.weight, self.w2.weight
            b1, b2 = self.w1.bias, self.w2.bias
        return [
            w1,
            b1,
            w2,
            b2,
            self.w3.weight,
            self.w3.bias,
        ]