# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

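"""Benchmark xformers.ops.memory_efficient_attention.

Times the forward and backward passes of the fused attention backends listed
in OPS (cutlass, flash) over a grid of shapes, dtypes, dropout and
attention-bias configurations, and compares them against an eager PyTorch
reference implementation.
"""
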
import itertools
import random
from functools import partial

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops
import xformers.ops.fmha as fmha

# Do not allow TF32 for fp32 matmuls, so the f32 cases and the eager reference
# run in full precision.
torch.backends.cuda.matmul.allow_tf32 = False


def create_attn_bias(
    bias_type,
    batch_size: int,
    num_heads: int,
    q_len: int,
    kv_len: int,
    device,
    dtype,
    bias_requires_grad: bool = False,
):
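    """Build an attention bias of the requested type for a (q_len, kv_len) attention.

    Supported types: None (no bias), a dense ``torch.Tensor`` broadcast over
    batch and heads, or ``xformers.ops.LowerTriangularMask`` (causal masking).
    """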
    NoneType = type(None)
    if bias_type is NoneType:
        return None
    if bias_type is torch.Tensor:
        attn_bias = torch.randn((1, 1, q_len, kv_len), device=device, dtype=dtype)
        return attn_bias.expand(batch_size, num_heads, q_len, kv_len)
    if bias_type is xformers.ops.LowerTriangularMask:
        return bias_type()
    assert False, f"Unsupported bias type: {bias_type}"


def ref_attention_bmk(q, k, v, attn_bias=None, p=0.0):
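    """Eager reference: softmax(q @ k^T / sqrt(K) + bias) @ v on (B*H, M, K)
    tensors, with optional dropout on the attention weights."""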
    if isinstance(attn_bias, xformers.ops.AttentionMask):
        attn_bias = (
            attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1]))
            .to(q)
            .squeeze()
        )
    q = q * (1.0 / q.shape[-1] ** 0.5)
    if attn_bias is None:
        attn = q @ k.transpose(-2, -1)
    else:
        # equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
        # but faster, and is what is used in PyTorch now
        attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
    attn = attn.softmax(-1)
    if p > 0:
        attn = torch.nn.functional.dropout(attn, p=p)
    return attn @ v


def ref_attention(q, k, v, attn_bias, p=0.0):
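    """Eager reference attention on (B, M, H, K) inputs: folds the heads into the
    batch dimension, calls ref_attention_bmk, and unfolds the result back."""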
    assert q.ndim == 4
    B, M, H, K = q.shape

    def T(t):
        return t.permute((0, 2, 1, 3)).reshape(
            [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
        )

    if isinstance(attn_bias, torch.Tensor):
        attn_bias = attn_bias.reshape(B * H, M, M)
    out = ref_attention_bmk(T(q), T(k), T(v), attn_bias, p)
    out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
    return out.permute((0, 2, 1, 3))


min_run_time = 0.5
device = torch.device("cuda")

NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
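# Each entry is (B, M, H, K) = (batch size, sequence length, num heads, head dim).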
SHAPES = [
    # ViT
    (384, 197, 1, 88),
    (384, 197, 1, 80),
    (384, 197, 1, 64),
    (1024, 197, 1, 88),
    (1024, 197, 1, 80),
    (1024, 197, 1, 64),
    # ViT-Huge
    (32 * 16, 197, 1, 80),
    (32, 197, 16, 80),
    (32, 197, 16, 64),
    (32, 197, 16, 128),
    # ViT-Giant
    (16 * 16, 197, 1, 88),
    (16, 197, 16, 88),
    (16, 197, 16, 64),
    (16, 197, 16, 128),
    # FB models
    (1024, 82, 8, 64),
    (150, 256, 16, 64),
    (64, 256, 12, 64),
    # Stable diffusion (https://github.com/huggingface/diffusers/pull/532)
    (1, 4096, 16, 40),  # 512x512
    (1, 16384, 16, 40),  # 1024x1024
    (1, 4096, 16, 80),
    (1, 16384, 16, 80),
    # + bs4
    (4, 4096, 16, 40),
    (4, 16384, 16, 40),
    (4, 4096, 16, 80),
    (4, 16384, 16, 80),
    # ParlAI model
    (256, 4096, 16, 64),
    # Zetta B M H K
    (8, 2048, 20, 128),
    # LLaMa 70b - mp=8/16
    *sorted(itertools.product([1, 2], [2048, 4096, 8192], [4, 8], [128])),
    *sorted(
        itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128, 160, 256])
    ),
]

OPS = [
    (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp),
    (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp),
    # TODO: Triton is not stable: it can trigger Illegal Memory Accesses
    # and its performance varies a lot between runs.
    # (xformers.ops.fmha.triton.FwOp, xformers.ops.fmha.triton.BwOp),
]


def product_dict(**kwargs):
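    """Yield every combination of the keyword-argument lists as a dict.

    Example: product_dict(a=[1, 2], b=["x"]) yields {"a": 1, "b": "x"},
    then {"a": 2, "b": "x"}.
    """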
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


CASES = list(
    product_dict(
        shape=SHAPES,
        num_threads=NUM_THREADS,
        dropout_p=[0.0],
        attn_bias_cfg=[(type(None), False)],
        dtype=[torch.half],
    )
)

# Add more cases with some variations
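# Each extra case flips exactly one knob (dropout, bias type/grad, or dtype);
# the choice is seeded with the shape string, so the case list is reproducible.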
for c in CASES.copy():
    c = c.copy()
    c.update(
        random.Random(str(c["shape"])).choice(
            [
                {"dropout_p": 0.3},
                {"attn_bias_cfg": (torch.Tensor, False)},
                {"attn_bias_cfg": (torch.Tensor, True)},
                {"attn_bias_cfg": (xformers.ops.LowerTriangularMask, False)},
                {"dtype": torch.bfloat16},
                {"dtype": torch.float},
            ]
        )
    )
    CASES.append(c)


def create_tensors(shape, dtype, requires_grad=False):
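    """Create a packed (B, M, 3, H, K) QKV tensor on the benchmark device and
    return it together with the unbound q, k, v views."""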
    B, M, H, K = shape
    qkv = torch.rand(
        [B, M, 3, H, K], device=device, dtype=dtype, requires_grad=requires_grad
    )
    q, k, v = xformers.ops.unbind(qkv, 2)
    return qkv, q, k, v


def mem_eff_attention_fw(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype):
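    """Yield forward-pass timers for every backend in OPS that supports this case,
    plus an eager-reference timer once at least one backend has run."""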
    B, M, H, K = shape
    _, q, k, v = create_tensors(shape, dtype)
    attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
    # Bias gradients only matter for the backward benchmark; skip those cases here.
    if attn_bias_requires_grad:
        return
    bias = create_attn_bias(
        attn_bias_type,
        batch_size=B,
        num_heads=H,
        q_len=M,
        kv_len=M,
        device=device,
        dtype=dtype,
        bias_requires_grad=attn_bias_requires_grad,
    )
    inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)

    dtype_str = {
        torch.bfloat16: "b16",
        torch.half: "f16",
        torch.float: "f32",
    }[dtype]
    sub_label = (
        f"{dtype_str} {B}-{M}-{H}-{K}, p={dropout_p}, "
        f"BiasT={attn_bias_type.__name__}"
    )

    has_run = False
    for fw_op, bw_op in OPS:
        if not fw_op.supports(inp):
            continue

        yield benchmark.Timer(
            stmt="fn(q, k, v, attn_bias, p)",
            globals={
                "q": q,
                "k": k,
                "v": v,
                "attn_bias": inp.attn_bias,
                "p": dropout_p,
                "fn": partial(
                    xformers.ops.memory_efficient_attention, op=(fw_op, bw_op)
                ),
            },
            label=f"attention (attn_bias={attn_bias_type})",
            description=fw_op.NAME,
            sub_label=sub_label,
            num_threads=num_threads,
        )
        has_run = True

    if not has_run:
        return

    yield benchmark.Timer(
        stmt="fn(q, k, v, attn_bias, p)",
        globals={
            "q": q,
            "k": k,
            "v": v,
            "attn_bias": inp.attn_bias,
            "p": dropout_p,
            "fn": ref_attention,
        },
        label=f"attention (attn_bias={attn_bias_type})",
        description="eager",
        sub_label=sub_label,
        num_threads=num_threads,
    )


def mem_eff_attention_bw(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype):
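    """Yield backward-pass timers for every backend in OPS whose forward and
    backward ops both support this case, plus a vanilla eager baseline once at
    least one backend has run."""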
    B, M, H, K = shape
    qkv, q, k, v = create_tensors(shape, dtype, requires_grad=True)

    attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
    bias = create_attn_bias(
        attn_bias_type,
        batch_size=B,
        num_heads=H,
        q_len=M,
        kv_len=M,
        device=device,
        dtype=dtype,
        bias_requires_grad=attn_bias_requires_grad,
    )
    inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)

    dtype_str = {
        torch.bfloat16: "b16",
        torch.half: "f16",
        torch.float: "f32",
    }[dtype]
    sub_label = (
        f"{dtype_str} {B}-{M}-{H}-{K}, p={dropout_p}, "
        f"BiasT={attn_bias_type.__name__}, BiasGrad={attn_bias_requires_grad}"
    )

    has_run = False
    for fw_op, bw_op in OPS:
        if not fw_op.supports(inp) or not bw_op.supports(inp):
            continue
        has_run = True
        out = xformers.ops.memory_efficient_attention(
            inp.query, inp.key, inp.value, inp.attn_bias, inp.p, op=(fw_op, bw_op)
        )
        grad_benchmark = torch.ones_like(q)

        yield benchmark.Timer(
            stmt="out.backward(grad, retain_graph=True)",
            globals={
                "out": out,
                "grad": grad_benchmark,
            },
            label=f"attention backward (attn_bias={attn_bias_type})",
            description=bw_op.NAME,
            sub_label=sub_label,
            num_threads=num_threads,
        )
        del out

    if not has_run:
        return
    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": ref_attention(q, k, v, inp.attn_bias, dropout_p),
            "grad": grad_benchmark,
        },
        label=f"attention backward (attn_bias={attn_bias_type})",
        description="vanilla",
        sub_label=sub_label,
        num_threads=num_threads,
    )


benchmark_main_helper(mem_eff_attention_fw, CASES, min_run_time=min_run_time)
benchmark_main_helper(mem_eff_attention_bw, CASES, min_run_time=min_run_time)