First commit
This commit is contained in:
160
pkgs/xformers/benchmarks/benchmark_swiglu.py
Normal file
160
pkgs/xformers/benchmarks/benchmark_swiglu.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import itertools
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.utils import benchmark
|
||||
from utils import benchmark_main_helper
|
||||
|
||||
import xformers.ops.swiglu_op as xsw
|
||||
|
||||
# Minimum measurement time handed to benchmark_main_helper for every timer
# (presumably seconds -- confirm against the helper).
min_run_time = 0.5

# Device on which all benchmark inputs and modules are created.
device = torch.device("cuda")
|
||||
|
||||
# Problem sizes to benchmark.
# Each entry is (inp.shape[0], inp.shape[1], hidden.shape[1]).
SHAPES = [
    # ViT-Giant
    (9456, 1536, 2736),
    (4440, 1536, 2736),
    (4728, 1536, 2736),
    # Some smaller shapes as well
    (4728, 1536, 1024),
    # GPT-3 (small)
    (32768, 2048, 5632),
    # Chinchilla
    (32768, 8192, 22016),
]
|
||||
|
||||
|
||||
# Operator implementation under test; the benchmarks time it side by side with
# xsw.SwiGLUEagerOp. Uncomment one of the alternatives to benchmark it instead.
# OP = xsw._SwiGLUDecomposedOp
# OP = xsw.SwiGLUFusedOp
OP = xsw.SwiGLUPackedFusedOp
|
||||
|
||||
|
||||
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the keyword values.

    Every keyword argument must be an iterable of candidate values. Each
    yielded dict maps every keyword name to one chosen value; combinations
    come out in ``itertools.product`` order (the last keyword varies fastest).
    """
    names = tuple(kwargs)
    for combo in itertools.product(*kwargs.values()):
        yield dict(zip(names, combo))
|
||||
|
||||
|
||||
# Every (shape, dtype, bias) combination to benchmark, one keyword dict each.
# "autocast_half" is a marker string rather than a torch dtype; the benchmark
# functions treat it as "float32 tensors run under half-precision autocast".
CASES = [
    *product_dict(
        shape=SHAPES,
        dtype=[torch.bfloat16, torch.half, "autocast_half"],
        bias=[True, False],
    )
]
|
||||
|
||||
# Short tag placed at the start of each benchmark sub-label, keyed by the
# "dtype" value of a case (a torch dtype or the "autocast_half" marker).
DTYPE2STR = {
    torch.bfloat16: "b16 ",
    torch.half: "f16 ",
    "autocast_half": "f16.ac",
}
|
||||
|
||||
|
||||
def benchmark_swiglu(shape, dtype, bias: bool):
    """Yield forward-pass timers for the selected SwiGLU op and the eager op.

    Args:
        shape: ``(inp.shape[0], inp.shape[1], hidden.shape[1])`` triple.
        dtype: a torch dtype used for both the input and the module, or the
            string ``"autocast_half"`` to keep both in float32 and run the
            timed call under ``torch.autocast("cuda", dtype=torch.half)``.
        bias: whether the ``xsw.SwiGLU`` module is built with bias.

    Yields:
        Two ``benchmark.Timer`` objects that share label and sub_label: the
        first times ``OP``, the second times ``xsw.SwiGLUEagerOp``.
    """
    use_autocast = dtype == "autocast_half"
    # In autocast mode tensors stay in float32; autocast picks the compute dtype.
    tensor_dtype = torch.float if use_autocast else dtype

    x = torch.randn(shape[:2], device=device, dtype=tensor_dtype)
    module = xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
    module = module.to(device).to(tensor_dtype)

    # Unknown dtypes fall back to the dtype object itself in the label.
    dtype_str = DTYPE2STR.get(dtype, dtype)
    bstr = "bias" if bias else "nobi"
    sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"

    params = module._ordered_params()

    # The timed statement is source text, so autocast is injected as a textual
    # "with" header placed in front of the call.
    stmt_prefix = (
        'with torch.autocast("cuda", dtype=torch.half):\n ' if use_autocast else ""
    )
    stmt = f"{stmt_prefix}fn(x, *args)"

    for description, op in ((OP.NAME, OP), ("eager", xsw.SwiGLUEagerOp)):
        yield benchmark.Timer(
            stmt=stmt,
            globals={
                "x": x,
                "args": params,
                "fn": partial(xsw.swiglu, op=op),
            },
            label="swiglu_fw",
            description=description,
            sub_label=sub_label,
        )
|
||||
|
||||
|
||||
def benchmark_swiglu_bw(shape, dtype, bias: bool):
    """Yield backward-pass timers for the selected SwiGLU op and the eager op.

    The forward pass runs once, outside the timed statement; each timer then
    measures only ``out.backward(grad, retain_graph=True)``.

    Args:
        shape: ``(inp.shape[0], inp.shape[1], hidden.shape[1])`` triple.
        dtype: a torch dtype used for both the input and the module, or the
            string ``"autocast_half"`` to keep both in float32 and run the
            forward pass under float16 autocast.
        bias: whether the ``xsw.SwiGLU`` module is built with bias.

    Yields:
        Two ``benchmark.Timer`` objects that share label and sub_label: the
        first times the backward of ``OP``, the second the backward of
        ``xsw.SwiGLUEagerOp``.
    """
    if dtype == "autocast_half":
        inp_dtype, model_dtype = torch.float, torch.float
        # Use the same autocast entry point as the forward benchmark;
        # torch.cuda.amp.autocast is deprecated in recent PyTorch releases.
        cm: Any = partial(
            torch.autocast, device_type="cuda", dtype=torch.float16, enabled=True
        )
    else:
        inp_dtype, model_dtype = dtype, dtype
        cm = nullcontext

    x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
    # The input must require grad so the backward pass also propagates to it.
    x.requires_grad_()
    module = (
        xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
        .to(device)
        .to(model_dtype)
    )

    # Unknown dtypes fall back to the dtype object itself in the label.
    dtype_str = DTYPE2STR.get(dtype, dtype)
    bstr = "bias" if bias else "nobi"
    sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"

    params = module._ordered_params()
    with cm():
        out = xsw.swiglu(x, *params, op=OP)
    # The incoming gradient's values do not matter for timing; it is reused
    # for the eager timer below.
    grad = torch.zeros_like(out)

    # retain_graph=True lets the same graph be replayed on every timed call.
    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": out,
            "grad": grad,
        },
        label="swiglu_bw",
        description=OP.NAME,
        sub_label=sub_label,
    )
    # Drop this function's reference to the first output before running the
    # eager forward pass.
    del out

    with cm():
        out = xsw.swiglu(x, *params, op=xsw.SwiGLUEagerOp)

    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": out,
            "grad": grad,
        },
        label="swiglu_bw",
        description="eager",
        sub_label=sub_label,
    )
|
||||
|
||||
|
||||
# Run the forward benchmarks first, then the backward ones, over the same cases.
for bench_fn in (benchmark_swiglu, benchmark_swiglu_bw):
    benchmark_main_helper(bench_fn, CASES, min_run_time=min_run_time)
|
||||
Reference in New Issue
Block a user