Files
enginex-bi_series-vllm/pkgs/xformers/benchmarks/benchmark_nystrom_utils.py
2025-08-05 19:02:46 +08:00

100 lines
3.1 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable
import torch
from torch.utils import benchmark
from xformers.components.attention.utils import iterative_pinv
# Minimum wall-clock time (seconds) each benchmark timer must accumulate.
MIN_RUN_TIME = 1
# Parallel lists of B, M, K values (column-major): zip(*SHAPES) yields the
# (B, M, K) triples actually benchmarked — (8, 256, 128) and (8, 1024, 256).
SHAPES = [[8, 8], [256, 1024], [128, 256]]
# Sparsity levels (fraction of entries zeroed) for the sparse benchmarks.
SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99]
def bench_inverse(inverse_fn: Callable[[torch.Tensor], torch.Tensor]) -> None:
    """Benchmark ``inverse_fn`` on batched square matrices, dense and sparse.

    For each (B, M, K) triple from ``zip(*SHAPES)`` this times ``inverse_fn``
    on one dense softmax-normalized (B, M, M) tensor and on one sparse COO
    tensor per sparsity level in ``SPARSITIES``, then prints a comparison
    table. Requires a CUDA device.

    Args:
        inverse_fn: callable mapping a (B, M, M) tensor to its (pseudo-)inverse,
            e.g. ``iterative_pinv`` or ``torch.linalg.pinv``.
    """
    min_run_time = MIN_RUN_TIME
    # Threshold used only for the dense case; ~90% of entries are zeroed
    # before the softmax renormalizes each row.
    dense_prob = 0.9
    device = torch.device("cuda")
    fn_name = inverse_fn.__name__
    results = []
    for B, M, K in zip(*SHAPES):
        # Dense input: sparsify then softmax so rows are valid distributions.
        a = torch.rand(B, M, M, device=device)
        a[a < dense_prob] = 0
        a = torch.softmax(a, dim=-1)
        results.append(
            benchmark.Timer(
                stmt=f"{fn_name}(a)",
                globals={
                    "a": a,
                    fn_name: inverse_fn,
                },
                label=f"{fn_name}",
                sub_label="dense",
                description=f"B={B}, M={M}, K={K}",
            ).blocked_autorange(min_run_time=min_run_time)
        )
        # BUGFIX: the original reused the name `prob` as this loop variable,
        # so after the first outer iteration the "dense" case above silently
        # used 0.99 instead of 0.9. A distinct name keeps the dense
        # threshold stable across shapes.
        for sparsity in SPARSITIES:
            a = torch.rand(B, M, M, device=device)
            a[a < sparsity] = 0
            a = a.to_sparse()
            results.append(
                benchmark.Timer(
                    stmt=f"{fn_name}(a)",
                    globals={
                        "a": a,
                        fn_name: inverse_fn,
                    },
                    label=f"{fn_name}",
                    sub_label=f"sparsity: {sparsity:0.2f}",
                    description=f"B={B}, M={M}, K={K}",
                ).blocked_autorange(min_run_time=min_run_time)
            )
    compare = benchmark.Compare(results)
    compare.print()
def iterative_pinv_analysis(
    identity_tolerance: float = 1e-1,
    pinv_tolerance: float = 5e-1,
    max_iters: int = 30,
    plot: bool = True,
):
    """Report how many ``iterative_pinv`` iterations each matrix size needs.

    For sizes M = 2**i (i in 1..9, batch of 1), increases the iteration count
    until ``a @ iterative_pinv(a)`` is within ``identity_tolerance`` of the
    identity in Frobenius norm (or ``max_iters`` is hit), then prints the
    iteration count plus the final errors against the identity and against
    ``torch.linalg.pinv``.

    Args:
        identity_tolerance: convergence threshold on ||I - a @ pinv(a)||_F.
        pinv_tolerance: unused; kept for backward compatibility.
        max_iters: upper bound on the iteration count tried per size.
        plot: unused; kept for backward compatibility.
    """
    for i in range(1, 10):
        B, M = 1, 2**i
        a = torch.rand(B, M, M)
        a = torch.softmax(a, dim=-1)
        # Hoisted out of the inner loop: `a` is fixed across n_iter, so the
        # O(M^3) exact pseudo-inverse and the identity need computing once.
        expected = torch.linalg.pinv(a)
        identity = torch.eye(M)
        for n_iter in range(1, max_iters + 1):
            result = iterative_pinv(a, n_iter=n_iter)
            result_identity = torch.matmul(a, result)
            # Default is frobenius norm.
            identity_error = torch.linalg.norm(
                identity - result_identity, dim=(-2, -1)
            )
            inverse_error = torch.linalg.norm(expected - result, dim=(-2, -1))
            if (identity_error < identity_tolerance).all() or n_iter == max_iters:
                # Explicit "\n\t" pieces replace the original backslash-continued
                # literal, which leaked source indentation into the output.
                print(
                    f"Size {M}, n_iters {n_iter}:\n"
                    f"\tFinal Error from Identity: {identity_error.item()}\n"
                    f"\tFinal Error from linalg.pinv {inverse_error.item()}"
                )
                break
if __name__ == "__main__":
    # Run the convergence analysis first, then benchmark the iterative
    # approximation against the exact torch.linalg.pinv baseline.
    iterative_pinv_analysis()
    bench_inverse(iterative_pinv)
    bench_inverse(torch.linalg.pinv)