# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import inspect
import os
from typing import Any, Dict, List, Type, TypeVar

import torch
from torch.torch_version import TorchVersion


def get_operator(library: str, name: str):
    """Fetch `torch.ops.<library>.<name>`, falling back to a stub that raises
    a helpful error at call time if the operator was not compiled in."""

    def no_such_operator(*args, **kwargs):
        raise RuntimeError(
            f"No such operator {library}::{name} - did you forget to build xformers with `python setup.py develop`?"
        )

    try:
        return getattr(getattr(torch.ops, library), name)
    except (RuntimeError, AttributeError):
        return no_such_operator


def get_xformers_operator(name: str):
    return get_operator("xformers", name)


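# Illustrative usage (a sketch; "my_custom_op" is a hypothetical name - real
# operator names depend on what was compiled into the extension):
#
#     my_op = get_xformers_operator("my_custom_op")
#     out = my_op(some_tensor)  # raises the RuntimeError above if not built

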
class BaseOperator:
    """Common interface for dispatchable operators: subclasses declare the
    backing `OPERATOR` plus metadata, and are collected via `register_operator`."""

    OPERATOR: Any
    NAME: str
    OPERATOR_CATEGORY: str

    @classmethod
    def is_available(cls) -> bool:
        if cls.OPERATOR is None or cls.OPERATOR.__name__ == "no_such_operator":
            return False
        return True

    @classmethod
    def operator_flop(cls, *inputs) -> int:
        """Calculate number of FLOP given inputs to `OPERATOR`"""
        return -1


OPERATORS_REGISTRY: List[Type[BaseOperator]] = []
FUNC_TO_XFORMERS_OPERATOR: Dict[Any, Type[BaseOperator]] = {}

ClsT = TypeVar("ClsT")


def register_operator(cls: ClsT) -> ClsT:
    """Class decorator: record an operator class in the global registries."""
    global OPERATORS_REGISTRY, FUNC_TO_XFORMERS_OPERATOR
    OPERATORS_REGISTRY.append(cls)  # type: ignore
    FUNC_TO_XFORMERS_OPERATOR[cls.OPERATOR] = cls  # type: ignore
    return cls


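# Illustrative usage (a sketch; `MyOp`, "my_custom_op" and "my_category" are
# hypothetical): declare an operator class and register it in one step.
#
#     @register_operator
#     class MyOp(BaseOperator):
#         OPERATOR = get_xformers_operator("my_custom_op")
#         NAME = "my_op"
#         OPERATOR_CATEGORY = "my_category"
#
#     MyOp.is_available()  # False if the extension was not built

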
# post-2.0, avoids a warning
# (`torch.Tensor.storage` will also be deleted in the future)
_GET_TENSOR_STORAGE = getattr(torch.Tensor, "untyped_storage", None)
if _GET_TENSOR_STORAGE is None:  # pre-2.0, `untyped_storage` didn't exist
    _GET_TENSOR_STORAGE = torch.Tensor.storage


def _get_storage_base(x: torch.Tensor) -> int:
    """Pointer to the start of `x`'s underlying storage."""
    return _GET_TENSOR_STORAGE(x).data_ptr()  # type: ignore


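# Illustrative check (a sketch): views of the same tensor report the same
# storage base, which makes this helper useful for detecting aliasing.
#
#     a = torch.empty(4)
#     assert _get_storage_base(a[:2]) == _get_storage_base(a[2:])

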
def make_pytorch_cuda_operator(fn: ClsT) -> ClsT:
    """Register `fn` as a CUDA operator on the PyTorch dispatcher, deriving
    its schema from the type annotations, and return a wrapper that calls it
    through the dispatcher."""
    from .. import get_python_lib

    def render_arg_type(annotation) -> str:
        if annotation is torch.Tensor:
            return "Tensor"
        if annotation is bool:
            return "bool"
        if annotation is int:
            return "int"
        if annotation is List[int]:
            return "int[]"
        if annotation is List[torch.Tensor]:
            return "Tensor[]"
        assert False, f"Unable to parse annotation: `{annotation}`"

    sign = inspect.signature(fn)  # type: ignore
    arguments = [
        f"{render_arg_type(arg.annotation)} {arg.name}"
        for arg in sign.parameters.values()
    ]
    op_name = fn.__name__  # type: ignore
    definition = f"{op_name}({', '.join(arguments)}) -> {render_arg_type(sign.return_annotation)}"

    xformers_lib = get_python_lib()
    xformers_lib.define(definition)
    xformers_lib.impl(op_name, fn, "CUDA")
    dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name)

    def wrapper(*args, **kwargs):
        return dispatcher_impl(*args, **kwargs)

    return wrapper  # type: ignore


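# Illustrative usage (a sketch; `scale_by` is hypothetical). Annotations must
# be among the types handled by `render_arg_type` above:
#
#     @make_pytorch_cuda_operator
#     def scale_by(x: torch.Tensor, factor: int) -> torch.Tensor:
#         return x * factor
#
# Calls to the returned wrapper are routed through the PyTorch dispatcher
# under the namespace of `get_python_lib()`.

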
def _has_a_version_of_triton():
    """True if Triton is importable, CUDA is available, and Triton has not
    been disabled via `XFORMERS_FORCE_DISABLE_TRITON=1`."""
    if os.environ.get("XFORMERS_FORCE_DISABLE_TRITON", "0") == "1":
        return False
    if not torch.cuda.is_available():
        return False
    try:
        import triton  # noqa: F401
    except ImportError:
        return False
    return True


def _has_triton2():
    if not _has_a_version_of_triton():
        return False
    import triton

    tv = TorchVersion(triton.__version__)
    # Triton 2.0 (exactly) or any version >= 2.1
    return tv >= (2, 1) or tv == (2, 0)


def _has_triton21():
    if not _has_a_version_of_triton():
        return False
    import triton

    tv = TorchVersion(triton.__version__)
    return tv >= (2, 1)


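# Illustrative usage (a sketch; the module and kernel names are hypothetical):
# these helpers are intended to gate Triton-backed code paths.
#
#     if _has_triton21():
#         from ._my_triton_kernels import fancy_kernel
#     else:
#         fancy_kernel = None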