First commit
This commit is contained in:
115
pkgs/xformers/benchmarks/benchmark_triton_dropout.py
Normal file
115
pkgs/xformers/benchmarks/benchmark_triton_dropout.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
|
||||
from xformers.components import Activation, build_activation
|
||||
from xformers.triton import FusedDropoutBias
|
||||
|
||||
SHAPES = [
|
||||
(8, 256, 512),
|
||||
(8, 512, 1024),
|
||||
(4, 1024, 1024),
|
||||
(2, 2048, 2048),
|
||||
(1, 2048, 12288),
|
||||
(2, 4096, 4096),
|
||||
]
|
||||
|
||||
P = 0.1
|
||||
|
||||
|
||||
def to_gbs_fw(a, ms, bias):
|
||||
# Read and write the full array
|
||||
total = 2 * a.numel() * a.element_size()
|
||||
|
||||
if bias:
|
||||
# Read the bias, ideally only once
|
||||
total += a.shape[-1] * a.element_size()
|
||||
|
||||
return total * 1e-9 / (ms * 1e-3)
|
||||
|
||||
|
||||
def bench_dropout(bias: bool, backward: bool, activation: Optional[Activation]):
|
||||
device = torch.device("cuda")
|
||||
|
||||
for dtype in [
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
]:
|
||||
results: Dict[str, Any] = {}
|
||||
|
||||
for B, M, K in SHAPES:
|
||||
a = torch.rand(
|
||||
(B, M, K), device=device, dtype=dtype, requires_grad=backward
|
||||
)
|
||||
b = torch.rand(K, device=device, dtype=dtype, requires_grad=backward)
|
||||
torch_act = build_activation(activation)
|
||||
triton_dropout = FusedDropoutBias(
|
||||
P, bias_shape=K if bias else None, activation=activation
|
||||
)
|
||||
|
||||
def torch_step(x):
|
||||
x_ = x + b if bias else x
|
||||
y = torch.nn.functional.dropout(x_, P)
|
||||
if activation:
|
||||
y = torch_act(y)
|
||||
|
||||
if backward:
|
||||
y.grad = None
|
||||
torch.norm(y).backward()
|
||||
return y
|
||||
|
||||
def triton_step(x):
|
||||
y = triton_dropout(x)
|
||||
if backward:
|
||||
y.grad = None
|
||||
torch.norm(y).backward()
|
||||
return y
|
||||
|
||||
for testcase in [
|
||||
TestCase(
|
||||
torch_step,
|
||||
"pytorch - bias: {} - fw{} - act: {}".format(
|
||||
bias, "+bw" if backward else "", activation
|
||||
),
|
||||
),
|
||||
TestCase(
|
||||
triton_step,
|
||||
"triton - bias: {} - fw{} - act: {}".format(
|
||||
bias, "+bw" if backward else "", activation
|
||||
),
|
||||
),
|
||||
]:
|
||||
time = triton.testing.do_bench(
|
||||
lambda: testcase.function(a), grad_to_none=[a, b]
|
||||
)[0]
|
||||
key = f"B={B}, M={M}, K={K}"
|
||||
if key not in results:
|
||||
results[key] = {}
|
||||
|
||||
# Record BW
|
||||
bandwidth = to_gbs_fw(a, time, bias)
|
||||
results[key][testcase.name] = f"{bandwidth:.1f}"
|
||||
|
||||
pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units="GB/s")
|
||||
pretty_plot(
|
||||
results,
|
||||
title="Dropout-Bias-{}-FW{}-{}-Act: {}".format(
|
||||
bias, "+BW" if backward else "", dtype, activation
|
||||
),
|
||||
units="GB/s",
|
||||
dash_key="pytorch",
|
||||
)
|
||||
|
||||
|
||||
for activation in [Activation.GeLU, None, Activation.SquaredReLU]:
|
||||
for bw in [True, False]:
|
||||
for bias in [True, False]:
|
||||
bench_dropout(bias, bw, activation)
|
||||
Reference in New Issue
Block a user