First commit
This commit is contained in:
91
pkgs/xformers/benchmarks/benchmark_triton_softmax.py
Normal file
91
pkgs/xformers/benchmarks/benchmark_triton_softmax.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from xformers.benchmarks.utils import TestCase, bench_functions
|
||||
from xformers.triton.softmax import log_softmax as triton_log_softmax
|
||||
from xformers.triton.softmax import softmax as triton_softmax
|
||||
|
||||
SHAPES = [
|
||||
(8, 384, 128),
|
||||
(8, 784, 512),
|
||||
(4, 1024, 768),
|
||||
(4, 2048, 1024),
|
||||
(2, 2048, 2048),
|
||||
(2, 2048, 4096),
|
||||
(2, 4096, 4096),
|
||||
(1, 2048, 12288),
|
||||
]
|
||||
|
||||
|
||||
def pytorch_fw_bw(x):
|
||||
y = torch.norm(torch.softmax(x, dim=-1))
|
||||
y.backward()
|
||||
|
||||
|
||||
def triton_causal_fw(x):
|
||||
_ = triton_softmax(x, causal=True)
|
||||
|
||||
|
||||
def triton_fw_bw(x):
|
||||
y = torch.norm(triton_softmax(x))
|
||||
y.backward()
|
||||
|
||||
|
||||
def triton_causal_fw_bw(x):
|
||||
y = torch.norm(triton_softmax(x, causal=True))
|
||||
y.backward()
|
||||
|
||||
|
||||
def pytorch_log_fw_bw(x):
|
||||
y = torch.norm(torch.log_softmax(x, dim=-1))
|
||||
y.backward()
|
||||
|
||||
|
||||
def triton_log_fw_bw(x):
|
||||
y = torch.norm(triton_log_softmax(x))
|
||||
y.backward()
|
||||
|
||||
|
||||
# Test FW
|
||||
def to_gbs_fw(a, ms):
|
||||
# Read and write the full array
|
||||
return (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
|
||||
|
||||
|
||||
def to_gbs_fwbw(a, ms):
|
||||
# same as above, but we do it twice (FW and then gradient)
|
||||
return 2 * to_gbs_fw(a, ms)
|
||||
|
||||
|
||||
bench_functions(
|
||||
[
|
||||
TestCase(lambda x: torch.softmax(x, dim=-1), "pytorch - fw"),
|
||||
TestCase(triton_softmax, "triton - fw"),
|
||||
TestCase(triton_causal_fw, "triton - causal - fw"),
|
||||
TestCase(lambda x: torch.log_softmax(x, dim=-1), "pytorch - log - fw"),
|
||||
TestCase(triton_log_softmax, "triton - log - fw"),
|
||||
],
|
||||
SHAPES,
|
||||
to_gbs_fw,
|
||||
"GB/s",
|
||||
"Softmax_Bandwidth_FW_",
|
||||
)
|
||||
|
||||
# Test FW+BW
|
||||
bench_functions(
|
||||
[
|
||||
TestCase(pytorch_fw_bw, "pytorch - fw+bw"),
|
||||
TestCase(triton_fw_bw, "triton - fw+bw"),
|
||||
TestCase(triton_causal_fw_bw, "triton - causal - fw+bw"),
|
||||
TestCase(pytorch_log_fw_bw, "pytorch - log - fw+bw"),
|
||||
TestCase(triton_log_fw_bw, "triton - log - fw+bw"),
|
||||
],
|
||||
SHAPES,
|
||||
to_gbs_fwbw,
|
||||
"GB/s",
|
||||
"Softmax_Bandwidth_FW_BW_",
|
||||
)
|
||||
Reference in New Issue
Block a user