# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Any

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops as xops

min_run_time = 0.5
device = torch.device("cuda")


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


# Decoding shapes: a single query token (Mq=1) attends to Mkv cached tokens.
# The batch size shrinks as the KV length grows, so the total amount of work
# per case stays roughly constant.
CASES = [
    dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=1, K=128)
    for i in range(8, 18)
] + [
    dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=2, K=128)
    for i in range(8, 18)
]


def _setup_test(
    functions, fw: bool = False, bw: bool = False, cuda_graph: bool = True, **kwargs
):
    for k, benchmark_cls in functions.items():
        benchmark_object = benchmark_cls(**kwargs, bw=bw)
        label = benchmark_object.label
        label += "fw" if fw else ""
        label += "bw" if bw else ""

        def run_one():
            if fw:
                benchmark_object.fw()
            if bw:
                benchmark_object.bw()

        if cuda_graph:
            # Warm up once, capture the op into a CUDA graph, then time graph
            # replays so kernel-launch overhead is excluded from the measurement.
            run_one()
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                run_one()

            def run_one():
                g.replay()

        yield benchmark.Timer(
            stmt="fn()",
            globals={
                "fn": run_one,
            },
            label=label,
            description=k,
            sub_label=benchmark_object.sub_label,
        )


class AttentionDecodingFlashDecoding:
    OP: Any = xops.fmha.flash.FwOp

    def __init__(
        self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool
    ) -> None:
        dtype = torch.float16
        self.sub_label = f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K}"
        self.label = "attn_decoding"
        self.shapes = (B, Mq, Mkv, Hq, Hkv, K)

        assert Hkv <= Hq
        assert Hq % Hkv == 0

        # 5D GQA/MQA layout: [B, M, Hkv, Hq // Hkv, K]. K/V are expanded
        # (broadcast views, not copies) so each KV head is shared by
        # Hq // Hkv query heads.
        self.q = torch.randn(
            [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=bw
        )
        self.k = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        ).expand(-1, -1, -1, Hq // Hkv, -1)
        self.v = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        ).expand(-1, -1, -1, Hq // Hkv, -1)

        # Collapse to the 4D (MHA) or 3D (single KV head) layouts when possible.
        if Hq == Hkv:
            self.q = self.q[:, :, :, 0]
            self.k = self.k[:, :, :, 0]
            self.v = self.v[:, :, :, 0]
        if Hkv == 1:
            self.q = self.q[:, :, 0]
            self.k = self.k[:, :, 0]
            self.v = self.v[:, :, 0]

    def fw(self) -> None:
        xops.memory_efficient_attention_forward(self.q, self.k, self.v, op=self.OP)


class AttentionDecodingSplitKV(AttentionDecodingFlashDecoding):
    OP = xops.fmha.triton_splitk.FwOp


class AttentionDecodingPyTorchRepeat(AttentionDecodingFlashDecoding):
    def fw(self) -> None:
        B, Mq, Mkv, Hq, Hkv, K = self.shapes
        scale = 1 / K**0.5
        q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
        k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        # Scale the logits before the softmax, as in standard attention.
        attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
        return attn @ v


BENCHMARKS = {
    "pytorch": AttentionDecodingPyTorchRepeat,
    "flash-decoding": AttentionDecodingFlashDecoding,
    "triton_splitK": AttentionDecodingSplitKV,
}


try:
    import flash_attn

    class AttentionDecodingFlashAttention(AttentionDecodingFlashDecoding):
        def fw(self) -> None:
            q, k, v = self.q, self.k, self.v
            if q.ndim == 5:
                # flash_attn expects 4D [B, M, H, K] inputs: fold the grouped
                # query heads together and drop the broadcast dimension of K/V.
                B, Mq, H1, H2, K = q.shape
                B, Mkv, H1, H2, K = k.shape
                q = q.reshape([B, Mq, H1 * H2, K])
                k = k[:, :, :, 0]
                v = v[:, :, :, 0]
            return flash_attn.flash_attn_func(q, k, v)

    BENCHMARKS[
        f"flash-attention@{flash_attn.__version__}"
    ] = AttentionDecodingFlashAttention
except ImportError:
    pass


def attn_decoding(**kwargs):
    yield from _setup_test(
        **kwargs,
        fw=True,
        cuda_graph=True,
        functions=BENCHMARKS,
    )

benchmark_main_helper(attn_decoding, CASES, min_run_time=min_run_time)
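
# A minimal sketch of calling one of the benchmarked code paths directly, outside
# the harness. It mirrors the multiquery (Hkv=1) cases above: K/V carry a single
# head that is expanded (broadcast, not copied) to the 16 query heads. The shapes
# below are illustrative, not taken from CASES.
#
#   import torch
#   import xformers.ops as xops
#
#   q = torch.randn([64, 1, 16, 128], device="cuda", dtype=torch.float16)
#   k = torch.randn([64, 4096, 1, 128], device="cuda", dtype=torch.float16)
#   v = torch.randn([64, 4096, 1, 128], device="cuda", dtype=torch.float16)
#   out = xops.memory_efficient_attention_forward(
#       q,
#       k.expand(-1, -1, 16, -1),
#       v.expand(-1, -1, 16, -1),
#       op=xops.fmha.triton_splitk.FwOp,
#   )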