Restructure sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -52,7 +52,7 @@ def benchmark_bf16_output(num_tokens, impl):
|
||||
def runner():
|
||||
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
|
||||
|
||||
- ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
||||
+ ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * M * K * N
|
||||
@@ -106,7 +106,7 @@ def benchmark_float_output(num_tokens, impl):
|
||||
def runner():
|
||||
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
|
||||
|
||||
- ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
||||
+ ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * M * K * N
|
||||
@@ -119,9 +119,5 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args()
|
||||
|
||||
- benchmark_bf16_output.run(
|
||||
-     print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
|
||||
- )
|
||||
- benchmark_float_output.run(
|
||||
-     print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
|
||||
- )
|
||||
+ benchmark_bf16_output.run(print_data=True)
|
||||
+ benchmark_float_output.run(print_data=True)
|
||||
|
||||
Reference in New Issue
Block a user