Files
enginex-bi_series-vllm/pkgs/xformers/benchmarks/benchmark_blocksparse_transformers.py
2025-08-05 19:02:46 +08:00

1066 lines
37 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import gc
import math
from collections import namedtuple
from dataclasses import dataclass
import matplotlib.pyplot as plt
import torch
import triton
from triton.ops.blocksparse import matmul as blocksparse_matmul
from xformers.benchmarks.utils import pretty_barplot
from xformers.components.attention.attention_patterns import (
axial_2d_pattern,
causal_1d_pattern,
global_token_pattern,
local_1d_pattern,
local_2d_pattern,
)
from xformers.components.attention.core import SparseCS, _matmul_with_mask
# All tensors in this benchmark run on the GPU.
device = "cuda"
# One benchmark variant: a factory returning the callable to time, the blocked
# mask it runs with, its Configuration, and a display name used in results/plots.
TestCase = namedtuple("TestCase", ["prepare_callable", "mask", "config", "name"])
##############################################
# Plotting utilities
##############################################
def plot_mask(mask, config, filename):
    """Render a blocked attention mask as a black/white image and save it to filename."""
    sparsity = get_sparsity(mask)
    # Expand the blocked mask to a full (b, h, s, s) boolean tensor, then
    # visualize one (head, batch) slice.
    full = torch.ones(
        config.batch_size,
        config.num_heads,
        config.seq_length,
        config.seq_length,
        dtype=torch.bool,
    )
    full = triton.testing.mask_tensor(full, mask, config.block_size, False)
    img = full[0, 0]
    fig = plt.figure()
    # Invert so that masked-out (zero) entries render white.
    plt.imshow(img.logical_not(), cmap="gray")
    plt.suptitle("Sparsity = " + str(sparsity) + "%")
    plt.savefig(filename)
    plt.close(fig)
##############################################
# Mask and testing utilities
##############################################
def get_mask(MaskGenType, config, config_setter=None):
    """Instantiate a mask generator and return its (mask, config, name) triple.

    Args:
        MaskGenType: AttentionMask subclass to instantiate.
        config: dict of Configuration field overrides.
        config_setter: optional iterable of (key, value) pairs applied to the
            generator's config before validation. Defaults to no overrides
            (was a mutable default `[]` — fixed to avoid the shared mutable
            default argument pitfall).

    Returns:
        (mask, config, name) from the generator, or None when the resulting
        configuration is invalid for this mask type.
    """
    mask_config = Configuration()
    mask_config.init(config)
    # Get the mask
    mask_generator = MaskGenType(mask_config)
    for key, value in config_setter or []:
        mask_generator.set_config_attr(key, value)
    if not mask_generator.is_valid_config():
        return None
    return mask_generator()
def densify_mask(mask, config):
    """Expand a blocked (num_heads, nb, nb) mask to a dense float mask.

    Every blocked entry is replicated over a block_size x block_size tile.
    Returns a float tensor of shape (num_heads, seq_length, seq_length);
    entries outside the tiled region (if seq_length is not an exact multiple
    of block_size) stay zero.

    Vectorized with repeat_interleave instead of a Python loop over the
    nonzero blocks, which was O(nnz) interpreter-side work.
    """
    num_heads = config.num_heads
    seq_length = config.seq_length
    block_size = config.block_size
    dense_mask = torch.zeros(num_heads, seq_length, seq_length)
    # Tile each blocked entry into a block_size x block_size patch.
    expanded = mask.repeat_interleave(block_size, dim=-2).repeat_interleave(
        block_size, dim=-1
    )
    # Clip to seq_length, mirroring the slice-clipping of the original loop.
    n = min(expanded.size(-1), seq_length)
    dense_mask[:, :n, :n] = expanded[:, :n, :n]
    return dense_mask
def mask_tensor(a, mask, config):
    """Zero out the blocks of `a` that are not selected by the blocked `mask`."""
    block = config.block_size
    return triton.testing.mask_tensor(a, mask, block, 0.0)
def sparsify_tensor(a, mask, config):
    """Pack the blocks of `a` selected by the blocked `mask` into BSR layout."""
    block = config.block_size
    return triton.testing.sparsify_tensor(a, mask, block)
def get_sparsity(mask):
    """Return the percentage of zero entries in `mask`, rounded to an int."""
    density = mask.sum().item() / mask.numel()
    return round((1.0 - density) * 100)
