Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -1,11 +1,20 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
@@ -162,9 +171,22 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(print_data=True, N=N, K=K)
|
||||
# Check architecture compatibility - FP4 operations require sm100a/sm103a
|
||||
major, minor = get_device_capability()
|
||||
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
|
||||
print("Skipping NVIDIA FP4 scaled GEMM benchmark")
|
||||
if major is not None:
|
||||
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
|
||||
else:
|
||||
print("Could not determine device capability")
|
||||
else:
|
||||
KN_model_names = prepare_shapes(args)
|
||||
|
||||
print("Benchmark finished!")
|
||||
# Limit iterations in CI
|
||||
if IS_CI:
|
||||
KN_model_names = KN_model_names[:2] # Only test first 2 shapes in CI
|
||||
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(print_data=True, N=N, K=K)
|
||||
print("Benchmark finished!")
|
||||
|
||||
Reference in New Issue
Block a user