Restruct sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -324,7 +324,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "sgl":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
@@ -336,7 +336,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "sgl_fusion":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
@@ -350,7 +350,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
)
|
||||
elif provider == "triton":
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
|
||||
Reference in New Issue
Block a user