Restruct sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -87,7 +87,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
|
||||
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: cutlass_mla_decode(
|
||||
qn.transpose(0, 1),
|
||||
qr,
|
||||
@@ -136,8 +136,6 @@ if __name__ == "__main__":
|
||||
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_blackwell_mla_res",
|
||||
block_size=block_size,
|
||||
num_kv_splits=kv_split,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user