# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from .fmha import (
    AttentionBias,
    AttentionOp,
    AttentionOpBase,
    AttentionOpDispatch,
    LowerTriangularMask,
    MemoryEfficientAttentionCutlassFwdFlashBwOp,
    MemoryEfficientAttentionCutlassOp,
    MemoryEfficientAttentionFlashAttentionOp,
    MemoryEfficientAttentionOp,
    MemoryEfficientAttentionTritonFwdFlashBwOp,
    TritonFlashAttentionOp,
    memory_efficient_attention,
    memory_efficient_attention_backward,
    memory_efficient_attention_forward,
    memory_efficient_attention_forward_requires_grad,
)
from .indexing import index_select_cat, scaled_index_add
from .rmsnorm import RMSNorm
from .rope_padded import rope_padded
from .swiglu_op import (
    SwiGLU,
    SwiGLUEagerOp,
    SwiGLUFusedOp,
    SwiGLUOp,
    SwiGLUOpDispatch,
    SwiGLUPackedFusedOp,
    swiglu,
)
from .unbind import get_stack_strides, stack_or_none, unbind

# BW compatibility
AttentionMask = AttentionBias


def masked_matmul(a, b, mask=None):
    """Compute ``a @ b`` and apply an optional attention mask.

    A boolean mask marks positions to ignore (``False`` entries are set to
    ``-inf``); a mask of any other dtype is treated as additive.
    """
    if torch.overrides.has_torch_function((a, b, mask)):
        return torch.overrides.handle_torch_function(
            masked_matmul, (a, b, mask), a, b, mask
        )

    att = a @ b

    if mask is None:
        return att

    if mask.dtype == torch.bool:
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
        # mask is presumed false == ignore
        att[~mask] = float("-inf")
    else:
        # mask is presumed additive
        att += mask
    return att


__all__ = [
    "memory_efficient_attention",
    "AttentionBias",
    "AttentionMask",
    "AttentionOp",
    "AttentionOpBase",
    "AttentionOpDispatch",
    "LowerTriangularMask",
    "MemoryEfficientAttentionCutlassFwdFlashBwOp",
    "MemoryEfficientAttentionCutlassOp",
    "MemoryEfficientAttentionFlashAttentionOp",
    "MemoryEfficientAttentionOp",
    "MemoryEfficientAttentionTritonFwdFlashBwOp",
    "memory_efficient_attention_backward",
    "memory_efficient_attention_forward",
    "memory_efficient_attention_forward_requires_grad",
    "RMSNorm",
    "SwiGLU",
    "SwiGLUEagerOp",
    "SwiGLUFusedOp",
    "SwiGLUOp",
    "SwiGLUOpDispatch",
    "SwiGLUPackedFusedOp",
    "swiglu",
    "TritonFlashAttentionOp",
    "unbind",
    "stack_or_none",
    "get_stack_strides",
    "masked_matmul",
    "scaled_index_add",
    "index_select_cat",
    "rope_padded",
]
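

# Minimal usage sketch for masked_matmul, assuming small arbitrary example
# tensors; the __main__ guard keeps it from running on import. It exercises
# both mask conventions handled above (boolean "ignore" mask and additive
# mask).
if __name__ == "__main__":
    q = torch.randn(2, 4, 8)
    k = torch.randn(2, 8, 4)

    # Boolean mask: False entries are treated as "ignore" and set to -inf.
    bool_mask = torch.rand(4, 4) > 0.5
    print(masked_matmul(q, k, bool_mask).shape)  # torch.Size([2, 4, 4])

    # Non-boolean mask: values are added to the matmul result.
    additive_mask = torch.zeros(2, 4, 4)
    print(masked_matmul(q, k, additive_mask).shape)  # torch.Size([2, 4, 4])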