Files
enginex-bi_series-vllm/pkgs/xformers/triton/k_dropout.py

212 lines
5.8 KiB
Python
Raw Normal View History

2025-08-05 19:02:46 +08:00
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This is heavily inspired by the Triton dropout tutorial
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py
import triton
import triton.language as tl
from xformers.triton.k_activations import (
gelu,
gelu_grad,
leaky_relu,
leaky_relu_grad,
relu,
relu_grad,
smelu,
smelu_grad,
squared_relu,
squared_relu_grad,
)
# Autotuning candidates: identical (empty) meta-parameters, only the number
# of warps per program varies.
_configs = [
    triton.Config({}, num_warps=warps)
    for warps in (1, 2, 4, 8, 16)
]
# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
    configs=_configs,
    key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_fw(
    Y, X, BIAS, SEEDS,
    stride,
    M, N,
    p: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune
    ACTIVATION: tl.constexpr,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SIZE_RAND_BLOCK: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """
    Forward pass: optional bias, optional activation, then dropout.

    Each program handles one (BLOCK_M, BLOCK_N) tile, selected by the program
    ids along axis 0 (rows) and axis 1 (columns).

    Y : Output  (M, N), written with row stride `stride`
    X : Input   (M, N), read with the same row stride `stride`
    BIAS        indexed by column; only read when USE_BIAS is set
    SEEDS       indexed by the column-block id (program id along axis 1)
    p : dropout probability; an element is kept when its random draw is > p
        and kept elements are scaled by 1 / (1 - p), so p must be < 1
    is_fp16 : not used in the body, only serves as an autotune key
    ACTIVATION : 1 relu, 2 leaky_relu, 3 gelu, 4 squared_relu, 5 smelu,
        any other value applies no activation
    SIZE_RAND_BLOCK : filled in by the heuristic as BLOCK_N * BLOCK_M
    """
    # fmt: on

    row_id = tl.program_id(axis=0)
    col_id = tl.program_id(axis=1)
    rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # pointers starting point
    x_ptrs = X + rows[:, None] * stride + cols[None, :]
    y_ptrs = Y + rows[:, None] * stride + cols[None, :]

    # guard the tile against the (M, N) bounds
    col_mask = cols[None, :] < N
    block_mask = (rows[:, None] < M) & col_mask
    p_scale = 1. / (1. - p)

    x = tl.load(x_ptrs, mask=block_mask, other=0.0)

    # optionally apply a fused bias, broadcast over the rows of the tile
    if USE_BIAS:
        b_ptrs = BIAS + cols[None, :]
        x += tl.load(b_ptrs, mask=col_mask, other=0.)

    # optional: fused activation on the freshly loaded tile
    if ACTIVATION == 1:
        x = relu(x)
    elif ACTIVATION == 2:
        x = leaky_relu(x)
    elif ACTIVATION == 3:
        x = gelu(x)
    elif ACTIVATION == 4:
        x = squared_relu(x)
    elif ACTIVATION == 5:
        x = smelu(x)

    # get the random keep mask
    # NOTE(review): neither the seed nor the offsets depend on row_id, so all
    # row blocks of a given column block draw the same random tile - confirm
    # this is intended. The backward kernel must generate the exact same mask.
    rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
    seed_int = tl.load(SEEDS + col_id)
    r = tl.rand(seed_int, rand_offsets)

    # prune and normalize in one go
    keep = tl.view(r > p, x.shape)
    output = tl.where(keep, (x * p_scale).to(x.dtype), 0.)

    tl.store(y_ptrs, output, mask=block_mask)
# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
    configs=_configs,
    key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_bw(
    GRAD_IN, GRAD_BIAS, GRAD_OUT,
    INPUTS, BIAS, SEEDS,
    stride_grad, stride_inputs,
    M, N,
    p: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune
    ACTIVATION: tl.constexpr,
    # Meta-parameters
    BLOCK_M: tl.constexpr,  # heuristics
    BLOCK_N: tl.constexpr,
    SIZE_RAND_BLOCK: tl.constexpr,
    TRAINABLE_BIAS: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """
    Backward pass of the fused bias + activation + dropout kernel.

    Each program handles one (BLOCK_M, BLOCK_N) tile, selected by the program
    ids along axis 0 (rows) and axis 1 (columns).

    GRAD_OUT    (M, N) gradient that is read, with row stride `stride_grad`
    GRAD_IN     (M, N) gradient that is written, with the same row stride
    GRAD_BIAS   written at `row_id * N + col`, i.e. one row of N per-column
                sums for each row block; only written when TRAINABLE_BIAS
                (presumably reduced over row blocks by the caller - TODO confirm)
    INPUTS      (M, N) read with row stride `stride_inputs`, only when
                ACTIVATION is non-zero
    BIAS        indexed by column; only read when ACTIVATION and USE_BIAS are set
    SEEDS       indexed by the column-block id (program id along axis 1)
    p : dropout probability; an element is kept when its random draw is > p
        and kept elements are scaled by 1 / (1 - p), so p must be < 1
    is_fp16 : not used in the body, only serves as an autotune key
    ACTIVATION : 1 relu, 2 leaky_relu, 3 gelu, 4 squared_relu, 5 smelu
    SIZE_RAND_BLOCK : filled in by the heuristic as BLOCK_N * BLOCK_M
    """
    # fmt: on

    row_id = tl.program_id(axis=0)
    col_id = tl.program_id(axis=1)
    rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # pointers starting point
    grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :]
    grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :]

    # guard the tile against the (M, N) bounds
    col_mask = cols[None, :] < N
    block_mask = (rows[:, None] < M) & col_mask
    p_scale = 1. / (1. - p)

    grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.)

    # optional: chain through the derivative of the fused activation, which
    # is evaluated on the (optionally biased) forward inputs
    if ACTIVATION:
        input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :]
        inputs = tl.load(input_ptrs, mask=block_mask, other=0.)

        # optionally apply a fused bias, broadcast over the rows of the tile
        if USE_BIAS:
            b_ptrs = BIAS + cols[None, :]
            inputs += tl.load(b_ptrs, mask=col_mask, other=0.)

        if ACTIVATION == 1:
            act_grad = relu_grad(inputs)
        elif ACTIVATION == 2:
            act_grad = leaky_relu_grad(inputs)
        elif ACTIVATION == 3:
            act_grad = gelu_grad(inputs)
        elif ACTIVATION == 4:
            act_grad = squared_relu_grad(inputs)
        elif ACTIVATION == 5:
            act_grad = smelu_grad(inputs)

        grad_out *= act_grad

    # randomly prune (and scale) the resulting buffer
    # the mask is not saved by the forward pass: it is regenerated here from
    # the same seed and offsets, so the same drop mask is applied
    # NOTE(review): neither the seed nor the offsets depend on row_id, so all
    # row blocks of a given column block draw the same random tile - confirm
    # this is intended.
    rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
    seed_int = tl.load(SEEDS + col_id)
    r = tl.rand(seed_int, rand_offsets)
    r = tl.view(r, grad_out.shape)
    output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.)

    # write-back
    tl.store(grad_in_ptrs, output, mask=block_mask)

    # optionally emit this row block's contribution to the bias gradient:
    # the tile is summed over its rows (out-of-bounds entries are zero) and
    # stored in the row of GRAD_BIAS owned by this row block
    if TRAINABLE_BIAS:
        grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)
        grad_bias += tl.sum(output, axis=0)

        grad_bias_ptr = GRAD_BIAS + row_id * N + cols
        tl.store(grad_bias_ptr, grad_bias, mask=cols < N)