adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
153
sgl-kernel/benchmark/bench_activation.py
Normal file
153
sgl-kernel/benchmark/bench_activation.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Benchmarks SGLang kernels versus vLLM across
|
||||
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
|
||||
import argparse
|
||||
import itertools
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import gelu_quick # activation-only kernel
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
if not hasattr(vllm_ops, "silu_and_mul"):
|
||||
vllm_ops = torch.ops._C
|
||||
|
||||
|
||||
def str2int_list(arg: str) -> List[int]:
|
||||
if arg in ("", None):
|
||||
return []
|
||||
if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
|
||||
raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
|
||||
return [int(x) for x in arg.split(",")]
|
||||
|
||||
|
||||
def calculate_diff(
|
||||
kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
|
||||
) -> bool:
|
||||
"""Compare vLLM with SGLang for one shape."""
|
||||
device = torch.device("cuda")
|
||||
|
||||
# activation-only quick GELU
|
||||
if kernel == "gelu_quick":
|
||||
x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||
ref_out = torch.zeros_like(x)
|
||||
getattr(vllm_ops, kernel)(ref_out, x)
|
||||
test_out = getattr(sgl_kernel, kernel)(x)
|
||||
# fused activation x mul kernels
|
||||
else:
|
||||
x = torch.randn(batch_size, seq_len, 2 * dim, dtype=dtype, device=device)
|
||||
ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||
getattr(vllm_ops, kernel)(ref_out, x)
|
||||
test_out = getattr(sgl_kernel, kernel)(x)
|
||||
|
||||
ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5)
|
||||
tag = "✅ match" if ok else "❌ mismatch"
|
||||
print(
|
||||
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
|
||||
f"L={seq_len:3d} | D={dim:5d}] {tag}"
|
||||
)
|
||||
return ok
|
||||
|
||||
|
||||
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
|
||||
dtypes = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
|
||||
return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
|
||||
|
||||
|
||||
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
|
||||
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
|
||||
default_dims = [2**i for i in range(7, 15)] # 128...16384
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
|
||||
x_vals=[],
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sglang", "speedup"],
|
||||
line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "--")],
|
||||
ylabel="µs (median) or × (speed-up)",
|
||||
plot_name="activation-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
|
||||
device = torch.device("cuda")
|
||||
in_mult = 1 if kernel == "gelu_quick" else 2
|
||||
x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
|
||||
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||
|
||||
vllm_kernel = getattr(vllm_ops, kernel)
|
||||
sglang_kernel = getattr(sgl_kernel, kernel)
|
||||
|
||||
def baseline():
|
||||
tmp = y0.clone()
|
||||
vllm_kernel(tmp, x)
|
||||
return tmp
|
||||
|
||||
def sglang():
|
||||
return sglang_kernel(x)
|
||||
|
||||
# one-time correctness check
|
||||
if provider == "vllm" and not calculate_diff(
|
||||
kernel, dtype, batch_size, seq_len, dim
|
||||
):
|
||||
raise ValueError("Mismatch – abort benchmark")
|
||||
|
||||
# timing helper
|
||||
def timed(fn):
|
||||
for _ in range(5):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
|
||||
return 1000 * ms, 1000 * qmax, 1000 * qmin
|
||||
|
||||
if provider == "vllm":
|
||||
return timed(baseline)
|
||||
if provider == "sglang":
|
||||
return timed(sglang)
|
||||
|
||||
# provider == "speedup"
|
||||
t_ref, _, _ = timed(baseline)
|
||||
t_sgl, _, _ = timed(sglang)
|
||||
spd = t_ref / t_sgl
|
||||
return (spd, spd, spd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser("Activation kernel benchmark")
|
||||
p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
|
||||
p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
|
||||
p.add_argument("--dims", type=str2int_list, default=default_dims)
|
||||
p.add_argument("--verify_only", action="store_true")
|
||||
args = p.parse_args()
|
||||
|
||||
# coerce lists
|
||||
if isinstance(args.batch_sizes, str):
|
||||
args.batch_sizes = str2int_list(args.batch_sizes)
|
||||
if isinstance(args.seq_lens, str):
|
||||
args.seq_lens = str2int_list(args.seq_lens)
|
||||
if isinstance(args.dims, str):
|
||||
args.dims = str2int_list(args.dims)
|
||||
|
||||
# patch perf_report grid
|
||||
benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
|
||||
if hasattr(benchmark, "benchmarks"):
|
||||
benchmark.benchmarks.x_vals = benchmark_grid
|
||||
else:
|
||||
benchmark.benchmark.x_vals = benchmark_grid
|
||||
|
||||
if args.verify_only:
|
||||
ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
|
||||
print("✅ sanity pass" if ok else "❌ mismatch")
|
||||
else:
|
||||
benchmark.run(print_data=True)
|
||||
Reference in New Issue
Block a user