1066 lines
37 KiB
Python
1066 lines
37 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import gc
|
|
import math
|
|
from collections import namedtuple
|
|
from dataclasses import dataclass
|
|
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
import triton
|
|
from triton.ops.blocksparse import matmul as blocksparse_matmul
|
|
|
|
from xformers.benchmarks.utils import pretty_barplot
|
|
from xformers.components.attention.attention_patterns import (
|
|
axial_2d_pattern,
|
|
causal_1d_pattern,
|
|
global_token_pattern,
|
|
local_1d_pattern,
|
|
local_2d_pattern,
|
|
)
|
|
from xformers.components.attention.core import SparseCS, _matmul_with_mask
|
|
|
|
device = "cuda"
|
|
TestCase = namedtuple("TestCase", ["prepare_callable", "mask", "config", "name"])
|
|
|
|
##############################################
|
|
# Plotting utilities
|
|
##############################################
|
|
|
|
|
|
def plot_mask(mask, config, filename):
|
|
sparsity = get_sparsity(mask)
|
|
batch_size = config.batch_size
|
|
num_heads = config.num_heads
|
|
seq_len = config.seq_length
|
|
|
|
proxy = torch.ones(batch_size, num_heads, seq_len, seq_len, dtype=torch.bool)
|
|
proxy = triton.testing.mask_tensor(proxy, mask, config.block_size, False)
|
|
proxy = proxy[0][0]
|
|
|
|
f = plt.figure()
|
|
plt.imshow(proxy.logical_not(), cmap="gray")
|
|
plt.suptitle("Sparsity = " + str(sparsity) + "%")
|
|
plt.savefig(filename)
|
|
plt.close(f)
|
|
|
|
|
|
##############################################
|
|
# Mask and testing utilities
|
|
##############################################
|
|
|
|
|
|
def get_mask(MaskGenType, config, config_setter=[]):
|
|
mask_config = Configuration()
|
|
mask_config.init(config)
|
|
|
|
# Get the mask
|
|
mask_generator = MaskGenType(mask_config)
|
|
for (key, value) in config_setter:
|
|
mask_generator.set_config_attr(key, value)
|
|
if not mask_generator.is_valid_config():
|
|
return None
|
|
|
|
return mask_generator()
|
|
|
|
|
|
def densify_mask(mask, config):
|
|
num_heads = config.num_heads
|
|
seq_length = config.seq_length
|
|
block_size = config.block_size
|
|
dense_mask = torch.zeros(num_heads, seq_length, seq_length)
|
|
for (h, i, j) in zip(*mask.nonzero(as_tuple=True)):
|
|
dense_mask[
|
|
h,
|
|
i * block_size : (i + 1) * block_size,
|
|
j * block_size : (j + 1) * block_size,
|
|
] = mask[h, i, j]
|
|
return dense_mask
|
|
|
|
|
|
def mask_tensor(a, mask, config):
|
|
return triton.testing.mask_tensor(a, mask, config.block_size, 0.0)
|
|
|
|
|
|
def sparsify_tensor(a, mask, config):
|
|
return triton.testing.sparsify_tensor(a, mask, config.block_size)
|
|
|
|
|
|
def get_sparsity(mask):
|
|
return round((1.0 - mask.sum().item() / mask.numel()) * 100)
|
|
|
|
|
|
##############################################
|
|
# Mask Generation
|
|
##############################################
|
|
|
|
|
|
@dataclass
|
|
class Configuration:
|
|
batch_size: int = 32
|
|
num_heads: int = 12
|
|
seq_length: int = 2048
|
|
hidden_size: int = 768 # hidden_size = n_heads * projection_hidden_dimension
|
|
block_size: int = 64
|
|
|
|
@property
|
|
def blocked_seq_length(self):
|
|
return int(self.seq_length / self.block_size)
|
|
|
|
def init(self, kwargs):
|
|
for key, value in kwargs.items():
|
|
setattr(self, key, value)
|
|
|
|
def __str__(self):
|
|
desc = [
|
|
f"bs={self.batch_size}",
|
|
f"h={self.num_heads}",
|
|
f"k={self.hidden_size}",
|
|
f"seq={self.seq_length}",
|
|
f"bl={self.block_size}",
|
|
]
|
|
return ",".join(desc)
|
|
|
|
|
|
class AttentionMask:
|
|
def __init__(self, config=None):
|
|
super().__init__()
|
|
if config is None:
|
|
config = Configuration()
|
|
self.config = config
|
|
|
|
def is_blocked(self):
|
|
return self.config.block_size != 1
|
|
|
|
def is_valid_config(self, keep_blocked=True):
|
|
return True
|
|
|
|
def expand(self, mask):
|
|
if mask.ndim == 2:
|
|
return mask.unsqueeze(0).expand(self.config.num_heads, -1, -1)
|
|
|
|
def gen_mask(self, keep_blocked=True):
|
|
raise NotImplementedError("Abstract data class")
|
|
|
|
def set_config_attr(self, key, value):
|
|
setattr(self.config, key, value)
|
|
|
|
def __str__(self):
|
|
raise NotImplementedError("Abstract data type")
|
|
|
|
def __call__(self):
|
|
mask = self.gen_mask()
|
|
return mask, self.config, str(self)
|
|
|
|
|
|
class RandomAttentionMask(AttentionMask):
|
|
"""
|
|
This is a Random mask. Useful for performance and memory analysis.
|
|
"""
|
|
|
|
def __init__(self, config=None):
|
|
super(RandomAttentionMask, self).__init__(config)
|
|
self.set_config_attr("mask_prob", 0.5)
|
|
|
|
def gen_mask(self, keep_blocked=True):
|
|
seq_length = self.config.seq_length
|
|
if keep_blocked:
|
|
seq_length = self.config.blocked_seq_length
|
|
|
|
mask = torch.rand(seq_length, seq_length) > self.config.mask_prob
|
|
return self.expand(mask)
|
|
|
|
def __str__(self):
|
|
return "random"
|
|
|
|
|
|
class LowerTriangularAttentionMask(AttentionMask):
|
|
"""
|
|
This is a lower triangular mask. This is common in decoder only models.
|
|
This should reduce the computation and memory to roughly half as close to
|
|
half of the mask is zero.
|
|
|
|
The mask stays same for each head and each input.
|
|
|
|
Nit pick (TODO) - While blocking, we need to ensure that the blocks along
|
|
the diagonals are themselves lower triangular blocks. But, for performance
|
|
measurement, this is ok to ignore as we treat the whole block as useful
|
|
values.
|
|
"""
|
|
|
|
def __init__(self, config=None):
|
|
super(LowerTriangularAttentionMask, self).__init__(config)
|
|
|
|
def gen_mask(self, keep_blocked=True):
|
|
seq_length = self.config.seq_length
|
|
if keep_blocked:
|
|
seq_length = self.config.blocked_seq_length
|
|
return self.expand(causal_1d_pattern(seq_length))
|
|
|
|
def __str__(self):
|
|
return "lower_triangular"
|
|
|
|
|
|
class BigBirdAttentionMask(AttentionMask):
|
|
"""
|
|
BigBird mask are composed of three types of masks - random, global and window.
|
|
For more details, refer to https://arxiv.org/pdf/2007.14062.pdf
|
|
|
|
One point to note is that mask is per head here. So, mask is 3D tensor.
|
|
(num_heads, seq_length, seq_length).
|
|
"""
|
|
|
|
def __init__(self, config=None):
|
|
super(BigBirdAttentionMask, self).__init__(config)
|
|
self.mask_per_head = True
|
|
self.set_config_attr("num_global_tokens", 2 * self.config.block_size)
|
|
self.set_config_attr("num_random_tokens", 3 * self.config.block_size)
|
|
self.set_config_attr("num_window_tokens", 3 * self.config.block_size)
|
|
|
|
def gen_global_mask(self, seq_length):
|
|
# Global tokens are tokens that attend to all tokens and to whom all tokens attend to in the sequence
|
|
num_global_blocks = self.config.num_global_tokens // self.config.block_size
|
|
mask_indices = torch.randint(0, seq_length - 1, size=(num_global_blocks,))
|
|
mask_indices = torch.unique(mask_indices)
|
|
query_mask = torch.zeros(seq_length).to(dtype=torch.bool)
|
|
query_mask.scatter_(0, mask_indices, True)
|
|
return global_token_pattern(query_mask)
|
|
|
|
def gen_random_mask(self, seq_length):
|
|
# Each query token attends over r random number of tokens
|
|
num_random_blocks = self.config.num_random_tokens // self.config.block_size
|
|
mask_indices = torch.randint(
|
|
0, seq_length - 1, size=(seq_length, num_random_blocks)
|
|
)
|
|
random_mask = torch.zeros(seq_length, seq_length).to(dtype=torch.bool)
|
|
random_mask.scatter_(1, mask_indices, True)
|
|
return random_mask
|
|
|
|
def gen_window_mask(self, seq_length):
|
|
num_window_blocks = self.config.num_window_tokens // self.config.block_size
|
|
if num_window_blocks % 2 == 0:
|
|
num_window_blocks += 1
|
|
return local_1d_pattern(seq_length, num_window_blocks)
|
|
|
|
def gen_mask(self, keep_blocked=True):
|
|
seq_length = self.config.seq_length
|
|
if keep_blocked:
|
|
seq_length = self.config.blocked_seq_length
|
|
assert keep_blocked, "Not implemented, call to_dense later to get full tensor"
|
|
|
|
if self.mask_per_head:
|
|
head_masks = []
|
|
|
|
for _ in range(self.config.num_heads):
|
|
global_mask = self.gen_global_mask(seq_length)
|
|
random_mask = self.gen_random_mask(seq_length)
|
|
window_mask = self.gen_window_mask(seq_length)
|
|
|
|
mask = global_mask + random_mask + window_mask
|
|
head_masks.append(mask)
|
|
|
|
mask = torch.stack(head_masks)
|
|
else:
|
|
global_mask = self.gen_global_mask(seq_length)
|
|
random_mask = self.gen_random_mask(seq_length)
|
|
window_mask = self.gen_window_mask(seq_length)
|
|
|
|
mask = global_mask + random_mask + window_mask
|
|
mask = self.expand(mask)
|
|
return mask
|
|
|
|
def __str__(self):
|
|
return "bigbird"
|
|
|
|
|
|
class AxialAttentionMask(AttentionMask):
|
|
"""
|
|
BigBird mask are composed of three types of masks - random, global and window.
|
|
For more details, refer to https://arxiv.org/pdf/2007.14062.pdf
|
|
|
|
One point to note is that mask is per head here. So, mask is 3D tensor.
|
|
(num_heads, seq_length, seq_length).
|
|
"""
|
|
|
|
def __init__(self, config=None):
|
|
super(AxialAttentionMask, self).__init__(config)
|
|
if config is None:
|
|
self.set_config_attr("seq_length", 1024)
|
|
|
|
def is_valid_config(self, keep_blocked=True):
|
|
seq_length = self.config.seq_length
|
|
if keep_blocked:
|
|
seq_length = self.config.blocked_seq_length
|
|
H = int(math.sqrt(seq_length))
|
|
if H * H == seq_length:
|
|
return True
|
|
return False
|
|
|
|
def gen_mask(self, keep_blocked=True):
|
|
seq_length = self.config.seq_length
|
|
if keep_blocked:
|
|
seq_length = self.config.blocked_seq_length
|
|
H = int(math.sqrt(seq_length))
|
|
assert H * H == seq_length, f"H={H}, seq_length={seq_length}"
|
|
return self.expand(axial_2d_pattern(H, H))
|
|
|
|
def __str__(self):
|
|
return "axial"
|
|
|
|
|
|
class LocalAttentionMask(AttentionMask):
|
|
"""
|
|
BigBird mask are composed of three types of masks - random, global and window.
|
|
For more details, refer to https://arxiv.org/pdf/2007.14062.pdf
|
|
|
|
One point to note is that mask is per head here. So, mask is 3D tensor.
|
|
(num_heads, seq_length, seq_length).
|
|
"""
|
|
|
|
def __init__(self, config=None):
|
|
super(LocalAttentionMask, self).__init__(config)
|
|
self.set_config_attr("num_local_blocks", 3)
|
|
if config is None:
|
|
self.set_config_attr("seq_length", 1024)
|
|
|
|
def is_valid_config(self, keep_blocked=True):
|
|
seq_length = self.config.seq_length
|
|
if keep_blocked:
|
|
seq_length = self.config.blocked_seq_length
|
|
H = int(math.sqrt(seq_length))
|
|
if H * H == seq_length:
|
|
return True
|
|
return False
|
|
|
|
def gen_mask(self, keep_blocked=True):
|
|
seq_length = self.config.seq_length
|
|
if keep_blocked:
|
|
seq_length = self.config.blocked_seq_length
|
|
H = int(math.sqrt(seq_length))
|
|
assert H * H == seq_length, f"H={H}, seq_length={seq_length}"
|
|
return self.expand(local_2d_pattern(H, H, self.config.num_local_blocks))
|
|
|
|
def __str__(self):
|
|
return "local"
|
|
|
|
|
|
##############################################
|
|
# Class to organize the experiments
|
|
##############################################
|
|
|
|
|
|
class Experiment:
|
|
def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik):
|
|
self.mode = mode
|
|
self.dtype = dtype
|
|
self.do_accuracy_check = do_accuracy_check
|
|
self.profile_sputnik = profile_sputnik
|
|
|
|
def reset_results(self):
|
|
self.results = {}
|
|
self.results["flops"] = {}
|
|
self.results["time"] = {}
|
|
self.results["memory"] = {}
|
|
self.results["speedup"] = {}
|
|
self.results["memory_savings"] = {}
|
|
|
|
def do_mem(sel, fn):
|
|
# bookeeping
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.reset_peak_memory_stats()
|
|
# actually run the function
|
|
fn()
|
|
fn()
|
|
torch.cuda.synchronize()
|
|
return torch.cuda.max_memory_allocated() // 2**20
|
|
|
|
def gen_config(self):
|
|
raise NotImplementedError("Not setup")
|
|
|
|
def plot(self, sparsity, pattern_name):
|
|
raise NotImplementedError("Not setup")
|
|
|
|
def run(self):
|
|
raise NotImplementedError("Not setup")
|
|
|
|
def add_kv(self, d, d_key, d_value, testcase):
|
|
d_value = max(0, d_value)
|
|
if d_key not in d:
|
|
d[d_key] = {}
|
|
d[d_key][testcase.name] = d_value
|
|
|
|
def bench_all(
|
|
self, a, b, tests, mask_config, sparsity, baseline_name, op_flops, dict_key
|
|
):
|
|
|
|
if self.do_accuracy_check:
|
|
self.check_all(tests, a, b)
|
|
|
|
for testcase in tests:
|
|
gc.collect()
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.reset_peak_memory_stats()
|
|
try:
|
|
fn = testcase.prepare_callable(a, b, testcase.mask, testcase.config)
|
|
ms = triton.testing.do_bench(fn)[0]
|
|
flops = op_flops / ms * 1e3 # TFlop per second
|
|
mem = self.do_mem(fn)
|
|
except Exception:
|
|
# raise
|
|
ms = -1
|
|
flops = -1
|
|
mem = -1
|
|
|
|
# Write into results
|
|
# dict_key = f"sp={sparsity}%,{mask_config}"
|
|
self.add_kv(self.results["time"], dict_key, ms, testcase)
|
|
self.add_kv(self.results["flops"], dict_key, flops, testcase)
|
|
self.add_kv(self.results["memory"], dict_key, mem, testcase)
|
|
|
|
speedup = self.results["time"][dict_key][baseline_name] / ms
|
|
memory_savings = self.results["memory"][dict_key][baseline_name] / mem
|
|
self.add_kv(self.results["speedup"], dict_key, speedup, testcase)
|
|
self.add_kv(self.results["flops"], dict_key, flops, testcase)
|
|
self.add_kv(
|
|
self.results["memory_savings"], dict_key, memory_savings, testcase
|
|
)
|
|
|
|
desc = f"sparsity={sparsity}, ops={op_flops}, time={ms}, tflops={flops}, mem={mem}"
|
|
print(f"{testcase.name} --> {mask_config}, {desc}")
|
|
|
|
def get_inputs(self, config, device="cuda"):
|
|
# if mode = sddmm, a, b = query, key
|
|
# if mode = spmm, a, b = attn, value
|
|
if self.mode == "sddmm":
|
|
return [
|
|
torch.randn(
|
|
config.batch_size,
|
|
config.num_heads,
|
|
config.seq_length,
|
|
config.hidden_size // config.num_heads,
|
|
device=device,
|
|
dtype=self.dtype,
|
|
)
|
|
for _ in range(2)
|
|
]
|
|
else:
|
|
assert self.mode == "spmm"
|
|
attn = torch.randn(
|
|
config.batch_size,
|
|
config.num_heads,
|
|
config.seq_length,
|
|
config.seq_length,
|
|
device=device,
|
|
dtype=self.dtype,
|
|
)
|
|
value = torch.randn(
|
|
config.batch_size,
|
|
config.num_heads,
|
|
config.seq_length,
|
|
config.hidden_size // config.num_heads,
|
|
device=device,
|
|
dtype=self.dtype,
|
|
)
|
|
return [attn, value]
|
|
|
|
def torch_matmul_callable(self, a, b, mask, config):
|
|
input_a = mask_tensor(a, mask, config) if self.mode == "spmm" else a
|
|
input_b = b.transpose(-1, -2) if self.mode == "sddmm" else b
|
|
|
|
def torch_fn():
|
|
return torch.matmul(input_a, input_b)
|
|
|
|
return torch_fn
|
|
|
|
def get_triton_fn(self, mask, config, mode="sddmm"):
|
|
if mode == "sddmm":
|
|
return blocksparse_matmul(
|
|
layout=mask,
|
|
block=config.block_size,
|
|
mode="sdd",
|
|
device="cuda",
|
|
trans_a=False,
|
|
trans_b=True,
|
|
)
|
|
else:
|
|
assert mode == "spmm"
|
|
return blocksparse_matmul(
|
|
layout=mask,
|
|
block=config.block_size,
|
|
mode="dsd",
|
|
device="cuda",
|
|
trans_a=False,
|
|
trans_b=False,
|
|
)
|
|
|
|
def triton_callable(self, a, b, mask, config):
|
|
triton_kernel = self.get_triton_fn(mask, config, self.mode)
|
|
input_a = sparsify_tensor(a, mask, config) if self.mode == "spmm" else a
|
|
input_b = b
|
|
|
|
def triton_fn():
|
|
return triton_kernel(input_a, input_b)
|
|
|
|
return triton_fn
|
|
|
|
def prepare_sputnik_inputs(self, query, key, config, mask):
|
|
# - sparse / sputnik
|
|
mask_cs = torch.ones(
|
|
[config.batch_size, config.num_heads, config.seq_length, config.seq_length],
|
|
dtype=torch.bool,
|
|
device="cuda",
|
|
)
|
|
mask_cs = triton.testing.mask_tensor(
|
|
mask_cs, mask, config.block_size, value=False
|
|
)
|
|
# Sputnik kernels only handle fp32
|
|
query_cs = query.flatten(start_dim=0, end_dim=1).to(torch.float32)
|
|
key_cs = key.flatten(start_dim=0, end_dim=1).to(torch.float32)
|
|
|
|
query_cs = query_cs.contiguous()
|
|
key_cs = key_cs.transpose(-2, -1)
|
|
|
|
sparse_mask_cs = SparseCS(
|
|
mask_cs.flatten(start_dim=0, end_dim=1).contiguous(),
|
|
device=torch.device("cuda"),
|
|
)
|
|
return query_cs, key_cs, sparse_mask_cs
|
|
|
|
def sputnik_callable(self, a, b, mask, config):
|
|
assert self.mode == "sddmm"
|
|
a_cs, b_cs, sparse_mask_cs = self.prepare_sputnik_inputs(a, b, config, mask)
|
|
|
|
def sputnik_fn():
|
|
return _matmul_with_mask(a_cs, b_cs, sparse_mask_cs)
|
|
|
|
return sputnik_fn
|
|
|
|
def get_op_flops(self, mask, config):
|
|
# Measure total compute ops
|
|
op_flops = (
|
|
2 # FMA
|
|
* config.batch_size # batched matmul
|
|
* (config.hidden_size // config.num_heads) # Reduce dimension
|
|
* float(mask.sum())
|
|
* config.block_size
|
|
* config.block_size # Effective seq length * seq_length
|
|
* 1e-12 # TFlops
|
|
)
|
|
return op_flops
|
|
|
|
def check_all(self, tests, a, b):
|
|
ref_test = tests[0]
|
|
ref_out = ref_test.prepare_callable(a, b, ref_test.mask, ref_test.config)()
|
|
|
|
res_test = tests[1]
|
|
res_out = res_test.prepare_callable(a, b, res_test.mask, res_test.config)()
|
|
|
|
self.check_accuracy(ref_out, res_out, ref_test.mask, ref_test.config)
|
|
|
|
def check_accuracy(self, ref_full, res_bsr, mask, config):
|
|
if self.mode == "sddmm":
|
|
# Get the dense representation of the bsr tensor
|
|
# Use triton sparse * dense multiplication to get the dense tensor back
|
|
sparse_dot_dsd = blocksparse_matmul(
|
|
layout=mask,
|
|
block=config.block_size,
|
|
mode="dsd",
|
|
device="cuda",
|
|
trans_a=False,
|
|
trans_b=False,
|
|
)
|
|
identity = torch.eye(
|
|
config.seq_length, config.seq_length, device=device, dtype=self.dtype
|
|
)
|
|
|
|
identity = identity.expand(config.batch_size, config.num_heads, -1, -1)
|
|
res = sparse_dot_dsd(res_bsr, identity)
|
|
|
|
# Get the res where values are masked. Expand the blocked mask
|
|
# ref = triton.testing.mask_tensor(ref_full, mask, config.block_size)
|
|
full_mask = densify_mask(mask, config)
|
|
ref = ref_full * full_mask.to(dtype=self.dtype, device=device)
|
|
try:
|
|
assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
|
|
except RuntimeError:
|
|
pass
|
|
except AssertionError:
|
|
raise
|
|
else:
|
|
assert self.mode == "spmm"
|
|
# Both are dense outputs
|
|
try:
|
|
assert torch.allclose(ref_full, res_bsr, atol=1e-3, rtol=1e-3)
|
|
except RuntimeError:
|
|
pass
|
|
except AssertionError:
|
|
import pdb
|
|
|
|
pdb.set_trace()
|
|
raise
|
|
|
|
|
|
class DifferentPatternExperiment(Experiment):
|
|
"""
|
|
In this experiment, we check if sparsity pattern (like bigbird, lower triangular
|
|
etc) changes the performance of different kernels. The idea is to check if
|
|
changing sparsity pattern, while keeping total sparsity ratio same, leads to any
|
|
perforamnce differences.
|
|
|
|
We will perform two experiments
|
|
|
|
1) LowerTraingularMask vs RandomMask - Both have ~50% sparsity.
|
|
2) BigBird Mask vs RandomMask - Both have same sparsity.
|
|
"""
|
|
|
|
def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik=False):
|
|
super(DifferentPatternExperiment, self).__init__(
|
|
mode, dtype, do_accuracy_check, profile_sputnik
|
|
)
|
|
|
|
def gen_config(self):
|
|
batch_sizes = [32]
|
|
heads = [16]
|
|
seq_lengths = [1024, 2048]
|
|
block_sizes = [64]
|
|
hidden_sizes = [1024, 4096, 8192]
|
|
for batch in batch_sizes:
|
|
for hidden_size in hidden_sizes:
|
|
for head in heads:
|
|
for seq in seq_lengths:
|
|
for block in block_sizes:
|
|
entry = {
|
|
"batch_size": batch,
|
|
"num_heads": head,
|
|
"seq_length": seq,
|
|
"block_size": block,
|
|
"hidden_size": hidden_size,
|
|
}
|
|
yield entry
|
|
|
|
def plot(self, sparsity, config, pattern_name):
|
|
desc = [
|
|
f"bs={config.batch_size}",
|
|
f"nheads={config.num_heads}",
|
|
f"block={config.block_size}",
|
|
f"dtype={self.dtype}",
|
|
]
|
|
title_suffix = ",".join(desc)
|
|
pretty_barplot(
|
|
self.results["speedup"],
|
|
title=f"{self.mode} - Pattern experiment ({sparsity}%) - speedup\n"
|
|
+ title_suffix,
|
|
filename=f"same_sparsity_{self.mode}_{self.dtype}_{pattern_name}_time.svg",
|
|
dash_key="pytorch",
|
|
units="Speedup normalized to torch_matmul",
|
|
)
|
|
|
|
pretty_barplot(
|
|
self.results["flops"],
|
|
title=f"{self.mode} - Pattern experiment ({sparsity}%) - throughput\n"
|
|
+ title_suffix,
|
|
filename=f"same_sparsity_{self.mode}_{self.dtype}_{pattern_name}_flops.svg",
|
|
dash_key="pytorch",
|
|
units="TFlops/s",
|
|
)
|
|
|
|
pretty_barplot(
|
|
self.results["memory_savings"],
|
|
title=f"{self.mode} - Pattern experiment ({sparsity}%) - memory savings\n"
|
|
+ title_suffix,
|
|
filename=f"same_sparsity_{self.mode}_{self.dtype}_{pattern_name}_memory.svg",
|
|
dash_key="pytorch",
|
|
units="Memory savings normalized to torch_matmul",
|
|
)
|
|
|
|
def run(self):
|
|
for MaskGenType in [LowerTriangularAttentionMask, BigBirdAttentionMask]:
|
|
self.reset_results()
|
|
for config in self.gen_config():
|
|
# Get pattern mask
|
|
pattern_mask, pattern_config, pattern_name = get_mask(
|
|
MaskGenType, config
|
|
)
|
|
sparsity = get_sparsity(pattern_mask)
|
|
mask_prob = sparsity / 100
|
|
|
|
# Get random mask
|
|
random_mask, random_config, _ = get_mask(
|
|
RandomAttentionMask,
|
|
config,
|
|
[("mask_prob", mask_prob)],
|
|
)
|
|
print(f"{pattern_name} sparsity", get_sparsity(pattern_mask))
|
|
print("Random sparsity", get_sparsity(random_mask))
|
|
|
|
# Create input tensors
|
|
a, b = self.get_inputs(random_config)
|
|
|
|
tests = []
|
|
baseline_name = "torch-matmul"
|
|
tests.append(
|
|
TestCase(
|
|
self.torch_matmul_callable,
|
|
random_mask,
|
|
random_config,
|
|
f"{baseline_name}",
|
|
)
|
|
)
|
|
tests.append(
|
|
TestCase(
|
|
self.triton_callable,
|
|
random_mask,
|
|
random_config,
|
|
"triton-random",
|
|
)
|
|
)
|
|
tests.append(
|
|
TestCase(
|
|
self.triton_callable,
|
|
pattern_mask,
|
|
pattern_config,
|
|
f"triton-{pattern_name}",
|
|
)
|
|
)
|
|
if self.profile_sputnik and self.mode == "sddmm":
|
|
tests.append(
|
|
TestCase(
|
|
self.sputnik_callable,
|
|
random_mask,
|
|
random_config,
|
|
"sputnik-random",
|
|
)
|
|
)
|
|
tests.append(
|
|
TestCase(
|
|
self.sputnik_callable,
|
|
pattern_mask,
|
|
pattern_config,
|
|
f"sputnik-{pattern_name}",
|
|
)
|
|
)
|
|
|
|
dict_key = f"hidden={random_config.hidden_size},seq_len={random_config.seq_length}"
|
|
self.bench_all(
|
|
a,
|
|
b,
|
|
tests,
|
|
random_config,
|
|
sparsity,
|
|
baseline_name,
|
|
self.get_op_flops(random_mask, random_config),
|
|
dict_key,
|
|
)
|
|
|
|
ideal_testcase = TestCase(None, None, None, "Ideal")
|
|
ideal_speedup = round(100 / (100 - sparsity), 1)
|
|
self.add_kv(
|
|
self.results["speedup"], dict_key, ideal_speedup, ideal_testcase
|
|
)
|
|
self.add_kv(
|
|
self.results["memory_savings"],
|
|
dict_key,
|
|
ideal_speedup,
|
|
ideal_testcase,
|
|
)
|
|
self.plot(sparsity, random_config, pattern_name)
|
|
|
|
|
|
class VarySparsityExperiment(Experiment):
|
|
"""
|
|
In this experiment, we check how sparsity ration affects the performance.
|
|
"""
|
|
|
|
def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik=False):
|
|
super(VarySparsityExperiment, self).__init__(
|
|
mode, dtype, do_accuracy_check, profile_sputnik
|
|
)
|
|
|
|
def gen_config(self):
|
|
batch_sizes = [32]
|
|
heads = [16]
|
|
seq_lengths = [2048]
|
|
hidden_sizes = [1024, 8192]
|
|
block_sizes = [64]
|
|
for batch in batch_sizes:
|
|
for seq in seq_lengths:
|
|
for head in heads:
|
|
for block in block_sizes:
|
|
for hidden_size in hidden_sizes:
|
|
entry = {
|
|
"batch_size": batch,
|
|
"num_heads": head,
|
|
"seq_length": seq,
|
|
"block_size": block,
|
|
"hidden_size": hidden_size,
|
|
}
|
|
yield entry
|
|
|
|
def plot(self, sparsity, config, pattern_name):
|
|
desc = [
|
|
f"bs={config.batch_size}",
|
|
f"nheads={config.num_heads}",
|
|
f"block={config.block_size}",
|
|
f"dtype={self.dtype}",
|
|
f"seq_len={config.seq_length}",
|
|
]
|
|
title_suffix = ",".join(desc)
|
|
pretty_barplot(
|
|
self.results["speedup"],
|
|
title=f"{self.mode} - SparsityRatio experiment speedup\n" + title_suffix,
|
|
filename=f"vary_sparsity_{self.mode}_{self.dtype}_{pattern_name}_time.svg",
|
|
dash_key="pytorch",
|
|
units="Speedup normalized to torch_matmul",
|
|
)
|
|
|
|
pretty_barplot(
|
|
self.results["flops"],
|
|
title=f"{self.mode} - SparsityRatio experiment throughput\n" + title_suffix,
|
|
filename=f"vary_sparsity_{self.mode}_{self.dtype}_{pattern_name}_flops.svg",
|
|
dash_key="pytorch",
|
|
units="TFlops/s",
|
|
)
|
|
|
|
pretty_barplot(
|
|
self.results["memory_savings"],
|
|
title=f"{self.mode} - SparsityRatio experiment memory savings\n"
|
|
+ title_suffix,
|
|
filename=f"vary_sparsity_{self.mode}_{self.dtype}_{pattern_name}_memory.svg",
|
|
dash_key="pytorch",
|
|
units="Memory savings normalized to torch_matmul",
|
|
)
|
|
|
|
def run(self):
|
|
self.reset_results()
|
|
random_config = None
|
|
for config in self.gen_config():
|
|
for x in range(10, 100, 10):
|
|
mask_prob = x / 100.0
|
|
# Get random mask
|
|
random_mask, random_config, _ = get_mask(
|
|
RandomAttentionMask,
|
|
config,
|
|
[
|
|
("mask_prob", mask_prob),
|
|
],
|
|
)
|
|
sparsity = get_sparsity(random_mask)
|
|
print("Random sparsity", get_sparsity(random_mask))
|
|
|
|
# Create input tensors
|
|
a, b = self.get_inputs(random_config)
|
|
|
|
tests = []
|
|
baseline_name = "torch-matmul"
|
|
tests.append(
|
|
TestCase(
|
|
self.torch_matmul_callable,
|
|
random_mask,
|
|
random_config,
|
|
f"{baseline_name}",
|
|
)
|
|
)
|
|
tests.append(
|
|
TestCase(
|
|
self.triton_callable,
|
|
random_mask,
|
|
random_config,
|
|
"triton-random",
|
|
)
|
|
)
|
|
if self.profile_sputnik and self.mode == "sddmm":
|
|
tests.append(
|
|
TestCase(
|
|
self.sputnik_callable,
|
|
random_mask,
|
|
random_config,
|
|
"sputnik-random",
|
|
)
|
|
)
|
|
dict_key = f"sp={mask_prob},hidden={random_config.hidden_size}"
|
|
self.bench_all(
|
|
a,
|
|
b,
|
|
tests,
|
|
random_config,
|
|
sparsity,
|
|
baseline_name,
|
|
self.get_op_flops(random_mask, random_config),
|
|
dict_key,
|
|
)
|
|
|
|
ideal_testcase = TestCase(None, None, None, "Ideal")
|
|
ideal_speedup = round(100 / (100 - mask_prob * 100), 1)
|
|
self.add_kv(
|
|
self.results["speedup"], dict_key, ideal_speedup, ideal_testcase
|
|
)
|
|
self.add_kv(
|
|
self.results["memory_savings"],
|
|
dict_key,
|
|
ideal_speedup,
|
|
ideal_testcase,
|
|
)
|
|
self.plot(None, random_config, "random")
|
|
|
|
|
|
class BlockSizeExperiment(Experiment):
|
|
"""
|
|
In this experiment, we analyze how increasing the block size affects
|
|
performance. We will take the lower triangular pattern. As we increase the
|
|
batch size, the blocks near the diagonal have to do more unnecessary
|
|
computation (the effective sparsity starts decreasing).
|
|
"""
|
|
|
|
def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik=False):
|
|
super(BlockSizeExperiment, self).__init__(
|
|
mode, dtype, do_accuracy_check, profile_sputnik
|
|
)
|
|
|
|
def gen_config(self):
|
|
batch_sizes = [32]
|
|
heads = [16]
|
|
seq_lengths = [2048]
|
|
block_sizes = [32, 64, 128, 256]
|
|
hidden_sizes = [1024, 8192]
|
|
for batch in batch_sizes:
|
|
for seq in seq_lengths:
|
|
for hidden_size in hidden_sizes:
|
|
for block in block_sizes:
|
|
for head in heads:
|
|
entry = {
|
|
"batch_size": batch,
|
|
"num_heads": head,
|
|
"seq_length": seq,
|
|
"block_size": block,
|
|
"hidden_size": hidden_size,
|
|
}
|
|
yield entry
|
|
|
|
def plot(self, sparsity, config, pattern_name):
|
|
pretty_barplot(
|
|
self.results["speedup"],
|
|
title=f"{self.mode} - BlockSize experiment speedup\n"
|
|
f"bs={config.batch_size}, nheads={config.num_heads}, seq_len={config.seq_length}, dtype={self.dtype}",
|
|
filename=f"vary_block_size_{self.mode}_{self.dtype}_{pattern_name}_time.svg",
|
|
dash_key="pytorch",
|
|
units="Speedup normalized to torch matmul",
|
|
)
|
|
|
|
pretty_barplot(
|
|
self.results["flops"],
|
|
title=f"{self.mode} - BlockSize experiment throughput\n"
|
|
f"bs={config.batch_size}, nheads={config.num_heads}, seq_len={config.seq_length}, dtype={self.dtype}",
|
|
filename=f"vary_block_size_{self.mode}_{self.dtype}_{pattern_name}_flops.svg",
|
|
dash_key="pytorch",
|
|
units="TFlops/s",
|
|
)
|
|
|
|
pretty_barplot(
|
|
self.results["memory_savings"],
|
|
title=f"{self.mode} - BlockSize experiment memory savings\n"
|
|
f"bs={config.batch_size}, nheads={config.num_heads}, seq_len={config.seq_length}, dtype={self.dtype}",
|
|
filename=f"vary_block_size_{self.mode}_{self.dtype}_{pattern_name}_memory.svg",
|
|
dash_key="pytorch",
|
|
units="Memory savings normalized to torch matmul",
|
|
)
|
|
|
|
def get_op_flops(self, mask, config):
|
|
# Op flops here refer to the original non blocked attention mask, where
|
|
# no unnecessary elements are unmasked. We can compute this by computing
|
|
# the total flops of batch matmul and then multiply by (n+1)/2n.
|
|
num_masked_elems = (config.seq_length + 1) / (2.0 * config.seq_length)
|
|
op_flops = (
|
|
2 # FMA
|
|
* config.batch_size # batched matmul
|
|
* config.num_heads
|
|
* (config.hidden_size // config.num_heads) # Reduce dimension
|
|
* config.seq_length
|
|
* config.seq_length
|
|
* num_masked_elems
|
|
* 1e-12 # TFlops
|
|
)
|
|
return op_flops
|
|
|
|
def run(self):
|
|
self.reset_results()
|
|
lt_config = None
|
|
for config in self.gen_config():
|
|
lt_mask, lt_config, lt_name = get_mask(
|
|
LowerTriangularAttentionMask,
|
|
config,
|
|
)
|
|
sparsity = get_sparsity(lt_mask)
|
|
print("Effective sparsity", sparsity)
|
|
if lt_config.seq_length == 2048:
|
|
plot_mask(lt_mask, lt_config, f"lt_mask_{lt_config.block_size}.svg")
|
|
|
|
# Create input tensors
|
|
a, b = self.get_inputs(lt_config)
|
|
|
|
tests = []
|
|
baseline_name = "torch-matmul"
|
|
tests.append(
|
|
TestCase(
|
|
self.torch_matmul_callable, lt_mask, lt_config, f"{baseline_name}"
|
|
)
|
|
)
|
|
tests.append(
|
|
TestCase(self.triton_callable, lt_mask, lt_config, "triton-random")
|
|
)
|
|
if self.profile_sputnik and self.mode == "sddmm":
|
|
tests.append(
|
|
TestCase(
|
|
self.sputnik_callable, lt_mask, lt_config, "sputnik-random"
|
|
)
|
|
)
|
|
dict_key = f"hidden={lt_config.hidden_size}, block={lt_config.block_size}"
|
|
self.bench_all(
|
|
a,
|
|
b,
|
|
tests,
|
|
lt_config,
|
|
sparsity,
|
|
baseline_name,
|
|
self.get_op_flops(lt_mask, lt_config),
|
|
dict_key,
|
|
)
|
|
|
|
ideal_testcase = TestCase(None, None, None, "Ideal")
|
|
seq_len = lt_config.seq_length
|
|
total_elems = seq_len * seq_len
|
|
nnz = seq_len * (seq_len + 1) / 2
|
|
ideal_speedup = (1.0 * total_elems) / nnz
|
|
|
|
self.add_kv(
|
|
self.results["speedup"], dict_key, ideal_speedup, ideal_testcase
|
|
)
|
|
self.add_kv(
|
|
self.results["memory_savings"],
|
|
dict_key,
|
|
ideal_speedup,
|
|
ideal_testcase,
|
|
)
|
|
self.plot(None, lt_config, lt_name)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
for MaskGen in [
|
|
RandomAttentionMask,
|
|
LowerTriangularAttentionMask,
|
|
BigBirdAttentionMask,
|
|
AxialAttentionMask,
|
|
LocalAttentionMask,
|
|
]:
|
|
mask_gen = MaskGen()
|
|
mask, config, name = mask_gen()
|
|
plot_mask(mask, config, f"{name}.svg")
|
|
|
|
for mode in ["sddmm", "spmm"]:
|
|
DifferentPatternExperiment(mode, torch.float16, True).run()
|
|
VarySparsityExperiment(mode, torch.float16, True).run()
|
|
BlockSizeExperiment(mode, torch.float16, True).run()
|