161 lines
4.1 KiB
Python
161 lines
4.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import itertools
|
|
from contextlib import nullcontext
|
|
from functools import partial
|
|
from typing import Any
|
|
|
|
import torch
|
|
from torch.utils import benchmark
|
|
from utils import benchmark_main_helper
|
|
|
|
import xformers.ops.swiglu_op as xsw
|
|
|
|
# Minimum measurement time handed to `benchmark_main_helper` for every case.
min_run_time = 0.5
# All inputs and modules in this file are created on this device.
device = torch.device("cuda")

# Problem sizes to benchmark.
# Each entry is (inp.shape[0], inp.shape[1], hidden.shape[1]); the benchmark
# functions report these three numbers as B, I and H in their sub-labels.
SHAPES = [
    # ViT-Giant
    (9456, 1536, 2736),
    (4440, 1536, 2736),
    (4728, 1536, 2736),
    # Some smaller shapes as well
    (4728, 1536, 1024),
    # GPT-3 (small)
    (32768, 2048, 5632),
    # Chinchilla
    (32768, 8192, 22016),
]
|
|
|
|
|
|
# Operator implementation that is timed against `xsw.SwiGLUEagerOp`.
# To benchmark a different implementation, assign one of these instead:
#   xsw._SwiGLUDecomposedOp
#   xsw.SwiGLUFusedOp
OP = xsw.SwiGLUPackedFusedOp
|
|
|
|
|
|
def product_dict(**kwargs):
    """Yield one dict per element of the cartesian product of the keyword values.

    Every keyword maps to an iterable of candidate values. Each yielded dict
    has the same keys as ``kwargs`` and takes one value from each iterable;
    combinations come out in ``itertools.product`` order, so the last keyword
    varies fastest.
    """
    names = kwargs.keys()
    for combination in itertools.product(*kwargs.values()):
        yield dict(zip(names, combination))
|
|
|
|
|
|
# Full benchmark grid: every shape x every dtype mode x with/without bias.
# "autocast_half" is not a torch dtype; the benchmark functions treat it as
# "float32 tensors run under half-precision autocast".
CASES = [
    *product_dict(
        shape=SHAPES,
        dtype=[torch.bfloat16, torch.half, "autocast_half"],
        bias=[True, False],
    )
]
|
|
|
|
# Short tags used for each dtype mode when building benchmark sub-labels.
DTYPE2STR = {
    torch.bfloat16: "b16 ",
    torch.half: "f16 ",
    "autocast_half": "f16.ac",
}
|
|
|
|
|
|
def benchmark_swiglu(shape, dtype, bias: bool):
    """Yield forward-pass timers for one SwiGLU benchmark case.

    Args:
        shape: Sequence indexed as ``(shape[0], shape[1], shape[2])``; the
            input is ``shape[:2]`` and the module is built with
            ``in_features=shape[1]`` and ``hidden_features=shape[2]``.
        dtype: A torch dtype used for both input and module, or the string
            ``"autocast_half"``, in which case both are float32 and the timed
            statement runs inside ``torch.autocast("cuda", dtype=torch.half)``.
        bias: Whether the ``xsw.SwiGLU`` module is created with bias.

    Yields:
        Two ``torch.utils.benchmark.Timer`` objects sharing label
        ``"swiglu_fw"`` and the same sub-label: first the module-level ``OP``
        (described by ``OP.NAME``), then ``xsw.SwiGLUEagerOp`` (``"eager"``).
    """
    use_autocast = dtype == "autocast_half"
    # Under autocast the tensors stay in float32; otherwise use dtype as-is.
    tensor_dtype = torch.float if use_autocast else dtype

    x = torch.randn(shape[:2], device=device, dtype=tensor_dtype)
    module = xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
    module = module.to(device).to(tensor_dtype)
    params = module._ordered_params()

    # Unknown dtypes fall back to being printed as themselves.
    dtype_tag = DTYPE2STR.get(dtype, dtype)
    bias_tag = "bias" if bias else "nobi"
    sub_label = f"{dtype_tag} B={shape[0]}, I={shape[1]}, H={shape[2]} {bias_tag}"

    stmt = "fn(x, *args)"
    if use_autocast:
        # Wrap the call in an autocast block inside the timed statement.
        stmt = 'with torch.autocast("cuda", dtype=torch.half):\n ' + stmt

    for description, op in ((OP.NAME, OP), ("eager", xsw.SwiGLUEagerOp)):
        yield benchmark.Timer(
            stmt=stmt,
            globals={
                "x": x,
                "args": params,
                "fn": partial(xsw.swiglu, op=op),
            },
            label="swiglu_fw",
            description=description,
            sub_label=sub_label,
        )
|
|
|
|
|
|
def benchmark_swiglu_bw(shape, dtype, bias: bool):
    """Yield backward-pass timers for one SwiGLU benchmark case.

    The forward pass is executed once per operator, outside the timed
    statement; only ``out.backward(grad, retain_graph=True)`` is timed.

    Args:
        shape: Sequence indexed as ``(shape[0], shape[1], shape[2])``; the
            input is ``shape[:2]`` and the module is built with
            ``in_features=shape[1]`` and ``hidden_features=shape[2]``.
        dtype: A torch dtype used for both input and module, or the string
            ``"autocast_half"``, in which case both are float32 and the
            forward pass runs under CUDA autocast with ``torch.float16``.
        bias: Whether the ``xsw.SwiGLU`` module is created with bias.

    Yields:
        Two ``torch.utils.benchmark.Timer`` objects sharing label
        ``"swiglu_bw"`` and the same sub-label: first the module-level ``OP``
        (described by ``OP.NAME``), then ``xsw.SwiGLUEagerOp`` (``"eager"``).
    """
    if dtype == "autocast_half":
        inp_dtype, model_dtype = torch.float, torch.float
        # Use the device-generic torch.autocast entry point: it matches the
        # forward benchmark in this file and avoids the deprecated
        # torch.cuda.amp.autocast alias.
        cm: Any = partial(torch.autocast, "cuda", enabled=True, dtype=torch.float16)
    else:
        inp_dtype, model_dtype = dtype, dtype
        cm = nullcontext

    x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
    # The input must require grad so backward also propagates to it.
    x.requires_grad_()
    module = (
        xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
        .to(device)
        .to(model_dtype)
    )

    # Unknown dtypes fall back to being printed as themselves.
    dtype_str = DTYPE2STR.get(dtype, dtype)
    bstr = "bias" if bias else "nobi"
    sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"

    params = module._ordered_params()
    with cm():
        out = xsw.swiglu(x, *params, op=OP)
    # All-zero upstream gradient shaped like the first output; it is reused
    # for the eager output below.
    grad = torch.zeros_like(out)

    yield benchmark.Timer(
        # retain_graph=True keeps the autograd graph alive so the statement
        # can be executed repeatedly.
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": out,
            "grad": grad,
        },
        label="swiglu_bw",
        description=OP.NAME,
        sub_label=sub_label,
    )
    # Drop this function's reference to the first output before running the
    # second forward pass.
    del out

    with cm():
        out = xsw.swiglu(x, *params, op=xsw.SwiGLUEagerOp)

    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": out,
            "grad": grad,
        },
        label="swiglu_bw",
        description="eager",
        sub_label=sub_label,
    )
|
|
|
|
|
|
# Run the forward benchmark first, then the backward one, over the same grid.
for bench_fn in (benchmark_swiglu, benchmark_swiglu_bw):
    benchmark_main_helper(bench_fn, CASES, min_run_time=min_run_time)
|