diff --git a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py index f93732154..bd02e2aee 100644 --- a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py +++ b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py @@ -5,7 +5,8 @@ import tilelang import tilelang.language as T import torch import triton -from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor +from deep_gemm import ceil_div +from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor from vllm.model_executor.layers.quantization.utils.fp8_utils import ( w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul, ) @@ -131,7 +132,7 @@ def fp8_gemm_deepgemm( out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) # Run DeepGEMM kernel - deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out) + deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out) return out @@ -179,7 +180,7 @@ def calculate_diff(m: int, n: int, k: int): x_fp8, x_scale = per_token_cast_to_fp8(x.clone()) y_fp8, y_scale = per_block_cast_to_fp8(y.clone()) - x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone()) + x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone()) out_deepgemm = fp8_gemm_deepgemm( x_fp8.clone(), @@ -300,7 +301,7 @@ def get_benchmark(tp_size): # Preprocess data before benchmarking x_fp8, x_scale = per_token_cast_to_fp8(x) y_fp8, y_scale = per_block_cast_to_fp8(y) - x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone()) + x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone()) quantiles = [0.5, 0.2, 0.8] diff --git a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py index 2c3e8dfcc..b2cea0705 100644 --- a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py +++ b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py @@ -4,7 +4,8 @@ import deep_gemm import torch import triton import triton.language as tl -from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor +from deep_gemm import calc_diff +from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor # Import shared functionality from the regular GEMM benchmark from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import ( @@ -71,9 +72,9 @@ def construct_grouped_and_flat_fp8( # Transpose earlier for testing x_fp8_grouped = ( x_fp8_grouped[0], - get_col_major_tma_aligned_tensor(x_fp8_grouped[1]), + get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]), ) - x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1])) + x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1])) return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out @@ -240,7 +241,7 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups): def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices): - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( x_fp8_grouped, y_fp8_grouped, out, diff --git a/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py b/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py deleted file mode 100644 index aeeea62c0..000000000 --- a/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py +++ /dev/null @@ -1,230 +0,0 @@ -import itertools -from typing import Optional, Tuple, Union - -import torch -import triton -from flashinfer.norm import fused_add_rmsnorm, rmsnorm -from torch import nn -from vllm import _custom_ops as vllm_ops - - -class HuggingFaceRMSNorm(nn.Module): - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - orig_dtype = x.dtype - x = x.to(torch.float32) - if residual is not None: - x = x + residual.to(torch.float32) - residual = x.to(orig_dtype) - - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - x = x.to(orig_dtype) * self.weight - if residual is None: - return x - else: - return x, residual - - -def rmsnorm_naive( - x: torch.Tensor, - weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, - eps: float = 1e-6, -): - naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) - naive_norm.weight = nn.Parameter(weight) - naive_norm = naive_norm.to(x.device) - - orig_shape = x.shape - x = x.view(-1, x.shape[-1]) - if residual is not None: - residual = residual.view(-1, residual.shape[-1]) - - output = naive_norm(x, residual) - - if isinstance(output, tuple): - output = (output[0].view(orig_shape), output[1].view(orig_shape)) - else: - output = output.view(orig_shape) - return output - - -def rmsnorm_flashinfer( - x: torch.Tensor, - weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, - eps: float = 1e-6, -): - orig_shape = x.shape - x = x.view(-1, x.shape[-1]) - if residual is not None: - residual = residual.view(-1, residual.shape[-1]) - - if residual is not None: - fused_add_rmsnorm(x, residual, weight, eps) - output = (x, residual) - else: - output = rmsnorm(x, weight, eps) - - if isinstance(output, tuple): - output = (output[0].view(orig_shape), output[1].view(orig_shape)) - else: - output = output.view(orig_shape) - return output - - -def rmsnorm_vllm( - x: torch.Tensor, - weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, - eps: float = 1e-6, -): - orig_shape = x.shape - x = x.view(-1, x.shape[-1]) - if residual is not None: - residual = residual.view(-1, residual.shape[-1]) - - if residual is not None: - vllm_ops.fused_add_rms_norm(x, residual, weight, eps) - output = (x, residual) - else: - out = torch.empty_like(x) - vllm_ops.rms_norm(out, x, weight, eps) - output = out - - if isinstance(output, tuple): - output = (output[0].view(orig_shape), output[1].view(orig_shape)) - else: - output = output.view(orig_shape) - return output - - -def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): - dtype = torch.bfloat16 - x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") - weight = torch.ones(hidden_size, dtype=dtype, device="cuda") - residual = torch.randn_like(x) if use_residual else None - - output_naive = rmsnorm_naive( - x.clone(), weight, residual.clone() if residual is not None else None - ) - output_flashinfer = rmsnorm_flashinfer( - x.clone(), weight, residual.clone() if residual is not None else None - ) - output_vllm = rmsnorm_vllm( - x.clone(), weight, residual.clone() if residual is not None else None - ) - - if use_residual: - output_naive = output_naive[0] - output_flashinfer = output_flashinfer[0] - output_vllm = output_vllm[0] - - print(f"Naive output={output_naive}") - print(f"FlashInfer output={output_flashinfer}") - print(f"VLLM output={output_vllm}") - - if torch.allclose( - output_naive, output_flashinfer, atol=1e-2, rtol=1e-2 - ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): - print("✅ All implementations match") - else: - print("❌ Implementations differ") - - -batch_size_range = [2**i for i in range(0, 7, 2)] -seq_length_range = [2**i for i in range(6, 11, 1)] -head_num_range = [32, 48] -configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range)) - - -def get_benchmark(use_residual): - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["head_num", "batch_size", "seq_len"], - x_vals=[list(_) for _ in configs], - line_arg="provider", - line_vals=["huggingface", "flashinfer", "vllm"], - line_names=["HuggingFace", "FlashInfer", "vLLM"], - styles=[("blue", "-"), ("green", "-"), ("red", "-")], - ylabel="us", - plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual", - args={}, - ) - ) - def benchmark(head_num, batch_size, seq_len, provider): - dtype = torch.bfloat16 - hidden_size = head_num * 128 # assuming head_dim = 128 - - x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") - weight = torch.ones(hidden_size, dtype=dtype, device="cuda") - residual = torch.randn_like(x) if use_residual else None - - quantiles = [0.5, 0.2, 0.8] - - if provider == "huggingface": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: rmsnorm_naive( - x.clone(), - weight, - residual.clone() if residual is not None else None, - ), - quantiles=quantiles, - ) - elif provider == "flashinfer": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: rmsnorm_flashinfer( - x.clone(), - weight, - residual.clone() if residual is not None else None, - ), - quantiles=quantiles, - ) - else: - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: rmsnorm_vllm( - x.clone(), - weight, - residual.clone() if residual is not None else None, - ), - quantiles=quantiles, - ) - - return 1000 * ms, 1000 * max_ms, 1000 * min_ms - - return benchmark - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--use_residual", action="store_true", help="Whether to use residual connection" - ) - parser.add_argument( - "--save_path", - type=str, - default="./configs/benchmark_ops/rmsnorm/", - help="Path to save rmsnorm benchmark results", - ) - args = parser.parse_args() - - # Run correctness test - calculate_diff( - batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual - ) - - # Get the benchmark function with proper use_residual setting - benchmark = get_benchmark(args.use_residual) - # Run performance benchmark - benchmark.run(print_data=True, save_path=args.save_path) diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index e95247041..158ae6561 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -19,10 +19,6 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import is_npu, set_weight_attrs -_is_npu = is_npu() -if not _is_npu: - from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe - if TYPE_CHECKING: from sglang.srt.layers.moe import MoeRunnerConfig from sglang.srt.layers.moe.ep_moe.layer import EPMoE diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index 45271e116..80202d15e 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase): w_s, ) - from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked + from deep_gemm import fp8_m_grouped_gemm_nt_masked with torch.inference_mode(): ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype) - m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m) + fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m) out = oe[:, :M, :] self.assertTrue( diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index 06e285101..f86d5851f 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -251,6 +251,14 @@ To use this with your library functions, simply wrap them with make_pytorch_shim ``` 2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark) + + **We recommend using `triton.testing.do_bench_cudagraph` for kernel benchmarking**: + + Compared to `triton.testing.do_bench`, `do_bench_cudagraph` provides: + - Reduced CPU overhead impact for more accurate kernel performance measurements + - Incorporation of PDL (Programmatic Dependent Launch) effects into individual kernel results + - More realistic performance data on PDL-supported architectures (SM >= 90) + 3. Run test suite ### FAQ diff --git a/sgl-kernel/benchmark/bench_activation.py b/sgl-kernel/benchmark/bench_activation.py index cfea78915..0c59cceee 100644 --- a/sgl-kernel/benchmark/bench_activation.py +++ b/sgl-kernel/benchmark/bench_activation.py @@ -10,10 +10,18 @@ import torch import torch.nn.functional as F import triton import triton.testing -from sgl_kernel import gelu_quick # activation-only kernel from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from vllm import _custom_ops as vllm_ops +# gelu_quick is only available on HIP/ROCm platforms +try: + from sgl_kernel import gelu_quick + + GELU_QUICK_AVAILABLE = True +except ImportError: + GELU_QUICK_AVAILABLE = False + gelu_quick = None + if not hasattr(vllm_ops, "silu_and_mul"): vllm_ops = torch.ops._C @@ -34,6 +42,12 @@ def calculate_diff( # activation-only quick GELU if kernel == "gelu_quick": + if not GELU_QUICK_AVAILABLE: + print( + f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | " + f"L={seq_len:3d} | D={dim:5d}] ⚠️ not available on this platform" + ) + return True x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device) ref_out = torch.zeros_like(x) getattr(vllm_ops, kernel)(ref_out, x) @@ -54,7 +68,9 @@ def calculate_diff( return ok -kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"] +kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"] +if GELU_QUICK_AVAILABLE: + kernels.append("gelu_quick") dtypes = [torch.float16, torch.bfloat16] @@ -64,7 +80,7 @@ def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[ default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16 default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64 -default_dims = [2**i for i in range(7, 15)] # 128...16384 +default_dims = [2**i for i in range(10, 15)] # 1024...16384 @triton.testing.perf_report( @@ -87,6 +103,9 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider): y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device) vllm_kernel = getattr(vllm_ops, kernel) + if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE: + # Skip benchmark for gelu_quick if not available + return (0, 0, 0) sglang_kernel = getattr(sgl_kernel, kernel) def baseline(): @@ -97,18 +116,14 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider): def sglang(): return sglang_kernel(x) - # one-time correctness check - if provider == "vllm" and not calculate_diff( - kernel, dtype, batch_size, seq_len, dim - ): - raise ValueError("Mismatch – abort benchmark") - # timing helper def timed(fn): for _ in range(5): fn() torch.cuda.synchronize() - ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8]) + ms, qmin, qmax = triton.testing.do_bench_cudagraph( + fn, quantiles=[0.5, 0.2, 0.8] + ) return 1000 * ms, 1000 * qmax, 1000 * qmin if provider == "vllm": @@ -147,7 +162,9 @@ if __name__ == "__main__": benchmark.benchmark.x_vals = benchmark_grid if args.verify_only: - ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0]) + # Test with the first available kernel + test_kernel = kernels[0] + ok = calculate_diff(test_kernel, torch.float16, 1, 1, args.dims[0]) print("✅ sanity pass" if ok else "❌ mismatch") else: benchmark.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_awq_dequant.py b/sgl-kernel/benchmark/bench_awq_dequant.py index 22280c250..a906894c9 100644 --- a/sgl-kernel/benchmark/bench_awq_dequant.py +++ b/sgl-kernel/benchmark/bench_awq_dequant.py @@ -108,7 +108,7 @@ def benchmark(qweight_row, qweight_col, provider): qweight.clone(), scales.clone(), qzeros.clone() ) - ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/sgl-kernel/benchmark/bench_cutlass_mla.py b/sgl-kernel/benchmark/bench_cutlass_mla.py index 785e51033..e01fdc110 100644 --- a/sgl-kernel/benchmark/bench_cutlass_mla.py +++ b/sgl-kernel/benchmark/bench_cutlass_mla.py @@ -87,7 +87,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits): workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8) quantiles = [0.5, 0.2, 0.8] - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: cutlass_mla_decode( qn.transpose(0, 1), qr, @@ -136,8 +136,6 @@ if __name__ == "__main__": print(f"block_size={block_size}, num_kv_splits={kv_split}: ") benchmark.run( print_data=True, - show_plots=True, - save_path="bench_blackwell_mla_res", block_size=block_size, num_kv_splits=kv_split, ) diff --git a/sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py b/sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py index 8c1e29980..a00c59ad0 100644 --- a/sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py +++ b/sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py @@ -41,7 +41,7 @@ def benchmark(num_tokens, impl): def runner(): dsv3_fused_a_gemm(mat_a, mat_b) - ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles) def tflops(t_ms): flops = 2 * M * K * N @@ -54,4 +54,4 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() args = parser.parse_args() - benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm") + benchmark.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py index dee090e21..b699d2580 100644 --- a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py +++ b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py @@ -52,7 +52,7 @@ def benchmark_bf16_output(num_tokens, impl): def runner(): dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16) - ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles) def tflops(t_ms): flops = 2 * M * K * N @@ -106,7 +106,7 @@ def benchmark_float_output(num_tokens, impl): def runner(): dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32) - ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles) def tflops(t_ms): flops = 2 * M * K * N @@ -119,9 +119,5 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() args = parser.parse_args() - benchmark_bf16_output.run( - print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm" - ) - benchmark_float_output.run( - print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm" - ) + benchmark_bf16_output.run(print_data=True) + benchmark_float_output.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_fp4_gemm.py b/sgl-kernel/benchmark/bench_fp4_gemm.py index 80773eb07..e8d467702 100755 --- a/sgl-kernel/benchmark/bench_fp4_gemm.py +++ b/sgl-kernel/benchmark/bench_fp4_gemm.py @@ -198,8 +198,6 @@ if __name__ == "__main__": print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ") benchmark.run( print_data=True, - show_plots=True, - save_path="bench_fp4_res", N=N, K=K, dtype=args.dtype, diff --git a/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py b/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py index ed0410298..8843457f9 100644 --- a/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py +++ b/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py @@ -5,7 +5,7 @@ import itertools import deep_gemm import torch import triton -from deep_gemm import get_col_major_tma_aligned_tensor +from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor from sgl_kernel import fp8_blockwise_scaled_mm from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm @@ -71,7 +71,7 @@ def fp8_gemm_deepgemm( out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) # Run DeepGEMM kernel - deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out) + deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out) return out @@ -117,7 +117,7 @@ def benchmark(batch_size, provider, N, K): if provider == "sgl-kernel": scale_a = scale_a.t().contiguous().t() b_fp8, scale_b = b_fp8.t(), scale_b.t() - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: fp8_blockwise_scaled_mm( a_fp8, b_fp8, scale_a, scale_b, torch.float16 ), @@ -126,20 +126,20 @@ def benchmark(batch_size, provider, N, K): if provider == "vllm": scale_a = scale_a.t().contiguous().t() b_fp8, scale_b = b_fp8.t(), scale_b.t() - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16), quantiles=quantiles, ) if provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: w8a8_block_fp8_matmul( a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16 ), quantiles=quantiles, ) if provider == "deepgemm": - scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone()) - ms, min_ms, max_ms = triton.testing.do_bench( + scale_a_col_major = get_mn_major_tma_aligned_tensor(scale_a.clone()) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: fp8_gemm_deepgemm( a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K ), @@ -174,8 +174,6 @@ if __name__ == "__main__": print(f"{model_name} N={N} K={K}: ") benchmark.run( print_data=True, - show_plots=True, - save_path="bench_fp8_blockwise_res", N=N, K=K, ) diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py index 5f16ca028..f94c32bec 100644 --- a/sgl-kernel/benchmark/bench_fp8_gemm.py +++ b/sgl-kernel/benchmark/bench_fp8_gemm.py @@ -125,7 +125,7 @@ def benchmark(batch_size, provider, N, K): a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) b_fp8 = b_fp8.t() - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype), quantiles=quantiles, ) @@ -133,7 +133,7 @@ def benchmark(batch_size, provider, N, K): a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a) b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b) b_fp8 = b_fp8.t() - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: sgl_scaled_mm( a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None ), @@ -177,8 +177,6 @@ if __name__ == "__main__": KN_model_names = prepare_shapes(args) for K, N, model_name in KN_model_names: print(f"{model_name} N={N} K={K}: ") - benchmark.run( - print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K - ) + benchmark.run(print_data=True, N=N, K=K) print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_int8_gemm.py b/sgl-kernel/benchmark/bench_int8_gemm.py index c5a709393..8a3c5d3f3 100644 --- a/sgl-kernel/benchmark/bench_int8_gemm.py +++ b/sgl-kernel/benchmark/bench_int8_gemm.py @@ -86,12 +86,12 @@ def benchmark(batch_size, provider, N, K): quantiles = [0.5, 0.2, 0.8] if provider == "sgl-kernel": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias), quantiles=quantiles, ) if provider == "vllm": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias), quantiles=quantiles, ) @@ -139,8 +139,6 @@ if __name__ == "__main__": KN_model_names = prepare_shapes(args) for K, N, model_name in KN_model_names: print(f"{model_name} N={N} K={K}: ") - benchmark.run( - print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K - ) + benchmark.run(print_data=True, N=N, K=K) print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_lightning_attention_decode.py b/sgl-kernel/benchmark/bench_lightning_attention_decode.py index 36bdccac0..6097e966d 100644 --- a/sgl-kernel/benchmark/bench_lightning_attention_decode.py +++ b/sgl-kernel/benchmark/bench_lightning_attention_decode.py @@ -246,7 +246,7 @@ def benchmark(batch_size, provider): quantiles = [0.5, 0.2, 0.8] if provider == "naive": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: lightning_attention_decode_naive( q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() ), @@ -257,7 +257,7 @@ def benchmark(batch_size, provider): batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype ) new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device) - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: lightning_attention_decode_kernel( q.clone(), k.clone(), @@ -270,7 +270,7 @@ def benchmark(batch_size, provider): quantiles=quantiles, ) elif provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: triton_lightning_attn_decode( q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() ), diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py index ed8a7b8f3..749d97d8f 100644 --- a/sgl-kernel/benchmark/bench_moe_align_block_size.py +++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py @@ -324,7 +324,7 @@ def benchmark(num_tokens, num_experts, topk, provider): quantiles = [0.5, 0.2, 0.8] if provider == "sgl": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: sgl_moe_align_block_size_with_empty( topk_ids, num_experts, @@ -336,7 +336,7 @@ def benchmark(num_tokens, num_experts, topk, provider): quantiles=quantiles, ) elif provider == "sgl_fusion": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: sgl_moe_align_block_size_with_empty( topk_ids, num_experts, @@ -350,7 +350,7 @@ def benchmark(num_tokens, num_experts, topk, provider): ) elif provider == "triton": sorted_ids.fill_(topk_ids.numel()) - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: moe_align_block_size_triton( topk_ids, num_experts, diff --git a/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py index faadd7698..38b07e83a 100644 --- a/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py +++ b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py @@ -63,7 +63,9 @@ def benchmark(batch_size, provider): block_size, ) - ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + run_triton, quantiles=quantiles + ) else: raise ValueError(f"Unknown provider: {provider}") diff --git a/sgl-kernel/benchmark/bench_moe_fused_gate.py b/sgl-kernel/benchmark/bench_moe_fused_gate.py index 36cc9c498..4455a91b6 100644 --- a/sgl-kernel/benchmark/bench_moe_fused_gate.py +++ b/sgl-kernel/benchmark/bench_moe_fused_gate.py @@ -46,7 +46,7 @@ configs = [(sq,) for sq in seq_length_range] ) ) def benchmark(seq_length, provider): - dtype = torch.bfloat16 + dtype = torch.float32 device = torch.device("cuda") num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8 @@ -56,14 +56,14 @@ def benchmark(seq_length, provider): quantiles = [0.5, 0.2, 0.8] if provider == "original": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: biased_grouped_topk_org( scores.clone(), bias.clone(), num_expert_group, topk_group, topk ), quantiles=quantiles, ) elif provider == "kernel": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: biased_grouped_topk_org_fuse_kernel( scores.clone(), bias.clone(), num_expert_group, topk_group, topk ), diff --git a/sgl-kernel/benchmark/bench_moe_topk_softmax.py b/sgl-kernel/benchmark/bench_moe_topk_softmax.py index 1d3e3e93f..36d89eb82 100644 --- a/sgl-kernel/benchmark/bench_moe_topk_softmax.py +++ b/sgl-kernel/benchmark/bench_moe_topk_softmax.py @@ -97,7 +97,7 @@ def benchmark(num_tokens, num_experts, topk, provider): fn = lambda: sglang_topk_softmax(gating_output, topk) quantiles = [0.5, 0.2, 0.8] - ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py b/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py index 44498a3b4..e7fdcc7ae 100644 --- a/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py +++ b/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py @@ -165,8 +165,6 @@ if __name__ == "__main__": KN_model_names = prepare_shapes(args) for K, N, model_name in KN_model_names: print(f"{model_name} N={N} K={K}: ") - benchmark.run( - print_data=True, show_plots=True, save_path="bench_fp4_res", N=N, K=K - ) + benchmark.run(print_data=True, N=N, K=K) print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py b/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py index 8bc7d1e01..d42007819 100644 --- a/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py @@ -88,7 +88,7 @@ def benchmark(batch_size, seq_len, provider): elif provider == "sglang": fn = lambda: sglang_scaled_fp8_quant(x.clone()) - ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py index a72a1a3d0..e77410f92 100644 --- a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py @@ -160,7 +160,7 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider): elif provider == "sglang": fn = lambda: sglang_per_token_quant_fp8(x.clone()) - ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py b/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py index 18fa4869d..bcf3d7fba 100644 --- a/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py +++ b/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py @@ -117,17 +117,17 @@ def benchmark(batch_size, provider, N, K): quantiles = [0.5, 0.2, 0.8] if provider == "FP16": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: torch.matmul(a_fp16, b_fp16), quantiles=quantiles, ) if provider == "W8A8": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16), quantiles=quantiles, ) if provider == "Qserve_W4A8_Per_Channel": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: qserve_w4a8_per_chn_gemm( a_qserve_chn, b_qserve_chn, @@ -139,7 +139,7 @@ def benchmark(batch_size, provider, N, K): quantiles=quantiles, ) if provider == "Qserve_W4A8_Per_Group": - ms, min_ms, max_ms = triton.testing.do_bench( + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: qserve_w4a8_per_group_gemm( a_qserve_group, b_qserve_group, @@ -189,8 +189,6 @@ if __name__ == "__main__": print(f"{model_name} N={N} K={K}: ") benchmark.run( print_data=True, - show_plots=True, - save_path="bench_qserve_w4a8_gemm_res", N=N, K=K, ) diff --git a/sgl-kernel/benchmark/bench_rmsnorm.py b/sgl-kernel/benchmark/bench_rmsnorm.py new file mode 100644 index 000000000..fc3b732c5 --- /dev/null +++ b/sgl-kernel/benchmark/bench_rmsnorm.py @@ -0,0 +1,318 @@ +# Benchmarks SGLang RMSNorm kernels versus vLLM and FlashInfer across +# (batch_size, seq_len, hidden_size) and prints speed-up. +import argparse +import itertools +import re +from typing import List, Optional, Tuple, Union + +import sgl_kernel +import torch +import torch.nn as nn +import triton +import triton.testing +from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from sgl_kernel.utils import is_arch_support_pdl +from vllm import _custom_ops as vllm_ops + + +def str2int_list(arg: str) -> List[int]: + if arg in ("", None): + return [] + if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None: + raise argparse.ArgumentTypeError(f"Bad int list: {arg}") + return [int(x) for x in arg.split(",")] + + +class HuggingFaceRMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + + +def rmsnorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) + naive_norm.weight = nn.Parameter(weight) + naive_norm = naive_norm.to(x.device) + + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + output = naive_norm(x, residual) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_flashinfer( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + fused_add_rmsnorm(x, residual, weight, eps) + output = (x, residual) + else: + output = rmsnorm(x, weight, eps) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_sglang( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, + enable_pdl: Optional[bool] = None, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if enable_pdl is None: + enable_pdl = is_arch_support_pdl() + + if residual is not None: + sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps, enable_pdl=enable_pdl) + output = (x, residual) + else: + out = torch.empty_like(x) + sgl_kernel.rmsnorm(x, weight, eps, out=out, enable_pdl=enable_pdl) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): + dtype = torch.bfloat16 + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + output_naive = rmsnorm_naive( + x.clone(), weight, residual.clone() if residual is not None else None + ) + output_flashinfer = rmsnorm_flashinfer( + x.clone(), weight, residual.clone() if residual is not None else None + ) + output_vllm = rmsnorm_vllm( + x.clone(), weight, residual.clone() if residual is not None else None + ) + output_sglang = rmsnorm_sglang( + x.clone(), weight, residual.clone() if residual is not None else None + ) + + if use_residual: + output_naive = output_naive[0] + output_flashinfer = output_flashinfer[0] + output_vllm = output_vllm[0] + output_sglang = output_sglang[0] + + print(f"Naive output={output_naive}") + print(f"FlashInfer output={output_flashinfer}") + print(f"VLLM output={output_vllm}") + print(f"SGLang output={output_sglang}") + + if ( + torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2) + and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2) + and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2) + ): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64 +default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024 +default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144 + + +def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]: + return list(itertools.product(bsizes, slens, hsizes)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "hidden_size"], + x_vals=[], + line_arg="provider", + line_vals=["huggingface", "flashinfer", "vllm", "sglang"], + line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"], + styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")], + ylabel="µs (median) or × (speed-up)", + plot_name="rmsnorm-performance", + args={}, + ) +) +def benchmark(batch_size, seq_len, hidden_size, provider, use_residual): + device = torch.device("cuda") + dtype = torch.bfloat16 + + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device) + weight = torch.ones(hidden_size, dtype=dtype, device=device) + residual = torch.randn_like(x) if use_residual else None + + # timing helper + def timed(fn): + for _ in range(5): + fn() + torch.cuda.synchronize() + ms, qmin, qmax = triton.testing.do_bench_cudagraph( + fn, quantiles=[0.5, 0.2, 0.8] + ) + return 1000 * ms, 1000 * qmax, 1000 * qmin + + if provider == "huggingface": + return timed( + lambda: rmsnorm_naive( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ) + ) + elif provider == "flashinfer": + return timed( + lambda: rmsnorm_flashinfer( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ) + ) + elif provider == "vllm": + return timed( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ) + ) + elif provider == "sglang": + return timed( + lambda: rmsnorm_sglang( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ) + ) + + # provider == "speedup" + t_ref, _, _ = timed( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ) + ) + t_sgl, _, _ = timed( + lambda: rmsnorm_sglang( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ) + ) + spd = t_ref / t_sgl + return (spd, spd, spd) + + +if __name__ == "__main__": + p = argparse.ArgumentParser("RMSNorm kernel benchmark") + p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes) + p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens) + p.add_argument("--hidden_sizes", type=str2int_list, default=default_hidden_sizes) + p.add_argument( + "--use_residual", action="store_true", help="Whether to use residual connection" + ) + p.add_argument("--verify_only", action="store_true") + args = p.parse_args() + + # coerce lists + if isinstance(args.batch_sizes, str): + args.batch_sizes = str2int_list(args.batch_sizes) + if isinstance(args.seq_lens, str): + args.seq_lens = str2int_list(args.seq_lens) + if isinstance(args.hidden_sizes, str): + args.hidden_sizes = str2int_list(args.hidden_sizes) + + # patch perf_report grid + benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.hidden_sizes) + if hasattr(benchmark, "benchmarks"): + benchmark.benchmarks.x_vals = benchmark_grid + else: + benchmark.benchmark.x_vals = benchmark_grid + + if args.verify_only: + ok = calculate_diff(4, 128, args.hidden_sizes[0], args.use_residual) + print("✅ sanity pass" if ok else "❌ mismatch") + else: + benchmark.run(print_data=True, use_residual=args.use_residual) diff --git a/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py b/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py index 3692b5b39..254584837 100644 --- a/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py +++ b/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py @@ -114,7 +114,9 @@ def benchmark_sampling(batch_size, vocab_size, p, provider): filter_apply_order="joint", ) - ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8]) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + fn, quantiles=[0.5, 0.2, 0.8] + ) return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/sgl-kernel/tests/test_norm.py b/sgl-kernel/tests/test_norm.py index d22da931f..ed61663ed 100644 --- a/sgl-kernel/tests/test_norm.py +++ b/sgl-kernel/tests/test_norm.py @@ -3,6 +3,7 @@ import pytest import sgl_kernel import torch +from sgl_kernel.utils import is_arch_support_pdl def llama_rms_norm(x, w, eps=1e-6): @@ -58,11 +59,12 @@ def test_norm(batch_size, hidden_size, dtype, specify_out): w = torch.randn(hidden_size).to(0).to(dtype) y_ref = llama_rms_norm(x, w) + enable_pdl = is_arch_support_pdl() if specify_out: y = torch.empty_like(x) - sgl_kernel.rmsnorm(x, w, out=y) + sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl) else: - y = sgl_kernel.rmsnorm(x, w) + y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl) torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) @@ -83,7 +85,10 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype): x_fused = x.clone() residual_fused = residual.clone() - sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps) + enable_pdl = is_arch_support_pdl() + sgl_kernel.fused_add_rmsnorm( + x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl + ) torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) @@ -98,11 +103,12 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out): w = torch.randn(hidden_size).to(0).to(dtype) y_ref = gemma_rms_norm(x, w) + enable_pdl = is_arch_support_pdl() if specify_out: y = torch.empty_like(x) - sgl_kernel.gemma_rmsnorm(x, w, out=y) + sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl) else: - y = sgl_kernel.gemma_rmsnorm(x, w) + y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl) torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) @@ -123,7 +129,10 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype): x_fused = x.clone() residual_fused = residual.clone() - sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps) + enable_pdl = is_arch_support_pdl() + sgl_kernel.gemma_fused_add_rmsnorm( + x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl + ) torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)