# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch

from .utils import _csr_to_coo, _transpose_with_info


def _should_use_coo(a, sparsity):
    if not a.is_cuda:
        return False
    B, M, K = a.shape
    # amortize overhead of converting from csr to coo
    if B < 32 and M < 4096:
        return False
    if sparsity > 0.995:
        return False
    if sparsity < 0.9:
        return False
    if K > 64:
        return False
    # let's be overly cautious here for now
    return sparsity > 0.97


def _should_use_csr_ge(a, sparsity):
    if not a.is_cuda:
        return False
    return sparsity > 0.99
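

# Illustrative sketch (not part of the original module): probes the dispatch
# heuristics above. The shape and sparsity values are arbitrary examples; on
# CPU both predicates are always False and everything falls through to the
# sputnik kernel in `_sddmm_func` below.
def _example_dispatch_regimes():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.empty(32, 4096, 64, device=device)
    for sparsity in (0.85, 0.95, 0.98, 0.999):
        # on CUDA: 0.85 / 0.95 -> sputnik, 0.98 -> coo, 0.999 -> csr_ge
        print(sparsity, _should_use_coo(a, sparsity), _should_use_csr_ge(a, sparsity))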


# Pick the fastest available sddmm kernel for this sparsity and problem size.
def _sddmm_func(a, b, row_indices, row_offsets, column_indices):
    sparsity = 1 - column_indices.shape[0] / (a.shape[1] * b.shape[1])
    if _should_use_coo(a, sparsity):
        m = a.shape[-2]
        n = b.shape[-2]
        # converting from csr to coo has a constant overhead of ~150us
        # so only dispatch to it for reasonably large problem sizes
        ro, ci = _csr_to_coo(m, n, row_offsets, column_indices)
        return torch.ops.xformers.coo_sddmm(a, b, row_indices, ro, ci)
    elif _should_use_csr_ge(a, sparsity):
        return torch.ops.xformers.csr_sddmm(
            a, b, row_indices, row_offsets, column_indices
        )
    return torch.ops.xformers.sddmm_sputnik(
        a, b, row_indices, row_offsets, column_indices
    )
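

# Illustrative usage sketch (not part of the original module): builds the CSR
# metadata `_sddmm_func` expects from a dense boolean mask. The int32 dtypes
# and the plain-arange `row_indices` are assumptions for illustration; the
# library derives `row_indices` from the per-row nonzero counts so the CUDA
# kernels can balance work across rows.
def _example_sddmm(a, b, mask):
    # a: (B, M, K), b: (B, N, K), mask: (M, N) bool, all on the same device
    nnz_per_row = mask.sum(dim=1, dtype=torch.int32)
    row_offsets = torch.cat(
        [nnz_per_row.new_zeros(1), nnz_per_row.cumsum(0).to(torch.int32)]
    )
    column_indices = mask.nonzero(as_tuple=True)[1].to(torch.int32)
    row_indices = torch.arange(mask.shape[0], dtype=torch.int32, device=mask.device)
    # result holds (a @ b.transpose(-2, -1))[:, mask], shape (B, nnz)
    return _sddmm_func(a, b, row_indices, row_offsets, column_indices)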


# Softmax restricted to the stored values of a CSR sparse matrix (sputnik kernel).
class _SparseSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, m, n, row_indices, values, row_offsets, column_indices):
        out = torch.ops.xformers.sparse_softmax_sputnik(
            m, n, row_indices, values, row_offsets, column_indices
        )
        # note: save out and not values, as an optimization step
        ctx.save_for_backward(row_indices, out, row_offsets, column_indices)
        ctx.size = (m, n)
        return out

    @staticmethod
    def backward(ctx, grad):
        row_indices, out, row_offsets, column_indices = ctx.saved_tensors
        m, n = ctx.size

        # gradients w.r.t. values
        grad = grad.contiguous()
        ga = torch.ops.xformers.sparse_softmax_backward_sputnik(
            m, n, row_indices, out, grad, row_offsets, column_indices
        )

        return None, None, None, ga, None, None
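

# Illustrative sketch (not part of the original module): `_SparseSoftmax`
# normalizes each row of the sparse matrix using only its stored values, so it
# composes directly with the output of `_sddmm_func`. Only `values` receives a
# gradient; the CSR metadata is non-differentiable, hence the None
# placeholders returned by backward.
def _example_sparse_softmax(m, n, values, row_indices, row_offsets, column_indices):
    values = values.detach().requires_grad_()
    out = _SparseSoftmax.apply(m, n, row_indices, values, row_offsets, column_indices)
    out.sum().backward()  # populates values.grad through the sputnik backward op
    return out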


# Sampled dense-dense matmul: computes (a @ b^T) only at the stored CSR positions.
class _sddmm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, row_indices, row_offsets, column_indices, _transp_info):
        out = _sddmm_func(a, b, row_indices, row_offsets, column_indices)
        ctx.save_for_backward(
            a, b, row_indices, row_offsets, column_indices, *_transp_info
        )
        return out

    @staticmethod
    def backward(ctx, grad):
        (
            a,
            b,
            row_indices,
            row_offsets,
            column_indices,
            *_transp_info,
        ) = ctx.saved_tensors
        m, n = a.shape[1], b.shape[1]

        # gradients w.r.t. the two dense inputs
        grad = grad.contiguous()
        a = a.contiguous()
        b = b.contiguous()

        a_grad = torch.ops.xformers.spmm_sputnik(
            b, row_indices, grad, row_offsets, column_indices, m
        )

        # the gradient w.r.t. b multiplies through the transposed pattern
        (
            row_indices_t,
            grad_t,
            row_offsets_t,
            column_indices_t,
        ) = _transpose_with_info(grad, _transp_info)

        b_grad = torch.ops.xformers.spmm_sputnik(
            a, row_indices_t, grad_t, row_offsets_t, column_indices_t, n
        )

        return a_grad, b_grad, None, None, None, None
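

# Illustrative sketch (not part of the original module): end-to-end autograd
# through `_sddmm`. It assumes a `_get_transpose_info` helper living next to
# `_transpose_with_info` in `.utils`, which precomputes the CSR metadata of
# the transposed sparsity pattern once so backward can reuse it cheaply.
def _example_sddmm_grads(a, b, row_indices, row_offsets, column_indices):
    from .utils import _get_transpose_info  # assumed helper, see note above

    m, n = a.shape[1], b.shape[1]
    _transp_info = _get_transpose_info(m, n, row_indices, row_offsets, column_indices)
    a = a.detach().requires_grad_()
    b = b.detach().requires_grad_()
    values = _sddmm.apply(a, b, row_indices, row_offsets, column_indices, _transp_info)
    values.sum().backward()  # fills a.grad and b.grad via the two spmm calls above
    return values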


# Sparse (CSR) @ dense matmul with autograd support.
class _spmm(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, b, row_indices, values, row_offsets, column_indices, m, _transp_info
    ):
        b = b.contiguous()
        out = torch.ops.xformers.spmm_sputnik(
            b, row_indices, values, row_offsets, column_indices, m
        )
        ctx.save_for_backward(
            b, row_indices, values, row_offsets, column_indices, *_transp_info
        )
        return out

    @staticmethod
    def backward(ctx, grad):
        (
            b,
            row_indices,
            values,
            row_offsets,
            column_indices,
            *_transp_info,
        ) = ctx.saved_tensors
        k = b.shape[1]

        # gradient w.r.t. the sparse values
        grad = grad.contiguous()
        grad_sparse = _sddmm_func(grad, b, row_indices, row_offsets, column_indices)

        # gradient w.r.t. the dense input b, through the transposed pattern
        (
            row_indices_t,
            values_t,
            row_offsets_t,
            column_indices_t,
        ) = _transpose_with_info(values, _transp_info)

        grad_dense = torch.ops.xformers.spmm_sputnik(
            grad, row_indices_t, values_t, row_offsets_t, column_indices_t, k
        )

        return grad_dense, None, grad_sparse, None, None, None, None
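

# Illustrative sketch (not part of the original module): the three autograd
# functions above compose into masked attention, which is the pattern this
# module exists to serve. Scaling by 1/sqrt(K) is omitted, and `_transp_info`
# follows the assumption described in the `_sddmm` sketch.
def _example_sparse_attention(
    q, k, v, row_indices, row_offsets, column_indices, _transp_info
):
    m, n = q.shape[1], k.shape[1]
    # attention scores only at the masked-in positions: shape (B, nnz)
    scores = _sddmm.apply(q, k, row_indices, row_offsets, column_indices, _transp_info)
    attn = _SparseSoftmax.apply(m, n, row_indices, scores, row_offsets, column_indices)
    # weighted sum of values: dense (B, M, K) output
    return _spmm.apply(v, row_indices, attn, row_offsets, column_indices, m, _transp_info)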