# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
|
import torch

from xformers.benchmarks.utils import TestCase, bench_functions
from xformers.triton.softmax import log_softmax as triton_log_softmax
from xformers.triton.softmax import softmax as triton_softmax
|
# Benchmark input shapes — presumably (batch, sequence, hidden) triples,
# typical for transformer attention workloads (TODO confirm against
# bench_functions). They sweep from many small rows to one very wide row.
SHAPES = [
    (8, 384, 128),
    (8, 784, 512),
    (4, 1024, 768),
    (4, 2048, 1024),
    (2, 2048, 2048),
    (2, 2048, 4096),
    (2, 4096, 4096),
    (1, 2048, 12288),
]
|
def pytorch_fw_bw(x):
    """Full forward + backward pass through PyTorch's native softmax.

    The softmax output is reduced to a scalar via a Frobenius norm so
    that ``backward()`` can run; gradients accumulate into ``x.grad``.
    """
    probs = torch.softmax(x, dim=-1)
    loss = torch.norm(probs)
    loss.backward()
|
def triton_causal_fw(x):
    """Forward-only pass through the Triton softmax, causal masking enabled."""
    triton_softmax(x, causal=True)
|
def triton_fw_bw(x):
    """Forward + backward pass through the Triton softmax.

    A Frobenius norm collapses the output to a scalar loss for backward().
    """
    loss = torch.norm(triton_softmax(x))
    loss.backward()
|
def triton_causal_fw_bw(x):
    """Forward + backward pass through the Triton softmax with causal masking."""
    out = triton_softmax(x, causal=True)
    torch.norm(out).backward()
|
def pytorch_log_fw_bw(x):
    """Forward + backward pass through PyTorch's native log-softmax."""
    out = torch.log_softmax(x, dim=-1)
    torch.norm(out).backward()
|
def triton_log_fw_bw(x):
    """Forward + backward pass through the Triton log-softmax."""
    out = triton_log_softmax(x)
    torch.norm(out).backward()
|
# Test FW
def to_gbs_fw(a, ms):
    """Convert a runtime in milliseconds into achieved bandwidth in GB/s.

    Assumes the kernel reads the whole tensor once and writes it once
    (hence the factor of 2).
    """
    bytes_moved = 2 * a.numel() * a.element_size()
    return (bytes_moved * 1e-9) / (ms * 1e-3)
|
def to_gbs_fwbw(a, ms):
    """Bandwidth in GB/s when both forward and backward passes run.

    The tensor is traversed twice as often as in the forward-only case:
    4 = 2 passes (FW, then gradient) x (1 read + 1 write) each.
    """
    return (4 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
|
# Forward-only benchmark: compare the memory bandwidth achieved by each
# softmax variant across every shape in SHAPES.
_fw_cases = [
    TestCase(lambda x: torch.softmax(x, dim=-1), "pytorch - fw"),
    TestCase(triton_softmax, "triton - fw"),
    TestCase(triton_causal_fw, "triton - causal - fw"),
    TestCase(lambda x: torch.log_softmax(x, dim=-1), "pytorch - log - fw"),
    TestCase(triton_log_softmax, "triton - log - fw"),
]
bench_functions(_fw_cases, SHAPES, to_gbs_fw, "GB/s", "Softmax_Bandwidth_FW_")
|
# Test FW+BW: same comparison, but the reported bandwidth also accounts
# for the gradient pass (see to_gbs_fwbw).
_fw_bw_cases = [
    TestCase(pytorch_fw_bw, "pytorch - fw+bw"),
    TestCase(triton_fw_bw, "triton - fw+bw"),
    TestCase(triton_causal_fw_bw, "triton - causal - fw+bw"),
    TestCase(pytorch_log_fw_bw, "pytorch - log - fw+bw"),
    TestCase(triton_log_fw_bw, "triton - log - fw+bw"),
]
bench_functions(_fw_bw_cases, SHAPES, to_gbs_fwbw, "GB/s", "Softmax_Bandwidth_FW_BW_")