244 lines
7.4 KiB
Python
244 lines
7.4 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
# CREDITS: This is heavily inspired by the Triton dropout tutorial
|
|
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import triton
|
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
|
|
|
from xformers.components.activations import Activation, build_activation
|
|
from xformers.triton.k_activations import get_triton_activation_index
|
|
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw
|
|
|
|
# Tile sizes shared by the forward and backward kernel launches below:
# each Triton program handles a (BLOCK_M rows) x (BLOCK_N columns) tile.
BLOCK_M = 32
# NOTE: This should ideally be GPU dependent, big impact on perf
BLOCK_N = 64
|
|
|
|
|
|
# Helper to handle the SPMD launch grid and error cases
class _dropout(torch.autograd.Function):
    """
    Autograd function wrapping the fused Triton dropout kernels.

    The forward pass launches ``k_dropout_fw`` over a 2D view of the input
    (optionally with a bias and an activation index), the backward pass
    launches ``k_dropout_bw`` with the seeds drawn in the forward pass.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x, p, bias, activation, trainable_bias):
        """
        Args:
            x: input tensor, flattened to 2D over its last dimension.
            p: dropout probability, must be strictly positive.
            bias: optional tensor, same dtype as ``x`` and of leading size
                ``x.shape[-1]``.
            activation: activation identifier forwarded to the kernel as
                ``ACTIVATION``.
            trainable_bias: whether a bias gradient is needed in backward.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Soft-flatten an hypothetical 3rd dimension
        x_ = x.reshape(-1, x.shape[-1]).contiguous()
        y = torch.empty_like(x_)
        M, N = x_.shape

        assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N)
        assert p > 0.0

        def grid(meta):
            return (
                triton.cdiv(M, meta["BLOCK_M"]),
                triton.cdiv(N, meta["BLOCK_N"]),
            )

        N_BLOCK_N = triton.cdiv(N, BLOCK_N)

        # Generate one seed per tile of BLOCK_N columns.
        # Seeds are drawn in [0, 2**16), which comfortably fits in a positive int32.
        seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device, dtype=torch.int32)

        # fmt: off
        bias_ptr = bias if bias is not None else x_  # Possibly not being used

        k_dropout_fw[grid](
            y, x_,
            bias_ptr,
            seeds,
            y.stride(0),
            M, N,
            p,
            x.dtype == torch.float16,
            USE_BIAS=bias is not None,
            ACTIVATION=activation,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
        # fmt: on

        # The inputs are only kept around when an activation was requested
        if activation is not None:
            ctx.save_for_backward(seeds, bias, x)
        else:
            ctx.save_for_backward(seeds, bias, None)

        ctx.trainable_bias = bias is not None and trainable_bias
        ctx.activation = activation
        ctx.p = p

        return y.reshape_as(x)

    @staticmethod
    @custom_bwd
    def backward(
        ctx, grad_out
    ):  # pragma: no cover  # This is covered, but called from C++ and not tracked
        """
        Returns:
            One entry per ``forward`` input ``(x, p, bias, activation,
            trainable_bias)``: the input gradient, ``None``, the bias gradient
            (or ``None`` when the bias is not trainable), ``None``, ``None``.
        """
        (seeds, bias, inputs) = ctx.saved_tensors

        # Soft-flatten an hypothetical 3rd dimension
        grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
        grad_in = torch.empty_like(grad_out_)

        M, N = grad_out_.shape

        # Optional inputs to compute the activation contribution to the gradient
        assert inputs is not None or ctx.activation is None

        if inputs is None:
            inputs = grad_out_
        else:
            # Mirror the forward pass: the kernel is only handed the row stride,
            # so presumably it expects a dense 2D buffer. The saved tensor is the
            # raw (possibly non-contiguous) input, make it dense here.
            inputs = inputs.reshape(-1, N).contiguous()

        # We split the problem in tiles:
        # - over M there will be a follow up reduction
        # - over N we compromise in between trying to use as much memory parallelism as possible,
        # (fill in the warps, there are 32 threads per warps, and 4 warps default), and not being too
        # big because of register spilling
        N_BLOCKS_M = triton.cdiv(M, BLOCK_M)

        if ctx.trainable_bias:
            # One partial bias gradient per row tile, reduced after the kernel
            grad_bias = torch.empty(
                (
                    N_BLOCKS_M,
                    N,
                ),
                device=grad_in.device,
                dtype=grad_in.dtype,
            )

        else:
            grad_bias = grad_in  # will not be used

        def grid(meta):
            # One program per (BLOCK_M rows, BLOCK_N columns) tile.
            # NOTE(review): the original comment mentioned that the Triton Philox
            # generator produces 4 blocks per seed/offset and that this factor
            # must be accounted for when scheduling; no such factor is applied
            # here - confirm against k_dropout_bw.
            return (
                N_BLOCKS_M,
                triton.cdiv(N, meta["BLOCK_N"]),
            )

        # fmt: off
        k_dropout_bw[grid](
            grad_in, grad_bias, grad_out_,
            inputs, bias if bias is not None else inputs,
            seeds,
            grad_out_.stride(0), inputs.stride(0),
            M, N,
            ctx.p,
            grad_in.dtype == torch.float16,
            USE_BIAS=bias is not None,
            ACTIVATION=ctx.activation,
            TRAINABLE_BIAS=ctx.trainable_bias,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
        # fmt: on

        # Exactly one gradient slot per forward input
        return (
            grad_in.reshape_as(grad_out),
            None,
            torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None,
            None,
            None,
        )
|
|
|
|
|
|
def dropout(
    x: torch.Tensor,
    p: float,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[Activation] = None,
) -> torch.Tensor:
    """
    Apply dropout on the input tensor.
    Optionally add a bias and apply an activation, the computation will be fused.

    Args:
        x: the input tensor.
        p: dropout probability, in [0, 1].
        bias: optional bias, added to ``x`` before the activation.
        activation: optional activation, applied before the dropout.

    Returns:
        A tensor with the same shape as ``x``. With ``p == 1.0`` this is all zeros.
        The fused Triton kernel is only used for CUDA inputs with ``0 < p < 1``,
        every other case goes through plain PyTorch ops.
    """

    assert 0.0 <= p <= 1.0

    if p == 1.0:
        return torch.zeros_like(x)

    # PyTorch codepath: either there is nothing to drop (micro optim), or the
    # tensor is not on a CUDA device and the Triton kernel cannot be launched.
    # Same order of operations as the FusedDropoutBias fallback.
    if p == 0.0 or not x.is_cuda:
        if bias is not None:
            x = x + bias
        if activation is not None:
            x = build_activation(activation)(x)
        return torch.nn.functional.dropout(x, p) if p > 0.0 else x

    # The normal triton enabled codepath
    activation_index = get_triton_activation_index(activation)
    return _dropout.apply(
        x,
        float(p),
        bias,
        activation_index,
        bias is not None and bias.requires_grad,
    )
|
|
|
|
|
|
class FusedDropoutBias(torch.nn.Module):
    """
    Layer computing Dropout(Activation(x + bias)), bias and activation being
    optional. CUDA inputs which are wide enough go through a single fused
    kernel, everything else goes through the equivalent PyTorch ops.
    """

    def __init__(
        self,
        p: float,
        bias_shape: Optional[int],
        activation: Optional[Activation] = None,
    ) -> None:
        super().__init__()

        self.p = float(p)

        assert (
            self.p < 1.0
        ), f"We don't want to drop all the values, most probably p={p} is not properly set"

        self.activation_type = activation

        # The bias is optional, and starts at zero when requested
        if bias_shape is None:
            self.bias = None
        else:
            self.bias = torch.zeros(bias_shape, requires_grad=True)

        # Keep both flavours of the activation: the index passed to the Triton
        # path and the callable used by the PyTorch path
        self.activation = get_triton_activation_index(self.activation_type)
        self.activation_pytorch = build_activation(self.activation_type)

    def init_weights(self, *args, **kwargs):
        """Reset the bias (if any) to zero, the arguments are ignored."""
        if self.bias is None:
            return

        with torch.no_grad():
            self.bias.fill_(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the bias, apply the activation, then the dropout (training only)."""
        # Convenience: align the bias with the dtype and device of the input
        if self.bias is not None:
            self.bias = self.bias.to(dtype=x.dtype, device=x.device)  # type: ignore

        # Nothing gets dropped at inference time
        drop_p = self.p if self.training else 0.0

        # The fused kernel is slower than pytorch for small buffers
        wide_enough = x.shape[-1] > 512

        # The Triton-backed path requires a CUDA tensor, a wide enough buffer
        # and something to actually drop
        if x.is_cuda and wide_enough and drop_p != 0.0:
            return _dropout.apply(x, drop_p, self.bias, self.activation, True)

        # PyTorch fallback
        out = x if self.bias is None else x + self.bias
        out = self.activation_pytorch(out)
        if drop_p > 0.0:
            return torch.nn.functional.dropout(out, drop_p)
        return out
|