Fix sgl-kernel benchmark dead code (#11022)

This commit is contained in:
Xiaoyu Zhang
2025-09-29 15:06:40 +08:00
committed by GitHub
parent 71959545df
commit 11965b0daf
25 changed files with 1019 additions and 260 deletions

View File

@@ -2,6 +2,7 @@ import argparse
import copy
import csv
import itertools
import os
import pytest
import torch
@@ -9,6 +10,14 @@ import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -33,27 +42,34 @@ def get_weight_shapes(args):
]
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1, 8] # Simplified for CI
else:
batch_sizes = [
1,
2,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
3072,
4096,
8192,
16384,
]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
2,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
3072,
4096,
8192,
16384,
],
x_vals=batch_sizes,
# x_vals = [64],
x_log=False,
line_arg="provider",
@@ -188,21 +204,38 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
if args.csv:
with open(args.csv, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["provider", "m", "n", "k", "time_ms"])
NKs = get_weight_shapes(args)
for N, K in NKs:
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
print_data=True,
N=N,
K=K,
dtype=args.dtype,
correctness=args.correctness,
csv_file=args.csv,
)
# Check architecture compatibility - FP4 operations require sm100a/sm103a
major, minor = get_device_capability()
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
print("Skipping FP4 GEMM benchmark")
if major is not None:
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
else:
print("Could not determine device capability")
else:
NKs = get_weight_shapes(args)
print("Benchmark finished!")
# Limit iterations in CI
if IS_CI:
NKs = NKs[:2] # Only test first 2 shapes in CI
for N, K in NKs:
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
print_data=True,
N=N,
K=K,
dtype=args.dtype,
correctness=args.correctness,
csv_file=args.csv,
)
print("Benchmark finished!")