Fix sgl-kernel benchmark dead code (#11022)

This commit is contained in:
Xiaoyu Zhang
2025-09-29 15:06:40 +08:00
committed by GitHub
parent 71959545df
commit 11965b0daf
25 changed files with 1019 additions and 260 deletions

View File

@@ -1,11 +1,18 @@
import itertools
import math
import os
import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
@@ -207,7 +214,12 @@ def calculate_diff(batch_size):
print("❌ Implementations differ")
batch_size_range = [i for i in range(1, 65)] # 1 to 128
# Simplified for CI environment
if IS_CI:
batch_size_range = [1] # Single batch size for CI
else:
batch_size_range = [i for i in range(1, 65)] # 1 to 64
configs = [(bs,) for bs in batch_size_range]
@@ -292,8 +304,9 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Run correctness test
calculate_diff(batch_size=4)
# Run correctness test - simplified for CI
test_batch_size = 1 if IS_CI else 4
calculate_diff(batch_size=test_batch_size)
# Run performance benchmark
benchmark.run(print_data=True)