# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. import torch from xformers.ops import masked_matmul from xformers.sparse import _csr_ops from xformers.sparse.utils import ( _csr_to_coo, _dense3d_to_sparse, _diffsort, _get_transpose_info, _transpose_with_info, ) class SparseCSRTensor(torch.Tensor): @staticmethod def __new__(cls, row_offsets, column_indices, values, shape): kwargs = {} kwargs["device"] = values.device kwargs["dtype"] = values.dtype kwargs["layout"] = values.layout kwargs["requires_grad"] = values.requires_grad assert len(shape) == 3 assert torch.__version__ > (1, 10), "SparseCSRTensor requires PyTorch 1.11+" return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) def __init__(self, row_offsets, column_indices, values, shape): assert row_offsets.ndim == 1 assert column_indices.ndim == 1 assert values.ndim == 2 self.__row_offsets = row_offsets.contiguous() self.__row_indices = _diffsort(row_offsets).to(row_offsets.dtype) self.__column_indices = column_indices.contiguous() self.__values = values.contiguous() self.__transp_info = _get_transpose_info( self.shape[1], self.shape[2], self.__row_indices, self.__row_offsets, self.__column_indices, ) def __repr__(self): return f"sparse_csr_tensor(shape={self.shape}, values={self.__values})" @classmethod def from_dense(cls, matrix): values, row_indices, row_offsets, column_indices = _dense3d_to_sparse( matrix, matrix.device ) return cls(row_offsets, column_indices, values, matrix.shape) @classmethod def from_sparse_coo(cls, arg0): """ assert arg0.is_sparse x = arg0.coalesce() rows, cols = x.indices().unbind(0) vals = x.values() _coo_to_csr() """ pass @classmethod def _wrap( cls, shape, values, row_indices, row_offsets, column_indices, _transp_info ): matrix = cls.__new__(cls, row_offsets, column_indices, values, shape) matrix.__values = values matrix.__row_indices = row_indices matrix.__row_offsets = row_offsets matrix.__column_indices = column_indices matrix.__transp_info = _transp_info return matrix def values(self): return self.__values @property def _csr_row_indices(self): return self.__row_indices @property def _csr_row_offsets(self): return self.__row_offsets @property def _csr_column_indices(self): return self.__column_indices @property def _csr_transp_info(self): return self.__transp_info @classmethod def _bmm(cls, arg0, arg1): if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor): return NotImplemented assert arg0.ndim == 3 assert arg1.ndim == 3 self = arg0 b = arg1 _, m, n = self.shape row_indices = self.__row_indices values = self.__values row_offsets = self.__row_offsets column_indices = self.__column_indices out = _csr_ops._spmm.apply( b, row_indices, values, row_offsets, column_indices, m, self.__transp_info ) return out @classmethod def _softmax(cls, arg0, dim): if not (dim == -1 or dim == 2): return NotImplemented self = arg0 _, m, n = self.shape row_indices = self.__row_indices values = self.__values row_offsets = self.__row_offsets column_indices = self.__column_indices out = _csr_ops._SparseSoftmax.apply( m, n, row_indices, values, row_offsets, column_indices ) return cls._wrap( self.shape, out, row_indices, row_offsets, column_indices, self.__transp_info, ) @classmethod def _transpose(cls, arg0, dim0, dim1): # TODO: check if need to return this or not if not (dim0 == 1 or dim0 == -2): return NotImplemented if not (dim1 == 2 or dim1 == -1): return NotImplemented B, m, n = arg0.shape values = arg0.__values ( output_row_indices, output_values, output_row_offsets, output_column_indices, ) = _transpose_with_info(values, arg0.__transp_info) new_transp_info = _get_transpose_info( n, m, output_row_indices, output_row_offsets, output_column_indices ) return cls._wrap( (B, n, m), output_values, output_row_indices, output_row_offsets, output_column_indices, new_transp_info, ) @classmethod def _masked_matmul(cls, a, b, mask): if not (type(a) == torch.Tensor and type(b) == torch.Tensor): return NotImplemented assert mask.shape[1] == a.shape[1] assert mask.shape[2] == b.shape[2] row_indices = mask.__row_indices row_offsets = mask.__row_offsets column_indices = mask.__column_indices a = a.contiguous() out = _csr_ops._sddmm.apply( a, b.transpose(-2, -1).contiguous(), row_indices, row_offsets, column_indices, mask.__transp_info, ) # TODO add bias here return cls._wrap( mask.shape, out, row_indices, row_offsets, column_indices, mask.__transp_info, ) @classmethod def _to(cls, arg0, device): if isinstance(device, str): device = torch.device(device) assert isinstance(device, torch.device) return cls._wrap( arg0.shape, arg0.__values.to(device=device), arg0.__row_indices.to(device=device), arg0.__row_offsets.to(device=device), arg0.__column_indices.to(device=device), tuple(t.to(device=device) for t in arg0.__transp_info), ) @classmethod def _copy(cls, arg0, arg1): if not (isinstance(arg0, cls) and isinstance(arg1, cls)): return NotImplemented assert arg0.shape == arg1.shape av0, av1 = arg0.__values, arg1.__values av0.resize_as_(av1).copy_(av1) av0, av1 = arg0.__row_indices, arg1.__row_indices av0.resize_as_(av1).copy_(av1) av0, av1 = arg0.__row_offsets, arg1.__row_offsets av0.resize_as_(av1).copy_(av1) av0, av1 = arg0.__column_indices, arg1.__column_indices av0.resize_as_(av1).copy_(av1) for v0, v1 in zip(arg0.__transp_info, arg1.__transp_info): v0.resize_as_(v1).copy_(v1) return arg0 @classmethod def _equal(cls, arg0, arg1): if not (isinstance(arg0, cls) and isinstance(arg1, cls)): return NotImplemented if arg0.shape != arg1.shape: return False if not torch.equal(arg0.__values, arg1.__values): return False if not torch.equal(arg0.__row_offsets, arg1.__row_offsets): return False if not torch.equal(arg0.__column_indices, arg1.__column_indices): return False return True @classmethod def _to_dense(cls, arg0): _, m, n = arg0.shape shape = arg0.shape matrix = torch.zeros(shape, dtype=arg0.dtype, device=arg0.device) row_offsets = arg0.__row_offsets.long() column_indices = arg0.__column_indices.long() row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices) b_idxs = torch.arange(len(arg0.__values), device=arg0.device)[:, None] matrix[b_idxs, row_coo, column_indices] = arg0.__values return matrix @classmethod def _binary_op(cls, func, arg0, arg1): if not ( isinstance(arg0, (cls, int, float)) and isinstance(arg1, (cls, int, float)) ): return NotImplemented v0, v1 = arg0, arg1 if isinstance(arg0, cls): v0 = arg0.__values if isinstance(arg1, cls): v1 = arg1.__values # assert arg0.shape == arg1.shape if isinstance(arg0, cls) and isinstance(arg1, cls): msg = f"arg0 and arg1 need to have the same sparsity pattern in {func} (for now)" if not arg0.__row_offsets.shape == arg1.__row_offsets.shape: raise NotImplementedError(msg) if not arg0.__column_indices.shape == arg1.__column_indices.shape: raise NotImplementedError(msg) if not arg0.__values.shape == arg1.__values.shape: raise NotImplementedError(msg) # TODO this is not always true, but is a fast approximation for now if arg0.__row_offsets is not arg1.__row_offsets: raise NotImplementedError(msg) if arg0.__column_indices is not arg1.__column_indices: raise NotImplementedError(msg) out = func(v0, v1) return cls._wrap( arg0.shape, out, arg0.__row_indices, arg0.__row_offsets, arg0.__column_indices, arg0.__transp_info, ) @classmethod def _binary_op_slow(cls, func, arg0, arg1): # assert arg0.shape == arg1.shape v0, v1 = arg0, arg1 if isinstance(arg0, cls): v0 = arg0.to_dense() if isinstance(arg1, cls): v1 = arg1.to_dense() out = func(v0, v1) return cls.from_dense(out) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if func in [ torch.Tensor.bmm, torch.bmm, torch.Tensor.__matmul__, torch.matmul, torch.Tensor.matmul, ]: assert len(args) == 2 return cls._bmm(args[0], args[1]) if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]: return cls._softmax(args[0], kwargs["dim"]) if func in [torch.Tensor.transpose, torch.transpose]: assert len(kwargs) == 0 return cls._transpose(args[0], args[1], args[2]) if func == masked_matmul: assert len(args) == 3 return cls._masked_matmul(args[0], args[1], args[2]) if func in [ torch.Tensor.add, torch.add, torch.Tensor.__add__, ]: assert len(args) == 2 if not (isinstance(args[0], cls) and isinstance(args[1], cls)): raise NotImplementedError( f"{func} with {type(args[0])} and {type(args[1])} not implemented" ) return cls._binary_op(func, args[0], args[1]) if func in [ torch.Tensor.mul, torch.mul, torch.Tensor.__mul__, ]: assert len(args) == 2 return cls._binary_op(func, args[0], args[1]) if func in [torch.Tensor.logical_and, torch.logical_and, torch.Tensor.__and__]: assert len(args) == 2 return cls._binary_op_slow(func, args[0], args[1]) if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]: x = args[0] values = x.__values.clone() values = func(values, *args[1:], **kwargs) return cls._wrap( x.shape, values, x.__row_indices, x.__row_offsets, x.__column_indices, x.__transp_info, ) if func == torch.Tensor.to: # print(args, kwargs) assert len(args) >= 2 return cls._to(args[0], args[1]) # return cls._to(args[0], kwargs["device"]) if func in [torch.Tensor.copy_]: assert len(args) == 2 return cls._copy(args[0], args[1]) if func in [torch.Tensor.equal, torch.equal]: assert len(args) == 2 return cls._equal(args[0], args[1]) if func == torch.Tensor.to_dense: assert len(args) == 1 return cls._to_dense(args[0]) if func == torch.Tensor.detach: x = args[0] return cls._wrap( x.shape, x.__values.detach(), x.__row_indices, x.__row_offsets, x.__column_indices, x.__transp_info, ) if func == torch.Tensor.__deepcopy__: x = args[0] memo = args[1] return cls._wrap( x.shape, x.__values.__deepcopy__(memo), x.__row_indices.__deepcopy__(memo), x.__row_offsets.__deepcopy__(memo), x.__column_indices.__deepcopy__(memo), tuple(v.__deepcopy__(memo) for v in x.__transp_info), ) if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]: assert len(args) == 1 assert len(kwargs) == 0 x = args[0] return cls._wrap( x.shape, x.__values.grad, x.__row_indices, x.__row_offsets, x.__column_indices, x.__transp_info, ) if func == torch.Tensor.requires_grad_: func(args[0].__values) with torch._C.DisableTorchFunction(): ret = func(*args, **kwargs) # TODO: check this if func in torch.overrides.get_default_nowrap_functions(): return ret return torch._tensor._convert(ret, cls) return NotImplemented @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): return NotImplemented