# Viewer metadata: 77 lines, 2.0 KiB, Python
import argparse
import os

# CI environment detection: CI mode is on when either the CI or the
# GITHUB_ACTIONS environment variable is set to "true" (case-insensitive).
# An unset variable falls back to "false".
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_fused_a_gemm

# Benchmark sweep configuration. CI uses simplified parameters; the full
# mode sweeps more token counts and also benchmarks the torch baseline.
if IS_CI:
    num_tokens_vals = [1]  # Only test 1 value in CI
    line_vals = ["sgl-kernel"]  # Only test sgl-kernel implementation in CI
else:
    num_tokens_vals = list(range(1, 17))  # Test 1-16 in full mode
    line_vals = ["torch", "sgl-kernel"]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=num_tokens_vals,
        x_log=False,
        line_arg="impl",
        line_vals=line_vals,
        # Names and styles must line up one-to-one with `line_vals`, which
        # holds only the sgl-kernel entry when running in CI.
        line_names=(
            ["torch (bf16)", "dsv3_fused_a_gemm"]
            if not IS_CI
            else ["dsv3_fused_a_gemm"]
        ),
        styles=[("blue", "-"), ("orange", "-")] if not IS_CI else [("orange", "-")],
        ylabel="TFLOPs",
        plot_name="bf16 dsv3 fused a GEMM throughput",
        args={},
    )
)
def benchmark(num_tokens, impl):
    """Measure bf16 GEMM throughput for one (num_tokens, impl) point.

    Multiplies a ``(num_tokens, 7168)`` matrix by a ``(7168, 2112)`` matrix
    on the CUDA device using the selected implementation and times it with
    ``triton.testing.do_bench_cudagraph``.

    Args:
        num_tokens: Number of rows (M) of the left-hand matrix.
        impl: ``"torch"`` for the ``F.linear`` baseline or ``"sgl-kernel"``
            for ``dsv3_fused_a_gemm``.

    Returns:
        A 3-tuple of throughputs in TFLOPs computed from the timings at the
        0.5, 0.8 and 0.2 quantiles respectively, i.e. the median throughput
        followed by the throughput at the slower and at the faster timing.

    Raises:
        ValueError: If ``impl`` is not one of the supported names.
    """
    kHdIn = 7168
    kHdOut = 2112
    M, K, N = num_tokens, kHdIn, kHdOut

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    # Allocated as (N, K) and then transposed, so mat_b is a (K, N) view
    # over (N, K) storage rather than a contiguous (K, N) tensor.
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

    quantiles = [0.5, 0.2, 0.8]

    if impl == "torch":

        def runner():
            # F.linear computes input @ weight.T, so hand it the (N, K) view.
            F.linear(mat_a, mat_b.T)

    elif impl == "sgl-kernel":

        def runner():
            dsv3_fused_a_gemm(mat_a, mat_b)

    else:
        # Fail early with a clear message instead of hitting an unbound
        # `runner` at the timing call below.
        raise ValueError(
            f"Unknown impl {impl!r}; expected 'torch' or 'sgl-kernel'"
        )

    # Results come back in the order of `quantiles`: 0.5, 0.2, 0.8.
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)

    def tflops(t_ms):
        # 2 * M * K * N floating-point operations per GEMM; t_ms is in ms.
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    # Throughput is inversely proportional to time, so the slower timing
    # (max_ms) yields the lower throughput and is reported second.
    return tflops(ms), tflops(max_ms), tflops(min_ms)
if __name__ == "__main__":
    # No options are defined; parsing still provides -h/--help and rejects
    # unexpected command-line arguments.
    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    # Run every configured benchmark point and print the results table.
    benchmark.run(print_data=True)