Fix sgl-kernel benchmark dead code (#11022)

This commit is contained in:
Xiaoyu Zhang
2025-09-29 15:06:40 +08:00
committed by GitHub
parent 71959545df
commit 11965b0daf
25 changed files with 1019 additions and 260 deletions

View File

@@ -1,11 +1,20 @@
import argparse
import copy
import itertools
import os
import torch
import triton
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -162,9 +171,22 @@ if __name__ == "__main__":
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(print_data=True, N=N, K=K)
# Check architecture compatibility - FP4 operations require sm100a/sm103a
major, minor = get_device_capability()
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
print("Skipping NVIDIA FP4 scaled GEMM benchmark")
if major is not None:
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
else:
print("Could not determine device capability")
else:
KN_model_names = prepare_shapes(args)
print("Benchmark finished!")
# Limit iterations in CI
if IS_CI:
KN_model_names = KN_model_names[:2] # Only test first 2 shapes in CI
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")