enginex-bi_series-vllm/pkgs/xformers/benchmarks/benchmark_mlp.py

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import Activation
from xformers.components.feedforward import MLP, FusedMLP
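
# Each benchmark shape is (B, M, K): batch size, sequence length, and model
# dimension; K is also the dim_model handed to both MLPs below.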
SHAPES = [
    (8, 256, 512),
    (8, 512, 1024),
    (4, 1024, 1024),
    (2, 2048, 2048),
    (1, 2048, 4096),
    (1, 1024, 12288),
]

HIDDEN_LAYER_MULTIPLIER = [4]
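

# Sweep: FusedMLP vs. the reference MLP across the shapes above, in fp16 and
# fp32, forward-only and forward+backward.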
def bench_MLP(backward: bool, bias: bool, dropout: float, activation: Activation):
    device = torch.device("cuda")
    bw = "+bw" if backward else ""

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for B, M, K in SHAPES:
            for hlm in HIDDEN_LAYER_MULTIPLIER:
                fused_mlp = FusedMLP(
                    dim_model=K,
                    dropout=dropout,
                    activation=activation,
                    hidden_layer_multiplier=hlm,
                    bias=bias,
                ).to(device=device, dtype=dtype)

                standard_mlp = MLP(
                    dim_model=K,
                    dropout=dropout,
                    activation=activation,
                    hidden_layer_multiplier=hlm,
                    bias=bias,
                ).to(device=device, dtype=dtype)

                a = torch.randn(
                    (B, M, K), requires_grad=backward, device=device, dtype=dtype
                )
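
                # Both candidates run a forward pass; when the backward pass
                # is benchmarked too, the output is reduced to a scalar with a
                # norm so that .backward() can be called on it.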
                def mlp_standard():
                    y = standard_mlp(a)
                    if backward:
                        torch.norm(y).backward()
                    return y

                def mlp_fused():
                    y = fused_mlp(a)
                    if backward:
                        torch.norm(y).backward()
                    return y
                for testcase in [
                    TestCase(
                        mlp_standard,
                        "standard - {} - {}bias - {} drop - fw{}".format(
                            activation,
                            "no " if not bias else "",
                            dropout,
                            "+bw" if backward else "",
                        ),
                    ),
                    TestCase(
                        mlp_fused,
                        "fused - {} - {}bias - {} drop - fw{}".format(
                            activation,
                            "no " if not bias else "",
                            dropout,
                            "+bw" if backward else "",
                        ),
                    ),
                ]:
                    # NOTE: the indexing assumes a triton version whose
                    # do_bench returns a tuple/list of timings; recent triton
                    # releases return a single float instead.
                    time = triton.testing.do_bench(testcase.function)[0]
                    key = f"{B} x {M} x {K} - {hlm}"
                    if key not in results:
                        results[key] = {}
                    results[key][testcase.name] = f"{time:.2f}"
        pretty_print(
            results,
            title=f"\n --- Type: {dtype} --- ",
            units="runtime in ms, lower is better. BMK - mul: ",
        )

        pretty_plot(
            results,
            title=f"MLP-{activation}-FW{bw}-{dtype}",
            units="runtime in ms, lower is better",
            dash_key="torch",
        )


if __name__ == "__main__":
    # Get the user requests
    parser = argparse.ArgumentParser("Benchmark MLP")
    parser.add_argument("-act", "--activations", nargs="+", default=[Activation.GeLU])
    parser.add_argument("-bias", "--bias", nargs="+", default=[False, True])
    parser.add_argument("-dropout", "--dropout", nargs="+", default=[0.0, 0.1])
    args = parser.parse_args()

    # NOTE: values given on the command line arrive as strings; dropout is
    # coerced below, but a --bias string such as "False" is truthy as-is.
    for bw in [False, True]:
        for bias in args.bias:
            for dropout in args.dropout:
                for activation in args.activations:
                    bench_MLP(
                        backward=bw,
                        bias=bias,
                        dropout=float(dropout),
                        activation=activation,
                    )
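
# Example invocation (a sketch, not from the source; assumes a CUDA device and
# that the Activation enum accepts its lowercase string names, e.g. "gelu"):
#
#   python benchmark_mlp.py --activations gelu --dropout 0.0 0.1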