# NOTE(review): the following metadata is web-listing residue from the page this
# file was scraped from (retrieved 2025-08-05 19:02:46 +08:00; listed as
# "439 lines, 14 KiB, Python"). Preserved here as a comment so the file parses.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from xformers.ops import masked_matmul
from xformers.sparse import _csr_ops
from xformers.sparse.utils import (
_csr_to_coo,
_dense3d_to_sparse,
_diffsort,
_get_transpose_info,
_transpose_with_info,
)
class SparseCSRTensor(torch.Tensor):
    """A batched sparse matrix stored in CSR (compressed sparse row) layout.

    The logical shape is 3-D: ``(B, M, N)`` (asserted in ``__new__``); the
    sparsity pattern (``row_offsets`` / ``column_indices``) is shared across
    the batch while ``values`` is 2-D (asserted in ``__init__``), presumably
    ``(B, nnz)`` — TODO confirm against ``_dense3d_to_sparse``.

    Implemented as a ``__torch_function__`` wrapper subclass: the tensor
    carries no dense storage of its own (``_make_wrapper_subclass``), and a
    whitelist of torch ops (bmm, softmax, transpose, masked_matmul, add, mul,
    dropout, ...) is intercepted and rerouted to sparse kernels in
    ``_csr_ops``.  Ops not in the whitelist fall through to the default
    implementation or return ``NotImplemented``.

    NOTE: attributes use double leading underscores on purpose.  Python name
    mangling turns ``self.__values`` into ``self._SparseCSRTensor__values``,
    which is why classmethods below can read ``mask.__row_indices`` etc. on
    *other* instances — the access happens lexically inside this class.
    """

    @staticmethod
    def __new__(cls, row_offsets, column_indices, values, shape):
        # The wrapper tensor inherits its metadata (device/dtype/layout/
        # requires_grad) from the values buffer; `shape` is the logical
        # (B, M, N) shape, not the shape of any real storage.
        kwargs = {}
        kwargs["device"] = values.device
        kwargs["dtype"] = values.dtype
        kwargs["layout"] = values.layout
        kwargs["requires_grad"] = values.requires_grad
        assert len(shape) == 3
        # NOTE(review): relies on torch.__version__ being a TorchVersion
        # object that supports comparison against a tuple (added in 1.10/1.11);
        # on a plain-string version this comparison would raise TypeError.
        assert torch.__version__ > (1, 10), "SparseCSRTensor requires PyTorch 1.11+"
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, row_offsets, column_indices, values, shape):
        """Store the CSR components contiguously and precompute transpose info.

        Args:
            row_offsets: 1-D CSR row-offset array (length presumably M+1 —
                TODO confirm).
            column_indices: 1-D column index per nonzero.
            values: 2-D nonzero values, batch-major.
            shape: logical 3-D shape (only used via ``self.shape`` here;
                consumed by ``__new__``).
        """
        assert row_offsets.ndim == 1
        assert column_indices.ndim == 1
        assert values.ndim == 2
        self.__row_offsets = row_offsets.contiguous()
        # Row index per nonzero, derived from the offsets; semantics of
        # _diffsort are project-internal — presumably rows sorted by
        # decreasing nnz or similar, TODO confirm in xformers.sparse.utils.
        self.__row_indices = _diffsort(row_offsets).to(row_offsets.dtype)
        self.__column_indices = column_indices.contiguous()
        self.__values = values.contiguous()
        # Precomputed once so transposes (and kernels that need the
        # transposed pattern) don't recompute it per call.
        self.__transp_info = _get_transpose_info(
            self.shape[1],
            self.shape[2],
            self.__row_indices,
            self.__row_offsets,
            self.__column_indices,
        )

    def __repr__(self):
        return f"sparse_csr_tensor(shape={self.shape}, values={self.__values})"

    @classmethod
    def from_dense(cls, matrix):
        """Build a SparseCSRTensor from a dense 3-D tensor, keeping its nonzeros."""
        values, row_indices, row_offsets, column_indices = _dense3d_to_sparse(
            matrix, matrix.device
        )
        return cls(row_offsets, column_indices, values, matrix.shape)

    @classmethod
    def from_sparse_coo(cls, arg0):
        """
        assert arg0.is_sparse
        x = arg0.coalesce()
        rows, cols = x.indices().unbind(0)
        vals = x.values()
        _coo_to_csr()
        """
        # TODO(review): unimplemented stub — the docstring above is a stashed
        # draft of the intended COO->CSR conversion; calling this silently
        # returns None.
        pass

    @classmethod
    def _wrap(
        cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
    ):
        """Fast constructor that reuses already-derived CSR components.

        Bypasses ``__init__`` (and its contiguity/derivation work) by calling
        ``__new__`` directly and assigning the mangled attributes by hand.
        """
        matrix = cls.__new__(cls, row_offsets, column_indices, values, shape)
        matrix.__values = values
        matrix.__row_indices = row_indices
        matrix.__row_offsets = row_offsets
        matrix.__column_indices = column_indices
        matrix.__transp_info = _transp_info
        return matrix

    def values(self):
        # Raw nonzero values buffer (not a copy).
        return self.__values

    @property
    def _csr_row_indices(self):
        return self.__row_indices

    @property
    def _csr_row_offsets(self):
        return self.__row_offsets

    @property
    def _csr_column_indices(self):
        return self.__column_indices

    @property
    def _csr_transp_info(self):
        return self.__transp_info

    @classmethod
    def _bmm(cls, arg0, arg1):
        """sparse @ dense batched matmul; returns a dense tensor.

        Only supports (SparseCSRTensor, plain torch.Tensor) in that order;
        anything else defers to other handlers via NotImplemented.
        """
        if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor):
            return NotImplemented
        assert arg0.ndim == 3
        assert arg1.ndim == 3
        self = arg0
        b = arg1
        _, m, n = self.shape
        row_indices = self.__row_indices
        values = self.__values
        row_offsets = self.__row_offsets
        column_indices = self.__column_indices
        # Autograd-aware custom op; transp_info is needed for the backward
        # pass (gradient w.r.t. the sparse operand uses the transposed layout).
        out = _csr_ops._spmm.apply(
            b, row_indices, values, row_offsets, column_indices, m, self.__transp_info
        )
        return out

    @classmethod
    def _softmax(cls, arg0, dim):
        """Softmax over the last (column) dimension of the sparse matrix.

        Only dim == -1 / 2 is supported; the sparsity pattern is unchanged,
        so the result is re-wrapped with the same indices/offsets.
        """
        if not (dim == -1 or dim == 2):
            return NotImplemented
        self = arg0
        _, m, n = self.shape
        row_indices = self.__row_indices
        values = self.__values
        row_offsets = self.__row_offsets
        column_indices = self.__column_indices
        out = _csr_ops._SparseSoftmax.apply(
            m, n, row_indices, values, row_offsets, column_indices
        )
        return cls._wrap(
            self.shape,
            out,
            row_indices,
            row_offsets,
            column_indices,
            self.__transp_info,
        )

    @classmethod
    def _transpose(cls, arg0, dim0, dim1):
        """Transpose the two matrix dims (1 and 2); batch dim stays put."""
        # TODO: check if need to return this or not
        if not (dim0 == 1 or dim0 == -2):
            return NotImplemented
        if not (dim1 == 2 or dim1 == -1):
            return NotImplemented
        B, m, n = arg0.shape
        values = arg0.__values
        # Reuse the precomputed transpose plan to permute values and derive
        # the transposed CSR pattern in one shot.
        (
            output_row_indices,
            output_values,
            output_row_offsets,
            output_column_indices,
        ) = _transpose_with_info(values, arg0.__transp_info)
        # The result needs its own transpose plan (mapping back to the
        # original layout) so it can itself be transposed/used in backward.
        new_transp_info = _get_transpose_info(
            n, m, output_row_indices, output_row_offsets, output_column_indices
        )
        return cls._wrap(
            (B, n, m),
            output_values,
            output_row_indices,
            output_row_offsets,
            output_column_indices,
            new_transp_info,
        )

    @classmethod
    def _masked_matmul(cls, a, b, mask):
        """Compute (a @ b) only at the nonzero positions of ``mask`` (SDDMM).

        ``a`` and ``b`` must be plain dense tensors; the result reuses the
        mask's sparsity pattern.
        """
        if not (type(a) == torch.Tensor and type(b) == torch.Tensor):
            return NotImplemented
        assert mask.shape[1] == a.shape[1]
        assert mask.shape[2] == b.shape[2]
        row_indices = mask.__row_indices
        row_offsets = mask.__row_offsets
        column_indices = mask.__column_indices
        a = a.contiguous()
        out = _csr_ops._sddmm.apply(
            a,
            # Kernel presumably expects b in transposed, row-major form —
            # TODO confirm against _csr_ops._sddmm.
            b.transpose(-2, -1).contiguous(),
            row_indices,
            row_offsets,
            column_indices,
            mask.__transp_info,
        )
        # TODO add bias here
        return cls._wrap(
            mask.shape,
            out,
            row_indices,
            row_offsets,
            column_indices,
            mask.__transp_info,
        )

    @classmethod
    def _to(cls, arg0, device):
        """Move every component tensor to ``device``; returns a new wrapper."""
        if isinstance(device, str):
            device = torch.device(device)
        assert isinstance(device, torch.device)
        return cls._wrap(
            arg0.shape,
            arg0.__values.to(device=device),
            arg0.__row_indices.to(device=device),
            arg0.__row_offsets.to(device=device),
            arg0.__column_indices.to(device=device),
            # transp_info is a tuple of tensors; move each element.
            tuple(t.to(device=device) for t in arg0.__transp_info),
        )

    @classmethod
    def _copy(cls, arg0, arg1):
        """In-place copy of arg1's components into arg0 (torch.Tensor.copy_).

        Each destination tensor is resized to match the source before the
        copy, so the sparsity patterns need not match beforehand.
        """
        if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
            return NotImplemented
        assert arg0.shape == arg1.shape
        av0, av1 = arg0.__values, arg1.__values
        av0.resize_as_(av1).copy_(av1)
        av0, av1 = arg0.__row_indices, arg1.__row_indices
        av0.resize_as_(av1).copy_(av1)
        av0, av1 = arg0.__row_offsets, arg1.__row_offsets
        av0.resize_as_(av1).copy_(av1)
        av0, av1 = arg0.__column_indices, arg1.__column_indices
        av0.resize_as_(av1).copy_(av1)
        for v0, v1 in zip(arg0.__transp_info, arg1.__transp_info):
            v0.resize_as_(v1).copy_(v1)
        return arg0

    @classmethod
    def _equal(cls, arg0, arg1):
        """Exact equality: same shape, same pattern, same values.

        NOTE(review): row_indices / transp_info are derived data and are
        deliberately not compared.
        """
        if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
            return NotImplemented
        if arg0.shape != arg1.shape:
            return False
        if not torch.equal(arg0.__values, arg1.__values):
            return False
        if not torch.equal(arg0.__row_offsets, arg1.__row_offsets):
            return False
        if not torch.equal(arg0.__column_indices, arg1.__column_indices):
            return False
        return True

    @classmethod
    def _to_dense(cls, arg0):
        """Materialize the sparse tensor as a dense (B, M, N) tensor of zeros
        plus the stored nonzeros, via batched fancy indexing."""
        _, m, n = arg0.shape
        shape = arg0.shape
        matrix = torch.zeros(shape, dtype=arg0.dtype, device=arg0.device)
        # .long() because advanced indexing requires int64 indices.
        row_offsets = arg0.__row_offsets.long()
        column_indices = arg0.__column_indices.long()
        row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices)
        # (B, 1) batch index broadcast against the (nnz,) row/col indices.
        b_idxs = torch.arange(len(arg0.__values), device=arg0.device)[:, None]
        matrix[b_idxs, row_coo, column_indices] = arg0.__values
        return matrix

    @classmethod
    def _binary_op(cls, func, arg0, arg1):
        """Elementwise binary op on the values, for sparse/sparse or
        sparse/scalar operands.

        Only supported when both sparse operands share the *identical*
        pattern tensors (checked by identity below); result reuses arg0's
        pattern.
        """
        if not (
            isinstance(arg0, (cls, int, float)) and isinstance(arg1, (cls, int, float))
        ):
            return NotImplemented
        v0, v1 = arg0, arg1
        if isinstance(arg0, cls):
            v0 = arg0.__values
        if isinstance(arg1, cls):
            v1 = arg1.__values
        # assert arg0.shape == arg1.shape
        if isinstance(arg0, cls) and isinstance(arg1, cls):
            msg = f"arg0 and arg1 need to have the same sparsity pattern in {func} (for now)"
            if not arg0.__row_offsets.shape == arg1.__row_offsets.shape:
                raise NotImplementedError(msg)
            if not arg0.__column_indices.shape == arg1.__column_indices.shape:
                raise NotImplementedError(msg)
            if not arg0.__values.shape == arg1.__values.shape:
                raise NotImplementedError(msg)
            # TODO this is not always true, but is a fast approximation for now
            # (identity of the index tensors is used as a cheap stand-in for
            # element-wise pattern equality).
            if arg0.__row_offsets is not arg1.__row_offsets:
                raise NotImplementedError(msg)
            if arg0.__column_indices is not arg1.__column_indices:
                raise NotImplementedError(msg)
        out = func(v0, v1)
        # NOTE(review): if arg0 is a scalar this reads arg0.shape /
        # arg0.__row_indices and would fail — callers apparently always pass
        # the sparse operand first; TODO confirm.
        return cls._wrap(
            arg0.shape,
            out,
            arg0.__row_indices,
            arg0.__row_offsets,
            arg0.__column_indices,
            arg0.__transp_info,
        )

    @classmethod
    def _binary_op_slow(cls, func, arg0, arg1):
        """Fallback binary op: densify, apply func, re-sparsify.

        The result's pattern is recomputed from the dense output, so it may
        differ from either input's pattern.
        """
        # assert arg0.shape == arg1.shape
        v0, v1 = arg0, arg1
        if isinstance(arg0, cls):
            v0 = arg0.to_dense()
        if isinstance(arg1, cls):
            v1 = arg1.to_dense()
        out = func(v0, v1)
        return cls.from_dense(out)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch whitelisted torch ops to the sparse implementations above;
        everything else runs the default implementation on the wrapper."""
        if kwargs is None:
            kwargs = {}
        if func in [
            torch.Tensor.bmm,
            torch.bmm,
            torch.Tensor.__matmul__,
            torch.matmul,
            torch.Tensor.matmul,
        ]:
            assert len(args) == 2
            return cls._bmm(args[0], args[1])
        if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]:
            # NOTE(review): requires dim to be passed as a keyword argument;
            # a positional dim would raise KeyError here.
            return cls._softmax(args[0], kwargs["dim"])
        if func in [torch.Tensor.transpose, torch.transpose]:
            assert len(kwargs) == 0
            return cls._transpose(args[0], args[1], args[2])
        if func == masked_matmul:
            assert len(args) == 3
            return cls._masked_matmul(args[0], args[1], args[2])
        if func in [
            torch.Tensor.add,
            torch.add,
            torch.Tensor.__add__,
        ]:
            assert len(args) == 2
            # add is restricted to sparse+sparse: sparse+scalar would densify
            # the zeros, which _binary_op cannot represent.
            if not (isinstance(args[0], cls) and isinstance(args[1], cls)):
                raise NotImplementedError(
                    f"{func} with {type(args[0])} and {type(args[1])} not implemented"
                )
            return cls._binary_op(func, args[0], args[1])
        if func in [
            torch.Tensor.mul,
            torch.mul,
            torch.Tensor.__mul__,
        ]:
            assert len(args) == 2
            return cls._binary_op(func, args[0], args[1])
        if func in [torch.Tensor.logical_and, torch.logical_and, torch.Tensor.__and__]:
            assert len(args) == 2
            # Slow path: the result's sparsity pattern can shrink, so it must
            # be recomputed via densify/re-sparsify.
            return cls._binary_op_slow(func, args[0], args[1])
        if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]:
            x = args[0]
            # clone() so the in-place variant (torch.dropout_) cannot mutate
            # the original values buffer.
            values = x.__values.clone()
            values = func(values, *args[1:], **kwargs)
            return cls._wrap(
                x.shape,
                values,
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )
        if func == torch.Tensor.to:
            # print(args, kwargs)
            # NOTE(review): only handles the positional .to(device) form;
            # .to(device=...) or .to(dtype) would trip the assert.
            assert len(args) >= 2
            return cls._to(args[0], args[1])
            # return cls._to(args[0], kwargs["device"])
        if func in [torch.Tensor.copy_]:
            assert len(args) == 2
            return cls._copy(args[0], args[1])
        if func in [torch.Tensor.equal, torch.equal]:
            assert len(args) == 2
            return cls._equal(args[0], args[1])
        if func == torch.Tensor.to_dense:
            assert len(args) == 1
            return cls._to_dense(args[0])
        if func == torch.Tensor.detach:
            x = args[0]
            # Detach only the values; the index tensors carry no grad.
            return cls._wrap(
                x.shape,
                x.__values.detach(),
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )
        if func == torch.Tensor.__deepcopy__:
            x = args[0]
            memo = args[1]
            return cls._wrap(
                x.shape,
                x.__values.__deepcopy__(memo),
                x.__row_indices.__deepcopy__(memo),
                x.__row_offsets.__deepcopy__(memo),
                x.__column_indices.__deepcopy__(memo),
                tuple(v.__deepcopy__(memo) for v in x.__transp_info),
            )
        if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]:
            assert len(args) == 1
            assert len(kwargs) == 0
            x = args[0]
            # Expose .grad as a sparse tensor sharing this one's pattern
            # (gradients only exist at the stored nonzeros).
            return cls._wrap(
                x.shape,
                x.__values.grad,
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )
        if func == torch.Tensor.requires_grad_:
            # Mirror the flag onto the values buffer, then fall through so
            # the default implementation also flags the wrapper itself.
            func(args[0].__values)
        # Default path: run func without __torch_function__ interception and
        # re-wrap plain-tensor results as this subclass.
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            # TODO: check this
            if func in torch.overrides.get_default_nowrap_functions():
                return ret
            return torch._tensor._convert(ret, cls)
        # Unreachable: both branches of the `with` block above return.
        return NotImplemented

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        # Dispatch-level interception is deliberately unsupported; all
        # behavior goes through __torch_function__ above.
        return NotImplemented