First commit
This commit is contained in:
187
pkgs/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py
Normal file
187
pkgs/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import itertools
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.utils import benchmark
|
||||
from utils import benchmark_main_helper
|
||||
|
||||
import xformers.ops
|
||||
import xformers.ops.fmha as fmha
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
# Run with
|
||||
# python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py --omit-baselines --quiet
|
||||
# The baselines for these benchmarks are really slow because there is
|
||||
# so much padding in the inputs, so there is no point running them.
|
||||
|
||||
|
||||
def ref_attention_bmk(q, k, v, attn_bias=None):
|
||||
if isinstance(attn_bias, xformers.ops.AttentionMask):
|
||||
attn_bias = (
|
||||
attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1]))
|
||||
.to(q)
|
||||
.squeeze()
|
||||
)
|
||||
q = q * (1.0 / q.shape[-1] ** 0.5)
|
||||
if attn_bias is None:
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
else:
|
||||
# equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
|
||||
# but faster, and is what is used in PyTorch now
|
||||
attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
|
||||
attn = attn.softmax(-1)
|
||||
return attn @ v
|
||||
|
||||
|
||||
def ref_attention(q, k, v, attn_bias):
|
||||
assert q.ndim == 4
|
||||
|
||||
def T(t):
|
||||
return t.permute((0, 2, 1, 3)).reshape(
|
||||
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
|
||||
)
|
||||
|
||||
out = ref_attention_bmk(T(q), T(k), T(v), attn_bias)
|
||||
out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
|
||||
return out.permute((0, 2, 1, 3))
|
||||
|
||||
|
||||
min_run_time = 0.5
|
||||
device = torch.device("cuda")
|
||||
|
||||
NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
|
||||
|
||||
OPS = [
|
||||
xformers.ops.fmha.cutlass.FwOp,
|
||||
xformers.ops.fmha.decoder.FwOp,
|
||||
]
|
||||
|
||||
KV_SHAPES = [
|
||||
# list of n_keys, padding_length, batchsize
|
||||
(2, 64, 3),
|
||||
(32, 1024, 500),
|
||||
(1000, 1024, 2),
|
||||
(8000, 8192, 1),
|
||||
(240, 256, 32),
|
||||
(2048, 2 * 1024, 4),
|
||||
(4096 * 2, 8 * 1024, 1),
|
||||
]
|
||||
|
||||
N_HEADS = [8, 16, 64]
|
||||
|
||||
|
||||
def product_dict(**kwargs):
|
||||
keys = kwargs.keys()
|
||||
vals = kwargs.values()
|
||||
for instance in itertools.product(*vals):
|
||||
yield dict(zip(keys, instance))
|
||||
|
||||
|
||||
CASES = list(
|
||||
product_dict(
|
||||
kv_shape=KV_SHAPES,
|
||||
n_heads=N_HEADS,
|
||||
num_threads=NUM_THREADS,
|
||||
multiquery=[True, False],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def mem_eff_attention_decoder(
|
||||
kv_shape, n_heads: int, num_threads: int, multiquery: bool
|
||||
):
|
||||
n_keys, padding, B = kv_shape
|
||||
torch.manual_seed(42)
|
||||
k_seqlen = torch.randint(1, n_keys + 1, (B,)).tolist()
|
||||
K = 128
|
||||
|
||||
q = torch.rand(1, B, n_heads, K, device=device, dtype=torch.bfloat16)
|
||||
if multiquery:
|
||||
k = torch.rand(
|
||||
1, B * padding, 1, K, device=device, dtype=torch.bfloat16
|
||||
).expand(1, B * padding, n_heads, K)
|
||||
v = torch.rand(
|
||||
1, B * padding, 1, K, device=device, dtype=torch.bfloat16
|
||||
).expand(1, B * padding, n_heads, K)
|
||||
else:
|
||||
k = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)
|
||||
v = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)
|
||||
|
||||
bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
|
||||
q_seqlen=[1] * B,
|
||||
kv_seqlen=k_seqlen,
|
||||
kv_padding=padding,
|
||||
)
|
||||
|
||||
sub_label = f"{B}batch-{k_seqlen[0]}keys-{n_heads}heads"
|
||||
if multiquery:
|
||||
sub_label += "-mq"
|
||||
|
||||
has_run = False
|
||||
for fw_op in OPS:
|
||||
inp = fmha.Inputs(q, k, v, attn_bias=bias)
|
||||
if not fw_op.supports(inp):
|
||||
continue
|
||||
|
||||
fn = partial(xformers.ops.memory_efficient_attention_forward, op=fw_op)
|
||||
|
||||
yield benchmark.Timer(
|
||||
stmt="fn(q, k, v, attn_bias)",
|
||||
globals={
|
||||
"q": q,
|
||||
"k": k,
|
||||
"v": v,
|
||||
"attn_bias": bias,
|
||||
"fn": fn,
|
||||
},
|
||||
label="attention",
|
||||
description=fw_op.NAME,
|
||||
sub_label=sub_label,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
fn(q, k, v, bias)
|
||||
yield benchmark.Timer(
|
||||
stmt="graph.replay()",
|
||||
globals={
|
||||
"graph": graph,
|
||||
},
|
||||
label="cuda graphed attention",
|
||||
description=fw_op.NAME,
|
||||
sub_label=sub_label,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
has_run = True
|
||||
|
||||
if not has_run:
|
||||
return
|
||||
|
||||
RUN_BASELINES = False
|
||||
if RUN_BASELINES:
|
||||
yield benchmark.Timer(
|
||||
stmt="fn(q, k, v, attn_bias)",
|
||||
globals={
|
||||
"q": q,
|
||||
"k": k,
|
||||
"v": v,
|
||||
"attn_bias": bias,
|
||||
"fn": ref_attention,
|
||||
},
|
||||
label="attention",
|
||||
description="eager",
|
||||
sub_label=sub_label,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
|
||||
benchmark_main_helper(mem_eff_attention_decoder, CASES, min_run_time=min_run_time)
|
||||
Reference in New Issue
Block a user