Files
enginex-bi_series-vllm/pkgs/xformers/benchmarks/benchmark_revnet.py
2025-08-05 19:02:46 +08:00

84 lines
2.6 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict
import torch
import triton
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.reversible import ReversibleSequence
# Problem sizes swept by the benchmark. Each entry is unpacked as (B, K) and
# reported as "Batch" and "Features" in the result keys.
SHAPES = [
    (16384, 32),
    (2048, 256),
    (128, 4096),
]

# Values of `depth` swept by the benchmark: how many times the f/g pair is
# applied in each timed step.
DEPTH = [
    4,
    32,
    256,
]
def bench_revnet(backward: bool):
    """Time a plain residual stack against a ``ReversibleSequence`` on CUDA.

    For every dtype in (float16, float32), every ``(B, K)`` in ``SHAPES`` and
    every ``depth`` in ``DEPTH``, two steps are timed with
    ``triton.testing.do_bench``:

    * "residual": ``y = y + f(y); y = y + g(y)`` repeated ``depth`` times;
    * "reversible": a single call to a ``ReversibleSequence`` built from the
      same ``f``/``g`` pair repeated ``depth`` times.

    Once all shapes and depths have been measured for a dtype, the collected
    timings are handed to ``pretty_print`` and ``pretty_plot``.

    Args:
        backward: When True, inputs are created with ``requires_grad=True``
            and every timed step also calls ``backward()`` on the norm of its
            output, so the measurement covers forward + backward.
    """
    device = torch.device("cuda")
    bw = "+bw" if backward else ""

    for dtype in [torch.float16, torch.float32]:
        # One result table per dtype: {configuration -> {test case -> time}}.
        results: Dict[str, Any] = {}

        for B, K in SHAPES:
            for depth in DEPTH:
                f = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
                g = torch.nn.Linear(K, K).to(device=device, dtype=dtype)

                # List multiplication repeats the *same* [f, g] module, so the
                # weights are shared across all `depth` entries -- just like
                # normal_step below, which reuses f and g on every iteration.
                revseq = ReversibleSequence(
                    torch.nn.ModuleList([torch.nn.ModuleList([f, g])] * depth)
                )
                revseq = revseq.to(device=device, dtype=dtype)

                a = torch.rand(
                    1, B, K, device=device, dtype=dtype, requires_grad=backward
                )
                # The reversible input carries 2 * K features.
                # NOTE(review): presumably ReversibleSequence splits its input
                # into two K-wide halves -- confirm against its implementation.
                b = torch.rand(
                    1, B, K * 2, device=device, dtype=dtype, requires_grad=backward
                )

                def normal_step():
                    y = a
                    for _ in range(depth):
                        y = y + f(y)
                        y = y + g(y)
                    if backward:
                        torch.norm(y).backward()
                    return y

                def reversible_step():
                    y = revseq(b)
                    if backward:
                        torch.norm(y).backward()
                    return y

                key = f"Batch={B}, Features={K}, Depth={depth}"
                for testcase in [
                    TestCase(normal_step, f"residual - fw{bw}"),
                    TestCase(reversible_step, f"reversible - fw{bw}"),
                ]:
                    measured = triton.testing.do_bench(testcase.function)
                    # Depending on the Triton version, do_bench returns either
                    # a sequence of timings or a single float. Accept both so
                    # indexing a float does not raise TypeError.
                    if isinstance(measured, (list, tuple)):
                        elapsed_ms = measured[0]
                    else:
                        elapsed_ms = measured

                    results.setdefault(key, {})[testcase.name] = f"{elapsed_ms:.2f}"

        pretty_print(
            results,
            title=f"\n --- Type: {dtype} --- ",
            units="runtime in ms, lower is better",
        )
        pretty_plot(
            results,
            title=f"RevNet-FW{bw}-{dtype}",
            units="runtime in ms, lower is better",
            dash_key="torch",
        )
# Run the benchmark twice: first without the backward pass, then with it.
for bw in (False, True):
    bench_revnet(backward=bw)