enginex-bi_series-vllm/pkgs/xformers/benchmarks/benchmark_causal_blocksparse.py

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.attention.attention_mask import AttentionMask
from xformers.components.attention.core import scaled_dot_product_attention

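# Launch CUDA kernels synchronously so errors surface at the offending call;
# this can add launch overhead to the measured times.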
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
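
# (batch size B, sequence length M, model dimension K) shapes to sweep;
# K is split across the attention heads below.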
SHAPES = [
    (8, 128, 2096),
    (8, 1024, 256),
    (12, 512, 1024),
    (128, 128, 512),
    (8, 2048, 4096),
    (16, 1024, 5120),
    (512, 128, 2560),
]

BLOCK_SIZES = [128]
N_HEADS = [8, 32]

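
# Compare the blocksparse causal-attention path with the dense SDP path across
# dtypes, block sizes, head counts and the shapes above, reporting runtime and
# peak memory.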
def bench_blocksparse_compare(backward: bool):
    device = torch.device("cuda")
    bw = "+bw" if backward else ""
    use_amp = True
    _use_cuda = True

    for dtype in [torch.float16, torch.float32]:
        datatype = "fp16" if dtype == torch.float16 else "fp32"
        results: Dict[str, Any] = {}
        results_mem: Dict[str, Any] = {}

        for BS in BLOCK_SIZES:
            for heads in N_HEADS:
                for B, M, K in SHAPES:
                    q = torch.randn(
                        (B, heads, M, K // heads),
                        requires_grad=backward,
                        device=device,
                        dtype=dtype,
                    )
                    k = q
                    v = q
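
                    # q, k and v alias the same tensor: only runtime and peak
                    # memory are compared here, not outputs.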
                    # Mask with causal flag
                    m_att_mask = AttentionMask.make_causal(
                        M, M, device=device, dtype=dtype
                    )
                    # Custom causal tensor mask
                    m_custom = torch.triu(
                        torch.ones(M, M, device=device, dtype=dtype) * float("-inf"),
                        diagonal=1,
                    )
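
                    # Both closures call the same entry point; the causal
                    # AttentionMask is what selects the blocksparse path here,
                    # while the dense -inf upper-triangular mask exercises the
                    # standard SDP path (hence the test-case names below).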
                    def blocksparse_attention():
                        with torch.cuda.amp.autocast(enabled=use_amp):
                            y = scaled_dot_product_attention(
                                q=q, k=k, v=v, att_mask=m_att_mask, block_size=BS
                            )
                            if backward:
                                torch.norm(y).backward()
                            return y

                    def sdp_attention():
                        with torch.cuda.amp.autocast(enabled=use_amp):
                            y = scaled_dot_product_attention(
                                q=q, k=k, v=v, att_mask=m_custom, block_size=BS
                            )
                            if backward:
                                torch.norm(y).backward()
                            return y
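
                    # Time each candidate and track its peak CUDA memory footprint.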
                    for testcase in [
                        TestCase(blocksparse_attention, f"blocksparse - fw{bw}"),
                        TestCase(sdp_attention, f"standard sdp - fw{bw}"),
                    ]:
                        if _use_cuda:
                            torch.cuda.empty_cache()
                            torch.cuda.reset_peak_memory_stats()
                            torch.cuda.synchronize()
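
                        # do_bench reports timings in milliseconds; depending on
                        # the Triton version it returns either a sequence of
                        # statistics or a single float. The [0] below assumes
                        # the sequence form.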
                        time = triton.testing.do_bench(testcase.function)[0]

                        if _use_cuda:
                            torch.cuda.synchronize()
                            max_memory = torch.cuda.max_memory_allocated() / 2**20
                        else:
                            max_memory = -1

                        key = f"B={B},M={M},K={K},NH={heads}"
                        if key not in results_mem:
                            results_mem[key] = {}
                        results_mem[key][testcase.name] = f"{max_memory:.1f}"

                        if key not in results:
                            results[key] = {}
                        results[key][testcase.name] = f"{time:.2f}"
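
            # For this dtype / block size, print the result tables and render
            # the runtime and peak-memory plots.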
            pretty_print(
                results,
                title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
                units="runtime in ms",
            )
            pretty_print(
                results_mem,
                title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
                units="peak memory usage in MB",
            )

            pretty_plot(
                results,
                title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}",
                units="runtime in ms",
                dash_key="torch",
                legend_loc="upper left",
            )
            pretty_plot(
                results_mem,
                title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}",
                units="peak memory usage in MB",
                dash_key="torch",
                legend_loc="upper left",
            )

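
# Run the comparison for the forward pass alone, then forward + backward.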
for bw in [False, True]:
    bench_blocksparse_compare(bw)