Fix sgl-kernel benchmark dead code (#11022)
@@ -1,18 +1,33 @@
 import argparse
 import copy
 import itertools
+import os
 
 import deep_gemm
 import torch
 import triton
 from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
-from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 
+# Optional vLLM import
+try:
+    from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    vllm_scaled_mm = None
+    VLLM_AVAILABLE = False
+
 from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
 )
 
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
 
 def get_weight_shapes(args):
     models_tps = list(itertools.product(args.models, args.tp_sizes))
@@ -80,15 +95,46 @@ def scale_shape(shape, group_shape):
     return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
 
 
+# CI environment uses simplified parameters
+if IS_CI:
+    batch_sizes = [1, 8]  # Simplified for CI
+else:
+    batch_sizes = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
+
+# Filter providers based on availability
+available_providers = ["sgl-kernel"]
+available_names = ["sgl-kernel"]
+available_styles = [("orange", "-")]
+
+if VLLM_AVAILABLE:
+    available_providers.insert(0, "vllm")
+    available_names.insert(0, "vllm")
+    available_styles.insert(0, ("blue", "-"))
+
+available_providers.append("triton")
+available_names.append("sglang triton")
+available_styles.append(("red", "-"))
+
+# Add deepgemm if available
+try:
+    import deep_gemm
+
+    available_providers.append("deepgemm")
+    available_names.append("deepgemm")
+    available_styles.append(("yellow", "-"))
+except ImportError:
+    pass
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
+        x_vals=batch_sizes,
         x_log=False,
         line_arg="provider",
-        line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
-        line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
-        styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
+        line_vals=available_providers,
+        line_names=available_names,
+        styles=available_styles,
         ylabel="GB/s",
         plot_name="fp8 blockwise scaled matmul",
         args={},
@@ -123,14 +169,16 @@ def benchmark(batch_size, provider, N, K):
             ),
             quantiles=quantiles,
         )
-    if provider == "vllm":
+    elif provider == "vllm":
+        if not VLLM_AVAILABLE:
+            return (0, 0, 0)
         scale_a = scale_a.t().contiguous().t()
         b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
             quantiles=quantiles,
         )
-    if provider == "triton":
+    elif provider == "triton":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: w8a8_block_fp8_matmul(
                 a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
@@ -166,7 +214,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
+    # Simplify for CI environment
+    if IS_CI:
+        args.models = [args.models[0]]  # Use only first model
+        args.tp_sizes = [args.tp_sizes[0]]  # Use only first TP size
+
     NK_model_names = get_weight_shapes(args)
+
+    # Limit iterations in CI
+    if IS_CI:
+        NK_model_names = NK_model_names[:2]  # Only test first 2 shapes in CI
+
     for N, K, model_name in NK_model_names:
         if N % 128 != 0 or K % 128 != 0:
             print(f"Skip {N=}, {K=} now")