122 lines
3.1 KiB
Python
122 lines
3.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import torch
|
|
|
|
from xformers.ops import masked_matmul
|
|
from xformers.sparse import SparseCSRTensor
|
|
|
|
# TODO: these re-exports are kept here for backward compatibility (BC)
|
|
from xformers.sparse.utils import _csr_to_coo, _dense_to_sparse # noqa: F401
|
|
|
|
|
|
class SparseCS:
    """Thin wrapper around a ``SparseCSRTensor`` stored in ``self._mat``.

    The wrapped tensor is expected to carry a leading (batch) dimension
    (``__init__`` guarantees a 3-D tensor); ``shape`` hides that dimension and
    reports only the trailing matrix shape. Most operations delegate to the
    wrapped tensor and re-wrap the result in the same class.
    """

    def __init__(self, matrix, device=None):
        """Build the sparse matrix from a dense ``matrix``.

        Args:
            matrix: dense 2-D or 3-D tensor. A 2-D input gets a leading
                dimension of size 1 so the stored tensor is always 3-D.
            device: target device. When ``None`` the result is placed on CPU,
                regardless of the device ``matrix`` currently lives on.
        """
        if device is None:
            device = torch.device("cpu")
        if matrix.ndim == 2:
            matrix = matrix[None]
        assert matrix.ndim == 3
        self._mat = SparseCSRTensor.from_dense(matrix).to(device)

    def __repr__(self):
        return (
            f"{type(self).__name__}(shape={tuple(self.shape)}, "
            f"dtype={self.dtype}, device={self.device})"
        )

    @property
    def device(self):
        """Device of the wrapped tensor."""
        return self._mat.device

    @property
    def ndim(self):
        """Rank of the wrapped tensor, *including* the leading dimension."""
        return self._mat.ndim

    @property
    def dtype(self):
        """Element dtype of the wrapped tensor."""
        return self._mat.dtype

    @property
    def is_sparse(self):
        """Always ``True``; lets callers distinguish this from dense tensors."""
        return True

    @property
    def shape(self):
        """Shape without the leading dimension (unlike ``ndim``)."""
        return self._mat.shape[1:]

    @property
    def values(self):
        """Stored values, as returned by the wrapped tensor's ``values()``."""
        return self._mat.values()

    @property
    def row_indices(self):
        """CSR row-index metadata of the wrapped tensor."""
        return self._mat._csr_row_indices

    @property
    def column_indices(self):
        """CSR column-index metadata of the wrapped tensor."""
        return self._mat._csr_column_indices

    @property
    def row_offsets(self):
        """CSR row-offset metadata of the wrapped tensor."""
        return self._mat._csr_row_offsets

    @property
    def _transp_info(self):
        # NOTE(review): presumably precomputed transpose metadata; its layout
        # is defined by SparseCSRTensor — confirm there.
        return self._mat._csr_transp_info

    @classmethod
    def wrap(
        cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
    ):
        """Assemble an instance directly from CSR components.

        ``__init__`` is bypassed, so no dense-to-sparse conversion happens.

        Args:
            shape: trailing shape (everything after the leading dimension);
                any sequence of ints is accepted.
            values: stored values; its first dimension becomes the leading
                dimension of the wrapped tensor.
            row_indices, row_offsets, column_indices, _transp_info: CSR
                metadata forwarded unchanged to ``SparseCSRTensor._wrap``.
        """
        matrix = cls.__new__(cls)
        # ``tuple(...)`` lets callers pass a list as well as a tuple/torch.Size.
        _shape = (values.shape[0],) + tuple(shape)
        matrix._mat = SparseCSRTensor._wrap(
            _shape, values, row_indices, row_offsets, column_indices, _transp_info
        )
        return matrix

    @classmethod
    def _wrap(cls, csr_matrix):
        """Wrap an existing ``SparseCSRTensor`` without copying or converting."""
        assert isinstance(csr_matrix, SparseCSRTensor)
        matrix = cls.__new__(cls)
        matrix._mat = csr_matrix
        return matrix

    def __mul__(self, other):
        """Scale by a Python ``int``/``float`` scalar."""
        assert isinstance(other, (int, float))
        return type(self)._wrap(self._mat * other)

    # Scalar scaling commutes, so ``2 * m`` behaves exactly like ``m * 2``.
    __rmul__ = __mul__

    def __add__(self, other):
        """Add another instance of the same class element-wise."""
        assert isinstance(other, type(self))
        return type(self)._wrap(self._mat + other._mat)

    def matmul_with_mask(self, a, b):
        """Return ``masked_matmul(a, b, self._mat)`` wrapped in this class.

        NOTE(review): presumably evaluates ``a @ b`` only at the positions
        present in this matrix's sparsity pattern — confirm in masked_matmul.
        """
        return type(self)._wrap(masked_matmul(a, b, self._mat))

    def softmax(self):
        """Softmax along the last dimension, re-wrapped in this class."""
        out = torch.nn.functional.softmax(self._mat, -1)
        return type(self)._wrap(out)

    def spmm(self, b):
        """Batched product ``torch.bmm(self._mat, b)``.

        The result is returned as-is (not re-wrapped in this class).
        """
        return torch.bmm(self._mat, b)

    def transpose(self):
        """Swap the last two dimensions, re-wrapped in this class."""
        out = torch.transpose(self._mat, -2, -1)
        return type(self)._wrap(out)

    def to(self, device):
        """Return a new instance wrapping ``self._mat.to(device)``.

        Args:
            device: a ``torch.device`` or a device string such as ``"cuda"``.
        """
        # Accept strings for consistency with ``__init__``, whose ``device``
        # is forwarded straight to ``Tensor.to``.
        if isinstance(device, str):
            device = torch.device(device)
        assert isinstance(device, torch.device)
        return type(self)._wrap(self._mat.to(device))

    def to_dense(self):
        """Dense representation of the wrapped tensor (leading dim kept)."""
        return self._mat.to_dense()

    def logical_and(self, other: torch.Tensor):
        """Element-wise logical AND with a (non-``SparseCS``) tensor."""
        assert not isinstance(other, SparseCS)
        out = torch.logical_and(self._mat, other)
        return type(self)._wrap(out)

    def __and__(self, other):
        """Alias for :meth:`logical_and`, enabling ``m & mask``."""
        return self.logical_and(other)
|