First commit

commit 99fb9f5cb0
parent 9efe891f99
Date: 2025-08-05 19:02:46 +08:00

1412 changed files with 203615 additions and 0 deletions

@@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Benchmark the blocksparse operations: matrix multiply and softmax.
# Matmul can be of three types:
# - Dense x Dense (COO) -> Sparse   ("sdd")
# - Sparse x Dense      -> Dense    ("dsd")
# - Dense x Sparse      -> Dense    ("dds")
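#
# The mode string reads as (output, a, b): "sdd" means a sparse output from
# two dense operands, i.e. a dense (M, K) @ (K, N) product where only the
# output blocks selected by the layout are materialized.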
from typing import Any, Dict

import torch
import triton
from triton.ops.blocksparse import matmul as blocksparse_matmul

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.attention.core import SparseCS, _matmul_with_mask


def bench_matmul(dtype: torch.dtype, shapes):
results: Dict[str, Any] = {}
    Z, H = 1, 1  # batch size and number of heads

    for M, N, K in shapes:
        # NOTE: "dds" is handled below but not benchmarked by default
        modes = [(mode, block) for mode in ["sdd", "dsd"] for block in [16, 32, 64]]
        for mode, block in modes:
# create inputs
a = torch.randn((Z, H, M, K), dtype=dtype, device="cuda")
b = torch.randn((Z, H, K, N), dtype=dtype, device="cuda")
shape = {
"sdd": (M, N),
"dsd": (a.shape[2], a.shape[3]),
"dds": (b.shape[2], b.shape[3]),
}[mode]
# Pre-sparsify everything
_layout = torch.eye(shape[0] // block, shape[1] // block, dtype=torch.long)
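            # NOTE: an identity block layout keeps only the diagonal blocks:
            # e.g. M = N = 1024 with block = 32 gives a 32 x 32 eye and a block
            # density of 1/32; a real attention pattern would be less regular.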
# - blocksparse
layout = _layout.unsqueeze(0).expand(H, -1, -1)
a_triton = (
triton.testing.sparsify_tensor(a, layout, block) if mode == "dsd" else a
)
b_triton = (
triton.testing.sparsify_tensor(b, layout, block) if mode == "dds" else b
)
bsmm = blocksparse_matmul(
layout=layout,
block=block,
mode=mode,
device=torch.device("cuda"),
trans_a=False,
trans_b=False,
)
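            # bsmm is a callable op specialized for this layout / block size:
            # bsmm(a_triton, b_triton) runs the actual block-sparse matmul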
            # - dense baseline, masked so that the results stay comparable
            ta = triton.testing.mask_tensor(a, layout, block) if mode == "dsd" else a
            tb = triton.testing.mask_tensor(b, layout, block) if mode == "dds" else b
# - sparse / sputnik
mask = torch.ones_like(a, dtype=torch.float, device="cuda")
mask = triton.testing.mask_tensor(mask, layout, block, value=0.0)
            # Sputnik kernels only handle fp32
            a_cs = a.flatten(start_dim=0, end_dim=1).to(torch.float32)
            b_cs = b.flatten(start_dim=0, end_dim=1).to(torch.float32)
a_cs = a_cs.contiguous()
b_cs = b_cs.transpose(-2, -1).contiguous()
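            # b_cs is now contiguous in its transposed layout; the "sdd" branch
            # below undoes the transpose as a view, presumably so that the
            # Sputnik kernel sees a column-major (K, N) operand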
if mode == "sdd":
b_cs = b_cs.transpose(-2, -1)
# pyre-fixme[16]: TODO(T101400990): Pyre did not recognize the
# `SparseCS` import.
sparse_cs_mask = SparseCS(
mask.flatten(start_dim=0, end_dim=1).contiguous(),
device=torch.device("cuda"),
)
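            # SparseCS stores the mask in a compressed-sparse format, which is
            # what the Sputnik kernels consume (hence the fp32 casts above)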
            # Raw compute per call, derived from the number of non-zero blocks
            op_flops = (
                {
                    "sdd": 2 * Z * K * float(layout.sum()) * block * block,
                    "dsd": 2 * Z * N * float(layout.sum()) * block * block,
                    "dds": 2 * Z * M * float(layout.sum()) * block * block,
                }[mode]
                * 1e-12  # TFlops
            )
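            # Sanity check on the arithmetic (hypothetical sizes): with
            # M = N = K = 1024, block = 32 and the identity layout,
            # layout.sum() == 32, so "dsd" gives
            # 2 * 1 * 1024 * 32 * 32 * 32 * 1e-12 ~= 6.7e-5 TFlops per call.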

            def torch_step():
                return torch.matmul(ta, tb)

            def triton_step():
                return bsmm(a_triton, b_triton)

            def sparse_step():
                # "sdd" computes the masked dense x dense product; the other
                # modes multiply the sparse mask with the dense operand
                if mode == "sdd":
                    return _matmul_with_mask(a_cs, b_cs, sparse_cs_mask)
                return sparse_cs_mask.spmm(b_cs)

# Run and measure, report perf in terms of TFlops
for testcase in [
TestCase(
torch_step,
f"pytorch - {mode} - {block}: ",
),
TestCase(
sparse_step,
f"sparse - {mode} - {block}: ",
),
TestCase(
triton_step,
f"triton - {mode} - {block}: ",
),
]:
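                # do_bench reports latency in milliseconds; the Triton version
                # assumed here returns a (mean, min, max) tuple, hence the [0]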
ms = triton.testing.do_bench(lambda: testcase.function())[0]
key = f"M={M}, N={N}, K={K}"
if key not in results:
results[key] = {}
                num_flops = op_flops / ms * 1e3  # TFlops per ms -> TFlops/s
                results[key][testcase.name] = f"{num_flops:.1f}"
                print(f"{key} - {testcase.name} - {num_flops:.2f} TFlops/s")

    pretty_print(
        results,
        title=f"\n ------------- Type: {dtype} -------------",
        units="TFlops/s",
    )

    pretty_plot(
        results,
        title=f"Sparse/Blocksparse throughput - {dtype}",
        filename=f"blocksparse_{dtype}.png",
        dash_key="pytorch",
        units="TFlops/s",
    )


if __name__ == "__main__":
    # Square problems of increasing size
    shapes = [(k, k, k) for k in [128, 512, 1024, 2048, 4096]]

    bench_matmul(torch.float16, shapes)
    bench_matmul(torch.float32, shapes)
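
# Example invocation (assuming a CUDA device and a Triton build that ships
# triton.ops.blocksparse; the path below is hypothetical):
#   python xformers/benchmarks/benchmark_triton_blocksparse.py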