# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable
|
|
|
|
import torch
|
|
from torch.utils import benchmark
|
|
|
|
from xformers.components.attention.utils import iterative_pinv
|
|
|
|
# Minimum measurement time (seconds) passed to ``Timer.blocked_autorange``.
MIN_RUN_TIME = 1

# Benchmark shapes, stored column-wise: the first row holds batch sizes (B),
# the second matrix sizes (M), the third K values, so ``zip(*SHAPES)``
# yields the (B, M, K) triples (8, 256, 128) and (8, 1024, 256).
SHAPES = [
    [8, 8],
    [256, 1024],
    [128, 256],
]

# Zeroing thresholds for the sparse inputs: entries of a uniform [0, 1)
# random matrix below the threshold are set to 0, so roughly that fraction
# of the matrix becomes zero.
SPARSITIES = [
    0.5,
    0.8,
    0.9,
    0.95,
    0.99,
]


def _time_inverse(
    inverse_fn: Callable[[torch.Tensor], torch.Tensor],
    a: torch.Tensor,
    sub_label: str,
    description: str,
    min_run_time: float,
) -> benchmark.Measurement:
    """Time ``inverse_fn(a)`` with ``torch.utils.benchmark`` and return the measurement."""
    # Bind the callable under a fixed global name so that ``stmt`` is valid
    # Python even when ``__name__`` is not an identifier (e.g. "<lambda>").
    return benchmark.Timer(
        stmt="inverse_fn(a)",
        globals={"a": a, "inverse_fn": inverse_fn},
        label=f"{inverse_fn.__name__}",
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


def bench_inverse(inverse_fn: Callable[[torch.Tensor], torch.Tensor]):
    """Benchmark ``inverse_fn`` on dense and sparsified random batched matrices.

    For every (B, M, K) triple produced by ``zip(*SHAPES)`` this times:

    * one "dense" input: a (B, M, M) uniform random tensor whose entries below
      0.9 are zeroed, followed by a softmax over the last dimension;
    * one sparse input per threshold in ``SPARSITIES``: a (B, M, M) uniform
      random tensor with entries below the threshold zeroed, converted with
      ``Tensor.to_sparse()``.

    All inputs are allocated on CUDA, so a CUDA device is required. The
    collected measurements are printed as a ``benchmark.Compare`` table.

    Args:
        inverse_fn: callable applied to each input tensor (e.g. a
            pseudo-inverse implementation). Its ``__name__`` is used as the
            table label.
    """
    min_run_time = MIN_RUN_TIME
    # Deliberately distinct from the ``sparsity`` loop variable below. Sharing
    # one name would make every shape after the first build its dense input
    # with the last SPARSITIES threshold instead of 0.9.
    dense_prob = 0.9
    device = torch.device("cuda")
    results = []

    for B, M, K in zip(*SHAPES):
        description = f"B={B}, M={M}, K={K}"

        a = torch.rand(B, M, M, device=device)
        a[a < dense_prob] = 0
        a = torch.softmax(a, dim=-1)
        results.append(
            _time_inverse(inverse_fn, a, "dense", description, min_run_time)
        )

        for sparsity in SPARSITIES:
            a = torch.rand(B, M, M, device=device)
            a[a < sparsity] = 0
            a = a.to_sparse()
            results.append(
                _time_inverse(
                    inverse_fn,
                    a,
                    f"sparsity: {sparsity:0.2f}",
                    description,
                    min_run_time,
                )
            )

    compare = benchmark.Compare(results)
    compare.print()


def iterative_pinv_analysis(
    identity_tolerance: float = 1e-1,
    pinv_tolerance: float = 5e-1,
    max_iters: int = 30,
    plot: bool = True,
):
    """Report how many ``iterative_pinv`` iterations are needed per matrix size.

    For M = 2, 4, ..., 512, a single (1, M, M) matrix ``a`` is built by
    applying a softmax over the last dimension to a uniform random tensor.
    ``iterative_pinv(a, n_iter=n)`` is then evaluated for n = 1, 2, ...
    until ``a @ result`` is within ``identity_tolerance`` (Frobenius norm) of
    the identity, or until ``max_iters`` is reached. For each size it prints:

    * the iteration count;
    * the distance from the identity;
    * the distance from ``torch.linalg.pinv(a)``.

    Args:
        identity_tolerance: Frobenius-norm threshold on ``I - a @ result``
            used as the stopping criterion.
        pinv_tolerance: currently not used by this function.
        max_iters: largest ``n_iter`` tried for each size.
        plot: currently not used by this function.
    """
    for i in range(1, 10):
        B, M = 1, 2**i
        a = torch.rand(B, M, M)
        a = torch.softmax(a, dim=-1)

        # Neither the reference pseudo-inverse nor the identity depends on
        # n_iter, so compute them once per size instead of once per iteration.
        expected = torch.linalg.pinv(a)
        identity = torch.eye(M)

        for n_iter in range(1, max_iters + 1):
            result = iterative_pinv(a, n_iter=n_iter)
            result_identity = torch.matmul(a, result)

            # Default is frobenius norm.
            identity_error = torch.linalg.norm(identity - result_identity, dim=(-2, -1))
            inverse_error = torch.linalg.norm(expected - result, dim=(-2, -1))

            converged = bool((identity_error < identity_tolerance).all())
            if converged or n_iter == max_iters:
                # Flag the case where the loop stopped only because it hit
                # max_iters, which was otherwise indistinguishable.
                status = "" if converged else " (identity tolerance not reached)"
                # Implicit concatenation instead of backslash continuations
                # inside the f-string, which leaked source indentation into
                # the printed text.
                print(
                    f"Size {M}, n_iters {n_iter}{status}: \n\t "
                    f"Final Error from Identity: {identity_error.item()} \n\t "
                    f"Final Error from linalg.pinv {inverse_error.item()}"
                )
                break


if __name__ == "__main__":
    # Run the convergence study first, then time the iterative and the
    # reference pseudo-inverse implementations, in that order.
    iterative_pinv_analysis()
    for inverse_impl in (iterative_pinv, torch.linalg.pinv):
        bench_inverse(inverse_impl)