Restructure sgl-kernel benchmark (#10861)

This commit is contained in:
Xiaoyu Zhang
2025-09-25 07:45:25 +08:00
committed by GitHub
parent 7a06ef984d
commit c4e314f986
27 changed files with 425 additions and 319 deletions

View File

@@ -52,7 +52,7 @@ def benchmark_bf16_output(num_tokens, impl):
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
@@ -106,7 +106,7 @@ def benchmark_float_output(num_tokens, impl):
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
@@ -119,9 +119,5 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()
benchmark_bf16_output.run(
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
)
benchmark_float_output.run(
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
)
benchmark_bf16_output.run(print_data=True)
benchmark_float_output.run(print_data=True)