First commit
This commit is contained in:
83
pkgs/xformers/benchmarks/benchmark_revnet.py
Normal file
83
pkgs/xformers/benchmarks/benchmark_revnet.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
|
||||
from xformers.components.reversible import ReversibleSequence
|
||||
|
||||
# (B, K) pairs swept by the benchmark: B is reported as "Batch" and K as
# "Features" in the result keys; inputs are created with B rows of K features.
SHAPES = [(16384, 32), (2048, 256), (128, 4096)]
|
||||
|
||||
# Number of times the (f, g) pair of layers is repeated in each benchmarked stack.
DEPTH = [4, 32, 256]
|
||||
|
||||
|
||||
def bench_revnet(backward: bool):
    """Benchmark a plain residual stack against ``ReversibleSequence`` on CUDA.

    For every dtype (fp16, fp32), every ``(B, K)`` pair in ``SHAPES`` and every
    depth in ``DEPTH``, two callables are timed with ``triton.testing.do_bench``:

    * ``residual``: ``depth`` repetitions of ``y = y + f(y); y = y + g(y)``
      applied to an input of shape ``(1, B, K)``;
    * ``reversible``: a ``ReversibleSequence`` built from the same ``(f, g)``
      pair repeated ``depth`` times, applied to an input of shape
      ``(1, B, 2 * K)``.

    One table (``pretty_print``) and one plot (``pretty_plot``) are emitted
    per dtype.

    Args:
        backward: when True, the inputs are created with ``requires_grad`` and
            every timed step also runs ``torch.norm(y).backward()``.
    """
    device = torch.device("cuda")
    bw = "+bw" if backward else ""

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for B, K in SHAPES:
            for depth in DEPTH:
                f = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
                g = torch.nn.Linear(K, K).to(device=device, dtype=dtype)

                # The very same (f, g) pair is referenced by every block, which
                # matches normal_step() reusing f and g at each iteration.
                revseq = ReversibleSequence(
                    torch.nn.ModuleList([torch.nn.ModuleList([f, g])] * depth)
                )
                revseq = revseq.to(device=device, dtype=dtype)

                a = torch.rand(
                    1, B, K, device=device, dtype=dtype, requires_grad=backward
                )
                # NOTE(review): the reversible input is twice as wide as `a`;
                # presumably ReversibleSequence splits the last dimension into
                # two K-wide halves - confirm against its implementation.
                b = torch.rand(
                    1, B, K * 2, device=device, dtype=dtype, requires_grad=backward
                )

                def normal_step():
                    y = a
                    for _ in range(depth):
                        y = y + f(y)
                        y = y + g(y)
                    if backward:
                        torch.norm(y).backward()
                    return y

                def reversible_step():
                    y = revseq(b)
                    if backward:
                        torch.norm(y).backward()
                    return y

                # The row label does not depend on the test case, build it once.
                key = f"Batch={B}, Features={K}, Depth={depth}"
                row = results.setdefault(key, {})

                for testcase in [
                    TestCase(normal_step, f"residual - fw{bw}"),
                    TestCase(reversible_step, f"reversible - fw{bw}"),
                ]:
                    measurement = triton.testing.do_bench(testcase.function)

                    # Older Triton releases return a sequence of timings (the
                    # first entry is the one this benchmark has always used),
                    # newer ones return a single float by default. Accept both
                    # so that indexing a float does not raise TypeError.
                    if isinstance(measurement, (tuple, list)):
                        elapsed_ms = measurement[0]
                    else:
                        elapsed_ms = measurement

                    row[testcase.name] = f"{elapsed_ms:.2f}"

        pretty_print(
            results,
            title=f"\n --- Type: {dtype} --- ",
            units="runtime in ms, lower is better",
        )
        pretty_plot(
            results,
            title=f"RevNet-FW{bw}-{dtype}",
            units="runtime in ms, lower is better",
            dash_key="torch",
        )
|
||||
|
||||
|
||||
# Run the forward-only benchmark first, then forward + backward.
for bw in (False, True):
    bench_revnet(backward=bw)
|
||||
Reference in New Issue
Block a user