First commit
This commit is contained in:
201
pkgs/xformers/ops/fmha/triton.py
Normal file
201
pkgs/xformers/ops/fmha/triton.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ... import _is_triton_available
|
||||
from ..common import register_operator
|
||||
|
||||
# This implementation needs pre-MLIR triton.
# The BW (backward) pass is not stable or well tested,
# and it does not have the latest improvements either.
|
||||
# The `False and ...` term hard-disables the runtime import: at runtime the
# `else` branch always runs and both operators are `None`. Only static type
# checkers (where TYPE_CHECKING is true) follow the first branch.
if TYPE_CHECKING or (False and _is_triton_available()):
    try:
        from flash_attn.flash_attn_triton import (
            _flash_attn_backward,
            _flash_attn_forward,
        )
    except ImportError:
        # Import the submodules explicitly: a bare `import importlib` does not
        # guarantee that `importlib.util` is available as an attribute.
        import importlib.util
        import pathlib
        import sys
        import types

        def import_module_from_path(path: str) -> types.ModuleType:
            """Import a module from the given path, w/o __init__.py

            The module is registered in ``sys.modules`` under the file's stem
            (``'path/x.py'`` -> ``'x'``) before it is executed.

            Args:
                path: Path to the Python source file; it is resolved to an
                    absolute path first.

            Returns:
                The freshly executed module object.

            Raises:
                ImportError: If no import spec/loader can be built for ``path``.
                Exception: Whatever executing the module raises; in that case
                    the partially initialised module is removed from
                    ``sys.modules`` again.
            """
            module_path = pathlib.Path(path).resolve()
            module_name = module_path.stem  # 'path/x.py' -> 'x'
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            # Explicit check instead of `assert`, which is stripped under -O.
            if spec is None or spec.loader is None:
                raise ImportError(
                    f"cannot build an import spec for {module_path}",
                    name=module_name,
                    path=str(module_path),
                )
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
            except BaseException:
                # Do not leave a half-initialised module behind.
                sys.modules.pop(module_name, None)
                raise
            return module

        flash_attn = import_module_from_path(
            "third_party/flash-attention/flash_attn/flash_attn_triton.py"
        )
        _flash_attn_backward = flash_attn._flash_attn_backward
        _flash_attn_forward = flash_attn._flash_attn_forward

    triton_flash_backward = _flash_attn_backward
    triton_flash_forward = _flash_attn_forward
else:
    triton_flash_backward = None
    triton_flash_forward = None
