Restruct sgl-kernel benchmark (#10861)

This commit is contained in:
Xiaoyu Zhang
2025-09-25 07:45:25 +08:00
committed by GitHub
parent 7a06ef984d
commit c4e314f986
27 changed files with 425 additions and 319 deletions

View File

@@ -117,17 +117,17 @@ def benchmark(batch_size, provider, N, K):
quantiles = [0.5, 0.2, 0.8]
if provider == "FP16":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.matmul(a_fp16, b_fp16),
quantiles=quantiles,
)
if provider == "W8A8":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Channel":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: qserve_w4a8_per_chn_gemm(
a_qserve_chn,
b_qserve_chn,
@@ -139,7 +139,7 @@ def benchmark(batch_size, provider, N, K):
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Group":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: qserve_w4a8_per_group_gemm(
a_qserve_group,
b_qserve_group,
@@ -189,8 +189,6 @@ if __name__ == "__main__":
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_qserve_w4a8_gemm_res",
N=N,
K=K,
)