First commit

2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

xformers/sparse/__init__.py

@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .blocksparse_tensor import BlockSparseTensor # noqa: F401
from .csr_tensor import SparseCSRTensor # noqa: F401
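
For orientation: both subclasses are drop-in torch.Tensor wrappers, and the CSR path can be smoke-tested without the compiled kernels since `from_dense` is pure PyTorch. A hypothetical sketch, assuming an installed xformers build:

import torch
from xformers.sparse import BlockSparseTensor, SparseCSRTensor  # both re-exported here

# the CSR constructor requires one sparsity pattern shared across the batch,
# and rounds nnz down to a multiple of 4 (see _dense3d_to_sparse below)
mask = torch.rand(8, 8) > 0.7
a = torch.rand(2, 8, 8) * mask
a_csr = SparseCSRTensor.from_dense(a)
print(a_csr.shape)           # torch.Size([2, 8, 8])
print(a_csr.values().shape)  # (2, nnz)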

Binary file not shown.

xformers/sparse/_csr_ops.py

@@ -0,0 +1,166 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .utils import _csr_to_coo, _transpose_with_info
def _should_use_coo(a, sparsity):
if not a.is_cuda:
return False
B, M, K = a.shape
# amortize overhead of converting from csr to coo
if B < 32 and M < 4096:
return False
if sparsity > 0.995:
return False
if sparsity < 0.9:
return False
if K > 64:
return False
# let's be overly cautious here for now
return sparsity > 0.97
def _should_use_csr_ge(a, sparsity):
if not a.is_cuda:
return False
return sparsity > 0.99
def _sddmm_func(a, b, row_indices, row_offsets, column_indices):
sparsity = 1 - column_indices.shape[0] / (a.shape[1] * b.shape[1])
if _should_use_coo(a, sparsity):
m = a.shape[-2]
n = b.shape[-2]
# converting from csr to coo has a constant overhead of ~150us
# so only dispatch to it for reasonably large problem sizes
ro, ci = _csr_to_coo(m, n, row_offsets, column_indices)
return torch.ops.xformers.coo_sddmm(a, b, row_indices, ro, ci)
elif _should_use_csr_ge(a, sparsity):
return torch.ops.xformers.csr_sddmm(
a, b, row_indices, row_offsets, column_indices
)
return torch.ops.xformers.sddmm_sputnik(
a, b, row_indices, row_offsets, column_indices
)
class _SparseSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, m, n, row_indices, values, row_offsets, column_indices):
out = torch.ops.xformers.sparse_softmax_sputnik(
m, n, row_indices, values, row_offsets, column_indices
)
# note: save out and not values, as an optimization step
ctx.save_for_backward(row_indices, out, row_offsets, column_indices)
ctx.size = (m, n)
return out
@staticmethod
def backward(ctx, grad):
row_indices, out, row_offsets, column_indices = ctx.saved_tensors
m, n = ctx.size
# gradients w.r.t. values
grad = grad.contiguous()
ga = torch.ops.xformers.sparse_softmax_backward_sputnik(
m, n, row_indices, out, grad, row_offsets, column_indices
)
return None, None, None, ga, None, None
class _sddmm(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b, row_indices, row_offsets, column_indices, _transp_info):
out = _sddmm_func(a, b, row_indices, row_offsets, column_indices)
ctx.save_for_backward(
a, b, row_indices, row_offsets, column_indices, *_transp_info
)
return out
@staticmethod
def backward(ctx, grad):
(
a,
b,
row_indices,
row_offsets,
column_indices,
*_transp_info,
) = ctx.saved_tensors
m, n = a.shape[1], b.shape[1]
# gradients w.r.t. values
grad = grad.contiguous()
a = a.contiguous()
b = b.contiguous()
a_grad = torch.ops.xformers.spmm_sputnik(
b, row_indices, grad, row_offsets, column_indices, m
)
(
row_indices_t,
grad_t,
row_offsets_t,
column_indices_t,
) = _transpose_with_info(grad, _transp_info)
b_grad = torch.ops.xformers.spmm_sputnik(
a, row_indices_t, grad_t, row_offsets_t, column_indices_t, n
)
return a_grad, b_grad, None, None, None, None
class _spmm(torch.autograd.Function):
@staticmethod
def forward(
ctx, b, row_indices, values, row_offsets, column_indices, m, _transp_info
):
b = b.contiguous()
out = torch.ops.xformers.spmm_sputnik(
b, row_indices, values, row_offsets, column_indices, m
)
ctx.save_for_backward(
b, row_indices, values, row_offsets, column_indices, *_transp_info
)
return out
@staticmethod
def backward(ctx, grad):
(
b,
row_indices,
values,
row_offsets,
column_indices,
*_transp_info,
) = ctx.saved_tensors
k = b.shape[1]
# gradients w.r.t. values
grad = grad.contiguous()
grad_sparse = _sddmm_func(grad, b, row_indices, row_offsets, column_indices)
(
row_indices_t,
values_t,
row_offsets_t,
column_indices_t,
) = _transpose_with_info(values, _transp_info)
grad_dense = torch.ops.xformers.spmm_sputnik(
grad, row_indices_t, values_t, row_offsets_t, column_indices_t, k
)
return grad_dense, None, grad_sparse, None, None, None, None
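
Note on `_SparseSoftmax`: it saves `out` instead of the input values because the softmax Jacobian can be written purely in terms of the output: for y = softmax(x) and an incoming gradient g, dL/dx = y * (g - sum(g * y)). A dense sketch of this identity, for intuition only (a hypothetical analogue of what the sputnik backward kernel computes from `out`):

import torch

x = torch.rand(4, 16, dtype=torch.float64, requires_grad=True)
y = torch.softmax(x, dim=-1)
g = torch.rand_like(y)

# backward reconstructed from the output alone
gx = y * (g - (g * y).sum(-1, keepdim=True))

y.backward(g)
assert torch.allclose(x.grad, gx)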

xformers/sparse/blocksparse_tensor.py

@@ -0,0 +1,357 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
from xformers import _is_triton_available
from xformers.ops import masked_matmul
logger = logging.getLogger("xformers")
try:
if not _is_triton_available():
raise ImportError("triton is not available")
from triton.ops.blocksparse import matmul as blocksparse_matmul
from triton.ops.blocksparse import softmax as blocksparse_softmax
except ImportError as e:
logger.warning(
"Triton is not available, some optimizations will not be enabled.\n"
+ f"This is just a warning: {e}"
)
blocksparse_matmul = None
blocksparse_softmax = None
def _can_use_triton(a):
if a.device.type == "cpu":
return False
if blocksparse_matmul is None:
return False
return True
def _spmm(b, layout, values):
N, nnz, _, block_size = values.shape
br = b.reshape(
b.shape[0], b.shape[1], b.shape[2] // block_size, block_size, b.shape[3]
)
# perform matmul on blocks
h, r, c = layout.nonzero(as_tuple=True)
temp = values @ br[:, h, c, :]
linear_idx = h * (b.shape[2] // block_size) + r
out = torch.zeros(
N,
b.shape[1] * layout.shape[-2],
block_size,
b.shape[3],
dtype=b.dtype,
device=b.device,
)
# now aggregate the results of the different blocks
out.index_add_(1, linear_idx.to(b.device), temp)
out = out.reshape(N, b.shape[1], -1, b.shape[3])
return out
def _softmax(layout, values):
h, r, c = layout.nonzero(as_tuple=True)
norms = torch.logsumexp(values, dim=-1, keepdim=True)
linear_idx = h * layout.shape[1] + r
out_t = torch.zeros(
norms.shape[0],
layout.shape[0] * layout.shape[1],
norms.shape[2],
norms.shape[3],
dtype=norms.dtype,
device=norms.device,
)
max_val = norms.max()
out_t.index_add_(
1, linear_idx.to(values.device), (norms - max_val).exp()
).clamp_min_(1e-24).log_().add_(max_val)
out = torch.exp(values - out_t[:, linear_idx])
return out
def _sddmm(a, b, layout):
block_size = a.shape[-2] // layout.shape[-2]
a = a.reshape(
a.shape[0], a.shape[1], a.shape[2] // block_size, block_size, a.shape[3]
)
b = b.reshape(
b.shape[0], b.shape[1], b.shape[2] // block_size, block_size, b.shape[3]
)
h, r, c = layout.nonzero(as_tuple=True)
out = torch.einsum("nhik,nhjk->nhij", a[:, h, r, :, :], b[:, h, c, :, :])
return out
class BlockSparseTensor(torch.Tensor):
@staticmethod
def __new__(cls, values, layout):
kwargs = {}
kwargs["device"] = values.device
kwargs["dtype"] = values.dtype
kwargs["layout"] = values.layout
kwargs["requires_grad"] = values.requires_grad
assert values.ndim == 4
B, _, block_size, _ = values.shape
C, h, w = layout.shape
# TODO validate shape of layout vs values
shape = (B, C, block_size * h, block_size * w)
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
def __init__(self, values, layout):
assert values.shape[-2] == values.shape[-1]
assert (
values.device == layout.device
), "Both values and layout need to reside on the same device"
block_size = values.shape[-1]
# TODO: make this check conditioned on the use of Triton
assert block_size >= 16, "Minimum block size is 16, for now at least"
# Pure blocksparse data
self.__values = values
self.__layout = layout
# blocksparse operators for triton
if blocksparse_matmul:
self._initialize_triton_ops()
else:
self.__sparse_dot_sdd = None
self.__sparse_dot_dsd = None
self.__sparse_softmax = None
def _initialize_triton_ops(self):
block_size = self.__values.shape[-1]
self.__sparse_dot_sdd = blocksparse_matmul(
self.__layout,
block_size,
"sdd",
trans_a=False,
trans_b=True,
device=self.__layout.device,
)
self.__sparse_dot_dsd = blocksparse_matmul(
self.__layout,
block_size,
"dsd",
trans_a=False,
trans_b=False,
device=self.__layout.device,
)
self.__sparse_softmax = blocksparse_softmax(
self.__layout, block_size, device=self.__layout.device
)
def __repr__(self):
return f"block_sparse_tensor(shape={self.shape}, values={self.__values})"
def values(self):
return self.__values
@classmethod
def _raw_wrap(cls, values, layout, sparse_dot_sdd, sparse_dot_dsd, sparse_softmax):
matrix = cls.__new__(cls, values, layout)
matrix.__values = values
matrix.__layout = layout
matrix.__sparse_dot_sdd = sparse_dot_sdd
matrix.__sparse_dot_dsd = sparse_dot_dsd
matrix.__sparse_softmax = sparse_softmax
return matrix
@classmethod
def _wrap(cls, values, bmat):
matrix = cls.__new__(cls, values, bmat.__layout)
matrix.__values = values
matrix.__layout = bmat.__layout
matrix.__sparse_dot_sdd = bmat.__sparse_dot_sdd
matrix.__sparse_dot_dsd = bmat.__sparse_dot_dsd
matrix.__sparse_softmax = bmat.__sparse_softmax
return matrix
@classmethod
def _bmm(cls, arg0, arg1):
if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor):
return NotImplemented
if _can_use_triton(arg1):
res = arg0.__sparse_dot_dsd(arg0.__values, arg1)
else:
res = _spmm(arg1, arg0.__layout, arg0.__values)
return res
@classmethod
def _masked_matmul(cls, a, b, mask):
if not (type(a) == torch.Tensor and type(b) == torch.Tensor):
return NotImplemented
b = b.transpose(-2, -1)
assert b.is_contiguous()
if _can_use_triton(a):
res = mask.__sparse_dot_sdd(a, b)
else:
res = _sddmm(a, b, mask.__layout)
return cls._wrap(res, mask)
@classmethod
def _softmax(cls, arg0, dim):
if not (dim == -1 or dim == 2):
return NotImplemented
if _can_use_triton(arg0):
res = arg0.__sparse_softmax(arg0.__values)
else:
res = _softmax(arg0.__layout, arg0.__values)
return cls._wrap(res, arg0)
@classmethod
def _to(cls, arg0, device):
if isinstance(device, str):
device = torch.device(device)
assert isinstance(device, torch.device)
return cls(
arg0.__values.to(device=device),
arg0.__layout,
)
@classmethod
def _copy(cls, arg0, arg1):
if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
return NotImplemented
assert arg0.shape == arg1.shape
av0, av1 = arg0.__values, arg1.__values
av0.resize_as_(av1).copy_(av1)
av0, av1 = arg0.__layout, arg1.__layout
av0.resize_as_(av1).copy_(av1)
out = cls(arg0.__values, arg0.__layout)
arg0.__sparse_dot_sdd = out.__sparse_dot_sdd
arg0.__sparse_dot_dsd = out.__sparse_dot_dsd
arg0.__sparse_softmax = out.__sparse_softmax
return arg0
@classmethod
def _equal(cls, arg0, arg1):
if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
return NotImplemented
if arg0.shape != arg1.shape:
return False
if not torch.equal(arg0.__values, arg1.__values):
return False
if not torch.equal(arg0.__layout, arg1.__layout):
return False
return True
@classmethod
def _to_dense(cls, arg0):
# out = torch.zeros(arg0.shape, dtype=arg0.dtype, device=arg0.device, requires_grad=arg0.requires_grad)
out = torch.zeros(arg0.shape, dtype=arg0.dtype, device=arg0.device)
values = arg0.__values
layout = arg0.__layout
block_size = values.shape[-1]
blocks_i = layout.shape[-2]
blocks_j = layout.shape[-1]
out_r = out.reshape(
arg0.shape[0], arg0.shape[1], blocks_i, block_size, blocks_j, block_size
)
for idx, (h, i, j) in enumerate(zip(*layout.nonzero(as_tuple=True))):
out_r[:, h, i, :, j, :] = values[:, idx, :, :]
return out
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func in [
torch.Tensor.bmm,
torch.bmm,
torch.Tensor.__matmul__,
torch.matmul,
torch.Tensor.matmul,
]:
assert len(args) == 2
return cls._bmm(args[0], args[1])
if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]:
return cls._softmax(args[0], kwargs["dim"])
if func == masked_matmul:
assert len(args) == 3
return cls._masked_matmul(args[0], args[1], args[2])
if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]:
x = args[0]
values = x.__values.clone()
values = func(values, *args[1:], **kwargs)
return cls._wrap(values, x)
if func == torch.Tensor.to:
# print(args, kwargs)
assert len(args) >= 2
return cls._to(args[0], args[1])
# return cls._to(args[0], kwargs["device"])
if func in [torch.Tensor.copy_]:
assert len(args) == 2
return cls._copy(args[0], args[1])
if func in [torch.Tensor.equal, torch.equal]:
assert len(args) == 2
return cls._equal(args[0], args[1])
if func == torch.Tensor.to_dense:
assert len(args) == 1
return cls._to_dense(args[0])
if func == torch.Tensor.detach:
x = args[0]
values = x.__values.clone()
values = func(values, *args[1:], **kwargs)
return cls._wrap(values, x)
if func == torch.Tensor.__deepcopy__:
x = args[0]
memo = args[1]
return cls._raw_wrap(
x.__values.__deepcopy__(memo),
x.__layout.__deepcopy__(memo),
# x.__sparse_dot_sdd.__deepcopy__(memo),
# x.__sparse_dot_dsd.__deepcopy__(memo),
# x.__sparse_softmax.__deepcopy__(memo),
x.__sparse_dot_sdd,
x.__sparse_dot_dsd,
x.__sparse_softmax,
)
if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]:
assert len(args) == 1
assert len(kwargs) == 0
x = args[0]
return cls._wrap(x.__values.grad, x)
if func == torch.Tensor.requires_grad_:
func(args[0].__values)
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
# TODO: check this
if func in torch.overrides.get_default_nowrap_functions():
return ret
return torch._tensor._convert(ret, cls)
return NotImplemented
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
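
A hypothetical end-to-end sketch of the block-sparse path. On CPU (or whenever Triton is unavailable) `_can_use_triton` returns False, so this exercises the pure-PyTorch fallbacks `_sddmm`, `_softmax` and `_spmm` defined above; all shapes are illustrative only:

import torch
from xformers.ops import masked_matmul
from xformers.sparse import BlockSparseTensor

B, H, L, K, BS = 2, 4, 128, 32, 16  # batch, heads, seq len, head dim, block size
layout = torch.randint(0, 2, (H, L // BS, L // BS))
values = torch.rand(B, int(layout.sum()), BS, BS)
mask = BlockSparseTensor(values, layout)  # dense shape (B, H, L, L)

q, k, v = (torch.rand(B, H, L, K) for _ in range(3))
att = masked_matmul(q, k.transpose(-2, -1), mask)  # dispatches to _masked_matmul
att = att.softmax(dim=-1)                          # dispatches to _softmax
out = att @ v                                      # dispatches to _bmm
print(out.shape)                                   # torch.Size([2, 4, 128, 32])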

xformers/sparse/csr_tensor.py

@@ -0,0 +1,438 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from xformers.ops import masked_matmul
from xformers.sparse import _csr_ops
from xformers.sparse.utils import (
_csr_to_coo,
_dense3d_to_sparse,
_diffsort,
_get_transpose_info,
_transpose_with_info,
)
class SparseCSRTensor(torch.Tensor):
@staticmethod
def __new__(cls, row_offsets, column_indices, values, shape):
kwargs = {}
kwargs["device"] = values.device
kwargs["dtype"] = values.dtype
kwargs["layout"] = values.layout
kwargs["requires_grad"] = values.requires_grad
assert len(shape) == 3
assert torch.__version__ > (1, 10), "SparseCSRTensor requires PyTorch 1.11+"
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
def __init__(self, row_offsets, column_indices, values, shape):
assert row_offsets.ndim == 1
assert column_indices.ndim == 1
assert values.ndim == 2
self.__row_offsets = row_offsets.contiguous()
self.__row_indices = _diffsort(row_offsets).to(row_offsets.dtype)
self.__column_indices = column_indices.contiguous()
self.__values = values.contiguous()
self.__transp_info = _get_transpose_info(
self.shape[1],
self.shape[2],
self.__row_indices,
self.__row_offsets,
self.__column_indices,
)
def __repr__(self):
return f"sparse_csr_tensor(shape={self.shape}, values={self.__values})"
@classmethod
def from_dense(cls, matrix):
values, row_indices, row_offsets, column_indices = _dense3d_to_sparse(
matrix, matrix.device
)
return cls(row_offsets, column_indices, values, matrix.shape)
@classmethod
def from_sparse_coo(cls, arg0):
"""
assert arg0.is_sparse
x = arg0.coalesce()
rows, cols = x.indices().unbind(0)
vals = x.values()
_coo_to_csr()
"""
pass
@classmethod
def _wrap(
cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
):
matrix = cls.__new__(cls, row_offsets, column_indices, values, shape)
matrix.__values = values
matrix.__row_indices = row_indices
matrix.__row_offsets = row_offsets
matrix.__column_indices = column_indices
matrix.__transp_info = _transp_info
return matrix
def values(self):
return self.__values
@property
def _csr_row_indices(self):
return self.__row_indices
@property
def _csr_row_offsets(self):
return self.__row_offsets
@property
def _csr_column_indices(self):
return self.__column_indices
@property
def _csr_transp_info(self):
return self.__transp_info
@classmethod
def _bmm(cls, arg0, arg1):
if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor):
return NotImplemented
assert arg0.ndim == 3
assert arg1.ndim == 3
self = arg0
b = arg1
_, m, n = self.shape
row_indices = self.__row_indices
values = self.__values
row_offsets = self.__row_offsets
column_indices = self.__column_indices
out = _csr_ops._spmm.apply(
b, row_indices, values, row_offsets, column_indices, m, self.__transp_info
)
return out
@classmethod
def _softmax(cls, arg0, dim):
if not (dim == -1 or dim == 2):
return NotImplemented
self = arg0
_, m, n = self.shape
row_indices = self.__row_indices
values = self.__values
row_offsets = self.__row_offsets
column_indices = self.__column_indices
out = _csr_ops._SparseSoftmax.apply(
m, n, row_indices, values, row_offsets, column_indices
)
return cls._wrap(
self.shape,
out,
row_indices,
row_offsets,
column_indices,
self.__transp_info,
)
@classmethod
def _transpose(cls, arg0, dim0, dim1):
        # TODO: check whether this needs to be returned or not
if not (dim0 == 1 or dim0 == -2):
return NotImplemented
if not (dim1 == 2 or dim1 == -1):
return NotImplemented
B, m, n = arg0.shape
values = arg0.__values
(
output_row_indices,
output_values,
output_row_offsets,
output_column_indices,
) = _transpose_with_info(values, arg0.__transp_info)
new_transp_info = _get_transpose_info(
n, m, output_row_indices, output_row_offsets, output_column_indices
)
return cls._wrap(
(B, n, m),
output_values,
output_row_indices,
output_row_offsets,
output_column_indices,
new_transp_info,
)
@classmethod
def _masked_matmul(cls, a, b, mask):
if not (type(a) == torch.Tensor and type(b) == torch.Tensor):
return NotImplemented
assert mask.shape[1] == a.shape[1]
assert mask.shape[2] == b.shape[2]
row_indices = mask.__row_indices
row_offsets = mask.__row_offsets
column_indices = mask.__column_indices
a = a.contiguous()
out = _csr_ops._sddmm.apply(
a,
b.transpose(-2, -1).contiguous(),
row_indices,
row_offsets,
column_indices,
mask.__transp_info,
)
# TODO add bias here
return cls._wrap(
mask.shape,
out,
row_indices,
row_offsets,
column_indices,
mask.__transp_info,
)
@classmethod
def _to(cls, arg0, device):
if isinstance(device, str):
device = torch.device(device)
assert isinstance(device, torch.device)
return cls._wrap(
arg0.shape,
arg0.__values.to(device=device),
arg0.__row_indices.to(device=device),
arg0.__row_offsets.to(device=device),
arg0.__column_indices.to(device=device),
tuple(t.to(device=device) for t in arg0.__transp_info),
)
@classmethod
def _copy(cls, arg0, arg1):
if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
return NotImplemented
assert arg0.shape == arg1.shape
av0, av1 = arg0.__values, arg1.__values
av0.resize_as_(av1).copy_(av1)
av0, av1 = arg0.__row_indices, arg1.__row_indices
av0.resize_as_(av1).copy_(av1)
av0, av1 = arg0.__row_offsets, arg1.__row_offsets
av0.resize_as_(av1).copy_(av1)
av0, av1 = arg0.__column_indices, arg1.__column_indices
av0.resize_as_(av1).copy_(av1)
for v0, v1 in zip(arg0.__transp_info, arg1.__transp_info):
v0.resize_as_(v1).copy_(v1)
return arg0
@classmethod
def _equal(cls, arg0, arg1):
if not (isinstance(arg0, cls) and isinstance(arg1, cls)):
return NotImplemented
if arg0.shape != arg1.shape:
return False
if not torch.equal(arg0.__values, arg1.__values):
return False
if not torch.equal(arg0.__row_offsets, arg1.__row_offsets):
return False
if not torch.equal(arg0.__column_indices, arg1.__column_indices):
return False
return True
@classmethod
def _to_dense(cls, arg0):
_, m, n = arg0.shape
shape = arg0.shape
matrix = torch.zeros(shape, dtype=arg0.dtype, device=arg0.device)
row_offsets = arg0.__row_offsets.long()
column_indices = arg0.__column_indices.long()
row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices)
b_idxs = torch.arange(len(arg0.__values), device=arg0.device)[:, None]
matrix[b_idxs, row_coo, column_indices] = arg0.__values
return matrix
@classmethod
def _binary_op(cls, func, arg0, arg1):
if not (
isinstance(arg0, (cls, int, float)) and isinstance(arg1, (cls, int, float))
):
return NotImplemented
v0, v1 = arg0, arg1
if isinstance(arg0, cls):
v0 = arg0.__values
if isinstance(arg1, cls):
v1 = arg1.__values
# assert arg0.shape == arg1.shape
if isinstance(arg0, cls) and isinstance(arg1, cls):
msg = f"arg0 and arg1 need to have the same sparsity pattern in {func} (for now)"
if not arg0.__row_offsets.shape == arg1.__row_offsets.shape:
raise NotImplementedError(msg)
if not arg0.__column_indices.shape == arg1.__column_indices.shape:
raise NotImplementedError(msg)
if not arg0.__values.shape == arg1.__values.shape:
raise NotImplementedError(msg)
# TODO this is not always true, but is a fast approximation for now
if arg0.__row_offsets is not arg1.__row_offsets:
raise NotImplementedError(msg)
if arg0.__column_indices is not arg1.__column_indices:
raise NotImplementedError(msg)
out = func(v0, v1)
return cls._wrap(
arg0.shape,
out,
arg0.__row_indices,
arg0.__row_offsets,
arg0.__column_indices,
arg0.__transp_info,
)
@classmethod
def _binary_op_slow(cls, func, arg0, arg1):
# assert arg0.shape == arg1.shape
v0, v1 = arg0, arg1
if isinstance(arg0, cls):
v0 = arg0.to_dense()
if isinstance(arg1, cls):
v1 = arg1.to_dense()
out = func(v0, v1)
return cls.from_dense(out)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func in [
torch.Tensor.bmm,
torch.bmm,
torch.Tensor.__matmul__,
torch.matmul,
torch.Tensor.matmul,
]:
assert len(args) == 2
return cls._bmm(args[0], args[1])
if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]:
return cls._softmax(args[0], kwargs["dim"])
if func in [torch.Tensor.transpose, torch.transpose]:
assert len(kwargs) == 0
return cls._transpose(args[0], args[1], args[2])
if func == masked_matmul:
assert len(args) == 3
return cls._masked_matmul(args[0], args[1], args[2])
if func in [
torch.Tensor.add,
torch.add,
torch.Tensor.__add__,
]:
assert len(args) == 2
if not (isinstance(args[0], cls) and isinstance(args[1], cls)):
raise NotImplementedError(
f"{func} with {type(args[0])} and {type(args[1])} not implemented"
)
return cls._binary_op(func, args[0], args[1])
if func in [
torch.Tensor.mul,
torch.mul,
torch.Tensor.__mul__,
]:
assert len(args) == 2
return cls._binary_op(func, args[0], args[1])
if func in [torch.Tensor.logical_and, torch.logical_and, torch.Tensor.__and__]:
assert len(args) == 2
return cls._binary_op_slow(func, args[0], args[1])
if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]:
x = args[0]
values = x.__values.clone()
values = func(values, *args[1:], **kwargs)
return cls._wrap(
x.shape,
values,
x.__row_indices,
x.__row_offsets,
x.__column_indices,
x.__transp_info,
)
if func == torch.Tensor.to:
# print(args, kwargs)
assert len(args) >= 2
return cls._to(args[0], args[1])
# return cls._to(args[0], kwargs["device"])
if func in [torch.Tensor.copy_]:
assert len(args) == 2
return cls._copy(args[0], args[1])
if func in [torch.Tensor.equal, torch.equal]:
assert len(args) == 2
return cls._equal(args[0], args[1])
if func == torch.Tensor.to_dense:
assert len(args) == 1
return cls._to_dense(args[0])
if func == torch.Tensor.detach:
x = args[0]
return cls._wrap(
x.shape,
x.__values.detach(),
x.__row_indices,
x.__row_offsets,
x.__column_indices,
x.__transp_info,
)
if func == torch.Tensor.__deepcopy__:
x = args[0]
memo = args[1]
return cls._wrap(
x.shape,
x.__values.__deepcopy__(memo),
x.__row_indices.__deepcopy__(memo),
x.__row_offsets.__deepcopy__(memo),
x.__column_indices.__deepcopy__(memo),
tuple(v.__deepcopy__(memo) for v in x.__transp_info),
)
if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]:
assert len(args) == 1
assert len(kwargs) == 0
x = args[0]
return cls._wrap(
x.shape,
x.__values.grad,
x.__row_indices,
x.__row_offsets,
x.__column_indices,
x.__transp_info,
)
if func == torch.Tensor.requires_grad_:
func(args[0].__values)
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
# TODO: check this
if func in torch.overrides.get_default_nowrap_functions():
return ret
return torch._tensor._convert(ret, cls)
return NotImplemented
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
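
And the corresponding hypothetical sketch for the CSR path, which needs CUDA plus the compiled sputnik kernels behind `torch.ops.xformers`; the three intercepted ops route through `_csr_ops`, so gradients flow end to end:

import torch
from xformers.ops import masked_matmul
from xformers.sparse import SparseCSRTensor

B, M, N, K = 8, 64, 64, 32
mask = (torch.rand(M, N, device="cuda") > 0.8).float()
att_mask = SparseCSRTensor.from_dense(mask[None].expand(B, M, N))

q = torch.rand(B, M, K, device="cuda", requires_grad=True)
k = torch.rand(B, N, K, device="cuda", requires_grad=True)
v = torch.rand(B, N, K, device="cuda", requires_grad=True)

att = masked_matmul(q, k.transpose(-2, -1), att_mask)  # _sddmm
att = att.softmax(dim=-1)                              # _SparseSoftmax
out = att @ v                                          # _spmm
out.sum().backward()                                   # grads for q, k, v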

xformers/sparse/utils.py

@@ -0,0 +1,123 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
def _coo_to_csr(m, n, row_indices, column_indices):
# assumes coalesced coo
    # note: bincount is sized by n rather than m, since the sole caller
    # (_get_transpose_info) builds offsets for the transposed (n x m) matrix
    row_offsets = row_indices.bincount(minlength=n).cumsum(0, dtype=row_indices.dtype)
row_offsets = torch.nn.functional.pad(row_offsets, (1, 0))
return row_offsets, column_indices
def _csr_to_coo(m, n, row_offsets, column_indices):
# convert from compressed rows to uncompressed
indices = torch.arange(m, dtype=row_offsets.dtype, device=row_offsets.device)
row_sizes = torch.diff(row_offsets)
row_coo = torch.repeat_interleave(indices, row_sizes.long())
return row_coo, column_indices
def _diffsort(a):
return torch.argsort(torch.diff(a), dim=0, descending=True)
def _get_transpose_info(m, n, row_indices, row_offsets, column_indices):
# strategy:
# - uncompress the rows to have data in COO format
# - get permutation for stable sort of the columns to get the rows for the transposed matrix
# - compress the new rows and return the permutation to be applied on the values
# convert from compressed rows to uncompressed
row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices)
    # get the permutation for the stable sort;
    # the sorted column indices are the row ids of the transposed matrix,
    # and the variable is overwritten with the actual row offsets just below
    row_offsets_t, perm = column_indices.sort(dim=0, stable=True)
    column_indices_t = row_coo[perm]
    row_offsets_t, _ = _coo_to_csr(m, n, row_offsets_t, column_indices)
row_indices_t = _diffsort(row_offsets_t).int()
return row_indices_t, row_offsets_t, column_indices_t, perm
def _transpose_with_info(values, _transpose_info):
row_indices_t, row_offsets_t, column_indices_t, perm = _transpose_info
values_t = values[:, perm]
return row_indices_t, values_t, row_offsets_t, column_indices_t
def _transpose(m, n, row_indices, values, row_offsets, column_indices):
_transpose_info = _get_transpose_info(
m, n, row_indices, row_offsets, column_indices
)
return _transpose_with_info(values, _transpose_info)
def _nonzero_mask_to_sparse_csr_indices(mask, device):
"""Converts dense 2d matrix to a csr sparse matrix."""
assert len(mask.shape) == 2
index_dtype = torch.int32
# Calculate the offset of each row.
row_offsets = mask.sum(dim=-1, dtype=index_dtype).cumsum(dim=-1, dtype=index_dtype)
row_offsets = torch.nn.functional.pad(row_offsets, (1, 0))
# Create the row indices and sort them.
row_indices = _diffsort(row_offsets).to(index_dtype)
# Extract the column indices for the nonzero values.
column_indices = torch.where(mask)[1].to(index_dtype).contiguous()
row_indices = row_indices.to(device)
row_offsets = row_offsets.to(device)
column_indices = column_indices.to(device)
return row_indices, row_offsets, column_indices
def _dense_to_sparse(matrix, device):
"""Converts dense 2d matrix to a csr sparse matrix."""
assert len(matrix.shape) == 2
value_dtype = torch.float32
# Extract the nonzero values.
mask = matrix != 0
values = matrix[mask].to(dtype=value_dtype, device=device)
row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices(
mask, device
)
return values, row_indices, row_offsets, column_indices
def _round_nnz(mask, divisible_by=4):
nonzero = torch.where(mask)
nnz = nonzero[0].shape[0]
nonzero = tuple(n[: (nnz - nnz % divisible_by)] for n in nonzero)
nm = torch.zeros_like(mask)
nm[nonzero] = True
return nm
def _dense3d_to_sparse(matrix, device):
assert len(matrix.shape) == 3
mask = matrix != 0
if not torch.all(mask == mask[0]):
raise ValueError("Expected the same sparsity pattern over the batch dimension")
    # for now, our kernels assume that the number of
    # nonzeros (nnz) is divisible by 4
mask = _round_nnz(mask[0], divisible_by=4)
mask = mask[None].expand(matrix.shape)
values = matrix[mask].reshape(matrix.shape[0], -1).to(device)
row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices(
mask[0], device
)
return values, row_indices, row_offsets, column_indices
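
The offset helpers above are pure PyTorch, so a tiny worked example (a hypothetical 4 x 4 pattern) shows the COO/CSR round trip:

import torch
from xformers.sparse.utils import _coo_to_csr, _csr_to_coo

# nonzeros at (0, 1), (0, 3), (2, 0), (3, 2)
row_coo = torch.tensor([0, 0, 2, 3])
cols = torch.tensor([1, 3, 0, 2])

row_offsets, _ = _coo_to_csr(4, 4, row_coo, cols)
print(row_offsets)  # tensor([0, 2, 2, 3, 4])
rows_back, _ = _csr_to_coo(4, 4, row_offsets, cols)
assert torch.equal(rows_back, row_coo)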