# Benchmarks SGLang RMSNorm kernels versus vLLM and FlashInfer across
|
||
# (batch_size, seq_len, hidden_size) and prints speed-up.
|
||
import argparse
|
||
import itertools
|
||
import os
|
||
import re
|
||
from typing import List, Optional, Tuple, Union
|
||
|
||
import sgl_kernel
|
||
import torch
|
||
import torch.nn as nn
|
||
import triton
|
||
import triton.testing
|
||
from sgl_kernel.utils import is_arch_support_pdl
|
||
|
||
# Optional imports
|
||
try:
|
||
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||
|
||
FLASHINFER_AVAILABLE = True
|
||
except ImportError:
|
||
fused_add_rmsnorm = None
|
||
rmsnorm = None
|
||
FLASHINFER_AVAILABLE = False
|
||
|
||
try:
|
||
from vllm import _custom_ops as vllm_ops
|
||
|
||
VLLM_AVAILABLE = True
|
||
except ImportError:
|
||
vllm_ops = None
|
||
VLLM_AVAILABLE = False
|
||
|
||
# CI environment detection
|
||
# True when running under a CI system (generic CI flag or GitHub Actions).
IS_CI = any(
    os.getenv(flag, "false").lower() == "true"
    for flag in ("CI", "GITHUB_ACTIONS")
)
|
||
|
||
|
||
def str2int_list(arg: str) -> List[int]:
    """Parse a comma-separated list of non-negative integers.

    Args:
        arg: A string such as ``"1,4,16"``. ``""`` or ``None`` yields ``[]``.

    Returns:
        The parsed integers, in order.

    Raises:
        argparse.ArgumentTypeError: If *arg* is not a comma-separated list
            of digits.
    """
    if arg in ("", None):
        return []
    # Strip once so validation and splitting operate on the same text.
    # (Previously the stripped string was validated but the un-stripped one
    # was split, relying on int() tolerating surrounding whitespace.)
    arg = arg.strip()
    if re.fullmatch(r"\d+(,\d+)*", arg) is None:
        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
    return [int(x) for x in arg.split(",")]
