# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from functools import partial, reduce

import timm
import torch
import torch.nn as nn
from timm.models.layers import Mlp as TimmMlp
from timm.models.vision_transformer import Attention as TimmAttention
from timm.models.vision_transformer import Block as TimmBlock
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops as xops


def replace_module(module: nn.Module, replace_class, factory):
    """Recursively replace every submodule of type `replace_class` with `factory(submodule)`."""
    if isinstance(module, replace_class):
        return factory(module)
    module_output = module
    for name, child in module.named_children():
        module_output.add_module(name, replace_module(child, replace_class, factory))
    del module
    return module_output


class TimmMemEffAttention(nn.Module):
    """Drop-in replacement for timm's `Attention` using xformers memory-efficient attention."""

    def __init__(self, attn: TimmAttention, op=None):
        super().__init__()
        self.op = op
        self.num_heads = attn.num_heads
        self.scale = attn.scale

        # Reuse the layers (and weights) of the original timm attention module
        self.qkv = attn.qkv
        self.attn_drop = attn.attn_drop
        self.proj = attn.proj
        self.proj_drop = attn.proj_drop

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = xops.unbind(qkv, dim=2)

        x = xops.memory_efficient_attention(q, k, v, op=self.op).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class TimmSwiGLU(nn.Module):
    """Replaces timm's `Mlp` with a SwiGLU feed-forward of matching dimensions."""

    def __init__(self, mlp: TimmMlp, op=None) -> None:
        super().__init__()
        self.fc1 = mlp.fc1
        self.swiglu = xops.SwiGLU(
            in_features=mlp.fc1.in_features,
            hidden_features=mlp.fc1.out_features,
            bias=True,
        )
        self.op = op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.swiglu(x)


def mod_memeff_attn(model: nn.Module, op=None) -> nn.Module:
    return replace_module(model, TimmAttention, partial(TimmMemEffAttention, op=op))


def mod_mlp_to_swiglu(model: nn.Module, op=None) -> nn.Module:
    def _mlp_to_swiglu(block: TimmBlock):
        block.mlp = TimmSwiGLU(block.mlp, op=op)
        return block

    return replace_module(model, TimmBlock, _mlp_to_swiglu)


mod_mlp_to_eager_swiglu = partial(mod_mlp_to_swiglu, op=xops.SwiGLUEagerOp)
mod_mlp_to_fast_swiglu = partial(mod_mlp_to_swiglu, op=None)


def compose(*fns):
    """Compose functions right-to-left: compose(f, g)(x) == f(g(x))."""

    def compose2(f, g):
        return lambda *a, **kw: f(g(*a, **kw))

    return reduce(compose2, fns)


MODELS = [
    # (model_name, model_factory, input_shape)
    ("ViT-B/16", timm.models.vit_base_patch16_224, [512, 3, 224, 224]),
    ("ViT-B/8", timm.models.vit_base_patch8_224, [64, 3, 224, 224]),
    ("ViT-L/16", timm.models.vit_large_patch16_224, [128, 3, 224, 224]),
    ("ViT-g/14", timm.models.vit_giant_patch14_224, [32, 3, 224, 224]),
]

MODIFIERS = [
    # (label, transform applied to the model before benchmarking)
    ["mlp", lambda x: x],
    ["mlp+memeff", mod_memeff_attn],
    ["swiglu", mod_mlp_to_eager_swiglu],
    ["swiglu+fast_swiglu", mod_mlp_to_fast_swiglu],
    ["swiglu+fast_swiglu+memeff", compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)],
]


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


CASES = list(
    product_dict(
        model_info=MODELS,
        dtype=[torch.half],
    )
)


def benchmark_transformer(model_info, dtype):
    device = "cuda"
    model_name, model_factory, input_shape = model_info
    inp = torch.randn(input_shape, dtype=dtype, device=device)

    for mod_name, mod_apply in MODIFIERS:
        model: nn.Module = model_factory()
        model = mod_apply(model).to(device).to(dtype)

        # Warm-up forward+backward pass to surface any errors before timing
        out = model(inp)
        grad = out.clone()
        out.backward(grad)

        yield benchmark.Timer(
            stmt="model(inp).backward(grad)",
            globals={
                "model": model,
                "inp": inp,
                "grad": grad,
            },
            label="fw+bw",
            description=mod_name,
            sub_label=model_name,
        )


benchmark_main_helper(benchmark_transformer, CASES)