|
||||
|
||||
from .attn_bias import LowerTriangularMask
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
check_lastdim_alignment_stride1,
|
||||
)
|
||||
|
||||
|
||||
def _prepare_inputs(inp: Inputs) -> Inputs:
    """Return a copy of ``inp`` adjusted for the Triton operators.

    * A 3-D tensor ``attn_bias`` is reshaped to 4-D by splitting its leading
      dimension into ``(B, h)``, where ``B = inp.query.shape[0]`` and
      ``h = attn_bias.shape[0] // B``. Any other bias is passed through.
    * ``query``/``key``/``value`` whose last-dimension stride is not 1 are
      replaced by ``.contiguous()`` copies; the others are reused as-is.
    """
    bias = inp.attn_bias
    if isinstance(bias, torch.Tensor) and bias.ndim == 3:
        batch = inp.query.shape[0]
        flat, rows, cols = bias.shape[0], bias.shape[1], bias.shape[2]
        bias = bias.reshape(batch, flat // batch, rows, cols)

    def _unit_last_stride(tensor):
        # Make sure that the last dimension is contiguous
        if tensor.stride(-1) == 1:
            return tensor
        return tensor.contiguous()

    return replace(
        inp,
        attn_bias=bias,
        query=_unit_last_stride(inp.query),
        key=_unit_last_stride(inp.key),
        value=_unit_last_stride(inp.value),
    )
|
||||
|
||||
|
||||
@register_operator
class FwOp(AttentionFwOpBase):
    """Operator that computes memory-efficient attention using \
        `Tri Dao's <https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py>`_ \
        implementation, based on
        `Phil Tillet's code <https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py>`_
    """

    # `None` when the Triton flash-attention forward could not be imported.
    OPERATOR = triton_flash_forward
    SUPPORTED_DEVICES = {"cuda"}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
    SUPPORTED_DTYPES = {torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 128
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        LowerTriangularMask,
        # TODO: backwards accuracy is failing for a few cases, perhaps we want to disable this for now.
        # torch.Tensor,
    }
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    NAME = "tritonflashattF"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons why this operator cannot handle ``d``.

        Extends the base-class reasons with last-dimension alignment checks
        on query/key/value, operator availability, the CUDA compute
        capability and the installed triton version.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)
        for label, tensor in (("query", d.query), ("key", d.key), ("value", d.value)):
            check_lastdim_alignment_stride1(reasons, label, tensor, 8)

        if cls.OPERATOR is None:
            reasons.append("triton is not available")

        # Has only been tested on 8.0 / 9.0.
        # Fails on 7.5 with illegal memory access
        on_cuda = d.device.type == "cuda"
        if on_cuda and torch.cuda.get_device_capability(d.device) < (8, 0):
            reasons.append(
                "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
            )

        if _is_triton_available():
            import triton

            # NOTE(review): this is a plain string comparison, not a parsed
            # version comparison -- confirm it behaves as intended for dev
            # builds and multi-digit version components.
            if triton.__version__ > "2.0.0":
                reasons.append("Only work on pre-MLIR triton for now")
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the Triton forward kernel on the prepared inputs.

        Returns the attention output together with a ``Context`` holding the
        ``lse`` and output tensors; the context is built whether or not
        ``needs_gradient`` is set.
        """
        prepared = _prepare_inputs(inp)
        bias = prepared.attn_bias
        # A tensor bias is forwarded as-is; a LowerTriangularMask is expressed
        # through the kernel's `causal` flag instead.
        tensor_bias = bias if isinstance(bias, torch.Tensor) else None
        is_causal = isinstance(bias, LowerTriangularMask)

        # The third return value (softmax scale) is not needed here.
        out, lse, _scale = triton_flash_forward(
            q=prepared.query,
            k=prepared.key,
            v=prepared.value,
            bias=tensor_bias,
            softmax_scale=prepared.scale_float,
            causal=is_causal,
        )
        return out, Context(lse=lse, out=out)
|
||||
|
||||
|
||||
@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    # `None` when the Triton flash-attention backward could not be imported.
    OPERATOR = triton_flash_backward
    # Capabilities are mirrored from the forward operator.
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    NAME = "tritonflashattB"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons why this operator cannot handle ``d``.

        Extends the base-class reasons with last-dimension alignment checks
        on query/key/value, operator availability, a compute capability of
        exactly (8, 0) on CUDA devices, and the installed triton version.
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        for label, tensor in (("query", d.query), ("key", d.key), ("value", d.value)):
            check_lastdim_alignment_stride1(reasons, label, tensor, 8)

        if cls.OPERATOR is None:
            reasons.append("triton is not available")

        on_cuda = d.device.type == "cuda"
        if on_cuda and torch.cuda.get_device_capability(d.device) != (8, 0):
            reasons.append("requires A100 GPU")

        if _is_triton_available():
            import triton

            # NOTE(review): this is a plain string comparison, not a parsed
            # version comparison -- confirm it behaves as intended for dev
            # builds and multi-digit version components.
            if triton.__version__ > "2.0.0":
                reasons.append("Only work on pre-MLIR triton for now")
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the Triton backward kernel and return the gradients.

        ``dq``/``dk``/``dv`` are allocated with the shapes of the prepared
        query/key/value and handed to ``cls.OPERATOR``, which presumably
        fills them in place (its return value is not used).
        """
        prepared = _prepare_inputs(inp)
        bias = prepared.attn_bias
        tensor_bias = bias if isinstance(bias, torch.Tensor) else None
        is_causal = isinstance(bias, LowerTriangularMask)

        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            grads = Gradients(
                dq=torch.empty_like(prepared.query),
                dk=torch.empty_like(prepared.key),
                dv=torch.empty_like(prepared.value),
            )
            cls.OPERATOR(
                grad,
                prepared.query,
                prepared.key,
                prepared.value,
                ctx.out,
                ctx.get_padded_lse(128),
                grads.dq,
                grads.dk,
                grads.dv,
                bias=tensor_bias,
                softmax_scale=prepared.scale_float,
                causal=is_causal,
            )
        return grads
|
||||
Reference in New Issue
Block a user