First commit
This commit is contained in:
121
pkgs/xformers/components/attention/_sputnik_sparse.py
Normal file
121
pkgs/xformers/components/attention/_sputnik_sparse.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from xformers.ops import masked_matmul
|
||||
from xformers.sparse import SparseCSRTensor
|
||||
|
||||
# TODO: this is here for BC
|
||||
from xformers.sparse.utils import _csr_to_coo, _dense_to_sparse # noqa: F401
|
||||
|
||||
|
||||
class SparseCS:
|
||||
def __init__(self, matrix, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
if matrix.ndim == 2:
|
||||
matrix = matrix[None]
|
||||
assert matrix.ndim == 3
|
||||
self._mat = SparseCSRTensor.from_dense(matrix).to(device)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._mat.device
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
return self._mat.ndim
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._mat.dtype
|
||||
|
||||
@property
|
||||
def is_sparse(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._mat.shape[1:]
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
return self._mat.values()
|
||||
|
||||
@property
|
||||
def row_indices(self):
|
||||
return self._mat._csr_row_indices
|
||||
|
||||
@property
|
||||
def column_indices(self):
|
||||
return self._mat._csr_column_indices
|
||||
|
||||
@property
|
||||
def row_offsets(self):
|
||||
return self._mat._csr_row_offsets
|
||||
|
||||
@property
|
||||
def _transp_info(self):
|
||||
return self._mat._csr_transp_info
|
||||
|
||||
@classmethod
|
||||
def wrap(
|
||||
cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
|
||||
):
|
||||
matrix = cls.__new__(cls)
|
||||
_shape = (values.shape[0],) + shape
|
||||
csr_matrix = SparseCSRTensor._wrap(
|
||||
_shape, values, row_indices, row_offsets, column_indices, _transp_info
|
||||
)
|
||||
matrix._mat = csr_matrix
|
||||
return matrix
|
||||
|
||||
@classmethod
|
||||
def _wrap(cls, csr_matrix):
|
||||
assert isinstance(csr_matrix, SparseCSRTensor)
|
||||
matrix = cls.__new__(cls)
|
||||
matrix._mat = csr_matrix
|
||||
return matrix
|
||||
|
||||
def __mul__(self, other):
|
||||
assert isinstance(other, (int, float))
|
||||
return type(self)._wrap(self._mat * other)
|
||||
|
||||
def __add__(self, other):
|
||||
assert isinstance(other, type(self))
|
||||
return type(self)._wrap(self._mat + other._mat)
|
||||
|
||||
def matmul_with_mask(self, a, b):
|
||||
return type(self)._wrap(masked_matmul(a, b, self._mat))
|
||||
|
||||
def softmax(self):
|
||||
out = torch.nn.functional.softmax(self._mat, -1)
|
||||
return type(self)._wrap(out)
|
||||
|
||||
def spmm(self, b):
|
||||
out = torch.bmm(self._mat, b)
|
||||
return out
|
||||
|
||||
def transpose(self):
|
||||
out = torch.transpose(self._mat, -2, -1)
|
||||
return type(self)._wrap(out)
|
||||
|
||||
def to(self, device):
|
||||
assert isinstance(device, torch.device)
|
||||
out = self._mat.to(device)
|
||||
return type(self)._wrap(out)
|
||||
|
||||
def to_dense(self):
|
||||
return self._mat.to_dense()
|
||||
|
||||
def logical_and(self, other: torch.Tensor):
|
||||
assert not isinstance(other, SparseCS)
|
||||
out = torch.logical_and(self._mat, other)
|
||||
return type(self)._wrap(out)
|
||||
|
||||
def __and__(self, other):
|
||||
return self.logical_and(other)
|
||||
Reference in New Issue
Block a user