# (file-viewer metadata, commented out so the module stays importable)
# Files — 2025-08-05 19:02:46 +08:00 — 124 lines, 4.2 KiB, Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
def _coo_to_csr(m, n, row_indices, column_indices):
# assumes coalesced coo
row_offsets = row_indices.bincount(minlength=n).cumsum(0, dtype=row_indices.dtype)
row_offsets = torch.nn.functional.pad(row_offsets, (1, 0))
return row_offsets, column_indices
def _csr_to_coo(m, n, row_offsets, column_indices):
# convert from compressed rows to uncompressed
indices = torch.arange(m, dtype=row_offsets.dtype, device=row_offsets.device)
row_sizes = torch.diff(row_offsets)
row_coo = torch.repeat_interleave(indices, row_sizes.long())
return row_coo, column_indices
def _diffsort(a):
return torch.argsort(torch.diff(a), dim=0, descending=True)
def _get_transpose_info(m, n, row_indices, row_offsets, column_indices):
    """Precompute CSR metadata for the transpose of an m x n sparse matrix.

    Strategy:
      - uncompress the rows so the data is in COO format
      - stable-sort by column; the sorted columns are the transpose's rows
      - recompress the new rows and return the permutation to be applied
        to the values

    Returns ``(row_indices_t, row_offsets_t, column_indices_t, perm)``.
    """
    # row index of every stored element, in COO form
    rows_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices)
    # stable sort preserves within-column (i.e. within transposed-row) order
    sorted_cols, perm = column_indices.sort(dim=0, stable=True)
    # the original rows become the columns of the transpose
    column_indices_t = rows_coo[perm]
    # compress the sorted columns into the transpose's row offsets
    # (relies on _coo_to_csr's minlength=n handling the n-row transpose)
    row_offsets_t, _ = _coo_to_csr(m, n, sorted_cols, column_indices)
    row_indices_t = _diffsort(row_offsets_t).int()
    return row_indices_t, row_offsets_t, column_indices_t, perm
def _transpose_with_info(values, _transpose_info):
row_indices_t, row_offsets_t, column_indices_t, perm = _transpose_info
values_t = values[:, perm]
return row_indices_t, values_t, row_offsets_t, column_indices_t
def _transpose(m, n, row_indices, values, row_offsets, column_indices):
    """Transpose an m x n CSR sparse matrix together with its batched values."""
    info = _get_transpose_info(m, n, row_indices, row_offsets, column_indices)
    return _transpose_with_info(values, info)
def _nonzero_mask_to_sparse_csr_indices(mask, device):
    """Convert a dense 2d boolean mask into CSR index tensors on ``device``.

    Returns ``(row_indices, row_offsets, column_indices)`` as int32 tensors.
    """
    assert mask.dim() == 2
    index_dtype = torch.int32
    # accumulate per-row nonzero counts into exclusive row offsets
    per_row = mask.sum(dim=-1, dtype=index_dtype)
    row_offsets = per_row.cumsum(dim=-1, dtype=index_dtype)
    row_offsets = torch.nn.functional.pad(row_offsets, (1, 0))
    # row indices ordered by decreasing nonzero count
    row_indices = _diffsort(row_offsets).to(index_dtype)
    # column of every nonzero entry, in row-major order
    column_indices = torch.where(mask)[1].to(index_dtype).contiguous()
    return (
        row_indices.to(device),
        row_offsets.to(device),
        column_indices.to(device),
    )
def _dense_to_sparse(matrix, device):
    """Convert a dense 2d matrix to CSR form.

    Returns ``(values, row_indices, row_offsets, column_indices)`` with the
    values cast to float32 and moved to ``device``.
    """
    assert matrix.dim() == 2
    nonzero_mask = matrix != 0
    # nonzero values in row-major order
    values = matrix[nonzero_mask].to(dtype=torch.float32, device=device)
    row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices(
        nonzero_mask, device
    )
    return values, row_indices, row_offsets, column_indices
def _round_nnz(mask, divisible_by=4):
nonzero = torch.where(mask)
nnz = nonzero[0].shape[0]
nonzero = tuple(n[: (nnz - nnz % divisible_by)] for n in nonzero)
nm = torch.zeros_like(mask)
nm[nonzero] = True
return nm
def _dense3d_to_sparse(matrix, device):
    """Convert a dense (batch, m, n) tensor sharing one sparsity pattern
    across the batch into CSR form.

    Returns ``(values, row_indices, row_offsets, column_indices)`` where
    ``values`` is (batch, nnz).

    Raises:
        ValueError: if the batch elements do not share a sparsity pattern.
    """
    assert matrix.dim() == 3
    nonzero = matrix != 0
    # every batch element must match the first one's pattern
    if not torch.all(nonzero == nonzero[0]):
        raise ValueError("Expected the same sparsity pattern over the batch dimension")
    # for now, our kernels assume that we have the number of
    # nnz to be divisible by 4
    pattern = _round_nnz(nonzero[0], divisible_by=4)
    batched_pattern = pattern[None].expand(matrix.shape)
    values = matrix[batched_pattern].reshape(matrix.shape[0], -1).to(device)
    row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices(
        pattern, device
    )
    return values, row_indices, row_offsets, column_indices