# Source file: enginex-bi_series-vllm/pkgs/xformers/benchmarks/benchmark_indexing.py
# Timestamp: 2025-08-05 19:02:46 +08:00
# Size: 242 lines, 6.8 KiB (Python)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import random
import torch
from torch.utils import benchmark
from utils import benchmark_main_helper
import xformers.ops as xops
# Forwarded as the ``min_run_time`` keyword to every ``benchmark_main_helper``
# call at the bottom of this file. Presumably a per-measurement time budget in
# seconds -- confirm against ``utils.benchmark_main_helper``.
min_run_time = 0.5
# CUDA device handle.
# NOTE(review): not referenced in the visible part of this file; the benchmark
# classes below pass ``device="cuda"`` directly.
device = torch.device("cuda")
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the keyword values.

    Every keyword argument maps a name to an iterable of candidate values.
    Each yielded dict uses the same names as keys and picks one candidate per
    name; combinations are produced in ``itertools.product`` order.
    """
    names = tuple(kwargs)
    for combination in itertools.product(*kwargs.values()):
        yield {name: value for name, value in zip(names, combination)}
# Parameter grid for the scaled_index_add benchmarks. Each entry is a kwargs
# dict with keys ``shape``, ``scaling`` and ``dtype``; ``shape`` is unpacked by
# ScaledIndexAddBenchmark as (B_src, B_out, M, D).
CASES_IADD = [
    # First group: timed both without and with the per-feature scaling vector.
    *product_dict(
        shape=[
            (int(48 * 0.6), 48, 1, 257 * 1536),
            (int(48 * 0.6), 48, 257, 1536),
        ],
        scaling=[False, True],
        dtype=[torch.half],
    ),
    # Second group: timed without scaling only.
    *product_dict(
        shape=[
            (int(192 * 0.6), 192, 50, 1536),
            (int(48 * 257 * 0.6), 257 * 48, 1, 1536),
            (int(192 * 50 * 0.6), 192 * 50, 1, 1536),
            (int(16 * 257 * 0.6), 48 * 257, 1, 1536),
        ],
        scaling=[False],
        dtype=[torch.half],
    ),
]
# Parameter grid for the index_select benchmarks. Each entry is a kwargs dict
# with keys ``batches``, ``D``, ``keep_ratio`` and ``dtype``; ``batches`` is a
# tuple of (B, seqlen) pairs consumed by IndexSelectBenchmark.
CASES_ISELECT = [
    *product_dict(
        batches=[((48, 257), (50, 192))],
        D=[1536],
        keep_ratio=[0.6],
        dtype=[torch.half],
    )
]
# Short dtype tags used as the prefix of each benchmark's ``sub_label``.
# The benchmark classes look dtypes up with ``DTYPE2STR.get(dtype, dtype)``,
# so a dtype missing from this table falls back to the dtype object itself.
DTYPE2STR = {
    torch.bfloat16: "b16",
    torch.half: "f16",
    torch.float32: "f32",
}
def _setup_test(functions, fw: bool = False, bw: bool = False, **kwargs):
for k, benchmark_cls in functions.items():
benchmark_object = benchmark_cls(**kwargs, bw=bw)
label = benchmark_object.label
label += "fw" if fw else ""
label += "bw" if bw else ""
def run_one():
if fw:
benchmark_object.fw()
if bw:
benchmark_object.bw()
yield benchmark.Timer(
stmt="fn()",
globals={
"fn": run_one,
},
label=label,
description=k,
sub_label=benchmark_object.sub_label,
)
class ScaledIndexAddBenchmark:
    """Benchmark case that times ``xops.scaled_index_add`` on CUDA tensors.

    Attributes set by ``__init__``: ``label`` / ``sub_label`` (reporting
    strings), ``inp``, ``src``, optional ``scaling``, ``index`` (the values
    ``0 .. B_src - 1``), ``alpha``, ``grad`` (upstream gradient used by
    ``bw``) and ``out`` (result of the most recent ``fw`` call).
    """

    def __init__(self, dtype, scaling: bool, shape, bw: bool) -> None:
        """Allocate the benchmark tensors.

        Args:
            dtype: torch dtype of every floating-point tensor.
            scaling: whether to create a ``[D]`` scaling vector.
            shape: ``(B_src, B_out, M, D)`` tuple.
            bw: whether input tensors are created with ``requires_grad``.
        """
        B_src, B_out, M, D = shape
        # Seed from the batch sizes so both implementations of a given case
        # draw the same random tensors.
        torch.manual_seed(B_out + B_src)

        dtype_str = DTYPE2STR.get(dtype, dtype)
        self.label = "scaled_index_add"
        self.sub_label = f"{dtype_str} B_src={B_src}, B_out={B_out}, M={M}, D={D} s={'Y' if scaling else 'N'}"
        self.alpha = 0.73

        leaf_options = dict(device="cuda", dtype=dtype, requires_grad=bw)
        self.inp = torch.randn([B_out, M, D], **leaf_options)
        self.src = torch.randn([B_src, M, D], **leaf_options)
        if scaling:
            self.scaling = torch.randn([D], **leaf_options)
        else:
            self.scaling = None

        # Identity index: source row i targets input row i.
        self.index = torch.arange(B_src, dtype=torch.int64, device="cuda")
        self.grad = torch.randn([B_out, M, D], device="cuda", dtype=dtype)
        self.out = torch.Tensor()

    def fw(self) -> None:
        """Run the xformers op and keep its result in ``self.out``."""
        # A fresh clone is passed on every call -- presumably because the op
        # updates ``input`` in place; confirm against the xformers docs.
        fresh_input = self.inp.clone()
        self.out = xops.scaled_index_add(
            input=fresh_input,
            index=self.index,
            source=self.src,
            scaling=self.scaling,
            alpha=self.alpha,
        )

    def bw(self):
        """Reset accumulated gradients, then backpropagate ``self.grad``."""
        for leaf in (self.inp, self.src, self.scaling):
            if leaf is not None:
                leaf.grad = None
        # retain_graph lets the same ``self.out`` be differentiated again.
        self.out.backward(self.grad, retain_graph=True)
class ScaledIndexAddBenchmarkBaseline(ScaledIndexAddBenchmark):
    """Pure-PyTorch baseline for the scaled index-add benchmark.

    Reuses the tensors and ``bw`` of ``ScaledIndexAddBenchmark`` and only
    replaces the forward pass with ``Tensor.index_add``.
    """

    def fw(self) -> None:
        """Compute ``inp.index_add(0, index, src * scaling, alpha=alpha)``.

        When ``self.scaling`` is ``None`` the source is used unscaled. The
        result is stored in ``self.out``.
        """
        src_scaled = self.src
        if self.scaling is not None:
            # Tensor multiplication is out of place, so the product must be
            # rebound; otherwise the scaling is silently dropped from both the
            # forward result and the autograd graph.
            src_scaled = src_scaled * self.scaling.unsqueeze(0).unsqueeze(0)
        self.out = self.inp.index_add(
            dim=0,
            source=src_scaled,
            index=self.index,
            alpha=self.alpha,
        )
def scaled_index_add_fw(**kwargs):
    """Yield forward-only timers for the xformers and PyTorch scaled index-add.

    ``kwargs`` is one case dict (``shape``, ``scaling``, ``dtype``) and is
    forwarded unchanged to ``_setup_test``.
    """
    implementations = {
        "xformers": ScaledIndexAddBenchmark,
        "pytorch": ScaledIndexAddBenchmarkBaseline,
    }
    yield from _setup_test(functions=implementations, fw=True, **kwargs)
def scaled_index_add_fwbw(**kwargs):
    """Yield forward+backward timers for both scaled index-add implementations.

    ``kwargs`` is one case dict (``shape``, ``scaling``, ``dtype``) and is
    forwarded unchanged to ``_setup_test``.
    """
    implementations = {
        "xformers": ScaledIndexAddBenchmark,
        "pytorch": ScaledIndexAddBenchmarkBaseline,
    }
    yield from _setup_test(functions=implementations, fw=True, bw=True, **kwargs)
class IndexSelectBenchmark:
    """Benchmark case that times ``xops.index_select_cat``.

    Attributes set by ``__init__``: ``label`` / ``sub_label`` (reporting
    strings), ``src`` (one CUDA leaf tensor of ``D``-wide rows), ``sources``
    (per-batch slices of ``src`` reshaped to ``[B, seqlen * D]``),
    ``indices`` (per-batch int64 row indices to keep) and ``out`` (result of
    the most recent ``fw`` call).
    """

    def __init__(self, dtype, batches, D, keep_ratio, bw: bool) -> None:
        """Build the sources and the kept-row indices.

        Args:
            dtype: torch dtype of the source data.
            batches: iterable of ``(B, seqlen)`` pairs, one per source.
            D: feature width; each source has ``seqlen * D`` columns.
            keep_ratio: fraction of each source's ``B`` rows to select.
            bw: accepted for interface parity with the other benchmarks; the
                source always requires grad regardless of its value.
        """
        dtype_str = DTYPE2STR.get(dtype, dtype)
        self.label = "index_select"
        self.sub_label = f"{dtype_str} D={D} batches={batches} keep={keep_ratio}"

        # Random data is generated per batch on the default device, then all
        # batches are stacked as D-wide rows into a single CUDA leaf tensor.
        srcs = [torch.randn([B, seqlen * D]) for (B, seqlen) in batches]
        stacked_rows = [s.view([-1, D]) for s in srcs]
        src = torch.cat(stacked_rows, dim=0).cuda().to(dtype)
        src.requires_grad_(True)

        indices = []
        sources = []
        row_start = 0
        for template in srcs:
            num_items, width = template.shape

            # Permutation seeded by the batch size, so it is reproducible.
            order = list(range(num_items))
            random.Random(num_items).shuffle(order)
            num_kept = int(keep_ratio * num_items)
            indices.append(
                torch.tensor(order[:num_kept], dtype=torch.int64, device="cuda")
            )

            # This batch occupies the next ``num_rows`` D-wide rows of ``src``.
            num_rows = num_items * width // D
            row_end = row_start + num_rows
            sources.append(src[row_start:row_end].reshape(template.shape))
            row_start = row_end

        self.indices = indices
        self.sources = sources
        self.src = src
        self.out = torch.Tensor()

    def fw(self) -> None:
        """Run the xformers op and keep its result in ``self.out``."""
        self.out = xops.index_select_cat(self.sources, self.indices)

    def bw(self):
        """Reset ``src.grad``, then backpropagate using ``out`` as gradient."""
        self.src.grad = None
        # retain_graph lets the same ``self.out`` be differentiated again.
        self.out.backward(self.out, retain_graph=True)
class IndexSelectBenchmarkBaseline(IndexSelectBenchmark):
    """Pure-PyTorch baseline: index each source, flatten, and concatenate."""

    def fw(self) -> None:
        """Store the concatenation of the flattened selected rows in ``out``."""
        selected = []
        for source, index in zip(self.sources, self.indices):
            selected.append(source[index].flatten())
        self.out = torch.cat(selected, dim=0)
def index_select_fw(**kwargs):
    """Yield forward-only timers for the xformers and PyTorch index-select.

    ``kwargs`` is one case dict (``batches``, ``D``, ``keep_ratio``,
    ``dtype``) and is forwarded unchanged to ``_setup_test``.
    """
    implementations = {
        "xformers": IndexSelectBenchmark,
        "pytorch": IndexSelectBenchmarkBaseline,
    }
    yield from _setup_test(functions=implementations, fw=True, **kwargs)
def index_select_fwbw(**kwargs):
    """Yield forward+backward timers for both index-select implementations.

    ``kwargs`` is one case dict (``batches``, ``D``, ``keep_ratio``,
    ``dtype``) and is forwarded unchanged to ``_setup_test``.
    """
    implementations = {
        "xformers": IndexSelectBenchmark,
        "pytorch": IndexSelectBenchmarkBaseline,
    }
    yield from _setup_test(functions=implementations, fw=True, bw=True, **kwargs)
# Run every benchmark family over its case grid, in this fixed order. These
# calls execute at module level, i.e. whenever this file is run or imported.
for _benchmark_fn, _cases in (
    (scaled_index_add_fw, CASES_IADD),
    (scaled_index_add_fwbw, CASES_IADD),
    (index_select_fw, CASES_ISELECT),
    (index_select_fwbw, CASES_ISELECT),
):
    benchmark_main_helper(_benchmark_fn, _cases, min_run_time=min_run_time)