Fix sgl-kernel benchmark dead code (#11022)

This commit is contained in:
Xiaoyu Zhang
2025-09-29 15:06:40 +08:00
committed by GitHub
parent 71959545df
commit 11965b0daf
25 changed files with 1019 additions and 260 deletions

View File

@@ -1,10 +1,17 @@
import itertools
import os
import sgl_kernel
import torch
import triton
import triton.testing
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def torch_top_k_top_p_joint_sampling_from_probs(
normalized_prob, top_k, top_p, eps=1e-4
@@ -67,10 +74,16 @@ def calculate_diff(batch_size, vocab_size, p):
)
# parameter space
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
# parameter space - simplified for CI
if IS_CI:
batch_size_range = [16] # Single batch size for CI
vocab_size_range = [111] # Single vocab size for CI
p_range = [0.1] # Single p value for CI
else:
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
@@ -114,15 +127,19 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
filter_apply_order="joint",
)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
fn, quantiles=[0.5, 0.2, 0.8]
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
# Correctness check
for cfg in configs:
# Correctness check - simplified for CI
if IS_CI:
# Only test one configuration in CI
test_configs = [configs[0]] if configs else [(16, 111, 0.1)]
else:
test_configs = configs
for cfg in test_configs:
calculate_diff(*cfg)
print("\n" + "=" * 60)