212 lines
5.8 KiB
Python
212 lines
5.8 KiB
Python
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||
|
|
#
|
||
|
|
# This source code is licensed under the BSD license found in the
|
||
|
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
|
||
|
|
|
||
|
|
# CREDITS: This is heavily inspired by the Triton dropout tutorial
|
||
|
|
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py
|
||
|
|
|
||
|
|
import triton
|
||
|
|
import triton.language as tl
|
||
|
|
|
||
|
|
from xformers.triton.k_activations import (
|
||
|
|
gelu,
|
||
|
|
gelu_grad,
|
||
|
|
leaky_relu,
|
||
|
|
leaky_relu_grad,
|
||
|
|
relu,
|
||
|
|
relu_grad,
|
||
|
|
smelu,
|
||
|
|
smelu_grad,
|
||
|
|
squared_relu,
|
||
|
|
squared_relu_grad,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Autotuning candidates shared by the kernels below: no meta-parameter
# overrides, only the number of warps per program is varied.
_configs = [triton.Config({}, num_warps=warps) for warps in (1, 2, 4, 8, 16)]
|
||
|
|
|
||
|
|
|
||
|
|
# fmt: off
|
||
|
|
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
    configs=_configs,
    key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_fw(
    Y, X, BIAS, SEEDS,
    stride,
    M, N,
    p: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune
    ACTIVATION: tl.constexpr,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SIZE_RAND_BLOCK: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """
    Fused (bias +) (activation +) dropout, forward pass.

    Each program handles one (BLOCK_M, BLOCK_N) tile, selected by
    program_id(0) (row block) and program_id(1) (column block).

    Y : Output (M, N)
    X : Input (M, N)
    BIAS (N,), only read when USE_BIAS is set
    SEEDS : one integer seed per column block (indexed by program_id(1))
    stride : row stride, in elements, used for both X and Y
             (columns are addressed with a unit stride)
    p : dropout probability; elements are kept where rand > p and
        scaled by 1 / (1 - p)
    is_fp16 : only part of the autotune key, not read by the kernel body
    ACTIVATION : 1 relu, 2 leaky_relu, 3 gelu, 4 squared_relu, 5 smelu,
                 any other value applies no activation
    SIZE_RAND_BLOCK : BLOCK_M * BLOCK_N, filled in by the heuristic
    """
    # fmt: on

    row_id = tl.program_id(axis=0)
    rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)

    col_id = tl.program_id(axis=1)
    cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # pointers starting point
    x_ptrs = X + rows[:, None] * stride + cols[None, :]
    y_ptrs = Y + rows[:, None] * stride + cols[None, :]

    # bounds masks: the tile may overhang the (M, N) tensor on either axis
    col_mask = cols[None, :] < N
    block_mask = (rows[:, None] < M) & col_mask
    p_scale = 1. / (1. - p)

    if USE_BIAS:
        b_ptrs = BIAS + cols[None, :]
        bias = tl.load(b_ptrs, mask=col_mask, other=0.)
    else:
        bias = x_ptrs  # will not be used

    x = tl.load(x_ptrs, mask=block_mask, other=0.0)

    # optionally apply a fused bias
    if USE_BIAS:
        x += bias

    # optional: fused activation, applied on the tile that was just loaded
    if ACTIVATION == 1:
        x = relu(x)
    elif ACTIVATION == 2:
        x = leaky_relu(x)
    elif ACTIVATION == 3:
        x = gelu(x)
    elif ACTIVATION == 4:
        x = squared_relu(x)
    elif ACTIVATION == 5:
        x = smelu(x)

    # get the random keep mask
    # NOTE(review): neither the seed (indexed by col_id) nor the offsets depend
    # on row_id, so all programs sharing a column block use the same seed and
    # offsets - confirm this is intended given the launch grid. Any change here
    # must be mirrored in k_dropout_bw, which regenerates the mask the same way.
    rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
    seed_int = tl.load(SEEDS + col_id)
    r = tl.rand(seed_int, rand_offsets)
    keep_mask = r > p

    # prune and normalize in one go
    # (the flat random block is viewed with the same shape as the data tile)
    keep = tl.view(keep_mask, x.shape)
    output = tl.where(keep, (x * p_scale).to(x.dtype), 0.)

    tl.store(y_ptrs, output, mask=block_mask)  # output
|
||
|
|
|
||
|
|
|
||
|
|
# fmt: off
|
||
|
|
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
    configs=_configs,
    key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_bw(
    GRAD_IN, GRAD_BIAS, GRAD_OUT,
    INPUTS, BIAS, SEEDS,
    stride_grad, stride_inputs,
    M, N,
    p: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune
    ACTIVATION: tl.constexpr,
    # Meta-parameters
    BLOCK_M: tl.constexpr,  # heuristics
    BLOCK_N: tl.constexpr,
    SIZE_RAND_BLOCK: tl.constexpr,
    TRAINABLE_BIAS: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """
    Backward pass of the fused (bias +) (activation +) dropout.

    Each program handles one (BLOCK_M, BLOCK_N) tile, selected by
    program_id(0) (row block) and program_id(1) (column block).

    GRAD_OUT (M, N) : incoming gradient, read
    GRAD_IN (M, N) : resulting gradient, written
    GRAD_BIAS : written at offset row_id * N + cols when TRAINABLE_BIAS is
                set, i.e. one row of N per-column partial sums per row block
                (presumably reduced by the caller - TODO confirm)
    INPUTS (M, N) : forward-pass inputs, only read when ACTIVATION is set
    BIAS (N,), only read when USE_BIAS is set
    SEEDS : one integer seed per column block (indexed by program_id(1))
    stride_grad : row stride, in elements, used for both GRAD_OUT and GRAD_IN
    stride_inputs : row stride, in elements, of INPUTS
    p : dropout probability; gradients are kept where rand > p and
        scaled by 1 / (1 - p)
    is_fp16 : only part of the autotune key, not read by the kernel body
    ACTIVATION : 1 relu, 2 leaky_relu, 3 gelu, 4 squared_relu, 5 smelu
    SIZE_RAND_BLOCK : BLOCK_M * BLOCK_N, filled in by the heuristic
    """
    # fmt: on

    row_id = tl.program_id(axis=0)
    rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)

    col_id = tl.program_id(axis=1)
    cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # pointers starting point
    grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :]
    grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :]
    input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :]

    # bounds masks: the tile may overhang the (M, N) tensor on either axis
    col_mask = cols[None, :] < N
    block_mask = (rows[:, None] < M) & col_mask
    p_scale = 1. / (1. - p)

    if USE_BIAS:
        b_ptrs = BIAS + cols[None, :]
        bias = tl.load(b_ptrs, mask=col_mask, other=0.)

    grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.)

    # optional: fused activation. The activation gradient is evaluated on the
    # (biased) forward inputs, which are reloaded here.
    if ACTIVATION:
        inputs = tl.load(input_ptrs, mask=block_mask, other=0.)

        # optionally apply a fused bias
        if USE_BIAS:
            inputs += bias

        if ACTIVATION == 1:
            act_grad = relu_grad(inputs)
        elif ACTIVATION == 2:
            act_grad = leaky_relu_grad(inputs)
        elif ACTIVATION == 3:
            act_grad = gelu_grad(inputs)
        elif ACTIVATION == 4:
            act_grad = squared_relu_grad(inputs)
        elif ACTIVATION == 5:
            act_grad = smelu_grad(inputs)

        grad_out *= act_grad

    # randomly prune (and scale) the resulting buffer, possibly a no-op
    # note that even if we did not save the mask from the FW pass, it is generated
    # from the same seeds, so the same drop mask is applied here
    # NOTE(review): neither the seed (indexed by col_id) nor the offsets depend
    # on row_id - confirm this is intended given the launch grid. This must stay
    # in lock-step with the mask generation in k_dropout_fw.
    rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
    seed_int = tl.load(SEEDS + col_id)
    r = tl.rand(seed_int, rand_offsets)
    r = tl.view(r, grad_out.shape)
    output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.)

    # write-back
    tl.store(grad_in_ptrs, output, mask=block_mask)

    # optionally write this tile's contribution to the bias gradient
    if TRAINABLE_BIAS:
        # Column-wise sum over the tile's rows. Out-of-bounds positions add
        # nothing: grad_out was loaded with other=0. there.
        grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)
        grad_bias += tl.sum(output, axis=0)

        # each row block writes its own row of partial sums
        grad_bias_ptr = GRAD_BIAS + row_id * N + cols
        tl.store(grad_bias_ptr, grad_bias, mask=cols < N)
|