First commit
pkgs/xformers/benchmarks/benchmark_triton_blocksparse.py (new file, 150 lines)
@@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# Benchmark the blocksparse operations:
# matrix multiply and softmax

# Matmul can be of three types:
# - Dense x Dense (COO) -> Sparse
# - Sparse x Dense -> Dense
# - Dense x Sparse -> Dense

from typing import Any, Dict

import torch
import triton
from triton.ops.blocksparse import matmul as blocksparse_matmul

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.attention.core import SparseCS, _matmul_with_mask

def bench_matmul(dtype: torch.dtype, shapes):
    results: Dict[str, Any] = {}
    Z, H = 1, 1

    for M, N, K in shapes:

        modes = [(mode, block) for mode in ["sdd", "dsd"] for block in [16, 32, 64]]

        for mode, block in modes:
            # create inputs
            a = torch.randn((Z, H, M, K), dtype=dtype, device="cuda")
            b = torch.randn((Z, H, K, N), dtype=dtype, device="cuda")
            shape = {
                "sdd": (M, N),
                "dsd": (a.shape[2], a.shape[3]),
                "dds": (b.shape[2], b.shape[3]),
            }[mode]
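            # Triton's blocksparse mode string gives the density of
            # (output, lhs, rhs): "sdd" computes a sparse output from two dense
            # inputs, "dsd" uses a sparse lhs, "dds" a sparse rhs. `shape` is
            # the logical shape of the sparse operand (or of the output).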

            # Pre-sparsify everything
            _layout = torch.eye(shape[0] // block, shape[1] // block, dtype=torch.long)
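            # (torch.eye over the block grid keeps only the diagonal blocks,
            # so the layout is block-diagonal and sparsity grows with size)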

            # - blocksparse
            layout = _layout.unsqueeze(0).expand(H, -1, -1)
            a_triton = (
                triton.testing.sparsify_tensor(a, layout, block) if mode == "dsd" else a
            )
            b_triton = (
                triton.testing.sparsify_tensor(b, layout, block) if mode == "dds" else b
            )
            bsmm = blocksparse_matmul(
                layout=layout,
                block=block,
                mode=mode,
                device=torch.device("cuda"),
                trans_a=False,
                trans_b=False,
            )
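            # sparsify_tensor compacts the sparse operand down to just its
            # active blocks, the storage format bsmm expects; only the operand
            # that is sparse in the current mode gets compacted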

            # - dense
            ta = triton.testing.mask_tensor(a, layout, block) if mode == "dsd" else a
            tb = triton.testing.mask_tensor(b, layout, block) if mode == "dds" else b
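            # mask_tensor zeroes the inactive blocks instead of dropping them,
            # so the dense baseline computes the same masked product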

            # - sparse / sputnik
            mask = torch.ones_like(a, dtype=torch.float, device="cuda")
            mask = triton.testing.mask_tensor(mask, layout, block, value=0.0)
            a_cs = a.flatten(start_dim=0, end_dim=1).to(
                torch.float32
            )  # Sputnik kernels only handle fp32
            b_cs = b.flatten(start_dim=0, end_dim=1).to(torch.float32)
            a_cs = a_cs.contiguous()
            b_cs = b_cs.transpose(-2, -1).contiguous()

            if mode == "sdd":
                b_cs = b_cs.transpose(-2, -1)

            # pyre-fixme[16]: TODO(T101400990): Pyre did not recognize the
            # `SparseCS` import.
            sparse_cs_mask = SparseCS(
                mask.flatten(start_dim=0, end_dim=1).contiguous(),
                device=torch.device("cuda"),
            )
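            # SparseCS converts the dense 0/1 mask into a compressed-sparse
            # (CSR-like) pattern that the sputnik-backed kernels consume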

            # The raw compute steps
            op_flops = {
                "sdd": 2 * Z * K * float(layout.sum()) * block * block,
                "dsd": 2 * Z * N * float(layout.sum()) * block * block,
                "dds": 2 * Z * M * float(layout.sum()) * block * block,
            }[mode] * 1e-12  # TFlops
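            # Each active block contributes 2 * block * block flops per unit
            # of the remaining dimension (K for sdd, N for dsd, M for dds),
            # and layout.sum() counts the active blocks per head.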

            def torch_step():
                return torch.matmul(ta, tb)

            def triton_step():
                return bsmm(a_triton, b_triton)

            def sparse_step():
                if mode == "sdd":
                    return _matmul_with_mask(a_cs, b_cs, sparse_cs_mask)
                else:
                    return sparse_cs_mask.spmm(b_cs)

            # Run and measure, report perf in terms of TFlops
            for testcase in [
                TestCase(
                    torch_step,
                    f"pytorch - {mode} - {block}: ",
                ),
                TestCase(
                    sparse_step,
                    f"sparse - {mode} - {block}: ",
                ),
                TestCase(
                    triton_step,
                    f"triton - {mode} - {block}: ",
                ),
            ]:
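                # In the triton versions this file targets, do_bench returns a
                # tuple of timings in ms; [0] keeps the central estimate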
                ms = triton.testing.do_bench(lambda: testcase.function())[0]
                key = f"M={M}, N={N}, K={K}"

                if key not in results:
                    results[key] = {}

                num_flops = op_flops / ms * 1e3  # Get to TFlop per second
                results[key][testcase.name] = f"{num_flops:.1f}"
                print(f"{key} - {testcase.name} - {num_flops:.2f}TFlops")

    pretty_print(
        results,
        title="\n ------------- Type: {} -------------".format(dtype),
        units="TFlops/s",
    )

    pretty_plot(
        results,
        title=f"Sparse/Blocksparse throughput - {dtype}",
        filename=f"blocksparse_{dtype}.png",
        dash_key="pytorch",
        units="TFlops/s",
    )
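

# Driver: square problems, benchmarked in fp16 and fp32. Running this file
# directly should print one throughput table and save one plot per dtype
# (blocksparse_{dtype}.png).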
shapes = [(k, k, k) for k in [128, 512, 1024, 2048, 4096]]
bench_matmul(torch.float16, shapes)
bench_matmul(torch.float32, shapes)