First commit
This commit is contained in:
137
pkgs/xformers/benchmarks/benchmark_causal_blocksparse.py
Normal file
137
pkgs/xformers/benchmarks/benchmark_causal_blocksparse.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
|
||||
from xformers.components.attention.attention_mask import AttentionMask
|
||||
from xformers.components.attention.core import scaled_dot_product_attention
|
||||
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
|
||||
SHAPES = [
|
||||
(8, 128, 2096),
|
||||
(8, 1024, 256),
|
||||
(12, 512, 1024),
|
||||
(128, 128, 512),
|
||||
(8, 2048, 4096),
|
||||
(16, 1024, 5120),
|
||||
(512, 128, 2560),
|
||||
]
|
||||
|
||||
BLOCK_SIZES = [128]
|
||||
N_HEADS = [8, 32]
|
||||
|
||||
|
||||
def bench_blocksparse_compare(backward: bool):
|
||||
device = torch.device("cuda")
|
||||
bw = "+bw" if backward else ""
|
||||
use_amp = True
|
||||
_use_cuda = True
|
||||
|
||||
for dtype in [torch.float16, torch.float32]:
|
||||
datatype = "fp16" if dtype == torch.float16 else "fp32"
|
||||
results: Dict[str, Any] = {}
|
||||
results_mem: Dict[str, Any] = {}
|
||||
for BS in BLOCK_SIZES:
|
||||
for heads in N_HEADS:
|
||||
for B, M, K in SHAPES:
|
||||
q = torch.randn(
|
||||
(B, heads, M, K // heads),
|
||||
requires_grad=backward,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
k = q
|
||||
v = q
|
||||
|
||||
# Mask with causal flag
|
||||
m_att_mask = AttentionMask.make_causal(
|
||||
M, M, device=device, dtype=dtype
|
||||
)
|
||||
# Custom causal tensor mask
|
||||
m_custom = torch.triu(
|
||||
torch.ones(M, M, device=device, dtype=dtype) * float("-inf"),
|
||||
diagonal=1,
|
||||
)
|
||||
|
||||
def blocksparse_attention():
|
||||
with torch.cuda.amp.autocast(enabled=use_amp):
|
||||
y = scaled_dot_product_attention(
|
||||
q=q, k=k, v=v, att_mask=m_att_mask, block_size=BS
|
||||
)
|
||||
if backward:
|
||||
torch.norm(y).backward()
|
||||
return y
|
||||
|
||||
def sdp_attention():
|
||||
with torch.cuda.amp.autocast(enabled=use_amp):
|
||||
y = scaled_dot_product_attention(
|
||||
q=q, k=k, v=v, att_mask=m_custom, block_size=BS
|
||||
)
|
||||
if backward:
|
||||
torch.norm(y).backward()
|
||||
return y
|
||||
|
||||
for testcase in [
|
||||
TestCase(blocksparse_attention, f"blocksparse - fw{bw}"),
|
||||
TestCase(sdp_attention, f"standard sdp - fw{bw}"),
|
||||
]:
|
||||
if _use_cuda:
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
time = triton.testing.do_bench(testcase.function)[0]
|
||||
|
||||
if _use_cuda:
|
||||
torch.cuda.synchronize()
|
||||
max_memory = torch.cuda.max_memory_allocated() / 2**20
|
||||
else:
|
||||
max_memory = -1
|
||||
|
||||
key = f"B={B},M={M},K={K},NH={heads}"
|
||||
|
||||
if key not in results_mem:
|
||||
results_mem[key] = {}
|
||||
results_mem[key][testcase.name] = f"{max_memory:.1f}"
|
||||
|
||||
if key not in results:
|
||||
results[key] = {}
|
||||
results[key][testcase.name] = f"{time:.2f}"
|
||||
|
||||
pretty_print(
|
||||
results,
|
||||
title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
|
||||
units="runtime in ms",
|
||||
)
|
||||
pretty_print(
|
||||
results_mem,
|
||||
title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
|
||||
units="peak memory usage in MB",
|
||||
)
|
||||
|
||||
pretty_plot(
|
||||
results,
|
||||
title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}",
|
||||
units="runtime in ms",
|
||||
dash_key="torch",
|
||||
legend_loc="upper left",
|
||||
)
|
||||
pretty_plot(
|
||||
results_mem,
|
||||
title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}",
|
||||
units="peak memory usage in MB",
|
||||
dash_key="torch",
|
||||
legend_loc="upper left",
|
||||
)
|
||||
|
||||
|
||||
for bw in [False, True]:
|
||||
bench_blocksparse_compare(bw)
|
||||
Reference in New Issue
Block a user