First commit
This commit is contained in:
159
pkgs/xformers/benchmarks/benchmark_attn_decoding.py
Normal file
159
pkgs/xformers/benchmarks/benchmark_attn_decoding.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import itertools
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.utils import benchmark
|
||||
from utils import benchmark_main_helper
|
||||
|
||||
import xformers.ops as xops
|
||||
|
||||
# Minimum measurement time handed to each benchmark Timer (seconds).
min_run_time = 0.5
# NOTE(review): `device` is not referenced below — tensors are built with
# device="cuda" string literals directly; confirm before removing.
device = torch.device("cuda")
|
||||
|
||||
|
||||
def product_dict(**kwargs):
    """Yield one dict per combination of the keyword-argument value lists.

    ``product_dict(a=[1, 2], b=[3])`` yields ``{'a': 1, 'b': 3}`` then
    ``{'a': 2, 'b': 3}`` (cartesian product, keys in keyword order).
    """
    names = list(kwargs)
    for combo in itertools.product(*kwargs.values()):
        yield dict(zip(names, combo))
|
||||
|
||||
|
||||
# Decode-shaped cases: one query token (Mq=1) against a KV cache of 2**i
# tokens, batch shrinking as the cache grows so total work stays roughly
# constant.  Each shape is run with Hkv=1 and Hkv=2 (16 query heads, K=128).
CASES = [
    dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=n_kv_heads, K=128)
    for n_kv_heads in (1, 2)
    for i in range(8, 18)
]
|
||||
|
||||
|
||||
def _setup_test(
    functions, fw: bool = False, bw: bool = False, cuda_graph: bool = True, **kwargs
):
    """Yield one ``benchmark.Timer`` per implementation in *functions*.

    *functions* maps a description string to a benchmark class; each class
    is instantiated with the shape ``kwargs`` plus ``bw``.  When
    ``cuda_graph`` is true, the fw/bw calls are captured into a CUDA graph
    once and the timed function only replays the graph.
    """
    for k, benchmark_cls in functions.items():
        benchmark_object = benchmark_cls(**kwargs, bw=bw)
        label = benchmark_object.label
        label += "fw" if fw else ""
        label += "bw" if bw else ""

        def run_one():
            if fw:
                benchmark_object.fw()
            if bw:
                benchmark_object.bw()

        if cuda_graph:
            # Warm-up launch outside the capture (allocations, lazy init).
            run_one()
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                run_one()

            # Deliberately shadow run_one: the Timer below then measures a
            # graph replay instead of fresh kernel launches.
            def run_one():
                g.replay()

        yield benchmark.Timer(
            stmt="fn()",
            globals={
                "fn": run_one,
            },
            label=label,
            description=k,
            sub_label=benchmark_object.sub_label,
        )
|
||||
|
||||
|
||||
class AttentionDecodingFlashDecoding:
    """Base decoding benchmark: builds q/k/v for one shape and runs the
    xformers flash forward op (subclasses override ``OP`` or ``fw``).

    Tensors start in the 5-d grouped layout ``[B, M, Hkv, Hq // Hkv, K]``
    (k/v are expanded views sharing one KV head per group); the group axis
    is dropped when ``Hq == Hkv`` and the KV-head axis when ``Hkv == 1``.
    """

    OP: Any = xops.fmha.flash.FwOp

    def __init__(
        self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool
    ) -> None:
        dtype = torch.float16
        self.sub_label = f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K}"
        self.label = "attn_decoding"
        self.shapes = (B, Mq, Mkv, Hq, Hkv, K)

        assert Hkv <= Hq
        assert Hq % Hkv == 0
        heads_per_group = Hq // Hkv
        common = dict(device="cuda", dtype=dtype, requires_grad=bw)
        self.q = torch.randn([B, Mq, Hkv, heads_per_group, K], **common)
        self.k = torch.randn([B, Mkv, Hkv, 1, K], **common).expand(
            -1, -1, -1, heads_per_group, -1
        )
        self.v = torch.randn([B, Mkv, Hkv, 1, K], **common).expand(
            -1, -1, -1, heads_per_group, -1
        )

        if Hq == Hkv:
            # One head per group: squeeze the group axis (-> BMHK layout).
            self.q, self.k, self.v = (
                t[:, :, :, 0] for t in (self.q, self.k, self.v)
            )
        if Hkv == 1:
            # Single KV head: squeeze the KV-head axis as well.
            self.q, self.k, self.v = (
                t[:, :, 0] for t in (self.q, self.k, self.v)
            )

    def fw(self) -> None:
        xops.memory_efficient_attention_forward(self.q, self.k, self.v, op=self.OP)
