242 lines
6.8 KiB
Python
242 lines
6.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import itertools
|
|
import random
|
|
|
|
import torch
|
|
from torch.utils import benchmark
|
|
from utils import benchmark_main_helper
|
|
|
|
import xformers.ops as xops
|
|
|
|
# Value forwarded as the ``min_run_time`` keyword to every
# ``benchmark_main_helper`` call at the bottom of this file.
# NOTE(review): presumably seconds per measurement -- confirm against the helper.
min_run_time = 0.5
|
|
# CUDA device handle. NOTE(review): the benchmark classes visible in this file
# pass the literal "cuda" (or call .cuda()) instead of reading this name.
device = torch.device("cuda")
|
|
|
|
|
|
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the keyword values.

    Each keyword maps a parameter name to an iterable of candidate values.
    Every yielded dict maps each name to a single chosen value; combinations
    are produced in ``itertools.product`` order (last keyword varies fastest).
    """
    names = tuple(kwargs)
    for combination in itertools.product(*kwargs.values()):
        yield {name: value for name, value in zip(names, combination)}
|
|
|
|
|
|
# Parameter grid for the scaled_index_add benchmarks.  Each entry is a dict of
# keyword arguments (shape, scaling, dtype).  ``shape`` is unpacked by
# ScaledIndexAddBenchmark as (B_src, B_out, M, D).
CASES_IADD = [
    # First group: run both with and without the per-feature scaling vector.
    *product_dict(
        shape=[
            (int(48 * 0.6), 48, 1, 257 * 1536),
            (int(48 * 0.6), 48, 257, 1536),
        ],
        scaling=[False, True],
        dtype=[torch.half],
    ),
    # Second group: no scaling.
    *product_dict(
        shape=[
            # Format: [B_src, B_inp, M, D]
            (int(192 * 0.6), 192, 50, 1536),
            (int(48 * 257 * 0.6), 257 * 48, 1, 1536),
            (int(192 * 50 * 0.6), 192 * 50, 1, 1536),
            (int(16 * 257 * 0.6), 48 * 257, 1, 1536),
        ],
        scaling=[False],
        dtype=[torch.half],
    ),
]
|
|
|
|
# Parameter grid for the index_select benchmarks.  ``batches`` is a tuple of
# (B, seqlen) pairs, unpacked that way by IndexSelectBenchmark.
CASES_ISELECT = [
    *product_dict(
        batches=[((48, 257), (50, 192))],
        D=[1536],
        keep_ratio=[0.6],
        dtype=[torch.half],
    )
]
|
|
|
|
# Short dtype tags used when building benchmark sub-labels.  Callers look up
# with ``DTYPE2STR.get(dtype, dtype)``, so unlisted dtypes fall back to the
# dtype object itself.
DTYPE2STR = {
    torch.bfloat16: "b16",
    torch.half: "f16",
    torch.float32: "f32",
}
|
|
|
|
|
|
def _setup_test(functions, fw: bool = False, bw: bool = False, **kwargs):
|
|
for k, benchmark_cls in functions.items():
|
|
benchmark_object = benchmark_cls(**kwargs, bw=bw)
|
|
label = benchmark_object.label
|
|
label += "fw" if fw else ""
|
|
label += "bw" if bw else ""
|
|
|
|
def run_one():
|
|
if fw:
|
|
benchmark_object.fw()
|
|
if bw:
|
|
benchmark_object.bw()
|
|
|
|
yield benchmark.Timer(
|
|
stmt="fn()",
|
|
globals={
|
|
"fn": run_one,
|
|
},
|
|
label=label,
|
|
description=k,
|
|
sub_label=benchmark_object.sub_label,
|
|
)
|
|
|
|
|
|
class ScaledIndexAddBenchmark:
    """Benchmark fixture around ``xops.scaled_index_add``.

    The constructor allocates random CUDA tensors for the destination
    (``inp``), the rows to add (``src``), an optional per-feature ``scaling``
    vector, an identity ``index`` covering every row of ``src``, and an
    upstream gradient (``grad``) used by the backward pass.
    """

    def __init__(self, dtype, scaling: bool, shape, bw: bool) -> None:
        """Build the tensors for one benchmark case.

        Args:
            dtype: Torch dtype of all floating-point tensors.
            scaling: When true, a random scaling vector of size ``D`` is used;
                otherwise ``self.scaling`` is ``None``.
            shape: ``(B_src, B_out, M, D)``.
            bw: Whether ``inp``, ``src`` and ``scaling`` require gradients.
        """
        B_src, B_out, M, D = shape
        # Seed from the batch sizes so a given case always sees the same data.
        torch.manual_seed(B_out + B_src)

        self.label = "scaled_index_add"
        dtype_str = DTYPE2STR.get(dtype, dtype)
        self.sub_label = f"{dtype_str} B_src={B_src}, B_out={B_out}, M={M}, D={D} s={'Y' if scaling else 'N'}"
        self.alpha = 0.73

        # The random draws below happen in the order inp, src, scaling, grad.
        tensor_opts = dict(device="cuda", dtype=dtype)
        self.inp = torch.randn([B_out, M, D], requires_grad=bw, **tensor_opts)
        self.src = torch.randn([B_src, M, D], requires_grad=bw, **tensor_opts)
        if scaling:
            self.scaling = torch.randn([D], requires_grad=bw, **tensor_opts)
        else:
            self.scaling = None
        # Identity mapping: source row i is added to destination row i.
        source_rows = list(range(self.src.shape[0]))
        self.index = torch.tensor(source_rows, dtype=torch.int64, device="cuda")
        self.grad = torch.randn([B_out, M, D], **tensor_opts)
        # Placeholder until fw() stores a real result.
        self.out = torch.Tensor()

    def fw(self) -> None:
        """Run the forward op and keep the result in ``self.out``."""
        # A fresh copy of the destination is passed on every call.
        # NOTE(review): presumably because the op may write into its input -- confirm.
        destination = self.inp.clone()
        self.out = xops.scaled_index_add(
            input=destination,
            index=self.index,
            source=self.src,
            scaling=self.scaling,
            alpha=self.alpha,
        )

    def bw(self):
        """Clear stored gradients and backpropagate ``self.grad`` through ``self.out``."""
        leaves = [self.inp, self.src]
        if self.scaling is not None:
            leaves.append(self.scaling)
        # Reset so repeated timing iterations do not accumulate gradients.
        for leaf in leaves:
            leaf.grad = None
        self.out.backward(self.grad, retain_graph=True)
|
|
|
|
|
|
class ScaledIndexAddBenchmarkBaseline(ScaledIndexAddBenchmark):
    """Plain-PyTorch reference for the scaled index-add benchmark.

    Reuses the tensors prepared by ``ScaledIndexAddBenchmark`` and overrides
    only the forward pass, which is expressed with ``Tensor.index_add``.
    """

    def fw(self) -> None:
        """Compute ``inp.index_add(0, index, src * scaling, alpha=alpha)`` into ``self.out``.

        When ``self.scaling`` is ``None`` the source is used unscaled.
        """
        src_scaled = self.src
        if self.scaling is not None:
            # Assign the product back.  Previously the result of the
            # multiplication was discarded, so the baseline added the unscaled
            # source and ``scaling`` never received a gradient in backward.
            src_scaled = src_scaled * self.scaling.unsqueeze(0).unsqueeze(0)
        # Out-of-place index_add: self.inp itself is left untouched.
        self.out = self.inp.index_add(
            dim=0,
            source=src_scaled,
            index=self.index,
            alpha=self.alpha,
        )
|
|
|
|
|
|
def scaled_index_add_fw(**kwargs):
    """Yield forward-only Timers for the xformers and PyTorch scaled index-add.

    ``kwargs`` is one benchmark case; it is forwarded to ``_setup_test`` and
    from there to the benchmark class constructors.
    """
    implementations = {
        "xformers": ScaledIndexAddBenchmark,
        "pytorch": ScaledIndexAddBenchmarkBaseline,
    }
    yield from _setup_test(functions=implementations, fw=True, **kwargs)
|
|
|
|
|
|
def scaled_index_add_fwbw(**kwargs):
    """Yield forward+backward Timers for the xformers and PyTorch scaled index-add.

    ``kwargs`` is one benchmark case; it is forwarded to ``_setup_test`` and
    from there to the benchmark class constructors.
    """
    implementations = {
        "xformers": ScaledIndexAddBenchmark,
        "pytorch": ScaledIndexAddBenchmarkBaseline,
    }
    yield from _setup_test(functions=implementations, fw=True, bw=True, **kwargs)
|
|
|
|
|
|
class IndexSelectBenchmark:
    """Benchmark fixture around ``xops.index_select_cat``.

    One CUDA tensor ``src`` holds the data of every batch; ``sources`` are
    per-batch slices of it reshaped to ``(B, seqlen * D)``, and ``indices``
    hold, per batch, a shuffled subset of row indices to select.
    """

    def __init__(self, dtype, batches, D, keep_ratio, bw: bool) -> None:
        """Build the tensors for one benchmark case.

        Args:
            dtype: Torch dtype ``src`` is converted to.
            batches: Iterable of ``(B, seqlen)`` pairs, one per source tensor.
            D: Feature size; each source row has ``seqlen * D`` elements.
            keep_ratio: Fraction of each source's rows kept in its index.
            bw: Accepted for interface parity with the other benchmark
                classes.  NOTE(review): it is not read here -- ``src`` always
                requires grad, including in forward-only runs.
        """
        dtype_str = DTYPE2STR.get(dtype, dtype)
        self.sub_label = f"{dtype_str} D={D} batches={batches} keep={keep_ratio}"
        self.label = "index_select"

        # Random data is drawn on CPU per batch, then packed into a single
        # (total_rows, D) tensor that is moved to CUDA and cast to `dtype`.
        cpu_chunks = [torch.randn([B, seqlen * D]) for (B, seqlen) in batches]
        packed_rows = [chunk.view([-1, D]) for chunk in cpu_chunks]
        src = torch.cat(packed_rows, dim=0).cuda().to(dtype)
        src.requires_grad_(True)

        indices, sources = [], []
        row_offset = 0
        for chunk in cpu_chunks:
            num_rows, row_width = chunk.shape

            # Shuffle with an RNG seeded by the row count, then keep a prefix.
            order = list(range(num_rows))
            random.Random(num_rows).shuffle(order)
            kept = order[: int(keep_ratio * num_rows)]
            indices.append(torch.tensor(kept, dtype=torch.int64, device="cuda"))

            # This batch occupies `rows_in_src` consecutive rows of `src`.
            rows_in_src = num_rows * row_width // D
            packed_slice = src[row_offset : row_offset + rows_in_src]
            sources.append(packed_slice.reshape(chunk.shape))
            row_offset += rows_in_src

        self.indices = indices
        self.sources = sources
        self.src = src
        # Placeholder until fw() stores a real result.
        self.out = torch.Tensor()

    def fw(self) -> None:
        """Run the forward op and keep the result in ``self.out``."""
        self.out = xops.index_select_cat(self.sources, self.indices)

    def bw(self):
        """Clear ``src.grad`` and backpropagate, using ``self.out`` as its own upstream gradient."""
        self.src.grad = None
        result = self.out
        result.backward(result, retain_graph=True)
|
|
|
|
|
|
class IndexSelectBenchmarkBaseline(IndexSelectBenchmark):
    """Plain-PyTorch reference for the index-select benchmark.

    Reuses the tensors prepared by ``IndexSelectBenchmark`` and overrides only
    the forward pass.
    """

    def fw(self) -> None:
        """Index each source with its index tensor, flatten, and concatenate into ``self.out``."""
        selected = []
        for source, index in zip(self.sources, self.indices):
            selected.append(source[index].flatten())
        self.out = torch.cat(selected, dim=0)
|
|
|
|
|
|
def index_select_fw(**kwargs):
    """Yield forward-only Timers for the xformers and PyTorch index-select.

    ``kwargs`` is one benchmark case; it is forwarded to ``_setup_test`` and
    from there to the benchmark class constructors.
    """
    implementations = {
        "xformers": IndexSelectBenchmark,
        "pytorch": IndexSelectBenchmarkBaseline,
    }
    yield from _setup_test(functions=implementations, fw=True, **kwargs)
|
|
|
|
|
|
def index_select_fwbw(**kwargs):
    """Yield forward+backward Timers for the xformers and PyTorch index-select.

    ``kwargs`` is one benchmark case; it is forwarded to ``_setup_test`` and
    from there to the benchmark class constructors.
    """
    implementations = {
        "xformers": IndexSelectBenchmark,
        "pytorch": IndexSelectBenchmarkBaseline,
    }
    yield from _setup_test(functions=implementations, fw=True, bw=True, **kwargs)
|
|
|
|
|
|
# Hand every (benchmark generator, case grid) pair to the shared helper, in
# the same order as before: scaled_index_add fw, fw+bw, then index_select fw,
# fw+bw.  These calls run at module import time.
for _benchmark_fn, _cases in (
    (scaled_index_add_fw, CASES_IADD),
    (scaled_index_add_fwbw, CASES_IADD),
    (index_select_fw, CASES_ISELECT),
    (index_select_fwbw, CASES_ISELECT),
):
    benchmark_main_helper(_benchmark_fn, _cases, min_run_time=min_run_time)
|