Restruct sgl-kernel benchmark (#10861)

@@ -10,10 +10,18 @@ import torch
 import torch.nn.functional as F
 import triton
 import triton.testing
-from sgl_kernel import gelu_quick  # activation-only kernel
 from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 from vllm import _custom_ops as vllm_ops

+# gelu_quick is only available on HIP/ROCm platforms
+try:
+    from sgl_kernel import gelu_quick
+
+    GELU_QUICK_AVAILABLE = True
+except ImportError:
+    GELU_QUICK_AVAILABLE = False
+    gelu_quick = None
+
 if not hasattr(vllm_ops, "silu_and_mul"):
     vllm_ops = torch.ops._C

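The guard keeps the script importable on builds of sgl_kernel that do not ship gelu_quick: the flag records availability and the name stays bound (to None) so later references do not raise NameError. For reference, quick GELU is the sigmoid approximation x * sigmoid(1.702 * x); a minimal fallback sketch under that assumption (gelu_quick_ref is hypothetical and not part of this change):

    import torch

    def gelu_quick_ref(x: torch.Tensor) -> torch.Tensor:
        # eager reference for quick GELU: x * sigmoid(1.702 * x)
        return x * torch.sigmoid(1.702 * x)

    x = torch.randn(4, 16, 1024, dtype=torch.float16, device="cuda")
    if GELU_QUICK_AVAILABLE:
        out = gelu_quick(x)      # fused kernel; takes a single input, as used later in this script
    else:
        out = gelu_quick_ref(x)  # eager fallback
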
@@ -34,6 +42,12 @@ def calculate_diff(

     # activation-only quick GELU
     if kernel == "gelu_quick":
+        if not GELU_QUICK_AVAILABLE:
+            print(
+                f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
+                f"L={seq_len:3d} | D={dim:5d}] ⚠️ not available on this platform"
+            )
+            return True
         x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
         ref_out = torch.zeros_like(x)
         getattr(vllm_ops, kernel)(ref_out, x)

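Worth noting: the unavailable path returns True rather than False, so any caller that gates on calculate_diff (the args.verify_only branch at the end of the file, for instance) treats a kernel missing on this platform as a pass instead of a hard failure. A usage sketch with illustrative sizes:

    # on platforms without gelu_quick this prints the "not available" banner and yields True;
    # elsewhere it runs the real vLLM vs. sgl_kernel comparison
    ok = calculate_diff("gelu_quick", torch.float16, 1, 1, 1024)
    assert ok
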
@@ -54,7 +68,9 @@ def calculate_diff(
     return ok


-kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
+kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
+if GELU_QUICK_AVAILABLE:
+    kernels.append("gelu_quick")
 dtypes = [torch.float16, torch.bfloat16]

@@ -64,7 +80,7 @@ def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[

 default_batch_sizes = [2**i for i in range(0, 5, 2)]  # 1,4,16
 default_seq_lens = [2**i for i in range(0, 8, 2)]  # 1,4,16,64
-default_dims = [2**i for i in range(7, 15)]  # 128...16384
+default_dims = [2**i for i in range(10, 15)]  # 1024...16384


 @triton.testing.perf_report(

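The dim sweep now starts at 1024 instead of 128, dropping the smallest shapes. make_configs itself is untouched and its body sits outside this hunk; judging from its signature and from how benchmark_grid is later assigned to x_vals, it presumably builds the cross product of kernels, dtypes and sizes. A rough sketch under that assumption (the body is a hypothetical reconstruction):

    import itertools
    from typing import List, Tuple

    def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
        # hypothetical: every (kernel, dtype, batch_size, seq_len, dim) combination
        return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))

    benchmark_grid = make_configs(default_batch_sizes, default_seq_lens, default_dims)
    # without gelu_quick: 3 kernels x 2 dtypes x 3 batch sizes x 4 seq lens x 5 dims = 360 configs
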
@@ -87,6 +103,9 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)

     vllm_kernel = getattr(vllm_ops, kernel)
+    if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
+        # Skip benchmark for gelu_quick if not available
+        return (0, 0, 0)
     sglang_kernel = getattr(sgl_kernel, kernel)

     def baseline():

@@ -97,18 +116,14 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     def sglang():
         return sglang_kernel(x)

-    # one-time correctness check
-    if provider == "vllm" and not calculate_diff(
-        kernel, dtype, batch_size, seq_len, dim
-    ):
-        raise ValueError("Mismatch – abort benchmark")
-
     # timing helper
     def timed(fn):
         for _ in range(5):
             fn()
         torch.cuda.synchronize()
-        ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
+        ms, qmin, qmax = triton.testing.do_bench_cudagraph(
+            fn, quantiles=[0.5, 0.2, 0.8]
+        )
         return 1000 * ms, 1000 * qmax, 1000 * qmin

     if provider == "vllm":

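Switching from do_bench to do_bench_cudagraph replays the closure inside a captured CUDA graph, which removes per-launch Python and driver overhead and gives steadier numbers for these small element-wise kernels. With quantiles=[0.5, 0.2, 0.8] it returns the median and the 20th/80th-percentile latencies in milliseconds, which timed() then scales to microseconds. A standalone sketch of the same call, using an arbitrary closure for illustration:

    import torch
    import torch.nn.functional as F
    import triton.testing

    x = torch.randn(16, 64, 4096, device="cuda", dtype=torch.float16)

    def fn():
        return F.silu(x)  # any CUDA-graph-capturable closure works here

    med_ms, p20_ms, p80_ms = triton.testing.do_bench_cudagraph(fn, quantiles=[0.5, 0.2, 0.8])
    print(f"median {med_ms * 1000:.2f} us  (p20 {p20_ms * 1000:.2f}, p80 {p80_ms * 1000:.2f})")
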
@@ -147,7 +162,9 @@ if __name__ == "__main__":
     benchmark.benchmark.x_vals = benchmark_grid

     if args.verify_only:
-        ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
+        # Test with the first available kernel
+        test_kernel = kernels[0]
+        ok = calculate_diff(test_kernel, torch.float16, 1, 1, args.dims[0])
         print("✅ sanity pass" if ok else "❌ mismatch")
     else:
         benchmark.run(print_data=True)

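Keying the sanity check off kernels[0] (silu_and_mul, which every build exposes) rather than the possibly missing gelu_quick means a verify-only run no longer aborts on platforms without gelu_quick. A small sketch of an equivalent check that loops over whatever kernels were registered, with an illustrative dim:

    # mirrors the args.verify_only branch, but covers every kernel available on this platform
    for k in kernels:
        assert calculate_diff(k, torch.float16, 1, 1, 1024), f"{k} mismatch"
    print("✅ sanity pass")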