##############################################
# Mask Generation
##############################################
@dataclass
class Configuration:
    """Benchmark problem dimensions shared by all mask generators and experiments."""

    batch_size: int = 32
    num_heads: int = 12
    seq_length: int = 2048
    hidden_size: int = 768  # hidden_size = n_heads * projection_hidden_dimension
    block_size: int = 64

    @property
    def blocked_seq_length(self) -> int:
        # Number of blocks along one side of the attention matrix. Use
        # integer division instead of int(a / b) to avoid the float round
        # trip (seq_length is expected to be a multiple of block_size).
        return self.seq_length // self.block_size

    def init(self, kwargs):
        """Overwrite fields from a dict of {field_name: value}."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __str__(self):
        desc = [
            f"bs={self.batch_size}",
            f"h={self.num_heads}",
            f"k={self.hidden_size}",
            f"seq={self.seq_length}",
            f"bl={self.block_size}",
        ]
        return ",".join(desc)
class AttentionMask:
    """Abstract base for attention-mask generators.

    Subclasses implement gen_mask() and __str__(); calling the instance
    returns (mask, config, name).
    """

    def __init__(self, config=None):
        super().__init__()
        if config is None:
            config = Configuration()
        self.config = config

    def is_blocked(self):
        """True when the mask operates on blocks rather than single tokens."""
        return self.config.block_size != 1

    def is_valid_config(self, keep_blocked=True):
        """Subclasses may reject configurations (e.g. non-square grids)."""
        return True

    def expand(self, mask):
        """Broadcast a 2D mask across all heads; pass a 3D mask through.

        Previously a non-2D mask fell through and returned None silently;
        a per-head (3D) mask is now returned unchanged.
        """
        if mask.ndim == 2:
            return mask.unsqueeze(0).expand(self.config.num_heads, -1, -1)
        return mask

    def gen_mask(self, keep_blocked=True):
        raise NotImplementedError("Abstract data class")

    def set_config_attr(self, key, value):
        setattr(self.config, key, value)

    def __str__(self):
        raise NotImplementedError("Abstract data type")

    def __call__(self):
        mask = self.gen_mask()
        return mask, self.config, str(self)
class RandomAttentionMask(AttentionMask):
    """
    This is a Random mask. Useful for performance and memory analysis.
    """

    def __init__(self, config=None):
        super().__init__(config)
        # Probability of dropping a block; each kept block passes rand > prob.
        self.set_config_attr("mask_prob", 0.5)

    def gen_mask(self, keep_blocked=True):
        cfg = self.config
        size = cfg.blocked_seq_length if keep_blocked else cfg.seq_length
        keep = torch.rand(size, size) > cfg.mask_prob
        return self.expand(keep)

    def __str__(self):
        return "random"
class LowerTriangularAttentionMask(AttentionMask):
    """
    Lower triangular (causal) mask, common in decoder-only models. Close to
    half of the mask is zero, so compute and memory should drop to roughly
    half. The mask is identical for every head and every input.
    Nit pick (TODO) - when blocking, the blocks along the diagonal should
    themselves be lower triangular. For performance measurement this is fine
    to ignore, as the whole block is treated as useful values.
    """

    def __init__(self, config=None):
        super().__init__(config)

    def gen_mask(self, keep_blocked=True):
        cfg = self.config
        size = cfg.blocked_seq_length if keep_blocked else cfg.seq_length
        return self.expand(causal_1d_pattern(size))

    def __str__(self):
        return "lower_triangular"
class BigBirdAttentionMask(AttentionMask):
    """
    BigBird masks are composed of three types of masks - random, global and
    window. For more details, refer to https://arxiv.org/pdf/2007.14062.pdf
    One point to note is that the mask is per head here. So, the mask is a 3D
    tensor (num_heads, seq_length, seq_length).
    """

    def __init__(self, config=None):
        super().__init__(config)
        # Each head gets its own independently sampled mask.
        self.mask_per_head = True
        self.set_config_attr("num_global_tokens", 2 * self.config.block_size)
        self.set_config_attr("num_random_tokens", 3 * self.config.block_size)
        self.set_config_attr("num_window_tokens", 3 * self.config.block_size)

    def gen_global_mask(self, seq_length):
        # Global tokens are tokens that attend to all tokens and to whom all tokens attend to in the sequence
        num_global_blocks = self.config.num_global_tokens // self.config.block_size
        # torch.randint's upper bound is exclusive; sampling over seq_length
        # covers every block (the previous seq_length - 1 bound could never
        # pick the last block).
        mask_indices = torch.randint(0, seq_length, size=(num_global_blocks,))
        mask_indices = torch.unique(mask_indices)
        query_mask = torch.zeros(seq_length).to(dtype=torch.bool)
        query_mask.scatter_(0, mask_indices, True)
        return global_token_pattern(query_mask)

    def gen_random_mask(self, seq_length):
        # Each query token attends over r random number of tokens
        num_random_blocks = self.config.num_random_tokens // self.config.block_size
        # Exclusive upper bound: sample over the full [0, seq_length) range.
        mask_indices = torch.randint(
            0, seq_length, size=(seq_length, num_random_blocks)
        )
        random_mask = torch.zeros(seq_length, seq_length).to(dtype=torch.bool)
        random_mask.scatter_(1, mask_indices, True)
        return random_mask

    def gen_window_mask(self, seq_length):
        # Sliding local window; width is forced odd so the window is symmetric.
        num_window_blocks = self.config.num_window_tokens // self.config.block_size
        if num_window_blocks % 2 == 0:
            num_window_blocks += 1
        return local_1d_pattern(seq_length, num_window_blocks)

    def _single_head_mask(self, seq_length):
        # Union of the three BigBird components (bool addition == logical or).
        return (
            self.gen_global_mask(seq_length)
            + self.gen_random_mask(seq_length)
            + self.gen_window_mask(seq_length)
        )

    def gen_mask(self, keep_blocked=True):
        seq_length = self.config.seq_length
        if keep_blocked:
            seq_length = self.config.blocked_seq_length
        assert keep_blocked, "Not implemented, call to_dense later to get full tensor"
        if self.mask_per_head:
            # Independent sample per head, stacked into (num_heads, s, s).
            mask = torch.stack(
                [
                    self._single_head_mask(seq_length)
                    for _ in range(self.config.num_heads)
                ]
            )
        else:
            mask = self.expand(self._single_head_mask(seq_length))
        return mask

    def __str__(self):
        return "bigbird"
class AxialAttentionMask(AttentionMask):
    """
    Axial attention mask: the sequence is viewed as a 2D (H x H) grid and each
    query attends to every token in its own row and its own column.
    Requires the (blocked) sequence length to be a perfect square.
    (Docstring previously copy-pasted from BigBirdAttentionMask.)
    """

    def __init__(self, config=None):
        super().__init__(config)
        if config is None:
            # Default seq_length (2048) is not a perfect square once blocked.
            self.set_config_attr("seq_length", 1024)

    def _grid_side(self, keep_blocked):
        # (sequence length in use, side of the 2D grid it folds into).
        # math.isqrt is exact for integers, unlike int(math.sqrt(...)).
        seq_length = (
            self.config.blocked_seq_length if keep_blocked else self.config.seq_length
        )
        return seq_length, math.isqrt(seq_length)

    def is_valid_config(self, keep_blocked=True):
        seq_length, H = self._grid_side(keep_blocked)
        return H * H == seq_length

    def gen_mask(self, keep_blocked=True):
        seq_length, H = self._grid_side(keep_blocked)
        assert H * H == seq_length, f"H={H}, seq_length={seq_length}"
        return self.expand(axial_2d_pattern(H, H))

    def __str__(self):
        return "axial"
class LocalAttentionMask(AttentionMask):
    """
    Local 2D attention mask: the sequence is viewed as a 2D (H x H) grid and
    each query attends to blocks within a num_local_blocks-wide neighborhood.
    Requires the (blocked) sequence length to be a perfect square.
    (Docstring previously copy-pasted from BigBirdAttentionMask.)
    """

    def __init__(self, config=None):
        super().__init__(config)
        self.set_config_attr("num_local_blocks", 3)
        if config is None:
            # Default seq_length (2048) is not a perfect square once blocked.
            self.set_config_attr("seq_length", 1024)

    def _grid_side(self, keep_blocked):
        # (sequence length in use, side of the 2D grid it folds into).
        # math.isqrt is exact for integers, unlike int(math.sqrt(...)).
        seq_length = (
            self.config.blocked_seq_length if keep_blocked else self.config.seq_length
        )
        return seq_length, math.isqrt(seq_length)

    def is_valid_config(self, keep_blocked=True):
        seq_length, H = self._grid_side(keep_blocked)
        return H * H == seq_length

    def gen_mask(self, keep_blocked=True):
        seq_length, H = self._grid_side(keep_blocked)
        assert H * H == seq_length, f"H={H}, seq_length={seq_length}"
        return self.expand(local_2d_pattern(H, H, self.config.num_local_blocks))

    def __str__(self):
        return "local"
##############################################
# Class to organize the experiments
##############################################
class Experiment:
    """
    Base class for the blocksparse matmul benchmarks.

    Args:
        mode: "sddmm" (dense query x dense key -> block-sparse output) or
            "spmm" (block-sparse attn x dense value -> dense output).
        dtype: dtype of the dense input tensors.
        do_accuracy_check: if True, validate the sparse kernel against the
            dense torch.matmul reference before benchmarking.
        profile_sputnik: if True, also benchmark the sputnik (SparseCS)
            kernels (sddmm mode only).
    """

    def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik):
        self.mode = mode
        self.dtype = dtype
        self.do_accuracy_check = do_accuracy_check
        self.profile_sputnik = profile_sputnik

    def reset_results(self):
        # Layout: results[metric][dict_key][testcase_name] = value
        self.results = {}
        self.results["flops"] = {}
        self.results["time"] = {}
        self.results["memory"] = {}
        self.results["speedup"] = {}
        self.results["memory_savings"] = {}

    def do_mem(self, fn):
        """Return peak CUDA memory (MiB) allocated while running fn.

        (Fixed the first parameter, which was misspelled `sel`.)
        """
        # bookkeeping: start from a clean allocator state
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        # actually run the function (twice, so one-time allocations count)
        fn()
        fn()
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() // 2**20

    def gen_config(self):
        raise NotImplementedError("Not setup")

    def plot(self, sparsity, pattern_name):
        raise NotImplementedError("Not setup")

    def run(self):
        raise NotImplementedError("Not setup")

    def add_kv(self, d, d_key, d_value, testcase):
        """Record d[d_key][testcase.name] = d_value, clamping failures (<0) to 0."""
        d_value = max(0, d_value)
        if d_key not in d:
            d[d_key] = {}
        d[d_key][testcase.name] = d_value

    def bench_all(
        self, a, b, tests, mask_config, sparsity, baseline_name, op_flops, dict_key
    ):
        """Benchmark each testcase, recording time/flops/memory and ratios vs baseline.

        The testcase named `baseline_name` must appear first in `tests` so its
        time/memory entries exist when the speedup/savings ratios are computed.
        """
        if self.do_accuracy_check:
            self.check_all(tests, a, b)
        for testcase in tests:
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            try:
                fn = testcase.prepare_callable(a, b, testcase.mask, testcase.config)
                ms = triton.testing.do_bench(fn)[0]
                flops = op_flops / ms * 1e3  # TFlop per second
                mem = self.do_mem(fn)
            except Exception:
                # Best-effort: a kernel that fails (e.g. OOM or unsupported
                # config) is recorded as -1; add_kv clamps the stored values
                # to 0 so the plots still render.
                ms = -1
                flops = -1
                mem = -1
            # Write into results (the duplicated flops add_kv was removed).
            self.add_kv(self.results["time"], dict_key, ms, testcase)
            self.add_kv(self.results["flops"], dict_key, flops, testcase)
            self.add_kv(self.results["memory"], dict_key, mem, testcase)
            speedup = self.results["time"][dict_key][baseline_name] / ms
            memory_savings = self.results["memory"][dict_key][baseline_name] / mem
            self.add_kv(self.results["speedup"], dict_key, speedup, testcase)
            self.add_kv(
                self.results["memory_savings"], dict_key, memory_savings, testcase
            )
            desc = f"sparsity={sparsity}, ops={op_flops}, time={ms}, tflops={flops}, mem={mem}"
            print(f"{testcase.name} --> {mask_config}, {desc}")

    def get_inputs(self, config, device="cuda"):
        """Create the two dense operands for the configured mode.

        if mode = sddmm, a, b = query, key
        if mode = spmm, a, b = attn, value
        """
        if self.mode == "sddmm":
            return [
                torch.randn(
                    config.batch_size,
                    config.num_heads,
                    config.seq_length,
                    config.hidden_size // config.num_heads,
                    device=device,
                    dtype=self.dtype,
                )
                for _ in range(2)
            ]
        else:
            assert self.mode == "spmm"
            attn = torch.randn(
                config.batch_size,
                config.num_heads,
                config.seq_length,
                config.seq_length,
                device=device,
                dtype=self.dtype,
            )
            value = torch.randn(
                config.batch_size,
                config.num_heads,
                config.seq_length,
                config.hidden_size // config.num_heads,
                device=device,
                dtype=self.dtype,
            )
            return [attn, value]

    def torch_matmul_callable(self, a, b, mask, config):
        """Dense torch.matmul baseline (masks the sparse operand for spmm)."""
        input_a = mask_tensor(a, mask, config) if self.mode == "spmm" else a
        input_b = b.transpose(-1, -2) if self.mode == "sddmm" else b

        def torch_fn():
            return torch.matmul(input_a, input_b)

        return torch_fn

    def get_triton_fn(self, mask, config, mode="sddmm"):
        """Build the triton blocksparse kernel: sdd for sddmm, dsd for spmm."""
        if mode == "sddmm":
            return blocksparse_matmul(
                layout=mask,
                block=config.block_size,
                mode="sdd",
                device="cuda",
                trans_a=False,
                trans_b=True,
            )
        else:
            assert mode == "spmm"
            return blocksparse_matmul(
                layout=mask,
                block=config.block_size,
                mode="dsd",
                device="cuda",
                trans_a=False,
                trans_b=False,
            )

    def triton_callable(self, a, b, mask, config):
        """Triton blocksparse kernel wrapped with pre-sparsified inputs."""
        triton_kernel = self.get_triton_fn(mask, config, self.mode)
        input_a = sparsify_tensor(a, mask, config) if self.mode == "spmm" else a
        input_b = b

        def triton_fn():
            return triton_kernel(input_a, input_b)

        return triton_fn

    def prepare_sputnik_inputs(self, query, key, config, mask):
        """Flatten batch/head dims and build the SparseCS mask for sputnik."""
        # - sparse / sputnik
        mask_cs = torch.ones(
            [config.batch_size, config.num_heads, config.seq_length, config.seq_length],
            dtype=torch.bool,
            device="cuda",
        )
        mask_cs = triton.testing.mask_tensor(
            mask_cs, mask, config.block_size, value=False
        )
        # Sputnik kernels only handle fp32
        query_cs = query.flatten(start_dim=0, end_dim=1).to(torch.float32)
        key_cs = key.flatten(start_dim=0, end_dim=1).to(torch.float32)
        query_cs = query_cs.contiguous()
        key_cs = key_cs.transpose(-2, -1)
        sparse_mask_cs = SparseCS(
            mask_cs.flatten(start_dim=0, end_dim=1).contiguous(),
            device=torch.device("cuda"),
        )
        return query_cs, key_cs, sparse_mask_cs

    def sputnik_callable(self, a, b, mask, config):
        """Sputnik masked-matmul kernel (sddmm only)."""
        assert self.mode == "sddmm"
        a_cs, b_cs, sparse_mask_cs = self.prepare_sputnik_inputs(a, b, config, mask)

        def sputnik_fn():
            return _matmul_with_mask(a_cs, b_cs, sparse_mask_cs)

        return sputnik_fn

    def get_op_flops(self, mask, config):
        """TFLOPs of the useful (unmasked) work, counting whole blocks."""
        op_flops = (
            2  # FMA
            * config.batch_size  # batched matmul
            * (config.hidden_size // config.num_heads)  # Reduce dimension
            * float(mask.sum())
            * config.block_size
            * config.block_size  # Effective seq length * seq_length
            * 1e-12  # TFlops
        )
        return op_flops

    def check_all(self, tests, a, b):
        """Compare the second testcase (sparse kernel) against the first (reference)."""
        ref_test = tests[0]
        ref_out = ref_test.prepare_callable(a, b, ref_test.mask, ref_test.config)()
        res_test = tests[1]
        res_out = res_test.prepare_callable(a, b, res_test.mask, res_test.config)()
        self.check_accuracy(ref_out, res_out, ref_test.mask, ref_test.config)

    def check_accuracy(self, ref_full, res_bsr, mask, config):
        """Verify the sparse kernel output matches the dense torch reference."""
        if self.mode == "sddmm":
            # Get the dense representation of the bsr tensor
            # Use triton sparse * dense multiplication to get the dense tensor back
            sparse_dot_dsd = blocksparse_matmul(
                layout=mask,
                block=config.block_size,
                mode="dsd",
                device="cuda",
                trans_a=False,
                trans_b=False,
            )
            identity = torch.eye(
                config.seq_length, config.seq_length, device=device, dtype=self.dtype
            )
            identity = identity.expand(config.batch_size, config.num_heads, -1, -1)
            res = sparse_dot_dsd(res_bsr, identity)
            # Get the res where values are masked. Expand the blocked mask
            full_mask = densify_mask(mask, config)
            ref = ref_full * full_mask.to(dtype=self.dtype, device=device)
            try:
                assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
            except RuntimeError:
                # NOTE(review): RuntimeError from allclose (e.g. shape/device
                # mismatch) is deliberately swallowed; only accuracy failures
                # (AssertionError) propagate.
                pass
        else:
            assert self.mode == "spmm"
            # Both are dense outputs. (Removed a leftover pdb.set_trace()
            # debugging hook that fired on accuracy failure.)
            try:
                assert torch.allclose(ref_full, res_bsr, atol=1e-3, rtol=1e-3)
            except RuntimeError:
                pass
class DifferentPatternExperiment(Experiment):
    """
    In this experiment, we check if sparsity pattern (like bigbird, lower triangular
    etc) changes the performance of different kernels. The idea is to check if
    changing sparsity pattern, while keeping total sparsity ratio same, leads to any
    performance differences.
    We will perform two experiments
    1) LowerTriangularMask vs RandomMask - Both have ~50% sparsity.
    2) BigBird Mask vs RandomMask - Both have same sparsity.
    """

    def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik=False):
        super().__init__(mode, dtype, do_accuracy_check, profile_sputnik)

    def gen_config(self):
        """Yield one config dict per (batch, hidden, head, seq, block) combination."""
        batch_sizes = [32]
        heads = [16]
        seq_lengths = [1024, 2048]
        block_sizes = [64]
        hidden_sizes = [1024, 4096, 8192]
        for batch in batch_sizes:
            for hidden_size in hidden_sizes:
                for head in heads:
                    for seq in seq_lengths:
                        for block in block_sizes:
                            entry = {
                                "batch_size": batch,
                                "num_heads": head,
                                "seq_length": seq,
                                "block_size": block,
                                "hidden_size": hidden_size,
                            }
                            yield entry

    def plot(self, sparsity, config, pattern_name):
        """Emit speedup / throughput / memory-savings bar plots as SVG files."""
        desc = [
            f"bs={config.batch_size}",
            f"nheads={config.num_heads}",
            f"block={config.block_size}",
            f"dtype={self.dtype}",
        ]
        title_suffix = ",".join(desc)
        pretty_barplot(
            self.results["speedup"],
            title=f"{self.mode} - Pattern experiment ({sparsity}%) - speedup\n"
            + title_suffix,
            filename=f"same_sparsity_{self.mode}_{self.dtype}_{pattern_name}_time.svg",
            dash_key="pytorch",
            units="Speedup normalized to torch_matmul",
        )
        pretty_barplot(
            self.results["flops"],
            title=f"{self.mode} - Pattern experiment ({sparsity}%) - throughput\n"
            + title_suffix,
            filename=f"same_sparsity_{self.mode}_{self.dtype}_{pattern_name}_flops.svg",
            dash_key="pytorch",
            units="TFlops/s",
        )
        pretty_barplot(
            self.results["memory_savings"],
            title=f"{self.mode} - Pattern experiment ({sparsity}%) - memory savings\n"
            + title_suffix,
            filename=f"same_sparsity_{self.mode}_{self.dtype}_{pattern_name}_memory.svg",
            dash_key="pytorch",
            units="Memory savings normalized to torch_matmul",
        )

    def run(self):
        """Benchmark each pattern mask against a random mask of equal sparsity."""
        for MaskGenType in [LowerTriangularAttentionMask, BigBirdAttentionMask]:
            self.reset_results()
            for config in self.gen_config():
                # Get pattern mask
                pattern_mask, pattern_config, pattern_name = get_mask(
                    MaskGenType, config
                )
                sparsity = get_sparsity(pattern_mask)
                # Match the random mask's sparsity to the pattern's.
                mask_prob = sparsity / 100
                random_mask, random_config, _ = get_mask(
                    RandomAttentionMask,
                    config,
                    [("mask_prob", mask_prob)],
                )
                print(f"{pattern_name} sparsity", get_sparsity(pattern_mask))
                print("Random sparsity", get_sparsity(random_mask))
                # Create input tensors
                a, b = self.get_inputs(random_config)
                tests = []
                baseline_name = "torch-matmul"
                tests.append(
                    TestCase(
                        self.torch_matmul_callable,
                        random_mask,
                        random_config,
                        baseline_name,
                    )
                )
                tests.append(
                    TestCase(
                        self.triton_callable,
                        random_mask,
                        random_config,
                        "triton-random",
                    )
                )
                tests.append(
                    TestCase(
                        self.triton_callable,
                        pattern_mask,
                        pattern_config,
                        f"triton-{pattern_name}",
                    )
                )
                if self.profile_sputnik and self.mode == "sddmm":
                    tests.append(
                        TestCase(
                            self.sputnik_callable,
                            random_mask,
                            random_config,
                            "sputnik-random",
                        )
                    )
                    tests.append(
                        TestCase(
                            self.sputnik_callable,
                            pattern_mask,
                            pattern_config,
                            f"sputnik-{pattern_name}",
                        )
                    )
                dict_key = f"hidden={random_config.hidden_size},seq_len={random_config.seq_length}"
                self.bench_all(
                    a,
                    b,
                    tests,
                    random_config,
                    sparsity,
                    baseline_name,
                    self.get_op_flops(random_mask, random_config),
                    dict_key,
                )
                # Ideal speedup assumes time scales with the kept fraction.
                ideal_testcase = TestCase(None, None, None, "Ideal")
                ideal_speedup = round(100 / (100 - sparsity), 1)
                self.add_kv(
                    self.results["speedup"], dict_key, ideal_speedup, ideal_testcase
                )
                self.add_kv(
                    self.results["memory_savings"],
                    dict_key,
                    ideal_speedup,
                    ideal_testcase,
                )
            self.plot(sparsity, random_config, pattern_name)
class VarySparsityExperiment(Experiment):
    """
    In this experiment, we check how the sparsity ratio affects the performance.
    """

    def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik=False):
        super().__init__(mode, dtype, do_accuracy_check, profile_sputnik)

    def gen_config(self):
        """Yield one config dict per (batch, seq, head, block, hidden) combination."""
        batch_sizes = [32]
        heads = [16]
        seq_lengths = [2048]
        hidden_sizes = [1024, 8192]
        block_sizes = [64]
        for batch in batch_sizes:
            for seq in seq_lengths:
                for head in heads:
                    for block in block_sizes:
                        for hidden_size in hidden_sizes:
                            entry = {
                                "batch_size": batch,
                                "num_heads": head,
                                "seq_length": seq,
                                "block_size": block,
                                "hidden_size": hidden_size,
                            }
                            yield entry

    def plot(self, sparsity, config, pattern_name):
        """Emit speedup / throughput / memory-savings bar plots as SVG files."""
        desc = [
            f"bs={config.batch_size}",
            f"nheads={config.num_heads}",
            f"block={config.block_size}",
            f"dtype={self.dtype}",
            f"seq_len={config.seq_length}",
        ]
        title_suffix = ",".join(desc)
        pretty_barplot(
            self.results["speedup"],
            title=f"{self.mode} - SparsityRatio experiment speedup\n" + title_suffix,
            filename=f"vary_sparsity_{self.mode}_{self.dtype}_{pattern_name}_time.svg",
            dash_key="pytorch",
            units="Speedup normalized to torch_matmul",
        )
        pretty_barplot(
            self.results["flops"],
            title=f"{self.mode} - SparsityRatio experiment throughput\n" + title_suffix,
            filename=f"vary_sparsity_{self.mode}_{self.dtype}_{pattern_name}_flops.svg",
            dash_key="pytorch",
            units="TFlops/s",
        )
        pretty_barplot(
            self.results["memory_savings"],
            title=f"{self.mode} - SparsityRatio experiment memory savings\n"
            + title_suffix,
            filename=f"vary_sparsity_{self.mode}_{self.dtype}_{pattern_name}_memory.svg",
            dash_key="pytorch",
            units="Memory savings normalized to torch_matmul",
        )

    def run(self):
        """Sweep mask probability from 10% to 90% with random masks."""
        self.reset_results()
        random_config = None
        for config in self.gen_config():
            for x in range(10, 100, 10):
                mask_prob = x / 100.0
                # Get random mask
                random_mask, random_config, _ = get_mask(
                    RandomAttentionMask,
                    config,
                    [
                        ("mask_prob", mask_prob),
                    ],
                )
                sparsity = get_sparsity(random_mask)
                print("Random sparsity", get_sparsity(random_mask))
                # Create input tensors
                a, b = self.get_inputs(random_config)
                tests = []
                baseline_name = "torch-matmul"
                tests.append(
                    TestCase(
                        self.torch_matmul_callable,
                        random_mask,
                        random_config,
                        baseline_name,
                    )
                )
                tests.append(
                    TestCase(
                        self.triton_callable,
                        random_mask,
                        random_config,
                        "triton-random",
                    )
                )
                if self.profile_sputnik and self.mode == "sddmm":
                    tests.append(
                        TestCase(
                            self.sputnik_callable,
                            random_mask,
                            random_config,
                            "sputnik-random",
                        )
                    )
                dict_key = f"sp={mask_prob},hidden={random_config.hidden_size}"
                self.bench_all(
                    a,
                    b,
                    tests,
                    random_config,
                    sparsity,
                    baseline_name,
                    self.get_op_flops(random_mask, random_config),
                    dict_key,
                )
                # Ideal speedup assumes time scales with the kept fraction.
                ideal_testcase = TestCase(None, None, None, "Ideal")
                ideal_speedup = round(100 / (100 - mask_prob * 100), 1)
                self.add_kv(
                    self.results["speedup"], dict_key, ideal_speedup, ideal_testcase
                )
                self.add_kv(
                    self.results["memory_savings"],
                    dict_key,
                    ideal_speedup,
                    ideal_testcase,
                )
        self.plot(None, random_config, "random")
class BlockSizeExperiment(Experiment):
    """
    In this experiment, we analyze how increasing the block size affects
    performance. We will take the lower triangular pattern. As we increase the
    block size, the blocks near the diagonal have to do more unnecessary
    computation (the effective sparsity starts decreasing).
    (Docstring previously said "batch size" where block size was meant.)
    """

    def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik=False):
        super().__init__(mode, dtype, do_accuracy_check, profile_sputnik)

    def gen_config(self):
        """Yield one config dict per (batch, seq, hidden, block, head) combination."""
        batch_sizes = [32]
        heads = [16]
        seq_lengths = [2048]
        block_sizes = [32, 64, 128, 256]
        hidden_sizes = [1024, 8192]
        for batch in batch_sizes:
            for seq in seq_lengths:
                for hidden_size in hidden_sizes:
                    for block in block_sizes:
                        for head in heads:
                            entry = {
                                "batch_size": batch,
                                "num_heads": head,
                                "seq_length": seq,
                                "block_size": block,
                                "hidden_size": hidden_size,
                            }
                            yield entry

    def plot(self, sparsity, config, pattern_name):
        """Emit speedup / throughput / memory-savings bar plots as SVG files."""
        pretty_barplot(
            self.results["speedup"],
            title=f"{self.mode} - BlockSize experiment speedup\n"
            f"bs={config.batch_size}, nheads={config.num_heads}, seq_len={config.seq_length}, dtype={self.dtype}",
            filename=f"vary_block_size_{self.mode}_{self.dtype}_{pattern_name}_time.svg",
            dash_key="pytorch",
            units="Speedup normalized to torch matmul",
        )
        pretty_barplot(
            self.results["flops"],
            title=f"{self.mode} - BlockSize experiment throughput\n"
            f"bs={config.batch_size}, nheads={config.num_heads}, seq_len={config.seq_length}, dtype={self.dtype}",
            filename=f"vary_block_size_{self.mode}_{self.dtype}_{pattern_name}_flops.svg",
            dash_key="pytorch",
            units="TFlops/s",
        )
        pretty_barplot(
            self.results["memory_savings"],
            title=f"{self.mode} - BlockSize experiment memory savings\n"
            f"bs={config.batch_size}, nheads={config.num_heads}, seq_len={config.seq_length}, dtype={self.dtype}",
            filename=f"vary_block_size_{self.mode}_{self.dtype}_{pattern_name}_memory.svg",
            dash_key="pytorch",
            units="Memory savings normalized to torch matmul",
        )

    def get_op_flops(self, mask, config):
        """TFLOPs of the ideal element-level lower triangular matmul.

        Op flops here refer to the original non blocked attention mask, where
        no unnecessary elements are unmasked. Computed as the total flops of
        the batched matmul scaled by the kept fraction (n+1)/(2n), since a
        lower triangular mask keeps n*(n+1)/2 of the n*n entries.
        """
        lt_fraction = (config.seq_length + 1) / (2.0 * config.seq_length)
        op_flops = (
            2  # FMA
            * config.batch_size  # batched matmul
            * config.num_heads
            * (config.hidden_size // config.num_heads)  # Reduce dimension
            * config.seq_length
            * config.seq_length
            * lt_fraction
            * 1e-12  # TFlops
        )
        return op_flops

    def run(self):
        """Benchmark the lower triangular mask across block sizes."""
        self.reset_results()
        lt_config = None
        for config in self.gen_config():
            lt_mask, lt_config, lt_name = get_mask(
                LowerTriangularAttentionMask,
                config,
            )
            sparsity = get_sparsity(lt_mask)
            print("Effective sparsity", sparsity)
            if lt_config.seq_length == 2048:
                plot_mask(lt_mask, lt_config, f"lt_mask_{lt_config.block_size}.svg")
            # Create input tensors
            a, b = self.get_inputs(lt_config)
            tests = []
            baseline_name = "torch-matmul"
            tests.append(
                TestCase(
                    self.torch_matmul_callable, lt_mask, lt_config, baseline_name
                )
            )
            # Labels fixed: these runs use the lower-triangular mask, they
            # were previously mislabeled "triton-random"/"sputnik-random".
            tests.append(
                TestCase(
                    self.triton_callable, lt_mask, lt_config, "triton-lower-triangular"
                )
            )
            if self.profile_sputnik and self.mode == "sddmm":
                tests.append(
                    TestCase(
                        self.sputnik_callable,
                        lt_mask,
                        lt_config,
                        "sputnik-lower-triangular",
                    )
                )
            dict_key = f"hidden={lt_config.hidden_size}, block={lt_config.block_size}"
            self.bench_all(
                a,
                b,
                tests,
                lt_config,
                sparsity,
                baseline_name,
                self.get_op_flops(lt_mask, lt_config),
                dict_key,
            )
            # Ideal speedup over the dense matmul: total / kept entries.
            ideal_testcase = TestCase(None, None, None, "Ideal")
            seq_len = lt_config.seq_length
            total_elems = seq_len * seq_len
            nnz = seq_len * (seq_len + 1) / 2
            ideal_speedup = (1.0 * total_elems) / nnz
            self.add_kv(
                self.results["speedup"], dict_key, ideal_speedup, ideal_testcase
            )
            self.add_kv(
                self.results["memory_savings"],
                dict_key,
                ideal_speedup,
                ideal_testcase,
            )
        self.plot(None, lt_config, lt_name)
if __name__ == "__main__":
    # Render one example figure per mask family.
    mask_generators = (
        RandomAttentionMask,
        LowerTriangularAttentionMask,
        BigBirdAttentionMask,
        AxialAttentionMask,
        LocalAttentionMask,
    )
    for MaskGen in mask_generators:
        mask, config, name = MaskGen()()
        plot_mask(mask, config, f"{name}.svg")
    # Run every experiment for both matmul flavours.
    for mode in ("sddmm", "spmm"):
        DifferentPatternExperiment(mode, torch.float16, True).run()
        VarySparsityExperiment(mode, torch.float16, True).run()
        BlockSizeExperiment(mode, torch.float16, True).run()