First commit
7
pkgs/xformers/sparse/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from .blocksparse_tensor import BlockSparseTensor  # noqa: F401
from .csr_tensor import SparseCSRTensor  # noqa: F401
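As a quick orientation (not part of the commit): the package's public surface is just the two tensor subclasses re-exported above, and the `# noqa: F401` markers silence the unused-import lint for the re-exports. A minimal sketch, assuming the `pkgs` tree is importable as `xformers`:

import torch

from xformers.sparse import BlockSparseTensor, SparseCSRTensor

# Both wrappers subclass torch.Tensor and hook __torch_function__,
# so ordinary torch calls (softmax, matmul, ...) dispatch to sparse kernels.
assert issubclass(BlockSparseTensor, torch.Tensor)
assert issubclass(SparseCSRTensor, torch.Tensor)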
BIN
pkgs/xformers/sparse/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/sparse/__pycache__/_csr_ops.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/sparse/__pycache__/csr_tensor.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/sparse/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
166
pkgs/xformers/sparse/_csr_ops.py
Normal file
@@ -0,0 +1,166 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from .utils import _csr_to_coo, _transpose_with_info


def _should_use_coo(a, sparsity):
    if not a.is_cuda:
        return False
    B, M, K = a.shape
    # amortize overhead of converting from csr to coo
    if B < 32 and M < 4096:
        return False
    if sparsity > 0.995:
        return False
    if sparsity < 0.9:
        return False
    if K > 64:
        return False
    # let's be overly cautious here for now
    return sparsity > 0.97


def _should_use_csr_ge(a, sparsity):
    if not a.is_cuda:
        return False
    return sparsity > 0.99


def _sddmm_func(a, b, row_indices, row_offsets, column_indices):
    sparsity = 1 - column_indices.shape[0] / (a.shape[1] * b.shape[1])
    if _should_use_coo(a, sparsity):
        m = a.shape[-2]
        n = b.shape[-2]
        # converting from csr to coo has a constant overhead of ~150us
        # so only dispatch to it for reasonably large problem sizes
        ro, ci = _csr_to_coo(m, n, row_offsets, column_indices)
        return torch.ops.xformers.coo_sddmm(a, b, row_indices, ro, ci)
    elif _should_use_csr_ge(a, sparsity):
        return torch.ops.xformers.csr_sddmm(
            a, b, row_indices, row_offsets, column_indices
        )
    return torch.ops.xformers.sddmm_sputnik(
        a, b, row_indices, row_offsets, column_indices
    )


class _SparseSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, m, n, row_indices, values, row_offsets, column_indices):
        out = torch.ops.xformers.sparse_softmax_sputnik(
            m, n, row_indices, values, row_offsets, column_indices
        )
        # note: save out and not values, as an optimization step
        ctx.save_for_backward(row_indices, out, row_offsets, column_indices)
        ctx.size = (m, n)
        return out

    @staticmethod
    def backward(ctx, grad):
        row_indices, out, row_offsets, column_indices = ctx.saved_tensors
        m, n = ctx.size

        # gradients w.r.t. values
        grad = grad.contiguous()
        ga = torch.ops.xformers.sparse_softmax_backward_sputnik(
            m, n, row_indices, out, grad, row_offsets, column_indices
        )

        return None, None, None, ga, None, None


class _sddmm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, row_indices, row_offsets, column_indices, _transp_info):
        out = _sddmm_func(a, b, row_indices, row_offsets, column_indices)

        ctx.save_for_backward(
            a, b, row_indices, row_offsets, column_indices, *_transp_info
        )
        return out

    @staticmethod
    def backward(ctx, grad):
        (
            a,
            b,
            row_indices,
            row_offsets,
            column_indices,
            *_transp_info,
        ) = ctx.saved_tensors
        m, n = a.shape[1], b.shape[1]

        # gradients w.r.t. values
        grad = grad.contiguous()
        a = a.contiguous()
        b = b.contiguous()

        a_grad = torch.ops.xformers.spmm_sputnik(
            b, row_indices, grad, row_offsets, column_indices, m
        )

        (
            row_indices_t,
            grad_t,
            row_offsets_t,
            column_indices_t,
        ) = _transpose_with_info(grad, _transp_info)

        b_grad = torch.ops.xformers.spmm_sputnik(
            a, row_indices_t, grad_t, row_offsets_t, column_indices_t, n
        )

        return a_grad, b_grad, None, None, None, None


class _spmm(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, b, row_indices, values, row_offsets, column_indices, m, _transp_info
    ):
        b = b.contiguous()
        out = torch.ops.xformers.spmm_sputnik(
            b, row_indices, values, row_offsets, column_indices, m
        )

        ctx.save_for_backward(
            b, row_indices, values, row_offsets, column_indices, *_transp_info
        )
        return out

    @staticmethod
    def backward(ctx, grad):
        (
            b,
            row_indices,
            values,
            row_offsets,
            column_indices,
            *_transp_info,
        ) = ctx.saved_tensors
        k = b.shape[1]

        # gradients w.r.t. values
        grad = grad.contiguous()

        grad_sparse = _sddmm_func(grad, b, row_indices, row_offsets, column_indices)

        (
            row_indices_t,
            values_t,
            row_offsets_t,
            column_indices_t,
        ) = _transpose_with_info(values, _transp_info)

        grad_dense = torch.ops.xformers.spmm_sputnik(
            grad, row_indices_t, values_t, row_offsets_t, column_indices_t, k
        )

        return grad_dense, None, grad_sparse, None, None, None, None
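A small sketch (not in the commit) of how the dispatch heuristics above behave; it only evaluates the pure-Python predicates, so it runs without the compiled sputnik/COO kernels, and the shapes are made up for illustration:

import torch

from xformers.sparse._csr_ops import _should_use_coo, _should_use_csr_ge

# Hypothetical attention-sized problem: B x M x K queries against B x N x K keys.
B, M, N, K = 64, 4096, 4096, 32
nnz = int(0.02 * M * N)  # keep ~2% of the entries
sparsity = 1 - nnz / (M * N)  # same formula as in _sddmm_func

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.empty(B, M, K, device=device)

# On CUDA this shape/sparsity combination takes the COO branch; on CPU both
# predicates return False and _sddmm_func falls through to sddmm_sputnik.
print(_should_use_coo(a, sparsity), _should_use_csr_ge(a, sparsity))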
357
pkgs/xformers/sparse/blocksparse_tensor.py
Normal file
@@ -0,0 +1,357 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch

from xformers import _is_triton_available
from xformers.ops import masked_matmul

logger = logging.getLogger("xformers")


try:
    if not _is_triton_available():
        raise ImportError("triton is not available")
    from triton.ops.blocksparse import matmul as blocksparse_matmul
    from triton.ops.blocksparse import softmax as blocksparse_softmax
except ImportError as e:
    logger.warning(
        "Triton is not available, some optimizations will not be enabled.\n"
        + f"This is just a warning: {e}"
    )
    blocksparse_matmul = None
    blocksparse_softmax = None


def _can_use_triton(a):
    if a.device.type == "cpu":
        return False

    if blocksparse_matmul is None:
        return False

    return True


def _spmm(b, layout, values):
    N, nnz, _, block_size = values.shape
    br = b.reshape(
        b.shape[0], b.shape[1], b.shape[2] // block_size, block_size, b.shape[3]
    )
    # perform matmul on blocks
    h, r, c = layout.nonzero(as_tuple=True)
    temp = values @ br[:, h, c, :]

    linear_idx = h * (b.shape[2] // block_size) + r
    out = torch.zeros(
        N,
        b.shape[1] * layout.shape[-2],
        block_size,
        b.shape[3],
        dtype=b.dtype,
        device=b.device,
    )
    # now aggregate the results of the different blocks
    out.index_add_(1, linear_idx.to(b.device), temp)
    out = out.reshape(N, b.shape[1], -1, b.shape[3])
    return out


def _softmax(layout, values):
    h, r, c = layout.nonzero(as_tuple=True)
    norms = torch.logsumexp(values, dim=-1, keepdim=True)
    linear_idx = h * layout.shape[1] + r

    out_t = torch.zeros(
        norms.shape[0],
        layout.shape[0] * layout.shape[1],
        norms.shape[2],
        norms.shape[3],
        dtype=norms.dtype,
        device=norms.device,
    )
    max_val = norms.max()
    out_t.index_add_(
        1, linear_idx.to(values.device), (norms - max_val).exp()
    ).clamp_min_(1e-24).log_().add_(max_val)
    out = torch.exp(values - out_t[:, linear_idx])
    return out


def _sddmm(a, b, layout):
    block_size = a.shape[-2] // layout.shape[-2]
    a = a.reshape(
        a.shape[0], a.shape[1], a.shape[2] // block_size, block_size, a.shape[3]
    )
    b = b.reshape(
        b.shape[0], b.shape[1], b.shape[2] // block_size, block_size, b.shape[3]
    )

    h, r, c = layout.nonzero(as_tuple=True)

    out = torch.einsum("nhik,nhjk->nhij", a[:, h, r, :, :], b[:, h, c, :, :])
    return out


class BlockSparseTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, values, layout):
        kwargs = {}
        kwargs["device"] = values.device
        kwargs["dtype"] = values.dtype
        kwargs["layout"] = values.layout
        kwargs["requires_grad"] = values.requires_grad
        assert values.ndim == 4
        B, _, block_size, _ = values.shape
        C, h, w = layout.shape
        # TODO validate shape of layout vs values
        shape = (B, C, block_size * h, block_size * w)
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, values, layout):
        assert values.shape[-2] == values.shape[-1]
        assert (
            values.device == layout.device
        ), "Both values and layout need to reside on the same device"
        block_size = values.shape[-1]
        # TODO: make this check conditioned on the use of Triton
        assert block_size >= 16, "Minimum block size is 16, for now at least"

        # Pure blocksparse data
        self.__values = values
        self.__layout = layout

        # blocksparse operators for triton
        if blocksparse_matmul:
            self._initialize_triton_ops()
        else:
            self.__sparse_dot_sdd = None
            self.__sparse_dot_dsd = None
            self.__sparse_softmax = None

    def _initialize_triton_ops(self):
        block_size = self.__values.shape[-1]
        self.__sparse_dot_sdd = blocksparse_matmul(
            self.__layout,
            block_size,
            "sdd",
            trans_a=False,
            trans_b=True,
            device=self.__layout.device,
        )
        self.__sparse_dot_dsd = blocksparse_matmul(
            self.__layout,
            block_size,
            "dsd",
            trans_a=False,
            trans_b=False,
            device=self.__layout.device,
        )
        self.__sparse_softmax = blocksparse_softmax(
            self.__layout, block_size, device=self.__layout.device
        )

    def __repr__(self):
        return f"block_sparse_tensor(shape={self.shape}, values={self.__values})"

    def values(self):
        return self.__values

    @classmethod
    def _raw_wrap(cls, values, layout, sparse_dot_sdd, sparse_dot_dsd, sparse_softmax):
        matrix = cls.__new__(cls, values, layout)
        matrix.__values = values
        matrix.__layout = layout
        matrix.__sparse_dot_sdd = sparse_dot_sdd
        matrix.__sparse_dot_dsd = sparse_dot_dsd
        matrix.__sparse_softmax = sparse_softmax
        return matrix

    @classmethod
    def _wrap(cls, values, bmat):
        matrix = cls.__new__(cls, values, bmat.__layout)
        matrix.__values = values
        matrix.__layout = bmat.__layout
        matrix.__sparse_dot_sdd = bmat.__sparse_dot_sdd
        matrix.__sparse_dot_dsd = bmat.__sparse_dot_dsd
        matrix.__sparse_softmax = bmat.__sparse_softmax
        return matrix

    @classmethod
    def _bmm(cls, arg0, arg1):
        if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor):
            return NotImplemented
        if _can_use_triton(arg1):
            res = arg0.__sparse_dot_dsd(arg0.__values, arg1)
        else:
            res = _spmm(arg1, arg0.__layout, arg0.__values)
        return res

    @classmethod
    def _masked_matmul(cls, a, b, mask):
        if not (type(a) == torch.Tensor and type(b) == torch.Tensor):
            return NotImplemented
        b = b.transpose(-2, -1)
        assert b.is_contiguous()
        if _can_use_triton(a):
            res = mask.__sparse_dot_sdd(a, b)
        else:
            res = _sddmm(a, b, mask.__layout)
        return cls._wrap(res, mask)

    @classmethod
    def _softmax(cls, arg0, dim):
        if not (dim == -1 or dim == 2):
            return NotImplemented
        if _can_use_triton(arg0):
            res = arg0.__sparse_softmax(arg0.__values)
        else:
            res = _softmax(arg0.__layout, arg0.__values)
        return cls._wrap(res, arg0)

    @classmethod
    def _to(cls, arg0, device):
        if isinstance(device, str):
            device = torch.device(device)
        assert isinstance(device, torch.device)
        return cls(
            arg0.__values.to(device=device),
            arg0.__layout,
        )

    @classmethod
    def _copy(cls, arg0, arg1):
        if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
            return NotImplemented
        assert arg0.shape == arg1.shape
        av0, av1 = arg0.__values, arg1.__values
        av0.resize_as_(av1).copy_(av1)
        av0, av1 = arg0.__layout, arg1.__layout
        av0.resize_as_(av1).copy_(av1)
        out = cls(arg0.__values, arg0.__layout)
        arg0.__sparse_dot_sdd = out.__sparse_dot_sdd
        arg0.__sparse_dot_dsd = out.__sparse_dot_dsd
        arg0.__sparse_softmax = out.__sparse_softmax
        return arg0

    @classmethod
    def _equal(cls, arg0, arg1):
        if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
            return NotImplemented
        if arg0.shape != arg1.shape:
            return False
        if not torch.equal(arg0.__values, arg1.__values):
            return False
        if not torch.equal(arg0.__layout, arg1.__layout):
            return False
        return True

    @classmethod
    def _to_dense(cls, arg0):
        # out = torch.zeros(arg0.shape, dtype=arg0.dtype, device=arg0.device, requires_grad=arg0.requires_grad)
        out = torch.zeros(arg0.shape, dtype=arg0.dtype, device=arg0.device)
        values = arg0.__values
        layout = arg0.__layout
        block_size = values.shape[-1]
        blocks_i = layout.shape[-2]
        blocks_j = layout.shape[-1]

        out_r = out.reshape(
            arg0.shape[0], arg0.shape[1], blocks_i, block_size, blocks_j, block_size
        )

        for idx, (h, i, j) in enumerate(zip(*layout.nonzero(as_tuple=True))):
            out_r[:, h, i, :, j, :] = values[:, idx, :, :]

        return out

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in [
            torch.Tensor.bmm,
            torch.bmm,
            torch.Tensor.__matmul__,
            torch.matmul,
            torch.Tensor.matmul,
        ]:
            assert len(args) == 2
            return cls._bmm(args[0], args[1])

        if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]:
            return cls._softmax(args[0], kwargs["dim"])

        if func == masked_matmul:
            assert len(args) == 3
            return cls._masked_matmul(args[0], args[1], args[2])

        if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]:
            x = args[0]
            values = x.__values.clone()
            values = func(values, *args[1:], **kwargs)
            return cls._wrap(values, x)

        if func == torch.Tensor.to:
            # print(args, kwargs)
            assert len(args) >= 2
            return cls._to(args[0], args[1])
            # return cls._to(args[0], kwargs["device"])

        if func in [torch.Tensor.copy_]:
            assert len(args) == 2
            return cls._copy(args[0], args[1])

        if func in [torch.Tensor.equal, torch.equal]:
            assert len(args) == 2
            return cls._equal(args[0], args[1])

        if func == torch.Tensor.to_dense:
            assert len(args) == 1
            return cls._to_dense(args[0])

        if func == torch.Tensor.detach:
            x = args[0]
            values = x.__values.clone()
            values = func(values, *args[1:], **kwargs)
            return cls._wrap(values, x)

        if func == torch.Tensor.__deepcopy__:
            x = args[0]
            memo = args[1]
            return cls._raw_wrap(
                x.__values.__deepcopy__(memo),
                x.__layout.__deepcopy__(memo),
                # x.__sparse_dot_sdd.__deepcopy__(memo),
                # x.__sparse_dot_dsd.__deepcopy__(memo),
                # x.__sparse_softmax.__deepcopy__(memo),
                x.__sparse_dot_sdd,
                x.__sparse_dot_dsd,
                x.__sparse_softmax,
            )

        if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]:
            assert len(args) == 1
            assert len(kwargs) == 0
            x = args[0]
            return cls._wrap(x.__values.grad, x)

        if func == torch.Tensor.requires_grad_:
            func(args[0].__values)

        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            # TODO: check this
            if func in torch.overrides.get_default_nowrap_functions():
                return ret
            return torch._tensor._convert(ret, cls)

        return NotImplemented

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        return NotImplemented
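A hedged end-to-end sketch of BlockSparseTensor (not part of the commit). It assumes a CPU-only environment where the Triton import above fails, so the pure-PyTorch fallbacks `_sddmm`, `_softmax` and `_spmm` are exercised; with Triton installed the constructor builds Triton kernels instead, and Triton's own layout/device requirements apply:

import torch

from xformers.ops import masked_matmul
from xformers.sparse import BlockSparseTensor

block = 16  # minimum block size asserted in __init__
layout = torch.tril(torch.ones(1, 4, 4, dtype=torch.long))  # causal block layout, 1 head
nnz = int(layout.sum())
values = torch.randn(2, nnz, block, block)  # (batch, nnz blocks, block, block)
mask = BlockSparseTensor(values, layout)    # logical shape (2, 1, 64, 64)

q = torch.randn(2, 1, 4 * block, 32)
k = torch.randn(2, 1, 4 * block, 32)
v = torch.randn(2, 1, 4 * block, 32)

att = masked_matmul(q, k.transpose(-2, -1), mask)  # SDDMM on nonzero blocks only
att = att.softmax(dim=-1)                          # blocksparse softmax fallback
out = att @ v                                      # blocksparse @ dense -> dense
print(out.shape)  # torch.Size([2, 1, 64, 32])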
438
pkgs/xformers/sparse/csr_tensor.py
Normal file
@@ -0,0 +1,438 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from xformers.ops import masked_matmul
from xformers.sparse import _csr_ops
from xformers.sparse.utils import (
    _csr_to_coo,
    _dense3d_to_sparse,
    _diffsort,
    _get_transpose_info,
    _transpose_with_info,
)


class SparseCSRTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, row_offsets, column_indices, values, shape):
        kwargs = {}
        kwargs["device"] = values.device
        kwargs["dtype"] = values.dtype
        kwargs["layout"] = values.layout
        kwargs["requires_grad"] = values.requires_grad
        assert len(shape) == 3
        assert torch.__version__ > (1, 10), "SparseCSRTensor requires PyTorch 1.11+"
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, row_offsets, column_indices, values, shape):
        assert row_offsets.ndim == 1
        assert column_indices.ndim == 1
        assert values.ndim == 2

        self.__row_offsets = row_offsets.contiguous()
        self.__row_indices = _diffsort(row_offsets).to(row_offsets.dtype)
        self.__column_indices = column_indices.contiguous()
        self.__values = values.contiguous()

        self.__transp_info = _get_transpose_info(
            self.shape[1],
            self.shape[2],
            self.__row_indices,
            self.__row_offsets,
            self.__column_indices,
        )

    def __repr__(self):
        return f"sparse_csr_tensor(shape={self.shape}, values={self.__values})"

    @classmethod
    def from_dense(cls, matrix):
        values, row_indices, row_offsets, column_indices = _dense3d_to_sparse(
            matrix, matrix.device
        )
        return cls(row_offsets, column_indices, values, matrix.shape)

    @classmethod
    def from_sparse_coo(cls, arg0):
        """
        assert arg0.is_sparse
        x = arg0.coalesce()
        rows, cols = x.indices().unbind(0)
        vals = x.values()
        _coo_to_csr()
        """
        pass

    @classmethod
    def _wrap(
        cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
    ):
        matrix = cls.__new__(cls, row_offsets, column_indices, values, shape)
        matrix.__values = values
        matrix.__row_indices = row_indices
        matrix.__row_offsets = row_offsets
        matrix.__column_indices = column_indices
        matrix.__transp_info = _transp_info
        return matrix

    def values(self):
        return self.__values

    @property
    def _csr_row_indices(self):
        return self.__row_indices

    @property
    def _csr_row_offsets(self):
        return self.__row_offsets

    @property
    def _csr_column_indices(self):
        return self.__column_indices

    @property
    def _csr_transp_info(self):
        return self.__transp_info

    @classmethod
    def _bmm(cls, arg0, arg1):
        if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor):
            return NotImplemented

        assert arg0.ndim == 3
        assert arg1.ndim == 3

        self = arg0
        b = arg1

        _, m, n = self.shape
        row_indices = self.__row_indices
        values = self.__values
        row_offsets = self.__row_offsets
        column_indices = self.__column_indices

        out = _csr_ops._spmm.apply(
            b, row_indices, values, row_offsets, column_indices, m, self.__transp_info
        )
        return out

    @classmethod
    def _softmax(cls, arg0, dim):
        if not (dim == -1 or dim == 2):
            return NotImplemented

        self = arg0
        _, m, n = self.shape
        row_indices = self.__row_indices
        values = self.__values
        row_offsets = self.__row_offsets
        column_indices = self.__column_indices
        out = _csr_ops._SparseSoftmax.apply(
            m, n, row_indices, values, row_offsets, column_indices
        )
        return cls._wrap(
            self.shape,
            out,
            row_indices,
            row_offsets,
            column_indices,
            self.__transp_info,
        )

    @classmethod
    def _transpose(cls, arg0, dim0, dim1):
        # TODO: check if need to return this or not
        if not (dim0 == 1 or dim0 == -2):
            return NotImplemented
        if not (dim1 == 2 or dim1 == -1):
            return NotImplemented

        B, m, n = arg0.shape
        values = arg0.__values

        (
            output_row_indices,
            output_values,
            output_row_offsets,
            output_column_indices,
        ) = _transpose_with_info(values, arg0.__transp_info)
        new_transp_info = _get_transpose_info(
            n, m, output_row_indices, output_row_offsets, output_column_indices
        )

        return cls._wrap(
            (B, n, m),
            output_values,
            output_row_indices,
            output_row_offsets,
            output_column_indices,
            new_transp_info,
        )

    @classmethod
    def _masked_matmul(cls, a, b, mask):
        if not (type(a) == torch.Tensor and type(b) == torch.Tensor):
            return NotImplemented
        assert mask.shape[1] == a.shape[1]
        assert mask.shape[2] == b.shape[2]
        row_indices = mask.__row_indices
        row_offsets = mask.__row_offsets
        column_indices = mask.__column_indices
        a = a.contiguous()
        out = _csr_ops._sddmm.apply(
            a,
            b.transpose(-2, -1).contiguous(),
            row_indices,
            row_offsets,
            column_indices,
            mask.__transp_info,
        )
        # TODO add bias here
        return cls._wrap(
            mask.shape,
            out,
            row_indices,
            row_offsets,
            column_indices,
            mask.__transp_info,
        )

    @classmethod
    def _to(cls, arg0, device):
        if isinstance(device, str):
            device = torch.device(device)
        assert isinstance(device, torch.device)
        return cls._wrap(
            arg0.shape,
            arg0.__values.to(device=device),
            arg0.__row_indices.to(device=device),
            arg0.__row_offsets.to(device=device),
            arg0.__column_indices.to(device=device),
            tuple(t.to(device=device) for t in arg0.__transp_info),
        )

    @classmethod
    def _copy(cls, arg0, arg1):
        if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
            return NotImplemented
        assert arg0.shape == arg1.shape
        av0, av1 = arg0.__values, arg1.__values
        av0.resize_as_(av1).copy_(av1)
        av0, av1 = arg0.__row_indices, arg1.__row_indices
        av0.resize_as_(av1).copy_(av1)
        av0, av1 = arg0.__row_offsets, arg1.__row_offsets
        av0.resize_as_(av1).copy_(av1)
        av0, av1 = arg0.__column_indices, arg1.__column_indices
        av0.resize_as_(av1).copy_(av1)
        for v0, v1 in zip(arg0.__transp_info, arg1.__transp_info):
            v0.resize_as_(v1).copy_(v1)
        return arg0

    @classmethod
    def _equal(cls, arg0, arg1):
        if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
            return NotImplemented
        if arg0.shape != arg1.shape:
            return False
        if not torch.equal(arg0.__values, arg1.__values):
            return False
        if not torch.equal(arg0.__row_offsets, arg1.__row_offsets):
            return False
        if not torch.equal(arg0.__column_indices, arg1.__column_indices):
            return False
        return True

    @classmethod
    def _to_dense(cls, arg0):
        _, m, n = arg0.shape
        shape = arg0.shape
        matrix = torch.zeros(shape, dtype=arg0.dtype, device=arg0.device)
        row_offsets = arg0.__row_offsets.long()
        column_indices = arg0.__column_indices.long()
        row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices)
        b_idxs = torch.arange(len(arg0.__values), device=arg0.device)[:, None]
        matrix[b_idxs, row_coo, column_indices] = arg0.__values
        return matrix

    @classmethod
    def _binary_op(cls, func, arg0, arg1):
        if not (
            isinstance(arg0, (cls, int, float)) and isinstance(arg1, (cls, int, float))
        ):
            return NotImplemented
        v0, v1 = arg0, arg1
        if isinstance(arg0, cls):
            v0 = arg0.__values
        if isinstance(arg1, cls):
            v1 = arg1.__values
        # assert arg0.shape == arg1.shape
        if isinstance(arg0, cls) and isinstance(arg1, cls):
            msg = f"arg0 and arg1 need to have the same sparsity pattern in {func} (for now)"
            if not arg0.__row_offsets.shape == arg1.__row_offsets.shape:
                raise NotImplementedError(msg)
            if not arg0.__column_indices.shape == arg1.__column_indices.shape:
                raise NotImplementedError(msg)
            if not arg0.__values.shape == arg1.__values.shape:
                raise NotImplementedError(msg)
            # TODO this is not always true, but is a fast approximation for now
            if arg0.__row_offsets is not arg1.__row_offsets:
                raise NotImplementedError(msg)
            if arg0.__column_indices is not arg1.__column_indices:
                raise NotImplementedError(msg)
        out = func(v0, v1)
        return cls._wrap(
            arg0.shape,
            out,
            arg0.__row_indices,
            arg0.__row_offsets,
            arg0.__column_indices,
            arg0.__transp_info,
        )

    @classmethod
    def _binary_op_slow(cls, func, arg0, arg1):
        # assert arg0.shape == arg1.shape
        v0, v1 = arg0, arg1
        if isinstance(arg0, cls):
            v0 = arg0.to_dense()
        if isinstance(arg1, cls):
            v1 = arg1.to_dense()
        out = func(v0, v1)
        return cls.from_dense(out)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in [
            torch.Tensor.bmm,
            torch.bmm,
            torch.Tensor.__matmul__,
            torch.matmul,
            torch.Tensor.matmul,
        ]:
            assert len(args) == 2
            return cls._bmm(args[0], args[1])

        if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]:
            return cls._softmax(args[0], kwargs["dim"])

        if func in [torch.Tensor.transpose, torch.transpose]:
            assert len(kwargs) == 0
            return cls._transpose(args[0], args[1], args[2])

        if func == masked_matmul:
            assert len(args) == 3
            return cls._masked_matmul(args[0], args[1], args[2])

        if func in [
            torch.Tensor.add,
            torch.add,
            torch.Tensor.__add__,
        ]:
            assert len(args) == 2
            if not (isinstance(args[0], cls) and isinstance(args[1], cls)):
                raise NotImplementedError(
                    f"{func} with {type(args[0])} and {type(args[1])} not implemented"
                )
            return cls._binary_op(func, args[0], args[1])

        if func in [
            torch.Tensor.mul,
            torch.mul,
            torch.Tensor.__mul__,
        ]:
            assert len(args) == 2
            return cls._binary_op(func, args[0], args[1])

        if func in [torch.Tensor.logical_and, torch.logical_and, torch.Tensor.__and__]:
            assert len(args) == 2
            return cls._binary_op_slow(func, args[0], args[1])

        if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]:
            x = args[0]
            values = x.__values.clone()
            values = func(values, *args[1:], **kwargs)
            return cls._wrap(
                x.shape,
                values,
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )

        if func == torch.Tensor.to:
            # print(args, kwargs)
            assert len(args) >= 2
            return cls._to(args[0], args[1])
            # return cls._to(args[0], kwargs["device"])

        if func in [torch.Tensor.copy_]:
            assert len(args) == 2
            return cls._copy(args[0], args[1])

        if func in [torch.Tensor.equal, torch.equal]:
            assert len(args) == 2
            return cls._equal(args[0], args[1])

        if func == torch.Tensor.to_dense:
            assert len(args) == 1
            return cls._to_dense(args[0])

        if func == torch.Tensor.detach:
            x = args[0]
            return cls._wrap(
                x.shape,
                x.__values.detach(),
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )

        if func == torch.Tensor.__deepcopy__:
            x = args[0]
            memo = args[1]
            return cls._wrap(
                x.shape,
                x.__values.__deepcopy__(memo),
                x.__row_indices.__deepcopy__(memo),
                x.__row_offsets.__deepcopy__(memo),
                x.__column_indices.__deepcopy__(memo),
                tuple(v.__deepcopy__(memo) for v in x.__transp_info),
            )

        if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]:
            assert len(args) == 1
            assert len(kwargs) == 0
            x = args[0]
            return cls._wrap(
                x.shape,
                x.__values.grad,
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )

        if func == torch.Tensor.requires_grad_:
            func(args[0].__values)

        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            # TODO: check this
            if func in torch.overrides.get_default_nowrap_functions():
                return ret
            return torch._tensor._convert(ret, cls)

        return NotImplemented

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        return NotImplemented
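A hedged usage sketch for SparseCSRTensor (not part of the commit). The heavy lifting is done by the `torch.ops.xformers.*` sputnik kernels, so this assumes the compiled xformers extension is available for the device in use:

import torch

from xformers.ops import masked_matmul
from xformers.sparse import SparseCSRTensor

B, M, K = 8, 64, 32
q, k, v = torch.randn(3, B, M, K).unbind(0)

# One sparsity pattern shared across the batch, as _dense3d_to_sparse requires.
pattern = (torch.rand(M, M) > 0.8).expand(B, M, M)
mask = SparseCSRTensor.from_dense(pattern.float())

att = masked_matmul(q, k.transpose(-2, -1), mask)  # scores only on the pattern
att = att.softmax(dim=-1)                          # sparse softmax on the kept values
out = att @ v                                      # spmm back to a dense (B, M, K) tensor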
123
pkgs/xformers/sparse/utils.py
Normal file
@@ -0,0 +1,123 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch


def _coo_to_csr(m, n, row_indices, column_indices):
    # assumes coalesced coo
    row_offsets = row_indices.bincount(minlength=n).cumsum(0, dtype=row_indices.dtype)
    row_offsets = torch.nn.functional.pad(row_offsets, (1, 0))
    return row_offsets, column_indices


def _csr_to_coo(m, n, row_offsets, column_indices):
    # convert from compressed rows to uncompressed
    indices = torch.arange(m, dtype=row_offsets.dtype, device=row_offsets.device)
    row_sizes = torch.diff(row_offsets)
    row_coo = torch.repeat_interleave(indices, row_sizes.long())
    return row_coo, column_indices


def _diffsort(a):
    return torch.argsort(torch.diff(a), dim=0, descending=True)


def _get_transpose_info(m, n, row_indices, row_offsets, column_indices):
    # strategy:
    # - uncompress the rows to have data in COO format
    # - get permutation for stable sort of the columns to get the rows for the transposed matrix
    # - compress the new rows and return the permutation to be applied on the values

    # convert from compressed rows to uncompressed
    row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices)

    # get the permutation for the stable sort
    # the sorted columns are the row indices of the transposed matrix
    row_indices_t, perm = column_indices.sort(dim=0, stable=True)
    column_indices_t = row_coo[perm]

    row_offsets_t, _ = _coo_to_csr(m, n, row_indices_t, column_indices)
    row_indices_t = _diffsort(row_offsets_t).int()

    return row_indices_t, row_offsets_t, column_indices_t, perm


def _transpose_with_info(values, _transpose_info):
    row_indices_t, row_offsets_t, column_indices_t, perm = _transpose_info
    values_t = values[:, perm]
    return row_indices_t, values_t, row_offsets_t, column_indices_t


def _transpose(m, n, row_indices, values, row_offsets, column_indices):
    _transpose_info = _get_transpose_info(
        m, n, row_indices, row_offsets, column_indices
    )
    return _transpose_with_info(values, _transpose_info)


def _nonzero_mask_to_sparse_csr_indices(mask, device):
    """Converts dense 2d matrix to a csr sparse matrix."""

    assert len(mask.shape) == 2
    index_dtype = torch.int32

    # Calculate the offset of each row.
    row_offsets = mask.sum(dim=-1, dtype=index_dtype).cumsum(dim=-1, dtype=index_dtype)
    row_offsets = torch.nn.functional.pad(row_offsets, (1, 0))

    # Create the row indices and sort them.
    row_indices = _diffsort(row_offsets).to(index_dtype)

    # Extract the column indices for the nonzero values.
    column_indices = torch.where(mask)[1].to(index_dtype).contiguous()

    row_indices = row_indices.to(device)
    row_offsets = row_offsets.to(device)
    column_indices = column_indices.to(device)
    return row_indices, row_offsets, column_indices


def _dense_to_sparse(matrix, device):
    """Converts dense 2d matrix to a csr sparse matrix."""

    assert len(matrix.shape) == 2
    value_dtype = torch.float32

    # Extract the nonzero values.
    mask = matrix != 0
    values = matrix[mask].to(dtype=value_dtype, device=device)

    row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices(
        mask, device
    )
    return values, row_indices, row_offsets, column_indices


def _round_nnz(mask, divisible_by=4):
    nonzero = torch.where(mask)
    nnz = nonzero[0].shape[0]
    nonzero = tuple(n[: (nnz - nnz % divisible_by)] for n in nonzero)
    nm = torch.zeros_like(mask)
    nm[nonzero] = True
    return nm


def _dense3d_to_sparse(matrix, device):
    assert len(matrix.shape) == 3
    mask = matrix != 0
    if not torch.all(mask == mask[0]):
        raise ValueError("Expected the same sparsity pattern over the batch dimension")

    # for now, our kernels assume that we have the number of
    # nnz to be divisible by 4
    mask = _round_nnz(mask[0], divisible_by=4)
    mask = mask[None].expand(matrix.shape)

    values = matrix[mask].reshape(matrix.shape[0], -1).to(device)
    row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices(
        mask[0], device
    )
    return values, row_indices, row_offsets, column_indices
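The helpers above are pure PyTorch, so the transpose machinery can be sanity-checked on CPU without the compiled extension; a minimal sketch (not part of the commit):

import torch

from xformers.sparse.utils import (
    _csr_to_coo,
    _dense3d_to_sparse,
    _get_transpose_info,
    _transpose_with_info,
)

m, n = 8, 6
dense = torch.randn(1, m, n) * (torch.rand(m, n) > 0.5)
values, row_indices, row_offsets, column_indices = _dense3d_to_sparse(
    dense, dense.device
)

# Precompute the permutation once; transposing is then just a gather on values.
info = _get_transpose_info(m, n, row_indices, row_offsets, column_indices)
row_indices_t, values_t, row_offsets_t, column_indices_t = _transpose_with_info(
    values, info
)

# Scatter the transposed CSR back to dense; dense_t matches
# dense.transpose(-2, -1) on the entries kept by _round_nnz.
rows_t, cols_t = _csr_to_coo(n, m, row_offsets_t.long(), column_indices_t.long())
dense_t = torch.zeros(1, n, m)
dense_t[0, rows_t, cols_t] = values_t[0]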