# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import itertools

import torch
from torch.utils import benchmark

from xformers.components.attention._sputnik_sparse import _csr_to_coo
from xformers.components.attention.core import SparseCS, _create_random_sparsity

MIN_RUN_TIME = 0.2

# Backends understood by `_get_fn`, in the order they are benchmarked (this
# is also the order in which they are appended to the results list).
_BACKENDS = ("csr_sputnik", "csr_ge", "coo_ge", "csr_to_coo")


def _get_fn(backend):
    """Return the callable to benchmark for ``backend``.

    Every returned callable is invoked as
    ``fn(a, b, row_indices, row_offsets, column_indices)``.

    Args:
        backend: One of ``"csr_ge"``, ``"csr_sputnik"``, ``"coo_ge"`` or
            ``"csr_to_coo"``.

    Returns:
        The matching ``torch.ops.xformers`` operator, or a small wrapper
        around ``_csr_to_coo`` for the two COO-related backends.

    Raises:
        ValueError: If ``backend`` is not one of the supported names.
    """
    if backend == "csr_ge":
        fn = torch.ops.xformers.csr_sddmm
    elif backend == "csr_sputnik":
        fn = torch.ops.xformers.sddmm_sputnik
    elif backend == "coo_ge":

        def fn(a, b, row_indices, row_offsets, column_indices):
            # The CSR -> COO conversion is part of the timed call here; the
            # "csr_to_coo" backend below times that conversion on its own.
            row_coo, _ = _csr_to_coo(
                a.shape[-2], b.shape[-2], row_offsets, column_indices
            )
            return torch.ops.xformers.coo_sddmm(
                a, b, row_indices, row_coo, column_indices
            )

    elif backend == "csr_to_coo":

        def fn(a, b, row_indices, row_offsets, column_indices):
            # Conversion only: no SDDMM kernel is launched.
            row_coo, _ = _csr_to_coo(
                a.shape[-2], b.shape[-2], row_offsets, column_indices
            )
            return row_coo

    else:
        # Without this branch `fn` would be unbound and the `return` below
        # would fail with an unhelpful UnboundLocalError.
        raise ValueError(
            f"Unknown backend {backend!r}; expected one of {_BACKENDS}"
        )

    return fn


def bench_sddmm(configs):
    """Time every backend on each configuration and print a comparison table.

    Args:
        configs: Iterable of ``((B, M, K), prob)`` pairs. Two random
            ``(B, M, K)`` tensors are created on the CUDA device, and
            ``prob`` is forwarded to ``_create_random_sparsity`` to build a
            ``(1, M, M)`` mask (presumably the fraction of masked-out
            entries; confirm against that helper).

    Returns:
        The list of objects returned by ``Timer.blocked_autorange``, one per
        (configuration, backend) pair, in iteration order.
    """
    min_run_time = MIN_RUN_TIME
    device = torch.device("cuda")
    fn_str = "fn(a, b, row_indices, row_offsets, column_indices)"

    results = []
    for (B, M, K), prob in configs:
        a = torch.rand(B, M, K, device=device)
        b = torch.rand(B, M, K, device=device)

        dense_mask = _create_random_sparsity(
            torch.ones(1, M, M, dtype=torch.bool), prob, divisible_by=16
        )
        mask = SparseCS(dense_mask, device)
        row_indices = mask.row_indices
        row_offsets = mask.row_offsets
        column_indices = mask.column_indices

        for backend in _BACKENDS:
            timer = benchmark.Timer(
                stmt=fn_str,
                globals={
                    "a": a,
                    "b": b,
                    "mask": mask,
                    "row_indices": row_indices,
                    "row_offsets": row_offsets,
                    "column_indices": column_indices,
                    "fn": _get_fn(backend),
                },
                label="sddmm",
                sub_label=f"B={B:>4d}, M={M:>4d}, K={K:>3d}, prob={prob:0.4f}",
                description=backend,
            )
            results.append(timer.blocked_autorange(min_run_time=min_run_time))

    compare = benchmark.Compare(results)
    compare.print()
    return results


# batch size 32, for different layers
SWIN_T_SIZES = [(96, 3136, 32), (192, 784, 32), (384, 196, 32), (768, 49, 32)]
swin_t_config = list(zip(SWIN_T_SIZES, (0.9844, 0.9375, 0.75, 0.0)))

# some random values
BASIC_SIZES = [(32, 1024, 32), (32, 1024, 128), (8, 4096, 32), (8, 4096, 128)]
SPARSITIES = [0.90, 0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.999]
basic_config = list(itertools.product(BASIC_SIZES, SPARSITIES))

# batch size 32 here
vit_sizes = [
    (192, 785, 64),  # deit_small_patch8_224
    (192, 197, 64),  # deit_small_patch16_224
    (384, 785, 64),  # deit_base_patch8_224
    (384, 197, 64),  # deit_base_patch16_224
]
# NOTE: SPARSITIES is rebound here; basic_config above was already built
# from the previous list and is unaffected.
SPARSITIES = [0.70, 0.80, 0.85, 0.90, 0.93, 0.95, 0.97]
vit_config = list(itertools.product(vit_sizes, SPARSITIES))

results = []

print("Swin Transformer")
results += bench_sddmm(swin_t_config)
print("ViT")
results += bench_sddmm(vit_config)
print("Basic cases")
results += bench_sddmm(basic_config)