First commit
This commit is contained in:
155
pkgs/xformers/benchmarks/benchmark_transformer.py
Normal file
155
pkgs/xformers/benchmarks/benchmark_transformer.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import itertools
|
||||
from functools import partial, reduce
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.models.layers import Mlp as TimmMlp
|
||||
from timm.models.vision_transformer import Attention as TimmAttention
|
||||
from timm.models.vision_transformer import Block as TimmBlock
|
||||
from torch.utils import benchmark
|
||||
from utils import benchmark_main_helper
|
||||
|
||||
import xformers.ops as xops
|
||||
|
||||
|
||||
def replace_module(module: nn.Module, replace_class, factory):
    """Recursively swap sub-modules of type ``replace_class``.

    Args:
        module: Root of the module tree to process (modified in place).
        replace_class: Class (or tuple of classes) accepted by ``isinstance``
            that selects the modules to replace.
        factory: Callable that receives a matching module and returns the
            module to use in its place.

    Returns:
        ``factory(module)`` when ``module`` itself matches; otherwise the same
        ``module`` object with every matching descendant replaced.
    """
    # A match is handed to the factory as-is; its children are not visited.
    if isinstance(module, replace_class):
        return factory(module)

    # Re-register each child under its existing name with the processed result.
    for child_name, child_module in module.named_children():
        swapped = replace_module(child_module, replace_class, factory)
        module.add_module(child_name, swapped)
    return module
|
||||
|
||||
|
||||
class TimmMemEffAttention(nn.Module):
    """Replacement for a timm ``Attention`` layer that computes attention
    with ``xops.memory_efficient_attention``.

    The ``qkv``/``proj`` layers and the dropout modules are taken from the
    wrapped layer by reference, so no parameters are copied.

    Args:
        attn: The timm attention module whose sub-modules are reused.
        op: Operator forwarded to ``xops.memory_efficient_attention`` on every
            forward call. Defaults to ``None``.
    """

    def __init__(self, attn: TimmAttention, op=None):
        super().__init__()
        # Keep the caller-supplied operator. It used to be overwritten with
        # ``None``, which made the ``op`` argument a silent no-op.
        self.op = op
        self.num_heads = attn.num_heads
        self.scale = attn.scale

        self.qkv = attn.qkv
        self.attn_drop = attn.attn_drop
        self.proj = attn.proj
        self.proj_drop = attn.proj_drop

    def forward(self, x):
        """Apply self-attention to a 3-D input ``x`` of shape ``(B, N, C)``.

        Returns a tensor reshaped back to ``(B, N, C)`` after the output
        projection and its dropout.
        """
        B, N, C = x.shape
        # Split the fused projection into (q, k, v) x heads x head_dim.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = xops.unbind(qkv, dim=2)

        # NOTE(review): ``self.scale`` and ``self.attn_drop`` are stored but
        # not passed to the attention call below - confirm this is intended.
        x = xops.memory_efficient_attention(q, k, v, op=self.op).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
||||
|
||||
|
||||
class TimmSwiGLU(nn.Module):
    """Stand-in for a timm ``Mlp`` that runs an ``xops.SwiGLU`` layer.

    The SwiGLU layer is freshly constructed; only its input and hidden sizes
    are read from the wrapped MLP's ``fc1`` layer.

    Args:
        mlp: The timm MLP whose ``fc1`` provides the feature sizes.
        op: Stored on the instance as ``self.op``.
    """

    def __init__(self, mlp: TimmMlp, op=None) -> None:
        super().__init__()
        first_linear = mlp.fc1
        # The original first linear layer stays registered on this module,
        # although ``forward`` does not call it.
        self.fc1 = first_linear
        self.swiglu = xops.SwiGLU(
            in_features=first_linear.in_features,
            hidden_features=first_linear.out_features,
            bias=True,
        )
        # NOTE(review): ``op`` is only recorded here; nothing in this class
        # hands it to ``self.swiglu`` - confirm whether that is intended.
        self.op = op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the SwiGLU layer's output for ``x``."""
        result = self.swiglu(x)
        return result
|
||||
|
||||
|
||||
def mod_memeff_attn(model: nn.Module, op=None) -> nn.Module:
    """Replace every timm ``Attention`` in ``model`` with ``TimmMemEffAttention``.

    Args:
        model: Model to modify (processed via ``replace_module``).
        op: Passed as the ``op`` keyword to each ``TimmMemEffAttention``.

    Returns:
        The result of ``replace_module`` for ``model``.
    """

    def _wrap_attention(attn_layer):
        return TimmMemEffAttention(attn_layer, op=op)

    return replace_module(model, TimmAttention, _wrap_attention)
