Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -1,11 +1,18 @@
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import lightning_attention_decode
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
return 2 ** (int(math.ceil(math.log(n, 2))))
|
||||
@@ -207,7 +214,12 @@ def calculate_diff(batch_size):
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [i for i in range(1, 65)] # 1 to 128
|
||||
# Simplified for CI environment
|
||||
if IS_CI:
|
||||
batch_size_range = [1] # Single batch size for CI
|
||||
else:
|
||||
batch_size_range = [i for i in range(1, 65)] # 1 to 64
|
||||
|
||||
configs = [(bs,) for bs in batch_size_range]
|
||||
|
||||
|
||||
@@ -292,8 +304,9 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test
|
||||
calculate_diff(batch_size=4)
|
||||
# Run correctness test - simplified for CI
|
||||
test_batch_size = 1 if IS_CI else 4
|
||||
calculate_diff(batch_size=test_batch_size)
|
||||
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
Reference in New Issue
Block a user