Restruct sgl-kernel benchmark (#10861)

This commit is contained in:
Xiaoyu Zhang
2025-09-25 07:45:25 +08:00
committed by GitHub
parent 7a06ef984d
commit c4e314f986
27 changed files with 425 additions and 319 deletions

View File

@@ -246,7 +246,7 @@ def benchmark(batch_size, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: lightning_attention_decode_naive(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),
@@ -257,7 +257,7 @@ def benchmark(batch_size, provider):
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: lightning_attention_decode_kernel(
q.clone(),
k.clone(),
@@ -270,7 +270,7 @@ def benchmark(batch_size, provider):
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: triton_lightning_attn_decode(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),