Files
enginex-bi_series-vllm/pkgs/xformers/triton/dropout.py
2025-08-05 19:02:46 +08:00

244 lines
7.4 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This is heavily inspired by the Triton dropout tutorial
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py
from typing import Optional
import torch
import triton
from torch.cuda.amp import custom_bwd, custom_fwd
from xformers.components.activations import Activation, build_activation
from xformers.triton.k_activations import get_triton_activation_index
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw
# Tile sizes handed to the Triton kernels as the BLOCK_M / BLOCK_N meta-parameters.
# They also drive the launch grids and the number of seeds generated per call.
BLOCK_M = 32
BLOCK_N = 64 # NOTE: This should ideally be GPU dependent, big impact on perf
# Helper to handle the SPMD launch grid and error cases
class _dropout(torch.autograd.Function):
    """Autograd function wrapping the fused Triton dropout kernels.

    ``forward`` launches ``k_dropout_fw`` and ``backward`` launches
    ``k_dropout_bw``. The seeds drawn in ``forward`` are saved on the context
    and handed to the backward kernel (presumably so that it can regenerate
    the same dropout mask instead of storing it).
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x, p, bias, activation, trainable_bias):
        """Run the fused forward kernel.

        Args:
            ctx: autograd context, used to stash what ``backward`` needs.
            x: input tensor; every leading dimension is flattened into rows.
            p: dropout probability, must be strictly positive.
            bias: optional tensor with ``bias.shape[0] == x.shape[-1]`` and the
                same dtype as ``x``.
            activation: value forwarded to the kernels as ``ACTIVATION``
                (``None`` means that no activation is applied).
            trainable_bias: whether a gradient should be produced for ``bias``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Soft-flatten an hypothetical 3rd dimension
        x_ = x.reshape(-1, x.shape[-1]).contiguous()
        y = torch.empty_like(x_)
        M, N = x_.shape

        assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N)
        assert p > 0.0

        def grid(meta):
            return (
                triton.cdiv(M, meta["BLOCK_M"]),
                triton.cdiv(N, meta["BLOCK_N"]),
            )

        N_BLOCK_N = triton.cdiv(N, BLOCK_N)

        # Generate one seed per block of BLOCK_N columns.
        # Seeds are drawn from [0, 2**16), which comfortably fits in a positive int32.
        seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device, dtype=torch.int32)

        # fmt: off
        bias_ptr = bias if bias is not None else x_  # Possibly not being used

        k_dropout_fw[grid](
            y, x_,
            bias_ptr,
            seeds,
            y.stride(0),
            M, N,
            p,
            x.dtype == torch.float16,
            USE_BIAS=bias is not None,
            ACTIVATION=activation,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
        # fmt: on

        # The inputs are only kept around when an activation is involved,
        # backward asserts that they are present in that case.
        if activation is not None:
            ctx.save_for_backward(seeds, bias, x)
        else:
            ctx.save_for_backward(seeds, bias, None)

        ctx.trainable_bias = bias is not None and trainable_bias
        ctx.activation = activation
        ctx.p = p

        return y.reshape_as(x)

    @staticmethod
    @custom_bwd
    def backward(
        ctx, grad_out
    ):  # pragma: no cover  # This is covered, but called from C++ and not tracked
        """Run the fused backward kernel.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: gradient with respect to the output of ``forward``.

        Returns:
            One entry per ``forward`` input ``(x, p, bias, activation,
            trainable_bias)``: the gradient for ``x``, ``None``, the gradient
            for ``bias`` (``None`` unless the bias is trainable), ``None``,
            ``None``.
        """
        (seeds, bias, inputs) = ctx.saved_tensors

        # Soft-flatten an hypothetical 3rd dimension
        grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
        grad_in = torch.empty_like(grad_out_)
        M, N = grad_out_.shape

        # Optional inputs to compute the activation contribution to the gradient
        assert inputs is not None or ctx.activation is None

        if inputs is None:
            # Placeholder so that the kernel always receives a valid tensor
            inputs = grad_out_
        elif inputs.ndim > 2:
            inputs = inputs.reshape(-1, N)

        # We split the problem in tiles:
        # - over M there will be a follow up reduction
        # - over N we compromise in between trying to use as much memory parallelism as possible,
        # (fill in the warps, there are 32 threads per warps, and 4 warps default), and not being too
        # big because of register spilling
        N_BLOCKS_M = triton.cdiv(M, BLOCK_M)

        if ctx.trainable_bias:
            # One partial bias gradient per block of rows, summed after the kernel
            grad_bias = torch.empty(
                (
                    N_BLOCKS_M,
                    N,
                ),
                device=grad_in.device,
                dtype=grad_in.dtype,
            )
        else:
            grad_bias = grad_in  # will not be used

        def grid(meta):
            # NOTE: We use Triton Philox random number generator, which optimally generates 4 blocks for
            # a given seed and offsets. "BLOCK_M" here describes the size of one of these blocks
            # but we need to take this factor of 4 into account when scheduling all the kernels
            return (
                N_BLOCKS_M,
                triton.cdiv(N, meta["BLOCK_N"]),
            )

        # fmt: off
        k_dropout_bw[grid](
            grad_in, grad_bias, grad_out_,
            inputs, bias if bias is not None else inputs,
            seeds,
            grad_out_.stride(0), inputs.stride(0),
            M, N,
            ctx.p,
            grad_in.dtype == torch.float16,
            USE_BIAS=bias is not None,
            ACTIVATION=ctx.activation,
            TRAINABLE_BIAS=ctx.trainable_bias,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
        # fmt: on

        # Exactly one gradient per forward input: x, p, bias, activation, trainable_bias
        return (
            grad_in.reshape_as(grad_out),
            None,
            torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None,
            None,
            None,
        )
def dropout(
    x: torch.Tensor,
    p: float,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[Activation] = None,
):
    """
    Apply dropout on the input tensor.
    Optionally add a bias and apply an activation beforehand, on the Triton
    codepath the whole computation is fused.

    Args:
        x: the input tensor.
        p: the dropout probability, within [0, 1].
        bias: optional bias added to ``x`` before the activation.
        activation: optional activation applied before the dropout.

    Returns:
        A tensor with the same shape as ``x``. With ``p == 0``, no bias and no
        activation, ``x`` itself is returned.

    Raises:
        AssertionError: if ``p`` is outside of [0, 1].
    """
    assert p <= 1.0 and p >= 0.0

    # Everything is dropped, no need to compute anything
    if p == 1.0:
        return torch.zeros_like(x)

    # Plain PyTorch codepath, taken when
    # - there is nothing to drop (micro optim, skip dropout)
    # - the input is not a CUDA tensor, same fallback as in FusedDropoutBias.forward
    if p == 0.0 or not x.is_cuda:
        x = x + bias if bias is not None else x
        if activation is not None:
            activation_fn = build_activation(activation)
            x = activation_fn(x)
        return torch.nn.functional.dropout(x, p) if p > 0.0 else x

    # The normal triton enabled codepath
    activation_index = get_triton_activation_index(activation)
    return _dropout.apply(
        x,
        float(p),
        bias,
        activation_index,
        bias is not None and bias.requires_grad,
    )
class FusedDropoutBias(torch.nn.Module):
    """
    A layer which fuses the computation of Dropout(Activation(x + bias))
    in a single GPU kernel. The bias and the activation are both optional.

    Small or non-CUDA inputs, as well as the inference mode, go through an
    equivalent sequence of PyTorch operations instead.
    """

    def __init__(
        self,
        p: float,
        bias_shape: Optional[int],
        activation: Optional[Activation] = None,
    ) -> None:
        """
        Args:
            p: the dropout probability, within [0, 1).
            bias_shape: the size of the bias, None to disable the bias.
            activation: the optional activation applied before the dropout.

        Raises:
            AssertionError: if ``p`` is negative, NaN, or not strictly below 1.
        """
        super().__init__()

        self.p = float(p)

        # Fail at construction time, instead of failing later in the Triton kernel wrapper
        # or being silently ignored by the PyTorch fallback
        assert self.p >= 0.0, f"The dropout probability cannot be negative, got p={p}"
        assert (
            self.p < 1.0
        ), f"We don't want to drop all the values, most probably p={p} is not properly set"

        self.activation_type = activation

        # NOTE(review): the bias is a plain tensor, not a torch.nn.Parameter,
        # so it is not registered in parameters() nor in the state_dict
        self.bias = (
            torch.zeros(bias_shape, requires_grad=True)
            if bias_shape is not None
            else None
        )

        # Index handed to the Triton kernel, and PyTorch module for the fallback path
        self.activation = get_triton_activation_index(self.activation_type)
        self.activation_pytorch = build_activation(self.activation_type)

    def init_weights(self, *args, **kwargs):
        """Reset the bias to zero, if there is one. The arguments are ignored."""
        with torch.no_grad():
            if self.bias is not None:
                self.bias.fill_(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the optional bias, the activation and the dropout to ``x``."""
        # Convenience, catch a possible type or device mismatch
        if self.bias is not None:
            self.bias = self.bias.to(dtype=x.dtype, device=x.device)  # type: ignore

        # Train/inference
        p = self.p if self.training else 0.0

        # This kernel is slower than pytorch for small buffers, bypassing it in that case
        perf_check = x.shape[-1] > 512

        # Catch a non-cuda setup, fallback to pytorch
        if not x.is_cuda or not perf_check or p == 0.0:
            x = x + self.bias if self.bias is not None else x
            x = self.activation_pytorch(x)
            return torch.nn.functional.dropout(x, p) if p > 0.0 else x

        # The normal, Triton-backed path
        return _dropout.apply(x, p, self.bias, self.activation, True)