First commit
pkgs/xformers/sparse/_csr_ops.py (new file, 166 lines)
@@ -0,0 +1,166 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from .utils import _csr_to_coo, _transpose_with_info

def _should_use_coo(a, sparsity):
    if not a.is_cuda:
        return False
    B, M, K = a.shape
    # amortize overhead of converting from csr to coo
    if B < 32 and M < 4096:
        return False
    if sparsity > 0.995:
        return False
    if sparsity < 0.9:
        return False
    if K > 64:
        return False
    # let's be overly cautious here for now
    return sparsity > 0.97


def _should_use_csr_ge(a, sparsity):
    if not a.is_cuda:
        return False
    return sparsity > 0.99

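
# Worked example of the heuristics above (illustrative, not part of the
# original commit): a CUDA input `a` of shape (B=64, M=4096, K=64) at
# sparsity 0.98 passes every guard in _should_use_coo (B >= 32, K <= 64,
# 0.9 <= 0.98 <= 0.995) and the final check 0.98 > 0.97 returns True, so
# _sddmm_func below takes the COO path. Past 0.995 sparsity the COO path
# declines and _should_use_csr_ge (sparsity > 0.99) takes over.
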
def _sddmm_func(a, b, row_indices, row_offsets, column_indices):
    sparsity = 1 - column_indices.shape[0] / (a.shape[1] * b.shape[1])
    if _should_use_coo(a, sparsity):
        m = a.shape[-2]
        n = b.shape[-2]
        # converting from csr to coo has a constant overhead of ~150us
        # so only dispatch to it for reasonably large problem sizes
        ro, ci = _csr_to_coo(m, n, row_offsets, column_indices)
        return torch.ops.xformers.coo_sddmm(a, b, row_indices, ro, ci)
    elif _should_use_csr_ge(a, sparsity):
        return torch.ops.xformers.csr_sddmm(
            a, b, row_indices, row_offsets, column_indices
        )
    return torch.ops.xformers.sddmm_sputnik(
        a, b, row_indices, row_offsets, column_indices
    )

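
# Hypothetical usage sketch (shapes assumed, not part of the original
# commit): `a` is (B, M, K), `b` is (B, N, K), and the CSR metadata encodes
# an M x N sparsity mask.
#
#   values = _sddmm_func(a, b, row_indices, row_offsets, column_indices)
#
# `values` holds the batched product a @ b.transpose(-2, -1) evaluated only
# at the mask's nonzero positions.
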
class _SparseSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, m, n, row_indices, values, row_offsets, column_indices):
        out = torch.ops.xformers.sparse_softmax_sputnik(
            m, n, row_indices, values, row_offsets, column_indices
        )
        # note: save out and not values, as an optimization step;
        # the softmax gradient can be computed from the output alone
        ctx.save_for_backward(row_indices, out, row_offsets, column_indices)
        ctx.size = (m, n)
        return out

    @staticmethod
    def backward(ctx, grad):
        row_indices, out, row_offsets, column_indices = ctx.saved_tensors
        m, n = ctx.size

        # gradients w.r.t. values
        grad = grad.contiguous()
        ga = torch.ops.xformers.sparse_softmax_backward_sputnik(
            m, n, row_indices, out, grad, row_offsets, column_indices
        )

        return None, None, None, ga, None, None

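
# Hypothetical usage sketch (not part of the original commit): row-wise
# softmax over the sparse scores, staying in the CSR layout. Only `values`
# is differentiable; every other input gets a None gradient in backward.
#
#   probs = _SparseSoftmax.apply(
#       m, n, row_indices, values, row_offsets, column_indices
#   )
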
class _sddmm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, row_indices, row_offsets, column_indices, _transp_info):
        out = _sddmm_func(a, b, row_indices, row_offsets, column_indices)

        ctx.save_for_backward(
            a, b, row_indices, row_offsets, column_indices, *_transp_info
        )
        return out

    @staticmethod
    def backward(ctx, grad):
        (
            a,
            b,
            row_indices,
            row_offsets,
            column_indices,
            *_transp_info,
        ) = ctx.saved_tensors
        m, n = a.shape[1], b.shape[1]

        # gradients w.r.t. the dense inputs a and b
        grad = grad.contiguous()
        a = a.contiguous()
        b = b.contiguous()

        # d(out)/da: multiply the sparse upstream gradient with b
        a_grad = torch.ops.xformers.spmm_sputnik(
            b, row_indices, grad, row_offsets, column_indices, m
        )

        # d(out)/db: same product, using the transposed sparse gradient
        (
            row_indices_t,
            grad_t,
            row_offsets_t,
            column_indices_t,
        ) = _transpose_with_info(grad, _transp_info)

        b_grad = torch.ops.xformers.spmm_sputnik(
            a, row_indices_t, grad_t, row_offsets_t, column_indices_t, n
        )

        return a_grad, b_grad, None, None, None, None

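
# Hypothetical usage sketch (not part of the original commit): `_transp_info`
# is assumed to be precomputed transpose metadata in the format consumed by
# _transpose_with_info from .utils.
#
#   values = _sddmm.apply(
#       a, b, row_indices, row_offsets, column_indices, _transp_info
#   )
#
# Unlike calling _sddmm_func directly, this wires up backward, so gradients
# flow to the dense inputs a and b.
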
class _spmm(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, b, row_indices, values, row_offsets, column_indices, m, _transp_info
    ):
        b = b.contiguous()
        out = torch.ops.xformers.spmm_sputnik(
            b, row_indices, values, row_offsets, column_indices, m
        )

        ctx.save_for_backward(
            b, row_indices, values, row_offsets, column_indices, *_transp_info
        )
        return out

    @staticmethod
    def backward(ctx, grad):
        (
            b,
            row_indices,
            values,
            row_offsets,
            column_indices,
            *_transp_info,
        ) = ctx.saved_tensors
        k = b.shape[1]  # number of rows of the transposed sparse matrix

        # gradients w.r.t. values: an sddmm between the upstream gradient and b
        grad = grad.contiguous()

        grad_sparse = _sddmm_func(grad, b, row_indices, row_offsets, column_indices)

        # gradients w.r.t. the dense input b: spmm with the transposed matrix
        (
            row_indices_t,
            values_t,
            row_offsets_t,
            column_indices_t,
        ) = _transpose_with_info(values, _transp_info)

        grad_dense = torch.ops.xformers.spmm_sputnik(
            grad, row_indices_t, values_t, row_offsets_t, column_indices_t, k
        )

        return grad_dense, None, grad_sparse, None, None, None, None
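
# Hypothetical usage sketch (not part of the original commit): multiply the
# sparse m x n matrix given by (values, row_offsets, column_indices) with a
# dense `b` of shape (B, n, k).
#
#   out = _spmm.apply(
#       b, row_indices, values, row_offsets, column_indices, m, _transp_info
#   )
#
# Backward returns gradients for both the dense input b and the sparse
# values, mirroring the two products computed in _spmm.backward above.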