Restruct sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -117,17 +117,17 @@ def benchmark(batch_size, provider, N, K):
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "FP16":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.matmul(a_fp16, b_fp16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "W8A8":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "Qserve_W4A8_Per_Channel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: qserve_w4a8_per_chn_gemm(
|
||||
a_qserve_chn,
|
||||
b_qserve_chn,
|
||||
@@ -139,7 +139,7 @@ def benchmark(batch_size, provider, N, K):
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "Qserve_W4A8_Per_Group":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: qserve_w4a8_per_group_gemm(
|
||||
a_qserve_group,
|
||||
b_qserve_group,
|
||||
@@ -189,8 +189,6 @@ if __name__ == "__main__":
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_qserve_w4a8_gemm_res",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user