Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -2,6 +2,7 @@ import argparse
|
||||
import copy
|
||||
import csv
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -9,6 +10,14 @@ import triton
|
||||
from flashinfer import mm_fp4
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
@@ -33,27 +42,34 @@ def get_weight_shapes(args):
|
||||
]
|
||||
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_sizes = [1, 8] # Simplified for CI
|
||||
else:
|
||||
batch_sizes = [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
8192,
|
||||
16384,
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
8192,
|
||||
16384,
|
||||
],
|
||||
x_vals=batch_sizes,
|
||||
# x_vals = [64],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
@@ -188,21 +204,38 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Simplify for CI environment
|
||||
if IS_CI:
|
||||
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
|
||||
|
||||
if args.csv:
|
||||
with open(args.csv, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["provider", "m", "n", "k", "time_ms"])
|
||||
|
||||
NKs = get_weight_shapes(args)
|
||||
for N, K in NKs:
|
||||
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
N=N,
|
||||
K=K,
|
||||
dtype=args.dtype,
|
||||
correctness=args.correctness,
|
||||
csv_file=args.csv,
|
||||
)
|
||||
# Check architecture compatibility - FP4 operations require sm100a/sm103a
|
||||
major, minor = get_device_capability()
|
||||
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
|
||||
print("Skipping FP4 GEMM benchmark")
|
||||
if major is not None:
|
||||
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
|
||||
else:
|
||||
print("Could not determine device capability")
|
||||
else:
|
||||
NKs = get_weight_shapes(args)
|
||||
|
||||
print("Benchmark finished!")
|
||||
# Limit iterations in CI
|
||||
if IS_CI:
|
||||
NKs = NKs[:2] # Only test first 2 shapes in CI
|
||||
|
||||
for N, K in NKs:
|
||||
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
N=N,
|
||||
K=K,
|
||||
dtype=args.dtype,
|
||||
correctness=args.correctness,
|
||||
csv_file=args.csv,
|
||||
)
|
||||
print("Benchmark finished!")
|
||||
|
||||
Reference in New Issue
Block a user