Restruct sgl-kernel benchmark (#10861)

This commit is contained in:
Xiaoyu Zhang
2025-09-25 07:45:25 +08:00
committed by GitHub
parent 7a06ef984d
commit c4e314f986
27 changed files with 425 additions and 319 deletions

View File

@@ -5,7 +5,8 @@ import tilelang
import tilelang.language as T
import torch
import triton
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
from deep_gemm import ceil_div
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)
@@ -131,7 +132,7 @@ def fp8_gemm_deepgemm(
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
# Run DeepGEMM kernel
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
return out
@@ -179,7 +180,7 @@ def calculate_diff(m: int, n: int, k: int):
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
out_deepgemm = fp8_gemm_deepgemm(
x_fp8.clone(),
@@ -300,7 +301,7 @@ def get_benchmark(tp_size):
# Preprocess data before benchmarking
x_fp8, x_scale = per_token_cast_to_fp8(x)
y_fp8, y_scale = per_block_cast_to_fp8(y)
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
quantiles = [0.5, 0.2, 0.8]

View File

@@ -4,7 +4,8 @@ import deep_gemm
import torch
import triton
import triton.language as tl
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
from deep_gemm import calc_diff
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
# Import shared functionality from the regular GEMM benchmark
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
@@ -71,9 +72,9 @@ def construct_grouped_and_flat_fp8(
# Transpose earlier for testing
x_fp8_grouped = (
x_fp8_grouped[0],
get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]),
)
x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1]))
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
@@ -240,7 +241,7 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
x_fp8_grouped,
y_fp8_grouped,
out,

View File

@@ -1,230 +0,0 @@
import itertools
from typing import Optional, Tuple, Union
import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn
from vllm import _custom_ops as vllm_ops
class HuggingFaceRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
return x
else:
return x, residual
def rmsnorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
naive_norm.weight = nn.Parameter(weight)
naive_norm = naive_norm.to(x.device)
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
output = naive_norm(x, residual)
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def rmsnorm_flashinfer(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
if residual is not None:
fused_add_rmsnorm(x, residual, weight, eps)
output = (x, residual)
else:
output = rmsnorm(x, weight, eps)
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def rmsnorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
if residual is not None:
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
output = (x, residual)
else:
out = torch.empty_like(x)
vllm_ops.rms_norm(out, x, weight, eps)
output = out
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
output_naive = rmsnorm_naive(
x.clone(), weight, residual.clone() if residual is not None else None
)
output_flashinfer = rmsnorm_flashinfer(
x.clone(), weight, residual.clone() if residual is not None else None
)
output_vllm = rmsnorm_vllm(
x.clone(), weight, residual.clone() if residual is not None else None
)
if use_residual:
output_naive = output_naive[0]
output_flashinfer = output_flashinfer[0]
output_vllm = output_vllm[0]
print(f"Naive output={output_naive}")
print(f"FlashInfer output={output_flashinfer}")
print(f"VLLM output={output_vllm}")
if torch.allclose(
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
def get_benchmark(use_residual):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["huggingface", "flashinfer", "vllm"],
line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
args={},
)
)
def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
quantiles = [0.5, 0.2, 0.8]
if provider == "huggingface":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_naive(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
elif provider == "flashinfer":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_flashinfer(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--use_residual", action="store_true", help="Whether to use residual connection"
)
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/rmsnorm/",
help="Path to save rmsnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
)
# Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual)
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)