First commit
This commit is contained in:
105
pkgs/xformers/benchmarks/benchmark_multi_head_dispatch.py
Normal file
105
pkgs/xformers/benchmarks/benchmark_multi_head_dispatch.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
|
||||
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
|
||||
from xformers.components import MultiHeadDispatch
|
||||
from xformers.components.attention import ScaledDotProduct
|
||||
|
||||
SHAPES = [
|
||||
(8, 384, 128),
|
||||
(8, 784, 512),
|
||||
(4, 1024, 768),
|
||||
(4, 2048, 1024),
|
||||
(2, 2048, 2048),
|
||||
(2, 2048, 4096),
|
||||
(2, 4096, 4096),
|
||||
(1, 2048, 12288),
|
||||
]
|
||||
|
||||
N_HEADS = [4]
|
||||
|
||||
|
||||
def bench_multihead_dispatch(backward: bool, self_attention: bool):
|
||||
device = torch.device("cuda")
|
||||
bw = "+bw" if backward else ""
|
||||
sa = " (self_attn)" if self_attention else ""
|
||||
|
||||
for dtype in [torch.float16, torch.float32]:
|
||||
results: Dict[str, Any] = {}
|
||||
|
||||
for B, M, K in SHAPES:
|
||||
for heads in N_HEADS:
|
||||
xf_multi_head = MultiHeadDispatch(
|
||||
dim_model=K,
|
||||
residual_dropout=0.0,
|
||||
num_heads=heads,
|
||||
attention=ScaledDotProduct(),
|
||||
bias=(True, True, True, True),
|
||||
).to(device=device, dtype=dtype)
|
||||
torch_multi_head = nn.MultiheadAttention(
|
||||
embed_dim=K, num_heads=heads, batch_first=True
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
q = torch.randn(
|
||||
(B, M, K), requires_grad=backward, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
if self_attention:
|
||||
k = q
|
||||
v = q
|
||||
else:
|
||||
k = torch.randn(
|
||||
(B, M, K), requires_grad=backward, device=device, dtype=dtype
|
||||
)
|
||||
v = torch.randn(
|
||||
(B, M, K), requires_grad=backward, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def torch_mha():
|
||||
y, _ = torch_multi_head(query=q, key=k, value=v)
|
||||
if backward:
|
||||
torch.norm(y).backward()
|
||||
return y
|
||||
|
||||
def xformers_mha():
|
||||
y = xf_multi_head(query=q, key=k, value=v)
|
||||
if backward:
|
||||
torch.norm(y).backward()
|
||||
return y
|
||||
|
||||
for testcase in [
|
||||
TestCase(torch_mha, f"torch - fw{bw}{sa}"),
|
||||
TestCase(xformers_mha, f"xf - fw{bw}{sa}"),
|
||||
]:
|
||||
time = triton.testing.do_bench(testcase.function)[0]
|
||||
key = f"B={B}, M={M}, K={K}, N_HEADS={heads}"
|
||||
if key not in results:
|
||||
results[key] = {}
|
||||
|
||||
results[key][testcase.name] = f"{time:.2f}"
|
||||
|
||||
pretty_print(
|
||||
results,
|
||||
title=f"\n --- Type: {dtype} --- ",
|
||||
units="runtime in ms, lower is better",
|
||||
)
|
||||
pretty_plot(
|
||||
results,
|
||||
title=f"MHA-FW{bw}-{dtype}",
|
||||
units="runtime in ms, lower is better",
|
||||
dash_key="torch",
|
||||
)
|
||||
|
||||
|
||||
for bw in [False, True]:
|
||||
for self_attention in [False, True]:
|
||||
bench_multihead_dispatch(bw, self_attention)
|
||||
Reference in New Issue
Block a user