# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
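
"""Micro-benchmark comparing causal blocksparse attention against the dense
scaled_dot_product_attention path in xformers, sweeping shapes, head counts
and dtypes, and reporting runtime and peak memory for each."""
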
import os
from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.attention.attention_mask import AttentionMask
from xformers.components.attention.core import scaled_dot_product_attention

# Surface CUDA errors at the offending call rather than asynchronously.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# (batch, sequence length, embedding dim) triples to sweep.
SHAPES = [
    (8, 128, 2096),
    (8, 1024, 256),
    (12, 512, 1024),
    (128, 128, 512),
    (8, 2048, 4096),
    (16, 1024, 5120),
    (512, 128, 2560),
]

BLOCK_SIZES = [128]
N_HEADS = [8, 32]

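# Blocksparse kernels typically require the sequence length M to be a
# multiple of the block size; every M in SHAPES above is a multiple of 128.

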
def bench_blocksparse_compare(backward: bool):
    device = torch.device("cuda")
    bw = "+bw" if backward else ""
    use_amp = True
    _use_cuda = True

    for dtype in [torch.float16, torch.float32]:
        datatype = "fp16" if dtype == torch.float16 else "fp32"
        results: Dict[str, Any] = {}
        results_mem: Dict[str, Any] = {}
        for BS in BLOCK_SIZES:
            for heads in N_HEADS:
                for B, M, K in SHAPES:
                    # Self-attention style inputs: one tensor shared by q, k and v.
                    q = torch.randn(
                        (B, heads, M, K // heads),
                        requires_grad=backward,
                        device=device,
                        dtype=dtype,
                    )
                    k = q
                    v = q

                    # Structured causal mask, which the blocksparse path can
                    # exploit to skip whole blocks above the diagonal.
                    m_att_mask = AttentionMask.make_causal(
                        M, M, device=device, dtype=dtype
                    )
                    # Equivalent dense causal mask (-inf above the diagonal),
                    # exercising the standard dense code path.
                    m_custom = torch.triu(
                        torch.ones(M, M, device=device, dtype=dtype) * float("-inf"),
                        diagonal=1,
                    )

                    def blocksparse_attention():
                        with torch.cuda.amp.autocast(enabled=use_amp):
                            y = scaled_dot_product_attention(
                                q=q, k=k, v=v, att_mask=m_att_mask, block_size=BS
                            )
                        if backward:
                            # Reduce to a scalar so that backward is well defined.
                            torch.norm(y).backward()
                        return y

                    def sdp_attention():
                        with torch.cuda.amp.autocast(enabled=use_amp):
                            y = scaled_dot_product_attention(
                                q=q, k=k, v=v, att_mask=m_custom, block_size=BS
                            )
                        if backward:
                            torch.norm(y).backward()
                        return y

                    for testcase in [
                        TestCase(blocksparse_attention, f"blocksparse - fw{bw}"),
                        TestCase(sdp_attention, f"standard sdp - fw{bw}"),
                    ]:
                        if _use_cuda:
                            torch.cuda.empty_cache()
                            torch.cuda.reset_peak_memory_stats()
                            torch.cuda.synchronize()

                        # do_bench returns timing statistics; keep the first
                        # entry (mean runtime in ms). The exact return shape
                        # varies across Triton versions.
                        time = triton.testing.do_bench(testcase.function)[0]

                        if _use_cuda:
                            torch.cuda.synchronize()
                            max_memory = torch.cuda.max_memory_allocated() / 2**20
                        else:
                            max_memory = -1

                        key = f"B={B},M={M},K={K},NH={heads}"

                        if key not in results_mem:
                            results_mem[key] = {}
                        results_mem[key][testcase.name] = f"{max_memory:.1f}"

                        if key not in results:
                            results[key] = {}
                        results[key][testcase.name] = f"{time:.2f}"

            pretty_print(
                results,
                title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
                units="runtime in ms",
            )
            pretty_print(
                results_mem,
                title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
                units="peak memory usage in MB",
            )

            pretty_plot(
                results,
                title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}",
                units="runtime in ms",
                dash_key="torch",
                legend_loc="upper left",
            )
            pretty_plot(
                results_mem,
                title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}",
                units="peak memory usage in MB",
                dash_key="torch",
                legend_loc="upper left",
            )


# Forward-only first, then forward+backward.
for bw in [False, True]:
    bench_blocksparse_compare(bw)
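
# A CUDA device and a Triton installation are required; for a quicker sweep,
# trim SHAPES or N_HEADS above.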