# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch

from .fmha import (
    AttentionBias,
    AttentionOp,
    AttentionOpBase,
    AttentionOpDispatch,
    LowerTriangularMask,
    MemoryEfficientAttentionCutlassFwdFlashBwOp,
    MemoryEfficientAttentionCutlassOp,
    MemoryEfficientAttentionFlashAttentionOp,
    MemoryEfficientAttentionOp,
    MemoryEfficientAttentionTritonFwdFlashBwOp,
    TritonFlashAttentionOp,
    memory_efficient_attention,
    memory_efficient_attention_backward,
    memory_efficient_attention_forward,
    memory_efficient_attention_forward_requires_grad,
)
from .indexing import index_select_cat, scaled_index_add
from .rmsnorm import RMSNorm
from .rope_padded import rope_padded
from .swiglu_op import (
    SwiGLU,
    SwiGLUEagerOp,
    SwiGLUFusedOp,
    SwiGLUOp,
    SwiGLUOpDispatch,
    SwiGLUPackedFusedOp,
    swiglu,
)
from .unbind import get_stack_strides, stack_or_none, unbind

# BW compatibility
AttentionMask = AttentionBias


def masked_matmul(a, b, mask=None):
    """Compute ``a @ b``, optionally applying an attention mask.

    A boolean ``mask`` selects positions to keep: entries where ``mask``
    is ``False`` are set to ``-inf``. A ``mask`` of any other dtype is
    treated as an additive bias.
    """
    if torch.overrides.has_torch_function((a, b, mask)):
        return torch.overrides.handle_torch_function(
            masked_matmul, (a, b, mask), a, b, mask
        )

    att = a @ b
    if mask is None:
        return att

    if mask.dtype == torch.bool:
        if mask.ndim == 2:
            # Broadcast a 2D mask across the batch dimension
            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
        # mask is presumed False == ignore
        att[~mask] = float("-inf")
    else:
        # mask is presumed additive
        att += mask
    return att
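
# A minimal usage sketch (illustrative only; the tensor shapes below are
# assumptions, not part of the function's contract):
#
#     q = torch.randn(2, 4, 8)                        # (batch, seq, dim)
#     k = torch.randn(2, 4, 8)
#     causal = torch.ones(4, 4, dtype=torch.bool).tril()
#     att = masked_matmul(q, k.transpose(-2, -1), causal)
#     # entries above the diagonal of `att` are now -inf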


__all__ = [
    "memory_efficient_attention",
    "AttentionBias",
    "AttentionMask",
    "AttentionOp",
    "AttentionOpBase",
    "AttentionOpDispatch",
    "LowerTriangularMask",
    "MemoryEfficientAttentionCutlassFwdFlashBwOp",
    "MemoryEfficientAttentionCutlassOp",
    "MemoryEfficientAttentionFlashAttentionOp",
    "MemoryEfficientAttentionOp",
    "MemoryEfficientAttentionTritonFwdFlashBwOp",
    "memory_efficient_attention_backward",
    "memory_efficient_attention_forward",
    "memory_efficient_attention_forward_requires_grad",
    "RMSNorm",
    "SwiGLU",
    "SwiGLUEagerOp",
    "SwiGLUFusedOp",
    "SwiGLUOp",
    "SwiGLUOpDispatch",
    "SwiGLUPackedFusedOp",
    "swiglu",
    "TritonFlashAttentionOp",
    "unbind",
    "stack_or_none",
    "get_stack_strides",
    "masked_matmul",
    "scaled_index_add",
    "index_select_cat",
    "rope_padded",
]