First commit

pkgs/xformers/ops/__init__.py (new file, 97 lines)
@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .fmha import (
    AttentionBias,
    AttentionOp,
    AttentionOpBase,
    AttentionOpDispatch,
    LowerTriangularMask,
    MemoryEfficientAttentionCutlassFwdFlashBwOp,
    MemoryEfficientAttentionCutlassOp,
    MemoryEfficientAttentionFlashAttentionOp,
    MemoryEfficientAttentionOp,
    MemoryEfficientAttentionTritonFwdFlashBwOp,
    TritonFlashAttentionOp,
    memory_efficient_attention,
    memory_efficient_attention_backward,
    memory_efficient_attention_forward,
    memory_efficient_attention_forward_requires_grad,
)
from .indexing import index_select_cat, scaled_index_add
from .rmsnorm import RMSNorm
from .rope_padded import rope_padded
from .swiglu_op import (
    SwiGLU,
    SwiGLUEagerOp,
    SwiGLUFusedOp,
    SwiGLUOp,
    SwiGLUOpDispatch,
    SwiGLUPackedFusedOp,
    swiglu,
)
from .unbind import get_stack_strides, stack_or_none, unbind


# BW compatibility
AttentionMask = AttentionBias


def masked_matmul(a, b, mask=None):
    if torch.overrides.has_torch_function((a, b, mask)):
        # Defer to tensor subclasses that override masked_matmul
        # via the __torch_function__ protocol.
        return torch.overrides.handle_torch_function(
            masked_matmul, (a, b, mask), a, b, mask
        )

    att = a @ b

    if mask is None:
        return att

    if mask.dtype == torch.bool:
        if mask.ndim == 2:
            # Broadcast a 2D mask across the batch dimension.
            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
        # mask is presumed false == ignore
        att[~mask] = float("-inf")
    else:
        # mask is presumed additive
        att += mask
    return att

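# Illustrative usage sketch (not part of the original file): with a boolean
# mask, masked positions are set to -inf so that a following softmax assigns
# them zero weight. Shapes below are arbitrary.
#
#   a = torch.randn(2, 4, 8)                            # (batch, m, k)
#   b = torch.randn(2, 8, 4)                            # (batch, k, n)
#   causal = torch.ones(4, 4, dtype=torch.bool).tril()  # False == ignore
#   att = masked_matmul(a, b, causal).softmax(-1)
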
__all__ = [
    "memory_efficient_attention",
    "AttentionBias",
    "AttentionMask",
    "AttentionOp",
    "AttentionOpBase",
    "AttentionOpDispatch",
    "LowerTriangularMask",
    "MemoryEfficientAttentionCutlassFwdFlashBwOp",
    "MemoryEfficientAttentionCutlassOp",
    "MemoryEfficientAttentionFlashAttentionOp",
    "MemoryEfficientAttentionOp",
    "MemoryEfficientAttentionTritonFwdFlashBwOp",
    "memory_efficient_attention_backward",
    "memory_efficient_attention_forward",
    "memory_efficient_attention_forward_requires_grad",
    "RMSNorm",
    "SwiGLU",
    "SwiGLUEagerOp",
    "SwiGLUFusedOp",
    "SwiGLUOp",
    "SwiGLUOpDispatch",
    "SwiGLUPackedFusedOp",
    "swiglu",
    "TritonFlashAttentionOp",
    "unbind",
    "stack_or_none",
    "get_stack_strides",
    "masked_matmul",
    "scaled_index_add",
    "index_select_cat",
    "rope_padded",
]
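
For context, a minimal usage sketch of the main export (assuming the vendored
package under pkgs/ resolves as xformers on the import path; shapes, dtype,
and device below are illustrative):

    import torch
    import xformers.ops as xops

    # q, k, v have shape (batch, seqlen, heads, head_dim)
    q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # Causal attention via the exported LowerTriangularMask bias
    out = xops.memory_efficient_attention(
        q, k, v, attn_bias=xops.LowerTriangularMask()
    )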