enginex-bi_series-vllm/pkgs/xformers/benchmarks/benchmark_transformer.py

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
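
# Benchmarks forward+backward passes of timm ViT models, comparing the stock
# timm attention and MLP blocks against xformers' memory-efficient attention
# and SwiGLU. Each (model, modifier) pair is timed with torch.utils.benchmark.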

import itertools
from functools import partial, reduce

import timm
import torch
import torch.nn as nn
from timm.models.layers import Mlp as TimmMlp
from timm.models.vision_transformer import Attention as TimmAttention
from timm.models.vision_transformer import Block as TimmBlock
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops as xops


def replace_module(module: nn.Module, replace_class, factory):
    if isinstance(module, replace_class):
        return factory(module)
    module_output = module
    for name, child in module.named_children():
        module_output.add_module(name, replace_module(child, replace_class, factory))
    del module
    return module_output
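
# `replace_module` recursively rebuilds the module tree, handing every
# instance of `replace_class` to `factory`. Illustrative use (hypothetical,
# not part of this benchmark): swap every nn.ReLU for nn.GELU:
#
#   model = replace_module(model, nn.ReLU, lambda _old: nn.GELU())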


class TimmMemEffAttention(nn.Module):
    def __init__(self, attn: TimmAttention, op=None):
        super().__init__()
        self.op = op
        self.num_heads = attn.num_heads
        self.scale = attn.scale
        self.qkv = attn.qkv
        self.attn_drop = attn.attn_drop
        self.proj = attn.proj
        self.proj_drop = attn.proj_drop

    def forward(self, x):
        B, N, C = x.shape
        # Single fused projection, then split into q, k, v:
        # (B, N, 3, num_heads, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = xops.unbind(qkv, dim=2)
        x = xops.memory_efficient_attention(q, k, v, op=self.op).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
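

# Replaces timm's GELU MLP with an xformers SwiGLU of the same input and
# hidden widths. `op` records the requested SwiGLU implementation (e.g.
# xops.SwiGLUEagerOp); with op=None, xformers is left to pick its fastest
# available kernel.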
class TimmSwiGLU(nn.Module):
    def __init__(self, mlp: TimmMlp, op=None) -> None:
        super().__init__()
        self.fc1 = mlp.fc1
        self.swiglu = xops.SwiGLU(
            in_features=mlp.fc1.in_features,
            hidden_features=mlp.fc1.out_features,
            bias=True,
        )
        self.op = op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.swiglu(x)
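

# Model-level rewrites: each takes a full timm model and swaps modules
# in place via replace_module.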
def mod_memeff_attn(model: nn.Module, op=None) -> nn.Module:
    return replace_module(model, TimmAttention, partial(TimmMemEffAttention, op=op))


def mod_mlp_to_swiglu(model: nn.Module, op=None) -> nn.Module:
    def _mlp_to_swiglu(block: TimmBlock):
        block.mlp = TimmSwiGLU(block.mlp, op=op)
        return block

    return replace_module(model, TimmBlock, _mlp_to_swiglu)


mod_mlp_to_eager_swiglu = partial(mod_mlp_to_swiglu, op=xops.SwiGLUEagerOp)
mod_mlp_to_fast_swiglu = partial(mod_mlp_to_swiglu, op=None)


def compose(*fns):
    def compose2(f, g):
        return lambda *a, **kw: f(g(*a, **kw))

    return reduce(compose2, fns)
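
# compose(f, g, h)(x) == f(g(h(x))): the rightmost modifier is applied first.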


MODELS = [
    # model_name, model_factory, input_shape
    ("ViT-B/16", timm.models.vit_base_patch16_224, [512, 3, 224, 224]),
    ("ViT-B/8", timm.models.vit_base_patch8_224, [64, 3, 224, 224]),
    ("ViT-L/16", timm.models.vit_large_patch16_224, [128, 3, 224, 224]),
    ("ViT-g/14", timm.models.vit_giant_patch14_224, [32, 3, 224, 224]),
]
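

# (report column label, model rewrite). "mlp" is the unmodified timm baseline;
# the other rows progressively swap in xformers components.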
MODIFIERS = [
    ["mlp", lambda x: x],
    ["mlp+memeff", mod_memeff_attn],
    ["swiglu", mod_mlp_to_eager_swiglu],
    ["swiglu+fast_swiglu", mod_mlp_to_fast_swiglu],
    ["swiglu+fast_swiglu+memeff", compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)],
]


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))
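
# e.g. product_dict(a=[1, 2], b=[3]) yields {"a": 1, "b": 3}, {"a": 2, "b": 3}.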


CASES = list(
    product_dict(
        model_info=MODELS,
        dtype=[torch.half],
    )
)
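
# Everything runs in fp16, which is what the fused xformers kernels are
# primarily tuned for.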


def benchmark_transformer(model_info, dtype):
    device = "cuda"
    model_name, model_factory, input_shape = model_info
    inp = torch.randn(input_shape, dtype=dtype, device=device)
    for mod_name, mod_apply in MODIFIERS:
        model: nn.Module = model_factory()
        model = mod_apply(model).to(device).to(dtype)
        # Sanity-check one fw+bw pass so errors surface here, not inside the timer
        out = model(inp)
        grad = out.clone()
        out.backward(grad)
        yield benchmark.Timer(
            stmt="model(inp).backward(grad)",
            globals={
                "model": model,
                "inp": inp,
                "grad": grad,
            },
            label="fw+bw",
            description=mod_name,
            sub_label=model_name,
        )


benchmark_main_helper(benchmark_transformer, CASES)
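
# benchmark_main_helper (from xformers' benchmarks/utils.py) drives every case
# through each MODIFIERS entry and prints a comparison table; the script is
# normally run directly, e.g. `python benchmark_transformer.py`.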