|
||||
|
||||
|
||||
def mod_mlp_to_swiglu(model: nn.Module, op=None) -> nn.Module:
    """Swap the ``mlp`` of every timm ``Block`` in ``model`` for ``TimmSwiGLU``.

    Args:
        model: Model to modify (processed via ``replace_module``).
        op: Passed as the ``op`` keyword to each ``TimmSwiGLU``.

    Returns:
        The result of ``replace_module`` for ``model``.
    """

    def _swap_mlp(blk: TimmBlock):
        # The block itself is kept; only its ``mlp`` attribute is replaced.
        replacement = TimmSwiGLU(blk.mlp, op=op)
        blk.mlp = replacement
        return blk

    return replace_module(model, TimmBlock, _swap_mlp)
|
||||
|
||||
|
||||
# mod_mlp_to_swiglu with ``op`` pre-bound to xops.SwiGLUEagerOp.
mod_mlp_to_eagr_swiglu = partial(mod_mlp_to_swiglu, op=xops.SwiGLUEagerOp)
|
||||
# mod_mlp_to_swiglu with ``op`` pre-bound to None.
mod_mlp_to_fast_swiglu = partial(mod_mlp_to_swiglu, op=None)
|
||||
|
||||
|
||||
def compose(*fns):
    """Chain callables right-to-left: ``compose(f, g, h)(x) == f(g(h(x)))``.

    The last callable receives the positional and keyword arguments of the
    call; every other callable receives the previous result as its only
    argument. With a single callable, that callable is returned unchanged.
    """

    def _chain(outer, inner):
        def _composed(*args, **kwargs):
            return outer(inner(*args, **kwargs))

        return _composed

    return reduce(_chain, fns)
|
||||
|
||||
|
||||
# Benchmarked models. Each entry is (display name, timm model factory,
# shape passed to torch.randn to build the input).
MODELS = [
    ("ViT-B/16", timm.models.vit_base_patch16_224, [512, 3, 224, 224]),
    ("ViT-B/8", timm.models.vit_base_patch8_224, [64, 3, 224, 224]),
    ("ViT-L/16", timm.models.vit_large_patch16_224, [128, 3, 224, 224]),
    ("ViT-g/14", timm.models.vit_giant_patch14_224, [32, 3, 224, 224]),
]
|
||||
|
||||
# Model variants to benchmark. Each entry is [label, callable applied to a
# freshly built model before it is moved to the device].
MODIFIERS = [
    # Unmodified timm model.
    ["mlp", lambda x: x],
    # Original MLP, attention replaced only. This entry used to also swap the
    # MLP, which made it a duplicate of "swiglu+fast_swiglu+memeff".
    ["mlp+memeff", mod_memeff_attn],
    ["swiglu", mod_mlp_to_eagr_swiglu],
    ["swiglu+fast_swiglu", mod_mlp_to_fast_swiglu],
    ["swiglu+fast_swiglu+memeff", compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)],
]
|
||||
|
||||
|
||||
def product_dict(**kwargs):
    """Yield one dict per combination of the given keyword iterables.

    Each keyword maps to an iterable of candidate values; every yielded dict
    maps each keyword to one value, covering the full Cartesian product in
    ``itertools.product`` order.
    """
    names = kwargs.keys()
    for combo in itertools.product(*kwargs.values()):
        yield dict(zip(names, combo))
|
||||
|
||||
|
||||
# One benchmark case per (model, dtype) combination; only torch.half is used.
CASES = list(product_dict(model_info=MODELS, dtype=[torch.half]))
|
||||
|
||||
|
||||
def benchmark_transformer(model_info, dtype):
    """Yield a forward+backward ``benchmark.Timer`` for each ``MODIFIERS`` entry.

    Args:
        model_info: ``(model_name, model_factory, input_shape)`` triple.
        dtype: dtype used for both the random input and the model.

    Yields:
        One ``benchmark.Timer`` per modifier, labelled ``"fw+bw"``, with the
        modifier name as description and the model name as sub-label.
    """
    device = "cuda"
    name, build_model, shape = model_info

    # A single random input is shared by all variants of this model.
    inp = torch.randn(shape, dtype=dtype, device=device)

    for variant, transform in MODIFIERS:
        # Build a fresh model per variant, modify it, then move and cast it.
        model: nn.Module = build_model()
        model = transform(model)
        model = model.to(device).to(dtype)

        # Make sure we don't have errors: run one full fw+bw pass up front.
        out = model(inp)
        grad = out.clone()
        out.backward(grad)

        timer_globals = {
            "model": model,
            "inp": inp,
            "grad": grad,
        }
        yield benchmark.Timer(
            stmt="model(inp).backward(grad)",
            globals=timer_globals,
            label="fw+bw",
            description=variant,
            sub_label=name,
        )
|
||||
|
||||
|
||||
# Hand the benchmark generator and its cases to the helper imported from utils.
benchmark_main_helper(benchmark_transformer, CASES)
|
||||
Reference in New Issue
Block a user