Fix sgl-kernel benchmark dead code (#11022)

Author: Xiaoyu Zhang
Date: 2025-09-29 15:06:40 +08:00
Committed by: GitHub
Parent: 71959545df
Commit: 11965b0daf

25 changed files with 1019 additions and 260 deletions


@@ -2,6 +2,7 @@
 # (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
 import argparse
 import itertools
+import os
 import re
 from typing import List, Tuple
@@ -11,7 +12,21 @@ import torch.nn.functional as F
 import triton
 import triton.testing
 from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-from vllm import _custom_ops as vllm_ops
+
+# Optional vLLM import
+try:
+    from vllm import _custom_ops as vllm_ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    vllm_ops = None
+    VLLM_AVAILABLE = False
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
 
 # gelu_quick is only available on HIP/ROCm platforms
 try:
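Note on the two new module-level flags: the try/except makes vLLM a soft dependency, and IS_CI treats either CI=true or GITHUB_ACTIONS=true as a CI run. A minimal sketch of the flag's behavior (hypothetical environment values; GitHub-hosted runners export GITHUB_ACTIONS=true automatically):

# Sketch with assumed environment values: either variable flips the flag.
import os

os.environ.pop("CI", None)
os.environ["GITHUB_ACTIONS"] = "true"  # set automatically on GitHub runners
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
print(IS_CI)  # True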
@@ -22,7 +37,7 @@ except ImportError:
     GELU_QUICK_AVAILABLE = False
     gelu_quick = None
 
-if not hasattr(vllm_ops, "silu_and_mul"):
+if VLLM_AVAILABLE and not hasattr(vllm_ops, "silu_and_mul"):
     vllm_ops = torch.ops._C
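The added VLLM_AVAILABLE guard matters because hasattr(None, "silu_and_mul") is simply False: without it, a missing vLLM install would silently rebind vllm_ops from None to torch.ops._C, where no custom ops were ever registered. A hedged helper expressing the intended resolution order (the helper name is mine, not from this commit):

# Hypothetical helper, assuming VLLM_AVAILABLE / vllm_ops defined as above.
def resolve_vllm_ops():
    if not VLLM_AVAILABLE:
        return None  # no baseline to compare against
    if hasattr(vllm_ops, "silu_and_mul"):
        return vllm_ops  # ops exposed on the _custom_ops wrapper
    return torch.ops._C  # otherwise fall back to the registered torch ops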
@@ -40,6 +55,13 @@ def calculate_diff(
     """Compare vLLM with SGLang for one shape."""
     device = torch.device("cuda")
 
+    if not VLLM_AVAILABLE:
+        print(
+            f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
+            f"L={seq_len:3d} | D={dim:5d}] ⚠️ vLLM not available, skipping comparison"
+        )
+        return True
+
     # activation-only quick GELU
     if kernel == "gelu_quick":
         if not GELU_QUICK_AVAILABLE:
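With vLLM absent, calculate_diff now reports the skip and returns True (counted as a pass) instead of failing at the later getattr. For the CI default shape, the f-string above prints:

[silu_and_mul   | torch.float16 | B=  1 | L=  1 | D= 1024] ⚠️ vLLM not available, skipping comparison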
@@ -68,19 +90,30 @@ def calculate_diff(
     return ok
 
 
-kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
-if GELU_QUICK_AVAILABLE:
-    kernels.append("gelu_quick")
-dtypes = [torch.float16, torch.bfloat16]
+# CI environment uses simplified parameters for kernels and dtypes too
+if IS_CI:
+    kernels = ["silu_and_mul"]  # Only test one kernel in CI
+    dtypes = [torch.float16]  # Only test one dtype in CI
+else:
+    kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
+    if GELU_QUICK_AVAILABLE:
+        kernels.append("gelu_quick")
+    dtypes = [torch.float16, torch.bfloat16]
 
 
 def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
     return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
 
 
-default_batch_sizes = [2**i for i in range(0, 5, 2)]  # 1,4,16
-default_seq_lens = [2**i for i in range(0, 8, 2)]  # 1,4,16,64
-default_dims = [2**i for i in range(10, 15)]  # 1024...16384
+# CI environment uses simplified parameters
+if IS_CI:
+    default_batch_sizes = [1]  # Single batch size for CI
+    default_seq_lens = [1]  # Single sequence length for CI
+    default_dims = [1024]  # Single dimension for CI
+else:
+    default_batch_sizes = [2**i for i in range(0, 5, 2)]  # 1,4,16
+    default_seq_lens = [2**i for i in range(0, 8, 2)]  # 1,4,16,64
+    default_dims = [2**i for i in range(10, 15)]  # 1024...16384
 
 
 @triton.testing.perf_report(
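The IS_CI branch shrinks the sweep dramatically: with the non-CI defaults, make_configs produces 3 kernels × 2 dtypes × 3 batch sizes × 4 sequence lengths × 5 dims = 360 configurations (480 when gelu_quick is available on ROCm), versus exactly one in CI. A quick check of the arithmetic:

# Sanity check of the non-CI sweep size (mirrors the defaults above).
import itertools

kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
dtypes = ["float16", "bfloat16"]
batch_sizes = [2**i for i in range(0, 5, 2)]  # 1, 4, 16
seq_lens = [2**i for i in range(0, 8, 2)]     # 1, 4, 16, 64
dims = [2**i for i in range(10, 15)]          # 1024 ... 16384
print(len(list(itertools.product(kernels, dtypes, batch_sizes, seq_lens, dims))))  # 360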
@@ -102,16 +135,24 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
     y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
 
-    vllm_kernel = getattr(vllm_ops, kernel)
+    if not VLLM_AVAILABLE and provider in ["vllm", "speedup"]:
+        # Skip vLLM-related benchmarks if vLLM is not available
+        return (0, 0, 0)
+
+    if VLLM_AVAILABLE:
+        vllm_kernel = getattr(vllm_ops, kernel)
+
+    if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
+        # Skip benchmark for gelu_quick if not available
+        return (0, 0, 0)
+
     sglang_kernel = getattr(sgl_kernel, kernel)
 
     def baseline():
-        tmp = y0.clone()
-        vllm_kernel(tmp, x)
-        return tmp
+        if VLLM_AVAILABLE:
+            tmp = y0.clone()
+            vllm_kernel(tmp, x)
+            return tmp
+        else:
+            return torch.zeros_like(y0)
 
     def sglang():
         return sglang_kernel(x)
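timed is defined outside the hunks shown here; judging by the t, _, _ unpacking below, it presumably wraps triton.testing.do_bench with three quantiles. A compatible sketch (assumed, not part of this commit):

# Assumed shape of the helper used below (not shown in this diff).
def timed(fn):
    # do_bench returns one time per requested quantile, in milliseconds
    return triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])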
@@ -134,7 +175,7 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     # provider == "speedup"
     t_ref, _, _ = timed(baseline)
     t_sgl, _, _ = timed(sglang)
-    spd = t_ref / t_sgl
+    spd = t_ref / t_sgl if t_ref > 0 else 1.0
     return (spd, spd, spd)
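The t_ref > 0 guard keeps the speed-up column sane when the baseline timing is degenerate; since the no-vLLM case already returns early for the "speedup" provider, this mainly turns a would-be 0.0 ratio into a neutral 1.0. Edge behavior, with illustrative values:

# Illustrative values only: the guard substitutes a neutral 1.0 speed-up
# when the reference timing is zero, instead of reporting 0.0.
for t_ref, t_sgl in [(2.0, 1.0), (0.0, 1.0)]:
    spd = t_ref / t_sgl if t_ref > 0 else 1.0
    print(f"t_ref={t_ref}: speedup={spd}")  # 2.0, then 1.0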