# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from functools import partial, reduce

import timm
import torch
import torch.nn as nn
from timm.models.layers import Mlp as TimmMlp
from timm.models.vision_transformer import Attention as TimmAttention
from timm.models.vision_transformer import Block as TimmBlock
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops as xops


def replace_module(module: nn.Module, replace_class, factory):
    # Recursively swap every submodule that is an instance of `replace_class`
    # with the module returned by `factory(old_module)`.
    # The root module itself may also be replaced.
    if isinstance(module, replace_class):
        return factory(module)
    module_output = module
    for name, child in module.named_children():
        module_output.add_module(name, replace_module(child, replace_class, factory))
    del module
    return module_output


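# Usage sketch (illustrative only, not exercised by this benchmark): wrapping
# every nn.Linear of a model with a behavior-preserving container would look
# roughly like
#
#   patched = replace_module(model, nn.Linear, lambda lin: nn.Sequential(lin))
#
# The modifiers further down use this same mechanism for attention and MLP blocks.

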
class TimmMemEffAttention(nn.Module):
    # Drop-in replacement for timm's Attention: keeps the original qkv / proj
    # weights and dropout modules, but computes attention with xformers'
    # memory-efficient kernel.
    def __init__(self, attn: TimmAttention, op=None):
        super().__init__()
        self.op = op
        self.num_heads = attn.num_heads
        self.scale = attn.scale

        self.qkv = attn.qkv
        self.attn_drop = attn.attn_drop
        self.proj = attn.proj
        self.proj_drop = attn.proj_drop

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = xops.unbind(qkv, dim=2)

        x = xops.memory_efficient_attention(q, k, v, op=self.op).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


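# Layout note: xops.memory_efficient_attention expects query/key/value tensors
# shaped [batch, seq_len, num_heads, head_dim], which is why the qkv projection
# above is reshaped to (B, N, 3, num_heads, C // num_heads) before unbinding.
# Illustrative sketch (requires a CUDA device):
#
#   q = torch.randn(2, 197, 12, 64, device="cuda", dtype=torch.half)
#   out = xops.memory_efficient_attention(q, q, q)  # -> shape [2, 197, 12, 64]

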
class TimmSwiGLU(nn.Module):
    # Replaces timm's Mlp with xformers' SwiGLU block of matching
    # input / hidden sizes.
    def __init__(self, mlp: TimmMlp, op=None) -> None:
        super().__init__()
        self.fc1 = mlp.fc1  # original first projection, kept but unused in forward
        self.swiglu = xops.SwiGLU(
            in_features=mlp.fc1.in_features,
            hidden_features=mlp.fc1.out_features,
            bias=True,
        )
        self.op = op  # kept for bookkeeping; not forwarded to the SwiGLU module here

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.swiglu(x)


def mod_memeff_attn(model: nn.Module, op=None) -> nn.Module:
    return replace_module(model, TimmAttention, partial(TimmMemEffAttention, op=op))


def mod_mlp_to_swiglu(model: nn.Module, op=None) -> nn.Module:
    def _mlp_to_swiglu(block: TimmBlock):
        block.mlp = TimmSwiGLU(block.mlp, op=op)
        return block

    return replace_module(model, TimmBlock, _mlp_to_swiglu)


mod_mlp_to_eagr_swiglu = partial(mod_mlp_to_swiglu, op=xops.SwiGLUEagerOp)
mod_mlp_to_fast_swiglu = partial(mod_mlp_to_swiglu, op=None)


def compose(*fns):
    def compose2(f, g):
        return lambda *a, **kw: f(g(*a, **kw))

    return reduce(compose2, fns)


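# Note: compose(f, g)(x) == f(g(x)), so in the composed modifiers further down
# the right-most function runs first (e.g. attention is swapped for the
# memory-efficient version before each block's MLP is converted to SwiGLU).

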
MODELS = [
    # model_name, model_factory, input_shape
    ("ViT-B/16", timm.models.vit_base_patch16_224, [512, 3, 224, 224]),
    ("ViT-B/8", timm.models.vit_base_patch8_224, [64, 3, 224, 224]),
    ("ViT-L/16", timm.models.vit_large_patch16_224, [128, 3, 224, 224]),
    ("ViT-g/14", timm.models.vit_giant_patch14_224, [32, 3, 224, 224]),
]

MODIFIERS = [
    # modifier_name, function applied to the freshly built model
    ["mlp", lambda x: x],
    ["mlp+memeff", mod_memeff_attn],
    ["swiglu", mod_mlp_to_eagr_swiglu],
    ["swiglu+fast_swiglu", mod_mlp_to_fast_swiglu],
    ["swiglu+fast_swiglu+memeff", compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)],
]


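# Sketch of applying a single modifier by hand (assumes a CUDA device with
# enough memory for the fp16 model):
#
#   model = timm.models.vit_base_patch16_224()
#   model = mod_memeff_attn(model).to("cuda").to(torch.half)

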
def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


CASES = list(
    product_dict(
        model_info=MODELS,
        dtype=[torch.half],
    )
)


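# Each case is a kwargs dict, e.g. {"model_info": MODELS[0], "dtype": torch.half};
# benchmark_main_helper (from the local utils module) is expected to call
# benchmark_transformer(**case) for every entry in CASES.

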
def benchmark_transformer(model_info, dtype):
    device = "cuda"

    model_name, model_factory, input_shape = model_info

    inp = torch.randn(input_shape, dtype=dtype, device=device)

    for mod_name, mod_apply in MODIFIERS:
        model: nn.Module = model_factory()
        model = mod_apply(model).to(device).to(dtype)

        # Run one forward + backward pass first to make sure the modified
        # model does not error out before timing it.
        out = model(inp)
        grad = out.clone()
        out.backward(grad)

        yield benchmark.Timer(
            stmt="model(inp).backward(grad)",
            globals={
                "model": model,
                "inp": inp,
                "grad": grad,
            },
            label="fw+bw",
            description=mod_name,
            sub_label=model_name,
        )


benchmark_main_helper(benchmark_transformer, CASES)