151 lines
5.0 KiB
Python
151 lines
5.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
# Benchmark the blocksparse matrix multiply operations
# (softmax is not benchmarked in this file).

# Matmul can be of three types:
# - Dense x Dense (COO) -> Sparse  ("sdd")
# - Sparse x Dense -> Dense        ("dsd")
# - Dense x Sparse -> Dense        ("dds")
# Only the "sdd" and "dsd" modes are currently exercised.
|
|
|
|
from typing import Any, Dict
|
|
|
|
import torch
|
|
import triton
|
|
from triton.ops.blocksparse import matmul as blocksparse_matmul
|
|
|
|
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
|
|
from xformers.components.attention.core import SparseCS, _matmul_with_mask
|
|
|
|
|
|
def bench_matmul(dtype: torch.dtype, shapes) -> None:
    """Benchmark block-sparse matmul against dense PyTorch and SparseCS paths.

    For every ``(M, N, K)`` triple in ``shapes`` and every combination of
    mode (``"sdd"``, ``"dsd"``) and block size (16, 32, 64), three callables
    are timed with ``triton.testing.do_bench``:

    - ``pytorch``: ``torch.matmul`` on dense operands (in ``"dsd"`` mode the
      sparse operand is first masked with the layout),
    - ``sparse``: the ``SparseCS`` based path, run in fp32,
    - ``triton``: the Triton ``blocksparse_matmul`` operator.

    The layout is block-diagonal (``torch.eye`` over the block grid). Each
    measurement is printed as it is taken; the collected results are then
    handed to ``pretty_print`` and ``pretty_plot`` (the plot is written to
    ``blocksparse_{dtype}.png``).

    Args:
        dtype: Element type of the randomly generated operands.
        shapes: Iterable of ``(M, N, K)`` integer triples. ``a`` is created
            with shape ``(Z, H, M, K)`` and ``b`` with ``(Z, H, K, N)``.

    Note:
        All tensors are allocated on ``"cuda"``, so a CUDA device is needed.
    """
    results: Dict[str, Any] = {}

    # Leading two dimensions of the operands; the layout is replicated H times.
    Z, H = 1, 1
    cuda = torch.device("cuda")

    # Loop-invariant: every shape is benchmarked with the same mode/block grid.
    modes = [(mode, block) for mode in ["sdd", "dsd"] for block in [16, 32, 64]]

    for M, N, K in shapes:
        key = f"M={M}, N={N}, K={K}"

        for mode, block in modes:
            # create inputs
            a = torch.randn((Z, H, M, K), dtype=dtype, device=cuda)
            b = torch.randn((Z, H, K, N), dtype=dtype, device=cuda)

            # Shape of the tensor that the sparsity layout describes:
            # the output for "sdd", operand a for "dsd", operand b for "dds".
            shape = {
                "sdd": (M, N),
                "dsd": (a.shape[2], a.shape[3]),
                "dds": (b.shape[2], b.shape[3]),
            }[mode]

            # Pre-sparsify everything
            _layout = torch.eye(shape[0] // block, shape[1] // block, dtype=torch.long)

            # - blocksparse
            layout = _layout.unsqueeze(0).expand(H, -1, -1)
            a_triton = (
                triton.testing.sparsify_tensor(a, layout, block) if mode == "dsd" else a
            )
            b_triton = (
                triton.testing.sparsify_tensor(b, layout, block) if mode == "dds" else b
            )
            bsmm = blocksparse_matmul(
                layout=layout,
                block=block,
                mode=mode,
                device=cuda,
                trans_a=False,
                trans_b=False,
            )

            # - dense
            ta = triton.testing.mask_tensor(a, layout, block) if mode == "dsd" else a
            tb = triton.testing.mask_tensor(b, layout, block) if mode == "dds" else b

            # - sparse / sputnik
            # The mask must have the shape the layout was built for (`shape`),
            # not unconditionally the shape of `a`: in "sdd" mode the layout
            # describes the (M, N) output, which only matches `a` when N == K.
            mask = torch.ones((Z, H, *shape), dtype=torch.float, device=cuda)
            mask = triton.testing.mask_tensor(mask, layout, block, value=0.0)

            # Sputnik kernels only handle fp32
            a_cs = a.flatten(start_dim=0, end_dim=1).to(torch.float32).contiguous()
            b_cs = b.flatten(start_dim=0, end_dim=1).to(torch.float32)
            b_cs = b_cs.transpose(-2, -1).contiguous()

            if mode == "sdd":
                # Transposed view over the contiguous transposed storage.
                b_cs = b_cs.transpose(-2, -1)
            # NOTE(review): outside "sdd" mode b_cs stays in (N, K) orientation;
            # presumably this is only dimensionally valid for square problems
            # (N == K) - confirm against SparseCS.spmm before using other shapes.

            # pyre-fixme[16]: TODO(T101400990): Pyre did not recognize the
            #  `SparseCS` import.
            sparse_cs_mask = SparseCS(
                mask.flatten(start_dim=0, end_dim=1).contiguous(),
                device=cuda,
            )

            # The raw compute steps: 2 * (non-zero layout blocks) * block^2
            # multiply-adds along the dense dimension that is not in the layout.
            dense_dim = {"sdd": K, "dsd": N, "dds": M}[mode]
            op_flops = (
                2 * Z * dense_dim * float(layout.sum()) * block * block
            ) * 1e-12  # TFlops

            def torch_step():
                return torch.matmul(ta, tb)

            def triton_step():
                return bsmm(a_triton, b_triton)

            def sparse_step():
                if mode == "sdd":
                    return _matmul_with_mask(a_cs, b_cs, sparse_cs_mask)
                return sparse_cs_mask.spmm(b_cs)

            # Run and measure, report perf in terms of TFlops
            testcases = [
                TestCase(torch_step, f"pytorch - {mode} - {block}: "),
                TestCase(sparse_step, f"sparse - {mode} - {block}: "),
                TestCase(triton_step, f"triton - {mode} - {block}: "),
            ]
            for testcase in testcases:
                # The first entry of do_bench's result is used as the time in ms.
                ms = triton.testing.do_bench(testcase.function)[0]

                num_flops = op_flops / ms * 1e3  # Get to TFlop per second
                results.setdefault(key, {})[testcase.name] = f"{num_flops:.1f}"
                print(f"{key} - {testcase.name} - {num_flops:.2f}TFlops")

    pretty_print(
        results,
        title="\n ------------- Type: {} -------------".format(dtype),
        units="TFlops/s",
    )

    pretty_plot(
        results,
        title=f"Sparse/Blocksparse throughput - {dtype}",
        filename=f"blocksparse_{dtype}.png",
        dash_key="pytorch",
        units="TFlops/s",
    )
|
|
|
|
|
|
# Square problems only: M == N == K for every benchmarked size.
shapes = [(dim, dim, dim) for dim in (128, 512, 1024, 2048, 4096)]

# Run the full sweep once in half precision, then once in single precision.
bench_matmul(dtype=torch.float16, shapes=shapes)
bench_matmul(dtype=torch.float32, shapes=shapes)
|