# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. import itertools from functools import partial import torch from torch.utils import benchmark from utils import benchmark_main_helper import xformers.ops import xformers.ops.fmha as fmha torch.backends.cuda.matmul.allow_tf32 = False # Run with # python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py --omit-baselines --quiet # The baselines for these benchmarks are really slow because there is # so much padding in the inputs, so there is no point running them. def ref_attention_bmk(q, k, v, attn_bias=None): if isinstance(attn_bias, xformers.ops.AttentionMask): attn_bias = ( attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1])) .to(q) .squeeze() ) q = q * (1.0 / q.shape[-1] ** 0.5) if attn_bias is None: attn = q @ k.transpose(-2, -1) else: # equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v # but faster, and is what is used in PyTorch now attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1)) attn = attn.softmax(-1) return attn @ v def ref_attention(q, k, v, attn_bias): assert q.ndim == 4 def T(t): return t.permute((0, 2, 1, 3)).reshape( [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] ) out = ref_attention_bmk(T(q), T(k), T(v), attn_bias) out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) return out.permute((0, 2, 1, 3)) min_run_time = 0.5 device = torch.device("cuda") NUM_THREADS = [1] if device.type == "cuda" else [1, 40] OPS = [ xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.decoder.FwOp, ] KV_SHAPES = [ # list of n_keys, padding_length, batchsize (2, 64, 3), (32, 1024, 500), (1000, 1024, 2), (8000, 8192, 1), (240, 256, 32), (2048, 2 * 1024, 4), (4096 * 2, 8 * 1024, 1), ] N_HEADS = [8, 16, 64] def product_dict(**kwargs): keys = kwargs.keys() vals = kwargs.values() for instance in itertools.product(*vals): yield dict(zip(keys, instance)) CASES = list( product_dict( kv_shape=KV_SHAPES, n_heads=N_HEADS, num_threads=NUM_THREADS, multiquery=[True, False], ) ) def mem_eff_attention_decoder( kv_shape, n_heads: int, num_threads: int, multiquery: bool ): n_keys, padding, B = kv_shape torch.manual_seed(42) k_seqlen = torch.randint(1, n_keys + 1, (B,)).tolist() K = 128 q = torch.rand(1, B, n_heads, K, device=device, dtype=torch.bfloat16) if multiquery: k = torch.rand( 1, B * padding, 1, K, device=device, dtype=torch.bfloat16 ).expand(1, B * padding, n_heads, K) v = torch.rand( 1, B * padding, 1, K, device=device, dtype=torch.bfloat16 ).expand(1, B * padding, n_heads, K) else: k = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16) v = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16) bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( q_seqlen=[1] * B, kv_seqlen=k_seqlen, kv_padding=padding, ) sub_label = f"{B}batch-{k_seqlen[0]}keys-{n_heads}heads" if multiquery: sub_label += "-mq" has_run = False for fw_op in OPS: inp = fmha.Inputs(q, k, v, attn_bias=bias) if not fw_op.supports(inp): continue fn = partial(xformers.ops.memory_efficient_attention_forward, op=fw_op) yield benchmark.Timer( stmt="fn(q, k, v, attn_bias)", globals={ "q": q, "k": k, "v": v, "attn_bias": bias, "fn": fn, }, label="attention", description=fw_op.NAME, sub_label=sub_label, num_threads=num_threads, ) graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): fn(q, k, v, bias) yield benchmark.Timer( stmt="graph.replay()", globals={ "graph": graph, }, label="cuda graphed attention", description=fw_op.NAME, sub_label=sub_label, num_threads=num_threads, ) has_run = True if not has_run: return RUN_BASELINES = False if RUN_BASELINES: yield benchmark.Timer( stmt="fn(q, k, v, attn_bias)", globals={ "q": q, "k": k, "v": v, "attn_bias": bias, "fn": ref_attention, }, label="attention", description="eager", sub_label=sub_label, num_threads=num_threads, ) benchmark_main_helper(mem_eff_attention_decoder, CASES, min_run_time=min_run_time)