Files
enginex-bi_series-vllm/pkgs/xformers/triton/softmax.py
2025-08-05 19:02:46 +08:00

204 lines
6.2 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Optional
import torch
import triton
from torch.cuda.amp import custom_bwd, custom_fwd
from xformers.triton.k_softmax import _softmax, _softmax_backward
# CREDITS: This is adapted from the vanilla Triton example. See https://openai.com/blog/triton/
# and https://triton-lang.org/getting-started/tutorials/02-fused-softmax.html
# Shared "xformers" logger; used to report a Triton failure before falling back.
logger: logging.Logger = logging.getLogger("xformers")

# When False, `custom_fwd` below receives cast_inputs=None, i.e. inputs are not
# cast to fp16 for the Triton kernel.
# NOTE: PyTorch keeps softmax as fp32
_triton_softmax_fp16_enabled: bool = False

# Set to True by `_softmax_dispatch` the first time the Triton path raises a
# RuntimeError; from then on every call takes the PyTorch fallback.
_triton_registered_warnings: bool = False
# Helper to handle the SPMD launch grid and error cases
class _softmax_triton(torch.autograd.Function):
    """Autograd function wrapping the fused Triton softmax / log-softmax kernels.

    Inputs are viewed as 3D ``(batch, rows, cols)`` tensors: a leading unit
    dimension is added to 2D inputs and all leading dimensions are flattened
    into one. The reduction always runs over the last dimension.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16 if _triton_softmax_fp16_enabled else None)
    def forward(ctx, x, mask, log_outputs, causal):
        """
        Fused softmax implementation, using the Triton programming model.
        This only supports a reduction over the last dimension for now

        Args:
            x: input tensor with at least 2 dimensions.
            mask: optional additive mask with the same dtype as ``x``, or None.
            log_outputs: if True the kernel computes log-softmax.
            causal: forwarded to the kernel as its ``causal`` flag.

        Returns:
            the result, reshaped like ``x``.
        """
        # Handle 2D/3D tensors: view the input as (batch, rows, cols)
        x_ = x.unsqueeze(0) if x.ndim == 2 else x
        x_ = x_.flatten(0, -3)
        if not x_.is_contiguous():
            x_ = x_.contiguous()

        y = torch.empty_like(x_)
        # Only the strides of dims 0 and 1 are passed to the kernel, so the
        # reduced (last) dimension has to be dense.
        assert (
            y.stride(2) == 1 and x_.stride(2) == 1
        ), f"{x.shape} - {x_.shape} - {x_.stride()}"

        # SPMD launch grid over (batch, rows)
        grid_2d = (
            x_.shape[0],
            x_.shape[1],
        )

        # enqueue GPU kernel
        use_mask = True
        if mask is None:
            # placeholder, will not be used
            mask = x_
            use_mask = False
        else:
            # The mask is additive (added to the logits), so it must share x's dtype
            assert mask.dtype == x.dtype, "An additive mask is requested"

        _softmax[grid_2d](
            y,
            x_,
            mask,
            y.stride(0),
            y.stride(1),
            x_.stride(0),
            x_.stride(1),
            mask.stride(0),
            x_.shape[2],
            log=log_outputs,
            use_mask=use_mask,
            causal=causal,
        )

        ctx.save_for_backward(y)
        ctx.log_outputs = log_outputs
        ctx.causal = causal
        return y.reshape_as(x)

    @staticmethod
    @custom_bwd
    def backward(
        ctx, grad_out
    ):  # pragma: no cover  # this is covered, but called directly from C++
        """Gradient of the fused (log-)softmax with respect to ``x``.

        Returns ``(grad_x, None, None, None)``: the mask, ``log_outputs`` and
        ``causal`` inputs receive no gradient.
        """
        (out,) = ctx.saved_tensors

        # Handle 2D/3D tensors
        grad_out_ = grad_out.unsqueeze(0) if grad_out.ndim == 2 else grad_out
        grad_out_ = grad_out_.flatten(0, -3)

        # The kernel only receives the strides of dims 0 and 1 and assumes the
        # last dim is dense. It is the *flattened* views that are handed to the
        # kernel, so those are the tensors that must be contiguous (an upstream
        # gradient can arrive non-contiguous, e.g. after a transpose).
        grad_out_ = grad_out_.contiguous()
        out = out.contiguous()

        # SPMD launch grid
        grid_2d = (
            grad_out_.shape[0],
            grad_out_.shape[1],
        )
        depth = triton.next_power_of_2(grad_out_.shape[2])

        # `out` is contiguous here, so empty_like yields a contiguous buffer.
        # torch.zeros is measurably slower, we'll zero out in the kernel
        grad_in = torch.empty_like(out)

        # fmt: off
        _softmax_backward[grid_2d](
            grad_in, grad_out_, out,
            grad_in.stride(0), grad_in.stride(1),
            grad_out_.stride(0), grad_out_.stride(1),
            out.stride(0), out.stride(1),
            out.shape[2],
            depth=depth,
            log=ctx.log_outputs,
            causal=ctx.causal
        )
        # fmt: on
        return grad_in.reshape_as(grad_out), None, None, None
def softmax(
    x: torch.Tensor, mask: Optional[torch.Tensor] = None, causal: bool = False
) -> torch.Tensor:
    r"""Apply the Softmax function over the last dimension of ``x``.

    The elements of the output lie in the range [0, 1] and sum to 1 along the
    last dimension.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    .. warning:: softmax is computed on the last dimension of the input tensor.

    Args:
        x: input tensor.
        mask: optional additive mask; its application is fused into the
            softmax computation if Triton is used.
        causal: if True, entries above the diagonal of the last two dimensions
            are excluded (fused into the kernel when Triton is used).

    Returns:
        a Tensor of the same dimension and shape as the input, with values in
        the range [0, 1] that sum to 1 along the last dimension.
    """
    probabilities = _softmax_dispatch(x, mask=mask, causal=causal, log=False)
    return probabilities
def log_softmax(
    x: torch.Tensor, mask: Optional[torch.Tensor] = None, causal: bool = False
) -> torch.Tensor:
    r"""Apply :math:`\log(\text{Softmax}(x))` over the last dimension of ``x``.

    The LogSoftmax formulation can be simplified as:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Args:
        x: input tensor.
        mask: optional additive mask; its application is fused into the
            computation if Triton is used.
        causal: if True, entries above the diagonal of the last two dimensions
            are excluded (fused into the kernel when Triton is used).

    Returns:
        a Tensor of the same dimension and shape as the input, with values in
        the range [-inf, 0].
    """
    log_probabilities = _softmax_dispatch(x, mask=mask, causal=causal, log=True)
    return log_probabilities
def _softmax_dispatch(
    x: torch.Tensor, log: bool, mask: Optional[torch.Tensor], causal: bool = False
) -> torch.Tensor:
    """Compute (log-)softmax over the last dim, using Triton when possible.

    Triton is used if
    - CUDA is available and ``x`` is a CUDA tensor
    - there was no previous failure of the Triton kernel in this process
    (no input-size heuristic is applied here).

    If the Triton path raises a RuntimeError (e.g. the GPU lacks the registers
    to hold a full row), the kernel is disabled for the rest of the process and
    PyTorch's implementation is used instead.

    Args:
        x: input tensor; the reduction runs over its last dimension.
        log: if True return log-softmax, otherwise softmax.
        mask: optional additive mask, added to ``x`` before normalization.
        causal: if True, entries strictly above the diagonal of the last two
            dimensions are set to ``-inf`` before normalization.

    Returns:
        a tensor with the same shape as ``x`` (after broadcasting with ``mask``
        on the fallback path).
    """
    global _triton_registered_warnings

    try:
        if torch.cuda.is_available() and x.is_cuda and not _triton_registered_warnings:
            return _softmax_triton.apply(x, mask, log, causal)
    except RuntimeError as e:
        # Catch cases where the current GPU does not have enough registers to hold a full tensor line
        # fallback to PyTorch's implementation, which streams the tensor in and out
        _triton_registered_warnings = True
        logger.warning(
            "Triton softmax kernel register spillover or invalid image caught. "
            "Deactivating this kernel, please file an issue in the xFormers repository"
        )
        logger.warning(e)

    if mask is not None:
        x = x + mask

    if causal:
        # Fill the strictly-upper triangle of the last two dims with -inf.
        # A small boolean mask + masked_fill avoids allocating an x-sized -inf
        # tensor and avoids inf + (-inf) = nan when x contains +inf there.
        n_rows, n_cols = x.shape[-2], x.shape[-1]
        above_diagonal = torch.ones(
            n_rows, n_cols, dtype=torch.bool, device=x.device
        ).triu(diagonal=1)
        x = x.masked_fill(above_diagonal, float("-inf"))

    return torch.log_softmax(x, dim=-1) if log else torch.softmax(x, dim=-1)