# (Web-viewer metadata, not source: "Files", 2025-08-05 19:02:46 +08:00, 161 lines, 4.1 KiB, Python)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
from contextlib import nullcontext
from functools import partial
from typing import Any
import torch
from torch.utils import benchmark
from utils import benchmark_main_helper
import xformers.ops.swiglu_op as xsw
# Minimum wall-clock time (seconds) each Timer measurement must accumulate.
min_run_time = 0.5
# Benchmarks require a CUDA device; all tensors/modules are placed here.
device = torch.device("cuda")
# Benchmark problem sizes.
SHAPES = [
# Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
# ViT-Giant
(9456, 1536, 2736),
(4440, 1536, 2736),
(4728, 1536, 2736),
# Some smaller shapes as well
(4728, 1536, 1024),
# GPT-3 (small)
(32768, 2048, 5632),
# Chinchilla
(32768, 8192, 22016),
]
# Fused op under test; alternatives kept for quick switching:
# OP = xsw._SwiGLUDecomposedOp
# OP = xsw.SwiGLUFusedOp
OP = xsw.SwiGLUPackedFusedOp
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the value lists.

    Example: product_dict(a=[1, 2], b=[3]) yields {'a': 1, 'b': 3} then
    {'a': 2, 'b': 3}. Key order follows the keyword-argument order.
    """
    names = list(kwargs)
    for combo in itertools.product(*kwargs.values()):
        yield {name: value for name, value in zip(names, combo)}
# Every combination of shape x dtype x bias that the benchmarks sweep over.
# "autocast_half" means fp32 storage with fp16 autocast at call time.
CASES = list(
product_dict(
shape=SHAPES,
dtype=[torch.bfloat16, torch.half, "autocast_half"],
bias=[True, False],
)
)
# Short, fixed-width dtype tags used in the benchmark sub-labels.
DTYPE2STR = {
torch.bfloat16: "b16 ",
torch.half: "f16 ",
"autocast_half": "f16.ac",
}
def benchmark_swiglu(shape, dtype, bias: bool):
    """Yield forward-pass Timers for the fused SwiGLU op and the eager reference.

    One Timer is produced per (op, description) pair so the benchmark table
    shows the fused kernel next to eager execution for the same case.
    """
    # "autocast_half" keeps params and inputs in fp32 and lets autocast
    # downcast to fp16 inside the timed statement.
    if dtype == "autocast_half":
        inp_dtype, model_dtype, autocast = torch.float, torch.float, True
    else:
        inp_dtype, model_dtype, autocast = dtype, dtype, False
    x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
    module = (
        xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
        .to(device)
        .to(model_dtype)
    )
    dtype_str = DTYPE2STR.get(dtype, dtype)
    bstr = "bias" if bias else "nobi"
    sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"
    params = module._ordered_params()
    # The autocast context is injected textually into the timed stmt.
    PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n ' if autocast else ""
    for op, description in ((OP, OP.NAME), (xsw.SwiGLUEagerOp, "eager")):
        yield benchmark.Timer(
            stmt=f"{PREFIX}fn(x, *args)",
            globals={
                "x": x,
                "args": params,
                "fn": partial(xsw.swiglu, op=op),
            },
            label="swiglu_fw",
            description=description,
            sub_label=sub_label,
        )
def benchmark_swiglu_bw(shape, dtype, bias: bool):
    """Yield backward-pass Timers for the fused SwiGLU op and the eager reference.

    The forward pass runs once (outside the timed region) to build the graph;
    each Timer then measures ``out.backward(grad, retain_graph=True)``.
    """
    if dtype == "autocast_half":
        # fp32 storage, fp16 autocast — matches benchmark_swiglu's PREFIX path.
        # Use torch.autocast("cuda", ...) instead of the deprecated
        # torch.cuda.amp.autocast, consistent with the forward benchmark.
        inp_dtype, model_dtype = torch.float, torch.float
        cm: Any = partial(torch.autocast, "cuda", enabled=True, dtype=torch.float16)
    else:
        inp_dtype, model_dtype = dtype, dtype
        cm = nullcontext
    x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
    x.requires_grad_()
    module = (
        xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
        .to(device)
        .to(model_dtype)
    )
    dtype_str = DTYPE2STR.get(dtype, dtype)
    bstr = "bias" if bias else "nobi"
    sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"
    params = module._ordered_params()
    with cm():
        out = xsw.swiglu(x, *params, op=OP)
    grad = torch.zeros_like(out)
    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": out,
            "grad": grad,
        },
        label="swiglu_bw",
        description=OP.NAME,
        sub_label=sub_label,
    )
    # Free the fused-op graph before building the eager one to limit peak memory.
    del out
    with cm():
        out = xsw.swiglu(x, *params, op=xsw.SwiGLUEagerOp)
    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": out,
            "grad": grad,
        },
        label="swiglu_bw",
        description="eager",
        sub_label=sub_label,
    )
# Run forward then backward sweeps over all CASES and print the result tables.
benchmark_main_helper(benchmark_swiglu, CASES, min_run_time=min_run_time)
benchmark_main_helper(benchmark_swiglu_bw, CASES, min_run_time=min_run_time)