Restruct sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -46,7 +46,7 @@ configs = [(sq,) for sq in seq_length_range]
|
||||
)
|
||||
)
|
||||
def benchmark(seq_length, provider):
|
||||
dtype = torch.bfloat16
|
||||
dtype = torch.float32
|
||||
device = torch.device("cuda")
|
||||
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
|
||||
|
||||
@@ -56,14 +56,14 @@ def benchmark(seq_length, provider):
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "original":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: biased_grouped_topk_org(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "kernel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: biased_grouped_topk_org_fuse_kernel(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user