Restructure sgl-kernel benchmark (#10861)

This commit is contained in:
Xiaoyu Zhang
2025-09-25 07:45:25 +08:00
committed by GitHub
parent 7a06ef984d
commit c4e314f986
27 changed files with 425 additions and 319 deletions

View File

@@ -251,6 +251,14 @@ To use this with your library functions, simply wrap them with make_pytorch_shim
```
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
**We recommend using `triton.testing.do_bench_cudagraph` for kernel benchmarking**:
Compared to `triton.testing.do_bench`, `do_bench_cudagraph` provides:
- Reduced CPU overhead impact for more accurate kernel performance measurements
- Incorporation of PDL (Programmatic Dependent Launch) effects into individual kernel results
- More realistic performance data on PDL-supported architectures (SM >= 90)
3. Run test suite
### FAQ

View File

@@ -10,10 +10,18 @@ import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_quick # activation-only kernel
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops
# gelu_quick is only available on HIP/ROCm platforms
try:
from sgl_kernel import gelu_quick
GELU_QUICK_AVAILABLE = True
except ImportError:
GELU_QUICK_AVAILABLE = False
gelu_quick = None
if not hasattr(vllm_ops, "silu_and_mul"):
vllm_ops = torch.ops._C
@@ -34,6 +42,12 @@ def calculate_diff(
# activation-only quick GELU
if kernel == "gelu_quick":
if not GELU_QUICK_AVAILABLE:
print(
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
f"L={seq_len:3d} | D={dim:5d}] ⚠️ not available on this platform"
)
return True
x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
ref_out = torch.zeros_like(x)
getattr(vllm_ops, kernel)(ref_out, x)
@@ -54,7 +68,9 @@ def calculate_diff(
return ok
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
if GELU_QUICK_AVAILABLE:
kernels.append("gelu_quick")
dtypes = [torch.float16, torch.bfloat16]
@@ -64,7 +80,7 @@ def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
default_dims = [2**i for i in range(7, 15)] # 128...16384
default_dims = [2**i for i in range(10, 15)] # 1024...16384
@triton.testing.perf_report(
@@ -87,6 +103,9 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
vllm_kernel = getattr(vllm_ops, kernel)
if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
# Skip benchmark for gelu_quick if not available
return (0, 0, 0)
sglang_kernel = getattr(sgl_kernel, kernel)
def baseline():
@@ -97,18 +116,14 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
def sglang():
return sglang_kernel(x)
# one-time correctness check
if provider == "vllm" and not calculate_diff(
kernel, dtype, batch_size, seq_len, dim
):
raise ValueError("Mismatch abort benchmark")
# timing helper
def timed(fn):
for _ in range(5):
fn()
torch.cuda.synchronize()
ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
ms, qmin, qmax = triton.testing.do_bench_cudagraph(
fn, quantiles=[0.5, 0.2, 0.8]
)
return 1000 * ms, 1000 * qmax, 1000 * qmin
if provider == "vllm":
@@ -147,7 +162,9 @@ if __name__ == "__main__":
benchmark.benchmark.x_vals = benchmark_grid
if args.verify_only:
ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
# Test with the first available kernel
test_kernel = kernels[0]
ok = calculate_diff(test_kernel, torch.float16, 1, 1, args.dims[0])
print("✅ sanity pass" if ok else "❌ mismatch")
else:
benchmark.run(print_data=True)

View File

@@ -108,7 +108,7 @@ def benchmark(qweight_row, qweight_col, provider):
qweight.clone(), scales.clone(), qzeros.clone()
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms

View File

@@ -87,7 +87,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: cutlass_mla_decode(
qn.transpose(0, 1),
qr,
@@ -136,8 +136,6 @@ if __name__ == "__main__":
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_blackwell_mla_res",
block_size=block_size,
num_kv_splits=kv_split,
)

View File

@@ -41,7 +41,7 @@ def benchmark(num_tokens, impl):
def runner():
dsv3_fused_a_gemm(mat_a, mat_b)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
@@ -54,4 +54,4 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()
benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
benchmark.run(print_data=True)

View File

@@ -52,7 +52,7 @@ def benchmark_bf16_output(num_tokens, impl):
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
@@ -106,7 +106,7 @@ def benchmark_float_output(num_tokens, impl):
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
@@ -119,9 +119,5 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()
benchmark_bf16_output.run(
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
)
benchmark_float_output.run(
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
)
benchmark_bf16_output.run(print_data=True)
benchmark_float_output.run(print_data=True)

View File

@@ -198,8 +198,6 @@ if __name__ == "__main__":
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_fp4_res",
N=N,
K=K,
dtype=args.dtype,

View File

@@ -5,7 +5,7 @@ import itertools
import deep_gemm
import torch
import triton
from deep_gemm import get_col_major_tma_aligned_tensor
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
@@ -71,7 +71,7 @@ def fp8_gemm_deepgemm(
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
# Run DeepGEMM kernel
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
return out
@@ -117,7 +117,7 @@ def benchmark(batch_size, provider, N, K):
if provider == "sgl-kernel":
scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: fp8_blockwise_scaled_mm(
a_fp8, b_fp8, scale_a, scale_b, torch.float16
),
@@ -126,20 +126,20 @@ def benchmark(batch_size, provider, N, K):
if provider == "vllm":
scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: w8a8_block_fp8_matmul(
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
),
quantiles=quantiles,
)
if provider == "deepgemm":
scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
ms, min_ms, max_ms = triton.testing.do_bench(
scale_a_col_major = get_mn_major_tma_aligned_tensor(scale_a.clone())
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: fp8_gemm_deepgemm(
a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
),
@@ -174,8 +174,6 @@ if __name__ == "__main__":
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_fp8_blockwise_res",
N=N,
K=K,
)

View File

@@ -125,7 +125,7 @@ def benchmark(batch_size, provider, N, K):
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
quantiles=quantiles,
)
@@ -133,7 +133,7 @@ def benchmark(batch_size, provider, N, K):
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: sgl_scaled_mm(
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
),
@@ -177,8 +177,6 @@ if __name__ == "__main__":
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
)
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")

View File

@@ -86,12 +86,12 @@ def benchmark(batch_size, provider, N, K):
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl-kernel":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
)
if provider == "vllm":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
)
@@ -139,8 +139,6 @@ if __name__ == "__main__":
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
)
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")

View File

@@ -246,7 +246,7 @@ def benchmark(batch_size, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: lightning_attention_decode_naive(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),
@@ -257,7 +257,7 @@ def benchmark(batch_size, provider):
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: lightning_attention_decode_kernel(
q.clone(),
k.clone(),
@@ -270,7 +270,7 @@ def benchmark(batch_size, provider):
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: triton_lightning_attn_decode(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),

View File

@@ -324,7 +324,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
@@ -336,7 +336,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
quantiles=quantiles,
)
elif provider == "sgl_fusion":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
@@ -350,7 +350,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
)
elif provider == "triton":
sorted_ids.fill_(topk_ids.numel())
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: moe_align_block_size_triton(
topk_ids,
num_experts,

View File

@@ -63,7 +63,9 @@ def benchmark(batch_size, provider):
block_size,
)
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
run_triton, quantiles=quantiles
)
else:
raise ValueError(f"Unknown provider: {provider}")

View File

@@ -46,7 +46,7 @@ configs = [(sq,) for sq in seq_length_range]
)
)
def benchmark(seq_length, provider):
dtype = torch.bfloat16
dtype = torch.float32
device = torch.device("cuda")
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
@@ -56,14 +56,14 @@ def benchmark(seq_length, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "original":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: biased_grouped_topk_org(
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
),
quantiles=quantiles,
)
elif provider == "kernel":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: biased_grouped_topk_org_fuse_kernel(
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
),

View File

@@ -97,7 +97,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
fn = lambda: sglang_topk_softmax(gating_output, topk)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms

View File

@@ -165,8 +165,6 @@ if __name__ == "__main__":
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_fp4_res", N=N, K=K
)
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")

View File

@@ -88,7 +88,7 @@ def benchmark(batch_size, seq_len, provider):
elif provider == "sglang":
fn = lambda: sglang_scaled_fp8_quant(x.clone())
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms

View File

@@ -160,7 +160,7 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
elif provider == "sglang":
fn = lambda: sglang_per_token_quant_fp8(x.clone())
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms

View File

@@ -117,17 +117,17 @@ def benchmark(batch_size, provider, N, K):
quantiles = [0.5, 0.2, 0.8]
if provider == "FP16":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.matmul(a_fp16, b_fp16),
quantiles=quantiles,
)
if provider == "W8A8":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Channel":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: qserve_w4a8_per_chn_gemm(
a_qserve_chn,
b_qserve_chn,
@@ -139,7 +139,7 @@ def benchmark(batch_size, provider, N, K):
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Group":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: qserve_w4a8_per_group_gemm(
a_qserve_group,
b_qserve_group,
@@ -189,8 +189,6 @@ if __name__ == "__main__":
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_qserve_w4a8_gemm_res",
N=N,
K=K,
)

View File

@@ -0,0 +1,318 @@
# Benchmarks SGLang RMSNorm kernels versus vLLM and FlashInfer across
# (batch_size, seq_len, hidden_size) and prints speed-up.
import argparse
import itertools
import re
from typing import List, Optional, Tuple, Union
import sgl_kernel
import torch
import torch.nn as nn
import triton
import triton.testing
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from sgl_kernel.utils import is_arch_support_pdl
from vllm import _custom_ops as vllm_ops
def str2int_list(arg: str) -> List[int]:
    """Parse a comma-separated list of non-negative integers, e.g. ``"1,4,16"``.

    Intended as an argparse ``type=`` callable, so malformed input raises
    ``argparse.ArgumentTypeError`` and produces a clean CLI error message.

    Args:
        arg: Raw CLI string; ``None`` or an empty string yields an empty list.

    Returns:
        The parsed integers, in the order given.

    Raises:
        argparse.ArgumentTypeError: If ``arg`` is not a comma-separated run
            of digits (ignoring surrounding whitespace).
    """
    if not arg:
        return []
    # Strip once and use the same text for both validation and splitting
    # (previously the stripped text was validated but the raw text was split).
    text = arg.strip()
    if re.fullmatch(r"\d+(,\d+)*", text) is None:
        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
    return [int(piece) for piece in text.split(",")]
class HuggingFaceRMSNorm(nn.Module):
    """Pure-PyTorch RMSNorm reference in the HuggingFace style.

    Normalizes over the last dimension in float32 and applies a learnable
    per-channel weight. When a residual tensor is supplied, the residual is
    added *before* normalization and the pre-norm sum (cast back to the
    input dtype) is returned alongside the normalized output.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        input_dtype = x.dtype
        # All arithmetic happens in float32 for numerical stability.
        hidden = x.to(torch.float32)
        if residual is not None:
            hidden = hidden + residual.to(torch.float32)
            residual = hidden.to(input_dtype)
        inv_rms = torch.rsqrt(
            hidden.pow(2).mean(dim=-1, keepdim=True) + self.variance_epsilon
        )
        normed = (hidden * inv_rms).to(input_dtype) * self.weight
        if residual is not None:
            return normed, residual
        return normed
def rmsnorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """Reference RMSNorm using the HuggingFace-style module (pure PyTorch).

    Flattens leading dims, runs the reference module, and restores the
    original shape. With a residual, returns ``(normed, pre_norm_sum)``.
    """
    ref = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
    ref.weight = nn.Parameter(weight)
    ref = ref.to(x.device)

    shape = x.shape
    flat_x = x.view(-1, shape[-1])
    flat_res = (
        residual.view(-1, residual.shape[-1]) if residual is not None else None
    )

    result = ref(flat_x, flat_res)
    if isinstance(result, tuple):
        return result[0].view(shape), result[1].view(shape)
    return result.view(shape)
def rmsnorm_flashinfer(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """RMSNorm via FlashInfer kernels.

    Without a residual, calls the out-of-place ``rmsnorm`` kernel. With a
    residual, calls the fused kernel, which updates ``x`` and ``residual``
    in place, and returns both reshaped to the input shape.
    """
    shape = x.shape
    x = x.view(-1, shape[-1])
    if residual is None:
        return rmsnorm(x, weight, eps).view(shape)
    residual = residual.view(-1, residual.shape[-1])
    # In-place fused add + norm: x becomes the normalized output,
    # residual becomes the pre-norm sum.
    fused_add_rmsnorm(x, residual, weight, eps)
    return x.view(shape), residual.view(shape)
def rmsnorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """RMSNorm via vLLM custom ops.

    Without a residual, writes into a fresh output buffer via
    ``vllm_ops.rms_norm``. With a residual, calls the fused op, which
    updates ``x`` and ``residual`` in place, and returns both reshaped.
    """
    shape = x.shape
    x = x.view(-1, shape[-1])
    if residual is None:
        out = torch.empty_like(x)
        vllm_ops.rms_norm(out, x, weight, eps)
        return out.view(shape)
    residual = residual.view(-1, residual.shape[-1])
    # In-place fused add + norm on x and residual.
    vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
    return x.view(shape), residual.view(shape)
def rmsnorm_sglang(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
):
    """RMSNorm via SGLang (sgl_kernel) kernels.

    ``enable_pdl=None`` auto-detects PDL (Programmatic Dependent Launch)
    support for the current architecture. Without a residual, writes into a
    fresh output buffer; with a residual, the fused kernel updates ``x``
    and ``residual`` in place and both are returned reshaped.
    """
    shape = x.shape
    x = x.view(-1, shape[-1])
    if enable_pdl is None:
        enable_pdl = is_arch_support_pdl()
    if residual is None:
        out = torch.empty_like(x)
        sgl_kernel.rmsnorm(x, weight, eps, out=out, enable_pdl=enable_pdl)
        return out.view(shape)
    residual = residual.view(-1, residual.shape[-1])
    # In-place fused add + norm on x and residual.
    sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps, enable_pdl=enable_pdl)
    return x.view(shape), residual.view(shape)
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    """Cross-check the naive, FlashInfer, vLLM and SGLang RMSNorm paths.

    Runs every implementation on the same random bf16 CUDA input and
    compares all outputs against each other within bf16-appropriate
    tolerances.

    Args:
        batch_size: Batch dimension of the test tensor.
        seq_len: Sequence-length dimension of the test tensor.
        hidden_size: Hidden dimension (the normalized axis).
        use_residual: Also exercise the fused residual-add variants.

    Returns:
        bool: True when all implementations agree within tolerance.
    """
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x) if use_residual else None

    def fresh_residual():
        # Each implementation may mutate its residual in place; hand each
        # one its own copy.
        return residual.clone() if residual is not None else None

    output_naive = rmsnorm_naive(x.clone(), weight, fresh_residual())
    output_flashinfer = rmsnorm_flashinfer(x.clone(), weight, fresh_residual())
    output_vllm = rmsnorm_vllm(x.clone(), weight, fresh_residual())
    output_sglang = rmsnorm_sglang(x.clone(), weight, fresh_residual())

    if use_residual:
        # Fused variants return (normed, residual); compare the normed part.
        output_naive = output_naive[0]
        output_flashinfer = output_flashinfer[0]
        output_vllm = output_vllm[0]
        output_sglang = output_sglang[0]

    print(f"Naive output={output_naive}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"VLLM output={output_vllm}")
    print(f"SGLang output={output_sglang}")

    ok = (
        torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
        and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
        and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
    )
    if ok:
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
    # Fix: previously this function returned None, so the --verify_only
    # path in __main__ ("ok = calculate_diff(...)") always reported mismatch.
    return ok
# Default sweep grid, used when the corresponding CLI flags
# (--batch_sizes / --seq_lens / --hidden_sizes) are not given.
default_batch_sizes = [2**i for i in range(0, 7, 2)]  # 1, 4, 16, 64
default_seq_lens = [2**i for i in range(6, 11, 1)]  # 64, 128, 256, 512, 1024
default_hidden_sizes = [32 * 128, 48 * 128]  # 4096, 6144
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
    """Return the full Cartesian product of the three sweep axes as tuples."""
    grid = []
    for b in bsizes:
        for s in slens:
            for h in hsizes:
                grid.append((b, s, h))
    return grid
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "hidden_size"],
        x_vals=[],
        line_arg="provider",
        line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
        line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
        ylabel="µs (median) or × (speed-up)",
        plot_name="rmsnorm-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
    """Time one RMSNorm provider on one (batch, seq, hidden) point.

    Returns median/high/low latency in microseconds, measured with
    ``do_bench_cudagraph`` so CPU launch overhead is excluded. Requires a
    CUDA device. An unrecognized provider falls through to a vLLM-vs-SGLang
    speed-up ratio.
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
    weight = torch.ones(hidden_size, dtype=dtype, device=device)
    residual = torch.randn_like(x) if use_residual else None

    def invoke(impl):
        # Fresh clones per call: the fused paths mutate x / residual in place.
        return impl(
            x.clone(),
            weight,
            residual.clone() if residual is not None else None,
        )

    def timed(fn):
        # Warm up before CUDA-graph capture, then report µs quantiles.
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
        ms, qmin, qmax = triton.testing.do_bench_cudagraph(
            fn, quantiles=[0.5, 0.2, 0.8]
        )
        return 1000 * ms, 1000 * qmax, 1000 * qmin

    impls = {
        "huggingface": rmsnorm_naive,
        "flashinfer": rmsnorm_flashinfer,
        "vllm": rmsnorm_vllm,
        "sglang": rmsnorm_sglang,
    }
    if provider in impls:
        impl = impls[provider]
        return timed(lambda: invoke(impl))

    # provider == "speedup": vLLM median latency over SGLang median latency
    t_ref, _, _ = timed(lambda: invoke(rmsnorm_vllm))
    t_sgl, _, _ = timed(lambda: invoke(rmsnorm_sglang))
    ratio = t_ref / t_sgl
    return (ratio, ratio, ratio)
if __name__ == "__main__":
    parser = argparse.ArgumentParser("RMSNorm kernel benchmark")
    parser.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
    parser.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
    parser.add_argument("--hidden_sizes", type=str2int_list, default=default_hidden_sizes)
    parser.add_argument(
        "--use_residual", action="store_true", help="Whether to use residual connection"
    )
    parser.add_argument("--verify_only", action="store_true")
    args = parser.parse_args()

    # Defaults are already int lists; any stray string value is parsed here.
    for attr in ("batch_sizes", "seq_lens", "hidden_sizes"):
        raw = getattr(args, attr)
        if isinstance(raw, str):
            setattr(args, attr, str2int_list(raw))

    # Inject the requested sweep into the perf_report grid (attribute name
    # differs across triton versions).
    grid = make_configs(args.batch_sizes, args.seq_lens, args.hidden_sizes)
    if hasattr(benchmark, "benchmarks"):
        benchmark.benchmarks.x_vals = grid
    else:
        benchmark.benchmark.x_vals = grid

    if args.verify_only:
        ok = calculate_diff(4, 128, args.hidden_sizes[0], args.use_residual)
        print("✅ sanity pass" if ok else "❌ mismatch")
    else:
        benchmark.run(print_data=True, use_residual=args.use_residual)

View File

@@ -114,7 +114,9 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
filter_apply_order="joint",
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
fn, quantiles=[0.5, 0.2, 0.8]
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms

View File

@@ -3,6 +3,7 @@
import pytest
import sgl_kernel
import torch
from sgl_kernel.utils import is_arch_support_pdl
def llama_rms_norm(x, w, eps=1e-6):
@@ -58,11 +59,12 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
w = torch.randn(hidden_size).to(0).to(dtype)
y_ref = llama_rms_norm(x, w)
enable_pdl = is_arch_support_pdl()
if specify_out:
y = torch.empty_like(x)
sgl_kernel.rmsnorm(x, w, out=y)
sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
else:
y = sgl_kernel.rmsnorm(x, w)
y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@@ -83,7 +85,10 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
x_fused = x.clone()
residual_fused = residual.clone()
sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
enable_pdl = is_arch_support_pdl()
sgl_kernel.fused_add_rmsnorm(
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
)
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
@@ -98,11 +103,12 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
w = torch.randn(hidden_size).to(0).to(dtype)
y_ref = gemma_rms_norm(x, w)
enable_pdl = is_arch_support_pdl()
if specify_out:
y = torch.empty_like(x)
sgl_kernel.gemma_rmsnorm(x, w, out=y)
sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
else:
y = sgl_kernel.gemma_rmsnorm(x, w)
y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@@ -123,7 +129,10 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
x_fused = x.clone()
residual_fused = residual.clone()
sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
enable_pdl = is_arch_support_pdl()
sgl_kernel.gemma_fused_add_rmsnorm(
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
)
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)