Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -1,10 +1,17 @@
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def torch_top_k_top_p_joint_sampling_from_probs(
|
||||
normalized_prob, top_k, top_p, eps=1e-4
|
||||
@@ -67,10 +74,16 @@ def calculate_diff(batch_size, vocab_size, p):
|
||||
)
|
||||
|
||||
|
||||
# parameter space
|
||||
batch_size_range = [16, 64, 128]
|
||||
vocab_size_range = [111, 32000]
|
||||
p_range = [0.1, 0.5]
|
||||
# parameter space - simplified for CI
|
||||
if IS_CI:
|
||||
batch_size_range = [16] # Single batch size for CI
|
||||
vocab_size_range = [111] # Single vocab size for CI
|
||||
p_range = [0.1] # Single p value for CI
|
||||
else:
|
||||
batch_size_range = [16, 64, 128]
|
||||
vocab_size_range = [111, 32000]
|
||||
p_range = [0.1, 0.5]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
|
||||
|
||||
|
||||
@@ -114,15 +127,19 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
|
||||
filter_apply_order="joint",
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
fn, quantiles=[0.5, 0.2, 0.8]
|
||||
)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Correctness check
|
||||
for cfg in configs:
|
||||
# Correctness check - simplified for CI
|
||||
if IS_CI:
|
||||
# Only test one configuration in CI
|
||||
test_configs = [configs[0]] if configs else [(16, 111, 0.1)]
|
||||
else:
|
||||
test_configs = configs
|
||||
|
||||
for cfg in test_configs:
|
||||
calculate_diff(*cfg)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
Reference in New Issue
Block a user