Restruct sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -5,7 +5,8 @@ import tilelang
|
||||
import tilelang.language as T
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
|
||||
from deep_gemm import ceil_div
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
|
||||
)
|
||||
@@ -131,7 +132,7 @@ def fp8_gemm_deepgemm(
|
||||
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Run DeepGEMM kernel
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||
return out
|
||||
|
||||
|
||||
@@ -179,7 +180,7 @@ def calculate_diff(m: int, n: int, k: int):
|
||||
|
||||
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
|
||||
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
|
||||
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
||||
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
|
||||
|
||||
out_deepgemm = fp8_gemm_deepgemm(
|
||||
x_fp8.clone(),
|
||||
@@ -300,7 +301,7 @@ def get_benchmark(tp_size):
|
||||
# Preprocess data before benchmarking
|
||||
x_fp8, x_scale = per_token_cast_to_fp8(x)
|
||||
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
||||
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
||||
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@ import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
|
||||
from deep_gemm import calc_diff
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
|
||||
# Import shared functionality from the regular GEMM benchmark
|
||||
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
|
||||
@@ -71,9 +72,9 @@ def construct_grouped_and_flat_fp8(
|
||||
# Transpose earlier for testing
|
||||
x_fp8_grouped = (
|
||||
x_fp8_grouped[0],
|
||||
get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
|
||||
get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]),
|
||||
)
|
||||
x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
|
||||
x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1]))
|
||||
|
||||
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
|
||||
|
||||
@@ -240,7 +241,7 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
|
||||
|
||||
|
||||
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
|
||||
x_fp8_grouped,
|
||||
y_fp8_grouped,
|
||||
out,
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||
from torch import nn
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
|
||||
class HuggingFaceRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
if residual is not None:
|
||||
x = x + residual.to(torch.float32)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x.to(orig_dtype) * self.weight
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
return x, residual
|
||||
|
||||
|
||||
def rmsnorm_naive(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
|
||||
naive_norm.weight = nn.Parameter(weight)
|
||||
naive_norm = naive_norm.to(x.device)
|
||||
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
output = naive_norm(x, residual)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def rmsnorm_flashinfer(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
if residual is not None:
|
||||
fused_add_rmsnorm(x, residual, weight, eps)
|
||||
output = (x, residual)
|
||||
else:
|
||||
output = rmsnorm(x, weight, eps)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def rmsnorm_vllm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
if residual is not None:
|
||||
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
|
||||
output = (x, residual)
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
vllm_ops.rms_norm(out, x, weight, eps)
|
||||
output = out
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
dtype = torch.bfloat16
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||
residual = torch.randn_like(x) if use_residual else None
|
||||
|
||||
output_naive = rmsnorm_naive(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
output_flashinfer = rmsnorm_flashinfer(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
output_vllm = rmsnorm_vllm(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
|
||||
if use_residual:
|
||||
output_naive = output_naive[0]
|
||||
output_flashinfer = output_flashinfer[0]
|
||||
output_vllm = output_vllm[0]
|
||||
|
||||
print(f"Naive output={output_naive}")
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"VLLM output={output_vllm}")
|
||||
|
||||
if torch.allclose(
|
||||
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
||||
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [2**i for i in range(0, 7, 2)]
|
||||
seq_length_range = [2**i for i in range(6, 11, 1)]
|
||||
head_num_range = [32, 48]
|
||||
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
|
||||
|
||||
|
||||
def get_benchmark(use_residual):
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["head_num", "batch_size", "seq_len"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["huggingface", "flashinfer", "vllm"],
|
||||
line_names=["HuggingFace", "FlashInfer", "vLLM"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(head_num, batch_size, seq_len, provider):
|
||||
dtype = torch.bfloat16
|
||||
hidden_size = head_num * 128 # assuming head_dim = 128
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||
residual = torch.randn_like(x) if use_residual else None
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "huggingface":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: rmsnorm_naive(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "flashinfer":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: rmsnorm_flashinfer(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else:
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: rmsnorm_vllm(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--use_residual", action="store_true", help="Whether to use residual connection"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/rmsnorm/",
|
||||
help="Path to save rmsnorm benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test
|
||||
calculate_diff(
|
||||
batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
|
||||
)
|
||||
|
||||
# Get the benchmark function with proper use_residual setting
|
||||
benchmark = get_benchmark(args.use_residual)
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
@@ -19,10 +19,6 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import is_npu, set_weight_attrs
|
||||
|
||||
_is_npu = is_npu()
|
||||
if not _is_npu:
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
|
||||
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
|
||||
w_s,
|
||||
)
|
||||
|
||||
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
from deep_gemm import fp8_m_grouped_gemm_nt_masked
|
||||
|
||||
with torch.inference_mode():
|
||||
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
|
||||
fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
|
||||
out = oe[:, :M, :]
|
||||
|
||||
self.assertTrue(
|
||||
|
||||
@@ -251,6 +251,14 @@ To use this with your library functions, simply wrap them with make_pytorch_shim
|
||||
```
|
||||
|
||||
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
|
||||
|
||||
**We recommend using `triton.testing.do_bench_cudagraph` for kernel benchmarking**:
|
||||
|
||||
Compared to `triton.testing.do_bench`, `do_bench_cudagraph` provides:
|
||||
- Reduced CPU overhead impact for more accurate kernel performance measurements
|
||||
- Incorporation of PDL (Programmatic Dependent Launch) effects into individual kernel results
|
||||
- More realistic performance data on PDL-supported architectures (SM >= 90)
|
||||
|
||||
3. Run test suite
|
||||
|
||||
### FAQ
|
||||
|
||||
@@ -10,10 +10,18 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import gelu_quick # activation-only kernel
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
# gelu_quick is only available on HIP/ROCm platforms
|
||||
try:
|
||||
from sgl_kernel import gelu_quick
|
||||
|
||||
GELU_QUICK_AVAILABLE = True
|
||||
except ImportError:
|
||||
GELU_QUICK_AVAILABLE = False
|
||||
gelu_quick = None
|
||||
|
||||
if not hasattr(vllm_ops, "silu_and_mul"):
|
||||
vllm_ops = torch.ops._C
|
||||
|
||||
@@ -34,6 +42,12 @@ def calculate_diff(
|
||||
|
||||
# activation-only quick GELU
|
||||
if kernel == "gelu_quick":
|
||||
if not GELU_QUICK_AVAILABLE:
|
||||
print(
|
||||
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
|
||||
f"L={seq_len:3d} | D={dim:5d}] ⚠️ not available on this platform"
|
||||
)
|
||||
return True
|
||||
x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||
ref_out = torch.zeros_like(x)
|
||||
getattr(vllm_ops, kernel)(ref_out, x)
|
||||
@@ -54,7 +68,9 @@ def calculate_diff(
|
||||
return ok
|
||||
|
||||
|
||||
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
|
||||
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
|
||||
if GELU_QUICK_AVAILABLE:
|
||||
kernels.append("gelu_quick")
|
||||
dtypes = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
@@ -64,7 +80,7 @@ def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[
|
||||
|
||||
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
|
||||
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
|
||||
default_dims = [2**i for i in range(7, 15)] # 128...16384
|
||||
default_dims = [2**i for i in range(10, 15)] # 1024...16384
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
@@ -87,6 +103,9 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
|
||||
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||
|
||||
vllm_kernel = getattr(vllm_ops, kernel)
|
||||
if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
|
||||
# Skip benchmark for gelu_quick if not available
|
||||
return (0, 0, 0)
|
||||
sglang_kernel = getattr(sgl_kernel, kernel)
|
||||
|
||||
def baseline():
|
||||
@@ -97,18 +116,14 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
|
||||
def sglang():
|
||||
return sglang_kernel(x)
|
||||
|
||||
# one-time correctness check
|
||||
if provider == "vllm" and not calculate_diff(
|
||||
kernel, dtype, batch_size, seq_len, dim
|
||||
):
|
||||
raise ValueError("Mismatch – abort benchmark")
|
||||
|
||||
# timing helper
|
||||
def timed(fn):
|
||||
for _ in range(5):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
|
||||
ms, qmin, qmax = triton.testing.do_bench_cudagraph(
|
||||
fn, quantiles=[0.5, 0.2, 0.8]
|
||||
)
|
||||
return 1000 * ms, 1000 * qmax, 1000 * qmin
|
||||
|
||||
if provider == "vllm":
|
||||
@@ -147,7 +162,9 @@ if __name__ == "__main__":
|
||||
benchmark.benchmark.x_vals = benchmark_grid
|
||||
|
||||
if args.verify_only:
|
||||
ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
|
||||
# Test with the first available kernel
|
||||
test_kernel = kernels[0]
|
||||
ok = calculate_diff(test_kernel, torch.float16, 1, 1, args.dims[0])
|
||||
print("✅ sanity pass" if ok else "❌ mismatch")
|
||||
else:
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
@@ -108,7 +108,7 @@ def benchmark(qweight_row, qweight_col, provider):
|
||||
qweight.clone(), scales.clone(), qzeros.clone()
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
|
||||
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: cutlass_mla_decode(
|
||||
qn.transpose(0, 1),
|
||||
qr,
|
||||
@@ -136,8 +136,6 @@ if __name__ == "__main__":
|
||||
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_blackwell_mla_res",
|
||||
block_size=block_size,
|
||||
num_kv_splits=kv_split,
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ def benchmark(num_tokens, impl):
|
||||
def runner():
|
||||
dsv3_fused_a_gemm(mat_a, mat_b)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * M * K * N
|
||||
@@ -54,4 +54,4 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
@@ -52,7 +52,7 @@ def benchmark_bf16_output(num_tokens, impl):
|
||||
def runner():
|
||||
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * M * K * N
|
||||
@@ -106,7 +106,7 @@ def benchmark_float_output(num_tokens, impl):
|
||||
def runner():
|
||||
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * M * K * N
|
||||
@@ -119,9 +119,5 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_bf16_output.run(
|
||||
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
|
||||
)
|
||||
benchmark_float_output.run(
|
||||
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
|
||||
)
|
||||
benchmark_bf16_output.run(print_data=True)
|
||||
benchmark_float_output.run(print_data=True)
|
||||
|
||||
@@ -198,8 +198,6 @@ if __name__ == "__main__":
|
||||
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_fp4_res",
|
||||
N=N,
|
||||
K=K,
|
||||
dtype=args.dtype,
|
||||
|
||||
@@ -5,7 +5,7 @@ import itertools
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import get_col_major_tma_aligned_tensor
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
@@ -71,7 +71,7 @@ def fp8_gemm_deepgemm(
|
||||
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Run DeepGEMM kernel
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||
return out
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ def benchmark(batch_size, provider, N, K):
|
||||
if provider == "sgl-kernel":
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: fp8_blockwise_scaled_mm(
|
||||
a_fp8, b_fp8, scale_a, scale_b, torch.float16
|
||||
),
|
||||
@@ -126,20 +126,20 @@ def benchmark(batch_size, provider, N, K):
|
||||
if provider == "vllm":
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: w8a8_block_fp8_matmul(
|
||||
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "deepgemm":
|
||||
scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
scale_a_col_major = get_mn_major_tma_aligned_tensor(scale_a.clone())
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: fp8_gemm_deepgemm(
|
||||
a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
|
||||
),
|
||||
@@ -174,8 +174,6 @@ if __name__ == "__main__":
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_fp8_blockwise_res",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
@@ -125,7 +125,7 @@ def benchmark(batch_size, provider, N, K):
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
b_fp8 = b_fp8.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
@@ -133,7 +133,7 @@ def benchmark(batch_size, provider, N, K):
|
||||
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
|
||||
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
|
||||
b_fp8 = b_fp8.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: sgl_scaled_mm(
|
||||
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
|
||||
),
|
||||
@@ -177,8 +177,6 @@ if __name__ == "__main__":
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
|
||||
)
|
||||
benchmark.run(print_data=True, N=N, K=K)
|
||||
|
||||
print("Benchmark finished!")
|
||||
|
||||
@@ -86,12 +86,12 @@ def benchmark(batch_size, provider, N, K):
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "sgl-kernel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "vllm":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
@@ -139,8 +139,6 @@ if __name__ == "__main__":
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
|
||||
)
|
||||
benchmark.run(print_data=True, N=N, K=K)
|
||||
|
||||
print("Benchmark finished!")
|
||||
|
||||
@@ -246,7 +246,7 @@ def benchmark(batch_size, provider):
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "naive":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: lightning_attention_decode_naive(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
),
|
||||
@@ -257,7 +257,7 @@ def benchmark(batch_size, provider):
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: lightning_attention_decode_kernel(
|
||||
q.clone(),
|
||||
k.clone(),
|
||||
@@ -270,7 +270,7 @@ def benchmark(batch_size, provider):
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: triton_lightning_attn_decode(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
),
|
||||
|
||||
@@ -324,7 +324,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "sgl":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
@@ -336,7 +336,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "sgl_fusion":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
@@ -350,7 +350,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
)
|
||||
elif provider == "triton":
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
|
||||
@@ -63,7 +63,9 @@ def benchmark(batch_size, provider):
|
||||
block_size,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
run_triton, quantiles=quantiles
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
@@ -46,7 +46,7 @@ configs = [(sq,) for sq in seq_length_range]
|
||||
)
|
||||
)
|
||||
def benchmark(seq_length, provider):
|
||||
dtype = torch.bfloat16
|
||||
dtype = torch.float32
|
||||
device = torch.device("cuda")
|
||||
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
|
||||
|
||||
@@ -56,14 +56,14 @@ def benchmark(seq_length, provider):
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "original":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: biased_grouped_topk_org(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "kernel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: biased_grouped_topk_org_fuse_kernel(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
|
||||
@@ -97,7 +97,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
fn = lambda: sglang_topk_softmax(gating_output, topk)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
@@ -165,8 +165,6 @@ if __name__ == "__main__":
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True, show_plots=True, save_path="bench_fp4_res", N=N, K=K
|
||||
)
|
||||
benchmark.run(print_data=True, N=N, K=K)
|
||||
|
||||
print("Benchmark finished!")
|
||||
|
||||
@@ -88,7 +88,7 @@ def benchmark(batch_size, seq_len, provider):
|
||||
elif provider == "sglang":
|
||||
fn = lambda: sglang_scaled_fp8_quant(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
@@ -160,7 +160,7 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
|
||||
elif provider == "sglang":
|
||||
fn = lambda: sglang_per_token_quant_fp8(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
@@ -117,17 +117,17 @@ def benchmark(batch_size, provider, N, K):
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "FP16":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.matmul(a_fp16, b_fp16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "W8A8":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "Qserve_W4A8_Per_Channel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: qserve_w4a8_per_chn_gemm(
|
||||
a_qserve_chn,
|
||||
b_qserve_chn,
|
||||
@@ -139,7 +139,7 @@ def benchmark(batch_size, provider, N, K):
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "Qserve_W4A8_Per_Group":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: qserve_w4a8_per_group_gemm(
|
||||
a_qserve_group,
|
||||
b_qserve_group,
|
||||
@@ -189,8 +189,6 @@ if __name__ == "__main__":
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_qserve_w4a8_gemm_res",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
318
sgl-kernel/benchmark/bench_rmsnorm.py
Normal file
318
sgl-kernel/benchmark/bench_rmsnorm.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# Benchmarks SGLang RMSNorm kernels versus vLLM and FlashInfer across
|
||||
# (batch_size, seq_len, hidden_size) and prints speed-up.
|
||||
import argparse
|
||||
import itertools
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.testing
|
||||
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||
from sgl_kernel.utils import is_arch_support_pdl
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
|
||||
def str2int_list(arg: str) -> List[int]:
|
||||
if arg in ("", None):
|
||||
return []
|
||||
if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
|
||||
raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
|
||||
return [int(x) for x in arg.split(",")]
|
||||
|
||||
|
||||
class HuggingFaceRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
if residual is not None:
|
||||
x = x + residual.to(torch.float32)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x.to(orig_dtype) * self.weight
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
return x, residual
|
||||
|
||||
|
||||
def rmsnorm_naive(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
|
||||
naive_norm.weight = nn.Parameter(weight)
|
||||
naive_norm = naive_norm.to(x.device)
|
||||
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
output = naive_norm(x, residual)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def rmsnorm_flashinfer(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
if residual is not None:
|
||||
fused_add_rmsnorm(x, residual, weight, eps)
|
||||
output = (x, residual)
|
||||
else:
|
||||
output = rmsnorm(x, weight, eps)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def rmsnorm_vllm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
if residual is not None:
|
||||
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
|
||||
output = (x, residual)
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
vllm_ops.rms_norm(out, x, weight, eps)
|
||||
output = out
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def rmsnorm_sglang(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
enable_pdl: Optional[bool] = None,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
if enable_pdl is None:
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
|
||||
if residual is not None:
|
||||
sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps, enable_pdl=enable_pdl)
|
||||
output = (x, residual)
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
sgl_kernel.rmsnorm(x, weight, eps, out=out, enable_pdl=enable_pdl)
|
||||
output = out
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
dtype = torch.bfloat16
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||
residual = torch.randn_like(x) if use_residual else None
|
||||
|
||||
output_naive = rmsnorm_naive(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
output_flashinfer = rmsnorm_flashinfer(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
output_vllm = rmsnorm_vllm(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
output_sglang = rmsnorm_sglang(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
|
||||
if use_residual:
|
||||
output_naive = output_naive[0]
|
||||
output_flashinfer = output_flashinfer[0]
|
||||
output_vllm = output_vllm[0]
|
||||
output_sglang = output_sglang[0]
|
||||
|
||||
print(f"Naive output={output_naive}")
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"VLLM output={output_vllm}")
|
||||
print(f"SGLang output={output_sglang}")
|
||||
|
||||
if (
|
||||
torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
|
||||
):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
|
||||
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
|
||||
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
|
||||
|
||||
|
||||
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
|
||||
return list(itertools.product(bsizes, slens, hsizes))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len", "hidden_size"],
|
||||
x_vals=[],
|
||||
line_arg="provider",
|
||||
line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
|
||||
line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
|
||||
ylabel="µs (median) or × (speed-up)",
|
||||
plot_name="rmsnorm-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.ones(hidden_size, dtype=dtype, device=device)
|
||||
residual = torch.randn_like(x) if use_residual else None
|
||||
|
||||
# timing helper
|
||||
def timed(fn):
|
||||
for _ in range(5):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
ms, qmin, qmax = triton.testing.do_bench_cudagraph(
|
||||
fn, quantiles=[0.5, 0.2, 0.8]
|
||||
)
|
||||
return 1000 * ms, 1000 * qmax, 1000 * qmin
|
||||
|
||||
if provider == "huggingface":
|
||||
return timed(
|
||||
lambda: rmsnorm_naive(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
elif provider == "flashinfer":
|
||||
return timed(
|
||||
lambda: rmsnorm_flashinfer(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
elif provider == "vllm":
|
||||
return timed(
|
||||
lambda: rmsnorm_vllm(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
elif provider == "sglang":
|
||||
return timed(
|
||||
lambda: rmsnorm_sglang(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
|
||||
# provider == "speedup"
|
||||
t_ref, _, _ = timed(
|
||||
lambda: rmsnorm_vllm(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
t_sgl, _, _ = timed(
|
||||
lambda: rmsnorm_sglang(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
spd = t_ref / t_sgl
|
||||
return (spd, spd, spd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser("RMSNorm kernel benchmark")
|
||||
p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
|
||||
p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
|
||||
p.add_argument("--hidden_sizes", type=str2int_list, default=default_hidden_sizes)
|
||||
p.add_argument(
|
||||
"--use_residual", action="store_true", help="Whether to use residual connection"
|
||||
)
|
||||
p.add_argument("--verify_only", action="store_true")
|
||||
args = p.parse_args()
|
||||
|
||||
# coerce lists
|
||||
if isinstance(args.batch_sizes, str):
|
||||
args.batch_sizes = str2int_list(args.batch_sizes)
|
||||
if isinstance(args.seq_lens, str):
|
||||
args.seq_lens = str2int_list(args.seq_lens)
|
||||
if isinstance(args.hidden_sizes, str):
|
||||
args.hidden_sizes = str2int_list(args.hidden_sizes)
|
||||
|
||||
# patch perf_report grid
|
||||
benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.hidden_sizes)
|
||||
if hasattr(benchmark, "benchmarks"):
|
||||
benchmark.benchmarks.x_vals = benchmark_grid
|
||||
else:
|
||||
benchmark.benchmark.x_vals = benchmark_grid
|
||||
|
||||
if args.verify_only:
|
||||
ok = calculate_diff(4, 128, args.hidden_sizes[0], args.use_residual)
|
||||
print("✅ sanity pass" if ok else "❌ mismatch")
|
||||
else:
|
||||
benchmark.run(print_data=True, use_residual=args.use_residual)
|
||||
@@ -114,7 +114,9 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
|
||||
filter_apply_order="joint",
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
fn, quantiles=[0.5, 0.2, 0.8]
|
||||
)
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import pytest
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from sgl_kernel.utils import is_arch_support_pdl
|
||||
|
||||
|
||||
def llama_rms_norm(x, w, eps=1e-6):
|
||||
@@ -58,11 +59,12 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
|
||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||
|
||||
y_ref = llama_rms_norm(x, w)
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
if specify_out:
|
||||
y = torch.empty_like(x)
|
||||
sgl_kernel.rmsnorm(x, w, out=y)
|
||||
sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
|
||||
else:
|
||||
y = sgl_kernel.rmsnorm(x, w)
|
||||
y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)
|
||||
|
||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@@ -83,7 +85,10 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
|
||||
|
||||
x_fused = x.clone()
|
||||
residual_fused = residual.clone()
|
||||
sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
sgl_kernel.fused_add_rmsnorm(
|
||||
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||
@@ -98,11 +103,12 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
|
||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||
|
||||
y_ref = gemma_rms_norm(x, w)
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
if specify_out:
|
||||
y = torch.empty_like(x)
|
||||
sgl_kernel.gemma_rmsnorm(x, w, out=y)
|
||||
sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
|
||||
else:
|
||||
y = sgl_kernel.gemma_rmsnorm(x, w)
|
||||
y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)
|
||||
|
||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@@ -123,7 +129,10 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
|
||||
|
||||
x_fused = x.clone()
|
||||
residual_fused = residual.clone()
|
||||
sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
sgl_kernel.gemma_fused_add_rmsnorm(
|
||||
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user