# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import itertools
|
|
from typing import Any
|
|
|
|
import torch
|
|
from torch.utils import benchmark
|
|
from utils import benchmark_main_helper
|
|
|
|
import xformers.ops as xops
|
|
|
|
# Minimum measurement time (seconds) handed to torch.utils.benchmark per case.
min_run_time = 0.5
# NOTE(review): `device` appears unused below — tensors are created with the
# string literal device="cuda" directly; confirm before removing.
device = torch.device("cuda")
|
|
|
|
|
|
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the value lists.

    Each keyword argument maps a key to an iterable of candidate values; every
    yielded dict picks one value per key (keys keep their keyword order).
    """
    names = kwargs.keys()
    for combo in itertools.product(*kwargs.values()):
        yield dict(zip(names, combo))
|
|
|
|
|
|
# Benchmark shapes: single-query decoding (Mq=1) with a growing KV length and a
# shrinking batch so total work stays roughly constant. The Hkv=1 sweep comes
# first, then the Hkv=2 sweep, matching the original concatenated lists.
CASES = [
    dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=n_kv_heads, K=128)
    for n_kv_heads in (1, 2)
    for i in range(8, 18)
]
|
|
|
|
|
|
def _setup_test(
    functions, fw: bool = False, bw: bool = False, cuda_graph: bool = True, **kwargs
):
    """Yield one ``benchmark.Timer`` per entry in *functions*.

    ``functions`` maps a description string to a benchmark class; each class is
    instantiated with the case's shape kwargs plus ``bw``. When ``cuda_graph``
    is True, the fw/bw calls are captured into a CUDA graph once and the timer
    measures graph replay instead of eager execution.
    """
    for k, benchmark_cls in functions.items():
        benchmark_object = benchmark_cls(**kwargs, bw=bw)
        label = benchmark_object.label
        # Suffixes are concatenated directly (e.g. "attn_decodingfw").
        label += "fw" if fw else ""
        label += "bw" if bw else ""

        def run_one():
            if fw:
                benchmark_object.fw()
            if bw:
                benchmark_object.bw()

        if cuda_graph:
            # Warm-up call outside the graph (allocations, lazy init), then
            # capture the same work into a replayable graph.
            run_one()
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                run_one()

            # Deliberately rebind run_one so the Timer below replays the
            # captured graph rather than re-running eager code.
            def run_one():
                g.replay()

        yield benchmark.Timer(
            stmt="fn()",
            globals={
                "fn": run_one,
            },
            label=label,
            description=k,
            sub_label=benchmark_object.sub_label,
        )
|
|
|
|
|
|
class AttentionDecodingFlashDecoding:
    """Decoding-attention benchmark driven by the xformers flash forward op.

    Builds fp16 CUDA tensors in the 5-D grouped-query layout
    ``[B, M, Hkv, Hq // Hkv, K]`` (K/V expanded across the group dimension),
    then squeezes to 4-D when ``Hq == Hkv`` and to 3-D when ``Hkv == 1`` to
    match the layouts the attention ops accept.
    """

    OP: Any = xops.fmha.flash.FwOp

    def __init__(
        self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool
    ) -> None:
        dtype = torch.float16
        self.sub_label = f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K}"
        self.label = "attn_decoding"
        self.shapes = (B, Mq, Mkv, Hq, Hkv, K)

        assert Hkv <= Hq
        assert Hq % Hkv == 0
        group = Hq // Hkv
        # Queries carry `group` heads per KV head; keys/values are stored once
        # per KV head and broadcast (expand: no copy) across the group dim.
        self.q = torch.randn(
            [B, Mq, Hkv, group, K], device="cuda", dtype=dtype, requires_grad=bw
        )
        self.k = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        ).expand(-1, -1, -1, group, -1)
        self.v = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        ).expand(-1, -1, -1, group, -1)

        if Hq == Hkv:
            # No grouping: drop the (size-1) group dimension -> 4-D BMHK.
            self.q, self.k, self.v = (
                t[:, :, :, 0] for t in (self.q, self.k, self.v)
            )
        if Hkv == 1:
            # Single KV head: drop the (size-1) KV-head dimension as well.
            self.q, self.k, self.v = (
                t[:, :, 0] for t in (self.q, self.k, self.v)
            )

    def fw(self) -> None:
        xops.memory_efficient_attention_forward(self.q, self.k, self.v, op=self.OP)
|
|
|
|
|
|
class AttentionDecodingSplitKV(AttentionDecodingFlashDecoding):
    """Same benchmark setup, dispatched to the Triton split-K forward op."""

    OP = xops.fmha.triton_splitk.FwOp
|
|
|
|
|
|
class AttentionDecodingPyTorchRepeat(AttentionDecodingFlashDecoding):
    """Eager-PyTorch baseline: flatten grouped heads and run plain attention.

    Reshapes q/k/v from the grouped layout to ``[B, H, M, K]`` and computes
    ``softmax(q @ k^T * scale) @ v``. Returns the attention output (the
    ``-> None`` annotation is kept to match the base-class ``fw`` signature).
    """

    def fw(self) -> None:
        B, Mq, Mkv, Hq, Hkv, K = self.shapes
        scale = 1 / K**0.5
        q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
        k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        # Fix: scale the logits BEFORE the softmax (standard scaled-dot-product
        # attention, and what the xformers ops above compute). The original
        # multiplied the softmax output by `scale`, yielding a different result.
        attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
        return attn @ v
|
|
|
|
|
|
# Registry of benchmarked implementations; the key is the `description` string
# shown in the benchmark output table.
BENCHMARKS = {
    "pytorch": AttentionDecodingPyTorchRepeat,
    "flash-decoding": AttentionDecodingFlashDecoding,
    "triton_splitK": AttentionDecodingSplitKV,
}
|
|
|
|
|
|
try:
    import flash_attn

    class AttentionDecodingFlashAttention(AttentionDecodingFlashDecoding):
        """Benchmark the flash-attn package's ``flash_attn_func`` directly.

        For the 5-D grouped layout, queries are flattened to one head axis and
        the broadcast group dimension of k/v is dropped (flash-attn handles
        MQA/GQA natively when Hq is a multiple of Hkv).
        """

        def fw(self) -> None:
            q, k, v = self.q, self.k, self.v
            if q.ndim == 5:
                batch, seqlen_q, kv_heads, group, dim = q.shape
                q = q.reshape([batch, seqlen_q, kv_heads * group, dim])
                k = k[:, :, :, 0]
                v = v[:, :, :, 0]
            return flash_attn.flash_attn_func(q, k, v)

    BENCHMARKS[
        f"flash-attention@{flash_attn.__version__}"
    ] = AttentionDecodingFlashAttention
except ImportError:
    # flash-attn is optional; skip its benchmark when not installed.
    pass
|
|
|
|
|
|
def attn_decoding(**kwargs):
    """Yield forward-only timers (CUDA-graph replayed) for every registered benchmark."""
    yield from _setup_test(
        functions=BENCHMARKS, fw=True, cuda_graph=True, **kwargs
    )
|
|
|
|
|
|
# Entry point: run every (implementation, case) pair and print a comparison table.
benchmark_main_helper(attn_decoding, CASES, min_run_time=min_run_time)
|