|
||||
|
||||
|
||||
class AttentionDecodingSplitKV(AttentionDecodingFlashDecoding):
    """Same harness as the base benchmark, forced onto the Triton split-K op."""

    OP = xops.fmha.triton_splitk.FwOp
|
||||
|
||||
|
||||
class AttentionDecodingPyTorchRepeat(AttentionDecodingFlashDecoding):
    """Eager-PyTorch reference implementation: plain matmul attention over
    the (flattened) head axis."""

    def fw(self) -> None:
        B, Mq, Mkv, Hq, Hkv, K = self.shapes
        scale = 1 / K**0.5
        # Collapse any grouped-head axes and move heads ahead of the
        # sequence dim: BM(H...)K -> BHMK.
        q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
        k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        # BUG FIX: the 1/sqrt(K) scale must be applied to the scores
        # *before* the softmax.  The original computed
        # ``softmax(q @ k.T) * scale``, which leaves the attention
        # distribution unscaled and multiplies the output by a constant —
        # a different (wrong) computation from every other backend here.
        attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
        return attn @ v
|
||||
|
||||
|
||||
# Registry of benchmarked implementations, keyed by the description shown in
# the results table.  (A flash-attention entry is appended further down when
# the package is importable.)
BENCHMARKS = {
    "pytorch": AttentionDecodingPyTorchRepeat,
    "flash-decoding": AttentionDecodingFlashDecoding,
    "triton_splitK": AttentionDecodingSplitKV,
}
|
||||
|
||||
|
||||
try:
    import flash_attn

    class AttentionDecodingFlashAttention(AttentionDecodingFlashDecoding):
        """Benchmark the external flash_attn package on the same tensors."""

        def fw(self) -> None:
            q, k, v = self.q, self.k, self.v
            if q.ndim == 5:
                # flash_attn expects 4-d BMHK: flatten q's grouped heads and
                # keep a single KV head from the expanded k/v views.
                n_batch, n_q, n_kv_heads, n_per_group, head_dim = q.shape
                q = q.reshape([n_batch, n_q, n_kv_heads * n_per_group, head_dim])
                k = k[:, :, :, 0]
                v = v[:, :, :, 0]
            return flash_attn.flash_attn_func(q, k, v)

    BENCHMARKS[
        f"flash-attention@{flash_attn.__version__}"
    ] = AttentionDecodingFlashAttention
except ImportError:
    # flash_attn is optional; skip its entry when not installed.
    pass
|
||||
|
||||
|
||||
def attn_decoding(**kwargs):
    """Generate forward-only, CUDA-graph-captured timers for every
    registered implementation at the shape given by ``kwargs``."""
    yield from _setup_test(
        functions=BENCHMARKS,
        fw=True,
        cuda_graph=True,
        **kwargs,
    )
|
||||
|
||||
|
||||
# Script entry point: run every CASES shape through attn_decoding and report
# the per-implementation timings (benchmark_main_helper comes from the
# sibling utils module — presumably handles CLI args and result aggregation;
# confirm against utils.py).
benchmark_main_helper(attn_decoding, CASES, min_run_time=min_run_time)
|
||||
Reference in New Issue
Block a user