|
||
|
||
|
||
class HuggingFaceRMSNorm(nn.Module):
    """Reference RMSNorm matching the HuggingFace formulation.

    Computes the norm in float32 for accuracy, casts back to the input
    dtype before the weight multiply, and optionally fuses a residual add
    (returning the updated residual alongside the output).
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        input_dtype = x.dtype
        hidden = x.to(torch.float32)

        new_residual = None
        if residual is not None:
            # Residual add happens in float32; the returned residual is the
            # post-add activation cast back to the caller's dtype.
            hidden = hidden + residual.to(torch.float32)
            new_residual = hidden.to(input_dtype)

        inv_rms = torch.rsqrt(
            hidden.pow(2).mean(dim=-1, keepdim=True) + self.variance_epsilon
        )
        normed = (hidden * inv_rms).to(input_dtype) * self.weight

        if new_residual is None:
            return normed
        return normed, new_residual
|
||
|
||
|
||
def rmsnorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """Reference RMSNorm via HuggingFaceRMSNorm, flattened to 2-D.

    Returns the normalized tensor, or an (output, residual) pair when a
    residual is supplied; shapes match the input.
    """
    ref = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
    ref.weight = nn.Parameter(weight)
    ref = ref.to(x.device)

    shape = x.shape
    x2d = x.view(-1, x.shape[-1])
    res2d = None if residual is None else residual.view(-1, residual.shape[-1])

    result = ref(x2d, res2d)

    # Restore the caller's original shape on every returned tensor.
    if isinstance(result, tuple):
        return result[0].view(shape), result[1].view(shape)
    return result.view(shape)
|
||
|
||
|
||
def rmsnorm_flashinfer(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """RMSNorm via FlashInfer kernels; falls back to naive when unavailable.

    With a residual, the fused kernel updates ``x`` and ``residual``
    in place and both are returned (reshaped to the input shape).
    """
    if not FLASHINFER_AVAILABLE:
        # Fallback to naive implementation if FlashInfer is not available
        return rmsnorm_naive(x, weight, residual, eps)

    shape = x.shape
    x = x.view(-1, shape[-1])

    if residual is None:
        return rmsnorm(x, weight, eps).view(shape)

    residual = residual.view(-1, residual.shape[-1])
    fused_add_rmsnorm(x, residual, weight, eps)  # mutates x and residual
    return x.view(shape), residual.view(shape)
|
||
|
||
|
||
def rmsnorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """RMSNorm via vLLM custom ops; falls back to naive when unavailable.

    With a residual, the fused kernel updates ``x`` and ``residual``
    in place and both are returned (reshaped to the input shape).
    """
    if not VLLM_AVAILABLE:
        # Fallback to naive implementation if vLLM is not available
        return rmsnorm_naive(x, weight, residual, eps)

    shape = x.shape
    x = x.view(-1, shape[-1])

    if residual is None:
        out = torch.empty_like(x)
        vllm_ops.rms_norm(out, x, weight, eps)
        return out.view(shape)

    residual = residual.view(-1, residual.shape[-1])
    vllm_ops.fused_add_rms_norm(x, residual, weight, eps)  # in place
    return x.view(shape), residual.view(shape)
|
||
|
||
|
||
def rmsnorm_sglang(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
):
    """RMSNorm via SGLang's sgl_kernel ops.

    ``enable_pdl=None`` auto-detects programmatic-dependent-launch support
    for the current GPU architecture. With a residual, the fused kernel
    updates ``x`` and ``residual`` in place and both are returned.
    """
    shape = x.shape
    x = x.view(-1, shape[-1])

    if enable_pdl is None:
        enable_pdl = is_arch_support_pdl()

    if residual is None:
        out = torch.empty_like(x)
        sgl_kernel.rmsnorm(x, weight, eps, out=out, enable_pdl=enable_pdl)
        return out.view(shape)

    residual = residual.view(-1, residual.shape[-1])
    sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps, enable_pdl=enable_pdl)
    return x.view(shape), residual.view(shape)
|
||
|
||
|
||
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    """Run every available implementation on identical inputs and compare.

    Args:
        batch_size: Batch dimension of the random input.
        seq_len: Sequence dimension of the random input.
        hidden_size: Hidden dimension (the normalized axis).
        use_residual: Whether to exercise the fused residual-add path.

    Returns:
        bool: True when all available implementations match the naive
        reference within atol/rtol of 1e-2.
    """
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x) if use_residual else None

    # Each backend gets its own clones because the fused kernels mutate
    # x/residual in place.
    output_naive = rmsnorm_naive(
        x.clone(), weight, residual.clone() if residual is not None else None
    )
    output_flashinfer = rmsnorm_flashinfer(
        x.clone(), weight, residual.clone() if residual is not None else None
    )
    output_vllm = rmsnorm_vllm(
        x.clone(), weight, residual.clone() if residual is not None else None
    )
    output_sglang = rmsnorm_sglang(
        x.clone(), weight, residual.clone() if residual is not None else None
    )

    if use_residual:
        # Compare only the normalized output, not the returned residual.
        output_naive = output_naive[0]
        output_flashinfer = output_flashinfer[0]
        output_vllm = output_vllm[0]
        output_sglang = output_sglang[0]

    print(f"Naive output={output_naive}")
    if FLASHINFER_AVAILABLE:
        print(f"FlashInfer output={output_flashinfer}")
    else:
        print("FlashInfer not available, skipped")
    if VLLM_AVAILABLE:
        print(f"VLLM output={output_vllm}")
    else:
        print("vLLM not available, skipped")
    print(f"SGLang output={output_sglang}")

    # Only compare available implementations
    all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
    if FLASHINFER_AVAILABLE:
        all_match = all_match and torch.allclose(
            output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
        )
    if VLLM_AVAILABLE:
        all_match = all_match and torch.allclose(
            output_naive, output_vllm, atol=1e-2, rtol=1e-2
        )

    if all_match:
        print("✅ All available implementations match")
    else:
        print("❌ Implementations differ")

    # BUG FIX: this function previously returned None, so the --verify_only
    # path ("ok = calculate_diff(...)") always reported a mismatch.
    return all_match
|
||
|
||
|
||
# CI environment uses simplified parameters
|
||
# Default benchmark grid; CI runs a single tiny configuration so the job
# stays fast, local runs sweep a fuller grid.
if IS_CI:
    default_batch_sizes = [1]
    default_seq_lens = [64]
    default_hidden_sizes = [4096]
else:
    default_batch_sizes = [1, 4, 16, 64]  # powers of four
    default_seq_lens = [64, 128, 256, 512, 1024]  # powers of two
    default_hidden_sizes = [4096, 6144]  # 32*128, 48*128
|
||
|
||
|
||
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
    """Cartesian product of (batch_size, seq_len, hidden_size) tuples."""
    return [(b, s, h) for b in bsizes for s in slens for h in hsizes]
|
||
|
||
|
||
# Filter providers based on availability
|
||
# Build the benchmark series: HuggingFace first, SGL Kernel last, with the
# optional backends (FlashInfer, vLLM) in between when importable.
available_providers = ["huggingface"]
available_names = ["HuggingFace"]
available_styles = [("blue", "-")]

if FLASHINFER_AVAILABLE:
    available_providers.append("flashinfer")
    available_names.append("FlashInfer")
    available_styles.append(("green", "-"))

if VLLM_AVAILABLE:
    available_providers.append("vllm")
    available_names.append("vLLM")
    available_styles.append(("red", "-"))

available_providers.append("sglang")
available_names.append("SGL Kernel")
available_styles.append(("orange", "-"))
|
||
|
||
|
||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "hidden_size"],
        # Placeholder grid; the real x_vals are patched in from CLI args in
        # the __main__ block before benchmark.run() is called.
        x_vals=[],
        line_arg="provider",
        line_vals=available_providers,
        line_names=available_names,
        styles=available_styles,
        ylabel="µs (median) or × (speed-up)",
        plot_name="rmsnorm-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
    """Time one provider on one (batch_size, seq_len, hidden_size) config.

    Returns a (median, high, low) latency triple in microseconds for the
    named provider, or a (speedup, speedup, speedup) triple for the
    "speedup" fallthrough at the bottom.
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16

    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
    weight = torch.ones(hidden_size, dtype=dtype, device=device)
    residual = torch.randn_like(x) if use_residual else None

    # timing helper
    def timed(fn):
        # Warm up before capture, then time fn under a CUDA graph.
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
        # quantiles=[0.5, 0.2, 0.8] -> (median, fast 20th, slow 80th), in ms.
        ms, qmin, qmax = triton.testing.do_bench_cudagraph(
            fn, quantiles=[0.5, 0.2, 0.8]
        )
        # Convert ms -> µs; return (median, 0.8-quantile, 0.2-quantile).
        return 1000 * ms, 1000 * qmax, 1000 * qmin

    # Inputs are cloned per call because the fused kernels mutate them
    # in place; cloning keeps each timed iteration on fresh data.
    if provider == "huggingface":
        return timed(
            lambda: rmsnorm_naive(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )
    elif provider == "flashinfer":
        if not FLASHINFER_AVAILABLE:
            return (0, 0, 0)  # plotted as zero when the backend is absent
        return timed(
            lambda: rmsnorm_flashinfer(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )
    elif provider == "vllm":
        if not VLLM_AVAILABLE:
            return (0, 0, 0)  # plotted as zero when the backend is absent
        return timed(
            lambda: rmsnorm_vllm(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )
    elif provider == "sglang":
        return timed(
            lambda: rmsnorm_sglang(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )

    # provider == "speedup"
    # NOTE(review): "speedup" is not in available_providers, so this branch
    # only runs if a caller passes it explicitly. Reference is vLLM when
    # available, else the naive implementation.
    if VLLM_AVAILABLE:
        t_ref, _, _ = timed(
            lambda: rmsnorm_vllm(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )
    else:
        t_ref, _, _ = timed(
            lambda: rmsnorm_naive(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )
    t_sgl, _, _ = timed(
        lambda: rmsnorm_sglang(
            x.clone(),
            weight,
            residual.clone() if residual is not None else None,
        )
    )
    # Guard against division by zero from a degenerate timing result.
    spd = t_ref / t_sgl if t_ref > 0 else 1.0
    return (spd, spd, spd)
|
||
|
||
|
||
if __name__ == "__main__":
    p = argparse.ArgumentParser("RMSNorm kernel benchmark")
    p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
    p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
    p.add_argument("--hidden_sizes", type=str2int_list, default=default_hidden_sizes)
    p.add_argument(
        "--use_residual", action="store_true", help="Whether to use residual connection"
    )
    p.add_argument("--verify_only", action="store_true")
    args = p.parse_args()

    # coerce lists
    # Defensive: type=str2int_list already converts CLI values, and the
    # defaults are lists, so these normally never fire.
    if isinstance(args.batch_sizes, str):
        args.batch_sizes = str2int_list(args.batch_sizes)
    if isinstance(args.seq_lens, str):
        args.seq_lens = str2int_list(args.seq_lens)
    if isinstance(args.hidden_sizes, str):
        args.hidden_sizes = str2int_list(args.hidden_sizes)

    # patch perf_report grid
    # The @perf_report decorator captured x_vals=[] at definition time;
    # inject the grid built from CLI args. Attribute name differs across
    # triton versions ("benchmarks" vs "benchmark"), hence the hasattr check.
    benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.hidden_sizes)
    if hasattr(benchmark, "benchmarks"):
        benchmark.benchmarks.x_vals = benchmark_grid
    else:
        benchmark.benchmark.x_vals = benchmark_grid

    if args.verify_only:
        # NOTE(review): relies on calculate_diff returning a truthy value
        # on success — verify it returns a bool, not None.
        ok = calculate_diff(4, 128, args.hidden_sizes[0], args.use_residual)
        print("✅ sanity pass" if ok else "❌ mismatch")
    else:
        benchmark.run(print_data=True, use_residual=args.use_residual)
|