Fix sgl-kernel benchmark dead code (#11022)
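Make the vLLM dependency optional in the sgl-kernel activation benchmark: wrap the import in try/except, guard vLLM-only code paths behind VLLM_AVAILABLE, detect CI environments and shrink the sweep there, and keep the reported speed-up well-defined when the reference time is zero.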
@@ -2,6 +2,7 @@
 # (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
 import argparse
 import itertools
+import os
 import re
 from typing import List, Tuple
 
@@ -11,7 +12,21 @@ import torch.nn.functional as F
 import triton
 import triton.testing
 from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-from vllm import _custom_ops as vllm_ops
 
+# Optional vLLM import
+try:
+    from vllm import _custom_ops as vllm_ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    vllm_ops = None
+    VLLM_AVAILABLE = False
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
 # gelu_quick is only available on HIP/ROCm platforms
 try:
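For context, the CI detection added above reduces to this standalone sketch; the env_flag helper is illustrative, not part of the patch. GitHub Actions exports both CI=true and GITHUB_ACTIONS=true, so either check fires there:

    import os

    def env_flag(name: str) -> bool:
        # Unset, or any value other than "true" (case-insensitive), counts as false.
        return os.getenv(name, "false").lower() == "true"

    IS_CI = env_flag("CI") or env_flag("GITHUB_ACTIONS")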
@@ -22,7 +37,7 @@ except ImportError:
     GELU_QUICK_AVAILABLE = False
     gelu_quick = None
 
-if not hasattr(vllm_ops, "silu_and_mul"):
+if VLLM_AVAILABLE and not hasattr(vllm_ops, "silu_and_mul"):
     vllm_ops = torch.ops._C
 
 
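The VLLM_AVAILABLE short-circuit matters here: once the import is optional, a missing vLLM leaves vllm_ops as None, and hasattr(None, "silu_and_mul") is merely False, so the unguarded check would have rebound vllm_ops to torch.ops._C even though vLLM never loaded.

    # Minimal illustration (assumes vLLM absent):
    vllm_ops = None
    not hasattr(vllm_ops, "silu_and_mul")  # True -> old code took the fallback branch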
@@ -40,6 +55,13 @@ def calculate_diff(
     """Compare vLLM with SGLang for one shape."""
     device = torch.device("cuda")
 
+    if not VLLM_AVAILABLE:
+        print(
+            f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
+            f"L={seq_len:3d} | D={dim:5d}] ⚠️ vLLM not available, skipping comparison"
+        )
+        return True
+
     # activation-only quick GELU
     if kernel == "gelu_quick":
         if not GELU_QUICK_AVAILABLE:
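With the CI defaults (silu_and_mul, torch.float16, B=1, L=1, D=1024), the skip line printed by this new branch renders as:

    [silu_and_mul   | torch.float16 | B=  1 | L=  1 | D= 1024] ⚠️ vLLM not available, skipping comparison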
@@ -68,19 +90,30 @@ def calculate_diff(
     return ok
 
 
-kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
-if GELU_QUICK_AVAILABLE:
-    kernels.append("gelu_quick")
-dtypes = [torch.float16, torch.bfloat16]
+# CI environment uses simplified parameters for kernels and dtypes too
+if IS_CI:
+    kernels = ["silu_and_mul"]  # Only test one kernel in CI
+    dtypes = [torch.float16]  # Only test one dtype in CI
+else:
+    kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
+    if GELU_QUICK_AVAILABLE:
+        kernels.append("gelu_quick")
+    dtypes = [torch.float16, torch.bfloat16]
 
 
 def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
     return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
 
 
-default_batch_sizes = [2**i for i in range(0, 5, 2)]  # 1,4,16
-default_seq_lens = [2**i for i in range(0, 8, 2)]  # 1,4,16,64
-default_dims = [2**i for i in range(10, 15)]  # 1024...16384
+# CI environment uses simplified parameters
+if IS_CI:
+    default_batch_sizes = [1]  # Single batch size for CI
+    default_seq_lens = [1]  # Single sequence length for CI
+    default_dims = [1024]  # Single dimension for CI
+else:
+    default_batch_sizes = [2**i for i in range(0, 5, 2)]  # 1,4,16
+    default_seq_lens = [2**i for i in range(0, 8, 2)]  # 1,4,16,64
+    default_dims = [2**i for i in range(10, 15)]  # 1024...16384
 
 
 @triton.testing.perf_report(
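As a sanity check on sweep size (a worked example, not part of the patch): the full grid is 3 kernels (4 with gelu_quick) x 2 dtypes x 3 batch sizes x 4 sequence lengths x 5 dims, while CI collapses every axis to a single value:

    import itertools

    kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
    dtypes = ["float16", "bfloat16"]  # stand-ins for the torch dtypes
    batch_sizes, seq_lens, dims = [1, 4, 16], [1, 4, 16, 64], [1024, 2048, 4096, 8192, 16384]
    print(len(list(itertools.product(kernels, dtypes, batch_sizes, seq_lens, dims))))  # 360
    print(len(list(itertools.product(["silu_and_mul"], ["float16"], [1], [1], [1024]))))  # 1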
@@ -102,16 +135,24 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
     y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
 
-    vllm_kernel = getattr(vllm_ops, kernel)
+    if not VLLM_AVAILABLE and provider in ["vllm", "speedup"]:
+        # Skip vLLM-related benchmarks if vLLM is not available
+        return (0, 0, 0)
+
+    if VLLM_AVAILABLE:
+        vllm_kernel = getattr(vllm_ops, kernel)
     if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
         # Skip benchmark for gelu_quick if not available
         return (0, 0, 0)
     sglang_kernel = getattr(sgl_kernel, kernel)
 
     def baseline():
-        tmp = y0.clone()
-        vllm_kernel(tmp, x)
-        return tmp
+        if VLLM_AVAILABLE:
+            tmp = y0.clone()
+            vllm_kernel(tmp, x)
+            return tmp
+        else:
+            return torch.zeros_like(y0)
 
     def sglang():
         return sglang_kernel(x)
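The timed() helper used in the next hunk is outside this diff; a plausible shape, assuming it wraps triton.testing.do_bench with quantiles (an assumption about the script, not shown by the patch):

    import triton.testing

    def timed(fn):
        # With quantiles, do_bench returns the requested latency quantiles in ms.
        p50, p20, p80 = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
        return p50, p20, p80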
@@ -134,7 +175,7 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     # provider == "speedup"
     t_ref, _, _ = timed(baseline)
     t_sgl, _, _ = timed(sglang)
-    spd = t_ref / t_sgl
+    spd = t_ref / t_sgl if t_ref > 0 else 1.0
     return (spd, spd, spd)
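The new conditional keeps the speed-up well-defined when the baseline path was skipped and its time comes back as zero; a neutral 1.0x is reported instead of a meaningless 0x:

    t_ref, t_sgl = 0.0, 0.12  # baseline skipped, SGLang measured
    spd = t_ref / t_sgl if t_ref > 0 else 1.0  # 1.0, not 0.0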