Restruct sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -5,7 +5,8 @@ import tilelang
|
|||||||
import tilelang.language as T
|
import tilelang.language as T
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
|
from deep_gemm import ceil_div
|
||||||
|
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
|
||||||
)
|
)
|
||||||
@@ -131,7 +132,7 @@ def fp8_gemm_deepgemm(
|
|||||||
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
# Run DeepGEMM kernel
|
# Run DeepGEMM kernel
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -179,7 +180,7 @@ def calculate_diff(m: int, n: int, k: int):
|
|||||||
|
|
||||||
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
|
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
|
||||||
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
|
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
|
||||||
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
|
||||||
|
|
||||||
out_deepgemm = fp8_gemm_deepgemm(
|
out_deepgemm = fp8_gemm_deepgemm(
|
||||||
x_fp8.clone(),
|
x_fp8.clone(),
|
||||||
@@ -300,7 +301,7 @@ def get_benchmark(tp_size):
|
|||||||
# Preprocess data before benchmarking
|
# Preprocess data before benchmarking
|
||||||
x_fp8, x_scale = per_token_cast_to_fp8(x)
|
x_fp8, x_scale = per_token_cast_to_fp8(x)
|
||||||
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
||||||
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import deep_gemm
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
|
from deep_gemm import calc_diff
|
||||||
|
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||||
|
|
||||||
# Import shared functionality from the regular GEMM benchmark
|
# Import shared functionality from the regular GEMM benchmark
|
||||||
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
|
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
|
||||||
@@ -71,9 +72,9 @@ def construct_grouped_and_flat_fp8(
|
|||||||
# Transpose earlier for testing
|
# Transpose earlier for testing
|
||||||
x_fp8_grouped = (
|
x_fp8_grouped = (
|
||||||
x_fp8_grouped[0],
|
x_fp8_grouped[0],
|
||||||
get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
|
get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]),
|
||||||
)
|
)
|
||||||
x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
|
x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1]))
|
||||||
|
|
||||||
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
|
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
|
||||||
|
|
||||||
@@ -240,7 +241,7 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
|
|||||||
|
|
||||||
|
|
||||||
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
|
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
|
||||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
|
||||||
x_fp8_grouped,
|
x_fp8_grouped,
|
||||||
y_fp8_grouped,
|
y_fp8_grouped,
|
||||||
out,
|
out,
|
||||||
|
|||||||
@@ -1,230 +0,0 @@
|
|||||||
import itertools
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
|
||||||
from torch import nn
|
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceRMSNorm(nn.Module):
|
|
||||||
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
residual: Optional[torch.Tensor] = None,
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
||||||
orig_dtype = x.dtype
|
|
||||||
x = x.to(torch.float32)
|
|
||||||
if residual is not None:
|
|
||||||
x = x + residual.to(torch.float32)
|
|
||||||
residual = x.to(orig_dtype)
|
|
||||||
|
|
||||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
|
||||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
|
||||||
x = x.to(orig_dtype) * self.weight
|
|
||||||
if residual is None:
|
|
||||||
return x
|
|
||||||
else:
|
|
||||||
return x, residual
|
|
||||||
|
|
||||||
|
|
||||||
def rmsnorm_naive(
|
|
||||||
x: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
residual: Optional[torch.Tensor] = None,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
):
|
|
||||||
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
|
|
||||||
naive_norm.weight = nn.Parameter(weight)
|
|
||||||
naive_norm = naive_norm.to(x.device)
|
|
||||||
|
|
||||||
orig_shape = x.shape
|
|
||||||
x = x.view(-1, x.shape[-1])
|
|
||||||
if residual is not None:
|
|
||||||
residual = residual.view(-1, residual.shape[-1])
|
|
||||||
|
|
||||||
output = naive_norm(x, residual)
|
|
||||||
|
|
||||||
if isinstance(output, tuple):
|
|
||||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
|
||||||
else:
|
|
||||||
output = output.view(orig_shape)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def rmsnorm_flashinfer(
|
|
||||||
x: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
residual: Optional[torch.Tensor] = None,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
):
|
|
||||||
orig_shape = x.shape
|
|
||||||
x = x.view(-1, x.shape[-1])
|
|
||||||
if residual is not None:
|
|
||||||
residual = residual.view(-1, residual.shape[-1])
|
|
||||||
|
|
||||||
if residual is not None:
|
|
||||||
fused_add_rmsnorm(x, residual, weight, eps)
|
|
||||||
output = (x, residual)
|
|
||||||
else:
|
|
||||||
output = rmsnorm(x, weight, eps)
|
|
||||||
|
|
||||||
if isinstance(output, tuple):
|
|
||||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
|
||||||
else:
|
|
||||||
output = output.view(orig_shape)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def rmsnorm_vllm(
|
|
||||||
x: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
residual: Optional[torch.Tensor] = None,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
):
|
|
||||||
orig_shape = x.shape
|
|
||||||
x = x.view(-1, x.shape[-1])
|
|
||||||
if residual is not None:
|
|
||||||
residual = residual.view(-1, residual.shape[-1])
|
|
||||||
|
|
||||||
if residual is not None:
|
|
||||||
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
|
|
||||||
output = (x, residual)
|
|
||||||
else:
|
|
||||||
out = torch.empty_like(x)
|
|
||||||
vllm_ops.rms_norm(out, x, weight, eps)
|
|
||||||
output = out
|
|
||||||
|
|
||||||
if isinstance(output, tuple):
|
|
||||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
|
||||||
else:
|
|
||||||
output = output.view(orig_shape)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
|
||||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
|
||||||
residual = torch.randn_like(x) if use_residual else None
|
|
||||||
|
|
||||||
output_naive = rmsnorm_naive(
|
|
||||||
x.clone(), weight, residual.clone() if residual is not None else None
|
|
||||||
)
|
|
||||||
output_flashinfer = rmsnorm_flashinfer(
|
|
||||||
x.clone(), weight, residual.clone() if residual is not None else None
|
|
||||||
)
|
|
||||||
output_vllm = rmsnorm_vllm(
|
|
||||||
x.clone(), weight, residual.clone() if residual is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_residual:
|
|
||||||
output_naive = output_naive[0]
|
|
||||||
output_flashinfer = output_flashinfer[0]
|
|
||||||
output_vllm = output_vllm[0]
|
|
||||||
|
|
||||||
print(f"Naive output={output_naive}")
|
|
||||||
print(f"FlashInfer output={output_flashinfer}")
|
|
||||||
print(f"VLLM output={output_vllm}")
|
|
||||||
|
|
||||||
if torch.allclose(
|
|
||||||
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
|
||||||
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
|
||||||
print("✅ All implementations match")
|
|
||||||
else:
|
|
||||||
print("❌ Implementations differ")
|
|
||||||
|
|
||||||
|
|
||||||
batch_size_range = [2**i for i in range(0, 7, 2)]
|
|
||||||
seq_length_range = [2**i for i in range(6, 11, 1)]
|
|
||||||
head_num_range = [32, 48]
|
|
||||||
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
|
|
||||||
|
|
||||||
|
|
||||||
def get_benchmark(use_residual):
|
|
||||||
@triton.testing.perf_report(
|
|
||||||
triton.testing.Benchmark(
|
|
||||||
x_names=["head_num", "batch_size", "seq_len"],
|
|
||||||
x_vals=[list(_) for _ in configs],
|
|
||||||
line_arg="provider",
|
|
||||||
line_vals=["huggingface", "flashinfer", "vllm"],
|
|
||||||
line_names=["HuggingFace", "FlashInfer", "vLLM"],
|
|
||||||
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
|
||||||
ylabel="us",
|
|
||||||
plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
|
|
||||||
args={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
def benchmark(head_num, batch_size, seq_len, provider):
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
hidden_size = head_num * 128 # assuming head_dim = 128
|
|
||||||
|
|
||||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
|
||||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
|
||||||
residual = torch.randn_like(x) if use_residual else None
|
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
|
||||||
|
|
||||||
if provider == "huggingface":
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
|
||||||
lambda: rmsnorm_naive(
|
|
||||||
x.clone(),
|
|
||||||
weight,
|
|
||||||
residual.clone() if residual is not None else None,
|
|
||||||
),
|
|
||||||
quantiles=quantiles,
|
|
||||||
)
|
|
||||||
elif provider == "flashinfer":
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
|
||||||
lambda: rmsnorm_flashinfer(
|
|
||||||
x.clone(),
|
|
||||||
weight,
|
|
||||||
residual.clone() if residual is not None else None,
|
|
||||||
),
|
|
||||||
quantiles=quantiles,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
|
||||||
lambda: rmsnorm_vllm(
|
|
||||||
x.clone(),
|
|
||||||
weight,
|
|
||||||
residual.clone() if residual is not None else None,
|
|
||||||
),
|
|
||||||
quantiles=quantiles,
|
|
||||||
)
|
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
|
||||||
|
|
||||||
return benchmark
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_residual", action="store_true", help="Whether to use residual connection"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--save_path",
|
|
||||||
type=str,
|
|
||||||
default="./configs/benchmark_ops/rmsnorm/",
|
|
||||||
help="Path to save rmsnorm benchmark results",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Run correctness test
|
|
||||||
calculate_diff(
|
|
||||||
batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the benchmark function with proper use_residual setting
|
|
||||||
benchmark = get_benchmark(args.use_residual)
|
|
||||||
# Run performance benchmark
|
|
||||||
benchmark.run(print_data=True, save_path=args.save_path)
|
|
||||||
@@ -19,10 +19,6 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
|||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.utils import is_npu, set_weight_attrs
|
from sglang.srt.utils import is_npu, set_weight_attrs
|
||||||
|
|
||||||
_is_npu = is_npu()
|
|
||||||
if not _is_npu:
|
|
||||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
|
|||||||
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
|
|||||||
w_s,
|
w_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
from deep_gemm import fp8_m_grouped_gemm_nt_masked
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
|
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
|
||||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
|
fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
|
||||||
out = oe[:, :M, :]
|
out = oe[:, :M, :]
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|||||||
@@ -251,6 +251,14 @@ To use this with your library functions, simply wrap them with make_pytorch_shim
|
|||||||
```
|
```
|
||||||
|
|
||||||
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
|
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
|
||||||
|
|
||||||
|
**We recommend using `triton.testing.do_bench_cudagraph` for kernel benchmarking**:
|
||||||
|
|
||||||
|
Compared to `triton.testing.do_bench`, `do_bench_cudagraph` provides:
|
||||||
|
- Reduced CPU overhead impact for more accurate kernel performance measurements
|
||||||
|
- Incorporation of PDL (Programmatic Dependent Launch) effects into individual kernel results
|
||||||
|
- More realistic performance data on PDL-supported architectures (SM >= 90)
|
||||||
|
|
||||||
3. Run test suite
|
3. Run test suite
|
||||||
|
|
||||||
### FAQ
|
### FAQ
|
||||||
|
|||||||
@@ -10,10 +10,18 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import triton
|
import triton
|
||||||
import triton.testing
|
import triton.testing
|
||||||
from sgl_kernel import gelu_quick # activation-only kernel
|
|
||||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||||
from vllm import _custom_ops as vllm_ops
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
|
# gelu_quick is only available on HIP/ROCm platforms
|
||||||
|
try:
|
||||||
|
from sgl_kernel import gelu_quick
|
||||||
|
|
||||||
|
GELU_QUICK_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
GELU_QUICK_AVAILABLE = False
|
||||||
|
gelu_quick = None
|
||||||
|
|
||||||
if not hasattr(vllm_ops, "silu_and_mul"):
|
if not hasattr(vllm_ops, "silu_and_mul"):
|
||||||
vllm_ops = torch.ops._C
|
vllm_ops = torch.ops._C
|
||||||
|
|
||||||
@@ -34,6 +42,12 @@ def calculate_diff(
|
|||||||
|
|
||||||
# activation-only quick GELU
|
# activation-only quick GELU
|
||||||
if kernel == "gelu_quick":
|
if kernel == "gelu_quick":
|
||||||
|
if not GELU_QUICK_AVAILABLE:
|
||||||
|
print(
|
||||||
|
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
|
||||||
|
f"L={seq_len:3d} | D={dim:5d}] ⚠️ not available on this platform"
|
||||||
|
)
|
||||||
|
return True
|
||||||
x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
|
x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||||
ref_out = torch.zeros_like(x)
|
ref_out = torch.zeros_like(x)
|
||||||
getattr(vllm_ops, kernel)(ref_out, x)
|
getattr(vllm_ops, kernel)(ref_out, x)
|
||||||
@@ -54,7 +68,9 @@ def calculate_diff(
|
|||||||
return ok
|
return ok
|
||||||
|
|
||||||
|
|
||||||
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
|
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
|
||||||
|
if GELU_QUICK_AVAILABLE:
|
||||||
|
kernels.append("gelu_quick")
|
||||||
dtypes = [torch.float16, torch.bfloat16]
|
dtypes = [torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
|
|
||||||
@@ -64,7 +80,7 @@ def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[
|
|||||||
|
|
||||||
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
|
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
|
||||||
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
|
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
|
||||||
default_dims = [2**i for i in range(7, 15)] # 128...16384
|
default_dims = [2**i for i in range(10, 15)] # 1024...16384
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
@@ -87,6 +103,9 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
|
|||||||
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
|
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
vllm_kernel = getattr(vllm_ops, kernel)
|
vllm_kernel = getattr(vllm_ops, kernel)
|
||||||
|
if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
|
||||||
|
# Skip benchmark for gelu_quick if not available
|
||||||
|
return (0, 0, 0)
|
||||||
sglang_kernel = getattr(sgl_kernel, kernel)
|
sglang_kernel = getattr(sgl_kernel, kernel)
|
||||||
|
|
||||||
def baseline():
|
def baseline():
|
||||||
@@ -97,18 +116,14 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
|
|||||||
def sglang():
|
def sglang():
|
||||||
return sglang_kernel(x)
|
return sglang_kernel(x)
|
||||||
|
|
||||||
# one-time correctness check
|
|
||||||
if provider == "vllm" and not calculate_diff(
|
|
||||||
kernel, dtype, batch_size, seq_len, dim
|
|
||||||
):
|
|
||||||
raise ValueError("Mismatch – abort benchmark")
|
|
||||||
|
|
||||||
# timing helper
|
# timing helper
|
||||||
def timed(fn):
|
def timed(fn):
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
fn()
|
fn()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
|
ms, qmin, qmax = triton.testing.do_bench_cudagraph(
|
||||||
|
fn, quantiles=[0.5, 0.2, 0.8]
|
||||||
|
)
|
||||||
return 1000 * ms, 1000 * qmax, 1000 * qmin
|
return 1000 * ms, 1000 * qmax, 1000 * qmin
|
||||||
|
|
||||||
if provider == "vllm":
|
if provider == "vllm":
|
||||||
@@ -147,7 +162,9 @@ if __name__ == "__main__":
|
|||||||
benchmark.benchmark.x_vals = benchmark_grid
|
benchmark.benchmark.x_vals = benchmark_grid
|
||||||
|
|
||||||
if args.verify_only:
|
if args.verify_only:
|
||||||
ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
|
# Test with the first available kernel
|
||||||
|
test_kernel = kernels[0]
|
||||||
|
ok = calculate_diff(test_kernel, torch.float16, 1, 1, args.dims[0])
|
||||||
print("✅ sanity pass" if ok else "❌ mismatch")
|
print("✅ sanity pass" if ok else "❌ mismatch")
|
||||||
else:
|
else:
|
||||||
benchmark.run(print_data=True)
|
benchmark.run(print_data=True)
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ def benchmark(qweight_row, qweight_col, provider):
|
|||||||
qweight.clone(), scales.clone(), qzeros.clone()
|
qweight.clone(), scales.clone(), qzeros.clone()
|
||||||
)
|
)
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
|
|||||||
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: cutlass_mla_decode(
|
lambda: cutlass_mla_decode(
|
||||||
qn.transpose(0, 1),
|
qn.transpose(0, 1),
|
||||||
qr,
|
qr,
|
||||||
@@ -136,8 +136,6 @@ if __name__ == "__main__":
|
|||||||
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
|
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
|
||||||
benchmark.run(
|
benchmark.run(
|
||||||
print_data=True,
|
print_data=True,
|
||||||
show_plots=True,
|
|
||||||
save_path="bench_blackwell_mla_res",
|
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_kv_splits=kv_split,
|
num_kv_splits=kv_split,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ def benchmark(num_tokens, impl):
|
|||||||
def runner():
|
def runner():
|
||||||
dsv3_fused_a_gemm(mat_a, mat_b)
|
dsv3_fused_a_gemm(mat_a, mat_b)
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
|
||||||
|
|
||||||
def tflops(t_ms):
|
def tflops(t_ms):
|
||||||
flops = 2 * M * K * N
|
flops = 2 * M * K * N
|
||||||
@@ -54,4 +54,4 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
|
benchmark.run(print_data=True)
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def benchmark_bf16_output(num_tokens, impl):
|
|||||||
def runner():
|
def runner():
|
||||||
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
|
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
|
||||||
|
|
||||||
def tflops(t_ms):
|
def tflops(t_ms):
|
||||||
flops = 2 * M * K * N
|
flops = 2 * M * K * N
|
||||||
@@ -106,7 +106,7 @@ def benchmark_float_output(num_tokens, impl):
|
|||||||
def runner():
|
def runner():
|
||||||
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
|
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
|
||||||
|
|
||||||
def tflops(t_ms):
|
def tflops(t_ms):
|
||||||
flops = 2 * M * K * N
|
flops = 2 * M * K * N
|
||||||
@@ -119,9 +119,5 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
benchmark_bf16_output.run(
|
benchmark_bf16_output.run(print_data=True)
|
||||||
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
|
benchmark_float_output.run(print_data=True)
|
||||||
)
|
|
||||||
benchmark_float_output.run(
|
|
||||||
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -198,8 +198,6 @@ if __name__ == "__main__":
|
|||||||
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
|
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
|
||||||
benchmark.run(
|
benchmark.run(
|
||||||
print_data=True,
|
print_data=True,
|
||||||
show_plots=True,
|
|
||||||
save_path="bench_fp4_res",
|
|
||||||
N=N,
|
N=N,
|
||||||
K=K,
|
K=K,
|
||||||
dtype=args.dtype,
|
dtype=args.dtype,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import itertools
|
|||||||
import deep_gemm
|
import deep_gemm
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
from deep_gemm import get_col_major_tma_aligned_tensor
|
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ def fp8_gemm_deepgemm(
|
|||||||
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
# Run DeepGEMM kernel
|
# Run DeepGEMM kernel
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -117,7 +117,7 @@ def benchmark(batch_size, provider, N, K):
|
|||||||
if provider == "sgl-kernel":
|
if provider == "sgl-kernel":
|
||||||
scale_a = scale_a.t().contiguous().t()
|
scale_a = scale_a.t().contiguous().t()
|
||||||
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: fp8_blockwise_scaled_mm(
|
lambda: fp8_blockwise_scaled_mm(
|
||||||
a_fp8, b_fp8, scale_a, scale_b, torch.float16
|
a_fp8, b_fp8, scale_a, scale_b, torch.float16
|
||||||
),
|
),
|
||||||
@@ -126,20 +126,20 @@ def benchmark(batch_size, provider, N, K):
|
|||||||
if provider == "vllm":
|
if provider == "vllm":
|
||||||
scale_a = scale_a.t().contiguous().t()
|
scale_a = scale_a.t().contiguous().t()
|
||||||
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
|
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
if provider == "triton":
|
if provider == "triton":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: w8a8_block_fp8_matmul(
|
lambda: w8a8_block_fp8_matmul(
|
||||||
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
|
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
|
||||||
),
|
),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
if provider == "deepgemm":
|
if provider == "deepgemm":
|
||||||
scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
|
scale_a_col_major = get_mn_major_tma_aligned_tensor(scale_a.clone())
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: fp8_gemm_deepgemm(
|
lambda: fp8_gemm_deepgemm(
|
||||||
a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
|
a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
|
||||||
),
|
),
|
||||||
@@ -174,8 +174,6 @@ if __name__ == "__main__":
|
|||||||
print(f"{model_name} N={N} K={K}: ")
|
print(f"{model_name} N={N} K={K}: ")
|
||||||
benchmark.run(
|
benchmark.run(
|
||||||
print_data=True,
|
print_data=True,
|
||||||
show_plots=True,
|
|
||||||
save_path="bench_fp8_blockwise_res",
|
|
||||||
N=N,
|
N=N,
|
||||||
K=K,
|
K=K,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ def benchmark(batch_size, provider, N, K):
|
|||||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||||
b_fp8 = b_fp8.t()
|
b_fp8 = b_fp8.t()
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
|
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
@@ -133,7 +133,7 @@ def benchmark(batch_size, provider, N, K):
|
|||||||
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
|
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
|
||||||
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
|
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
|
||||||
b_fp8 = b_fp8.t()
|
b_fp8 = b_fp8.t()
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: sgl_scaled_mm(
|
lambda: sgl_scaled_mm(
|
||||||
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
|
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
|
||||||
),
|
),
|
||||||
@@ -177,8 +177,6 @@ if __name__ == "__main__":
|
|||||||
KN_model_names = prepare_shapes(args)
|
KN_model_names = prepare_shapes(args)
|
||||||
for K, N, model_name in KN_model_names:
|
for K, N, model_name in KN_model_names:
|
||||||
print(f"{model_name} N={N} K={K}: ")
|
print(f"{model_name} N={N} K={K}: ")
|
||||||
benchmark.run(
|
benchmark.run(print_data=True, N=N, K=K)
|
||||||
print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Benchmark finished!")
|
print("Benchmark finished!")
|
||||||
|
|||||||
@@ -86,12 +86,12 @@ def benchmark(batch_size, provider, N, K):
|
|||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
if provider == "sgl-kernel":
|
if provider == "sgl-kernel":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
if provider == "vllm":
|
if provider == "vllm":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
@@ -139,8 +139,6 @@ if __name__ == "__main__":
|
|||||||
KN_model_names = prepare_shapes(args)
|
KN_model_names = prepare_shapes(args)
|
||||||
for K, N, model_name in KN_model_names:
|
for K, N, model_name in KN_model_names:
|
||||||
print(f"{model_name} N={N} K={K}: ")
|
print(f"{model_name} N={N} K={K}: ")
|
||||||
benchmark.run(
|
benchmark.run(print_data=True, N=N, K=K)
|
||||||
print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Benchmark finished!")
|
print("Benchmark finished!")
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ def benchmark(batch_size, provider):
|
|||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
if provider == "naive":
|
if provider == "naive":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: lightning_attention_decode_naive(
|
lambda: lightning_attention_decode_naive(
|
||||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||||
),
|
),
|
||||||
@@ -257,7 +257,7 @@ def benchmark(batch_size, provider):
|
|||||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
|
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: lightning_attention_decode_kernel(
|
lambda: lightning_attention_decode_kernel(
|
||||||
q.clone(),
|
q.clone(),
|
||||||
k.clone(),
|
k.clone(),
|
||||||
@@ -270,7 +270,7 @@ def benchmark(batch_size, provider):
|
|||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
elif provider == "triton":
|
elif provider == "triton":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: triton_lightning_attn_decode(
|
lambda: triton_lightning_attn_decode(
|
||||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -324,7 +324,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
|||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
if provider == "sgl":
|
if provider == "sgl":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: sgl_moe_align_block_size_with_empty(
|
lambda: sgl_moe_align_block_size_with_empty(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -336,7 +336,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
|||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
elif provider == "sgl_fusion":
|
elif provider == "sgl_fusion":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: sgl_moe_align_block_size_with_empty(
|
lambda: sgl_moe_align_block_size_with_empty(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -350,7 +350,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
|||||||
)
|
)
|
||||||
elif provider == "triton":
|
elif provider == "triton":
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: moe_align_block_size_triton(
|
lambda: moe_align_block_size_triton(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
|
|||||||
@@ -63,7 +63,9 @@ def benchmark(batch_size, provider):
|
|||||||
block_size,
|
block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
run_triton, quantiles=quantiles
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown provider: {provider}")
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ configs = [(sq,) for sq in seq_length_range]
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
def benchmark(seq_length, provider):
|
def benchmark(seq_length, provider):
|
||||||
dtype = torch.bfloat16
|
dtype = torch.float32
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
|
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
|
||||||
|
|
||||||
@@ -56,14 +56,14 @@ def benchmark(seq_length, provider):
|
|||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
if provider == "original":
|
if provider == "original":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: biased_grouped_topk_org(
|
lambda: biased_grouped_topk_org(
|
||||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||||
),
|
),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
elif provider == "kernel":
|
elif provider == "kernel":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: biased_grouped_topk_org_fuse_kernel(
|
lambda: biased_grouped_topk_org_fuse_kernel(
|
||||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
|||||||
fn = lambda: sglang_topk_softmax(gating_output, topk)
|
fn = lambda: sglang_topk_softmax(gating_output, topk)
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|||||||
@@ -165,8 +165,6 @@ if __name__ == "__main__":
|
|||||||
KN_model_names = prepare_shapes(args)
|
KN_model_names = prepare_shapes(args)
|
||||||
for K, N, model_name in KN_model_names:
|
for K, N, model_name in KN_model_names:
|
||||||
print(f"{model_name} N={N} K={K}: ")
|
print(f"{model_name} N={N} K={K}: ")
|
||||||
benchmark.run(
|
benchmark.run(print_data=True, N=N, K=K)
|
||||||
print_data=True, show_plots=True, save_path="bench_fp4_res", N=N, K=K
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Benchmark finished!")
|
print("Benchmark finished!")
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ def benchmark(batch_size, seq_len, provider):
|
|||||||
elif provider == "sglang":
|
elif provider == "sglang":
|
||||||
fn = lambda: sglang_scaled_fp8_quant(x.clone())
|
fn = lambda: sglang_scaled_fp8_quant(x.clone())
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
|
|||||||
elif provider == "sglang":
|
elif provider == "sglang":
|
||||||
fn = lambda: sglang_per_token_quant_fp8(x.clone())
|
fn = lambda: sglang_per_token_quant_fp8(x.clone())
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|||||||
@@ -117,17 +117,17 @@ def benchmark(batch_size, provider, N, K):
|
|||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
if provider == "FP16":
|
if provider == "FP16":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: torch.matmul(a_fp16, b_fp16),
|
lambda: torch.matmul(a_fp16, b_fp16),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
if provider == "W8A8":
|
if provider == "W8A8":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
if provider == "Qserve_W4A8_Per_Channel":
|
if provider == "Qserve_W4A8_Per_Channel":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: qserve_w4a8_per_chn_gemm(
|
lambda: qserve_w4a8_per_chn_gemm(
|
||||||
a_qserve_chn,
|
a_qserve_chn,
|
||||||
b_qserve_chn,
|
b_qserve_chn,
|
||||||
@@ -139,7 +139,7 @@ def benchmark(batch_size, provider, N, K):
|
|||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
if provider == "Qserve_W4A8_Per_Group":
|
if provider == "Qserve_W4A8_Per_Group":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
lambda: qserve_w4a8_per_group_gemm(
|
lambda: qserve_w4a8_per_group_gemm(
|
||||||
a_qserve_group,
|
a_qserve_group,
|
||||||
b_qserve_group,
|
b_qserve_group,
|
||||||
@@ -189,8 +189,6 @@ if __name__ == "__main__":
|
|||||||
print(f"{model_name} N={N} K={K}: ")
|
print(f"{model_name} N={N} K={K}: ")
|
||||||
benchmark.run(
|
benchmark.run(
|
||||||
print_data=True,
|
print_data=True,
|
||||||
show_plots=True,
|
|
||||||
save_path="bench_qserve_w4a8_gemm_res",
|
|
||||||
N=N,
|
N=N,
|
||||||
K=K,
|
K=K,
|
||||||
)
|
)
|
||||||
|
|||||||
318
sgl-kernel/benchmark/bench_rmsnorm.py
Normal file
318
sgl-kernel/benchmark/bench_rmsnorm.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
# Benchmarks SGLang RMSNorm kernels versus vLLM and FlashInfer across
|
||||||
|
# (batch_size, seq_len, hidden_size) and prints speed-up.
|
||||||
|
import argparse
|
||||||
|
import itertools
|
||||||
|
import re
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import sgl_kernel
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import triton
|
||||||
|
import triton.testing
|
||||||
|
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||||
|
from sgl_kernel.utils import is_arch_support_pdl
|
||||||
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
|
|
||||||
|
def str2int_list(arg: str) -> List[int]:
|
||||||
|
if arg in ("", None):
|
||||||
|
return []
|
||||||
|
if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
|
||||||
|
raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
|
||||||
|
return [int(x) for x in arg.split(",")]
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceRMSNorm(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
if residual is not None:
|
||||||
|
x = x + residual.to(torch.float32)
|
||||||
|
residual = x.to(orig_dtype)
|
||||||
|
|
||||||
|
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||||
|
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
x = x.to(orig_dtype) * self.weight
|
||||||
|
if residual is None:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return x, residual
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_naive(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
|
||||||
|
naive_norm.weight = nn.Parameter(weight)
|
||||||
|
naive_norm = naive_norm.to(x.device)
|
||||||
|
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
output = naive_norm(x, residual)
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_flashinfer(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
if residual is not None:
|
||||||
|
fused_add_rmsnorm(x, residual, weight, eps)
|
||||||
|
output = (x, residual)
|
||||||
|
else:
|
||||||
|
output = rmsnorm(x, weight, eps)
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_vllm(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
if residual is not None:
|
||||||
|
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
|
||||||
|
output = (x, residual)
|
||||||
|
else:
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
vllm_ops.rms_norm(out, x, weight, eps)
|
||||||
|
output = out
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_sglang(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
enable_pdl: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
if enable_pdl is None:
|
||||||
|
enable_pdl = is_arch_support_pdl()
|
||||||
|
|
||||||
|
if residual is not None:
|
||||||
|
sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps, enable_pdl=enable_pdl)
|
||||||
|
output = (x, residual)
|
||||||
|
else:
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
sgl_kernel.rmsnorm(x, weight, eps, out=out, enable_pdl=enable_pdl)
|
||||||
|
output = out
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||||
|
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||||
|
residual = torch.randn_like(x) if use_residual else None
|
||||||
|
|
||||||
|
output_naive = rmsnorm_naive(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
output_flashinfer = rmsnorm_flashinfer(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
output_vllm = rmsnorm_vllm(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
output_sglang = rmsnorm_sglang(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_residual:
|
||||||
|
output_naive = output_naive[0]
|
||||||
|
output_flashinfer = output_flashinfer[0]
|
||||||
|
output_vllm = output_vllm[0]
|
||||||
|
output_sglang = output_sglang[0]
|
||||||
|
|
||||||
|
print(f"Naive output={output_naive}")
|
||||||
|
print(f"FlashInfer output={output_flashinfer}")
|
||||||
|
print(f"VLLM output={output_vllm}")
|
||||||
|
print(f"SGLang output={output_sglang}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
|
||||||
|
and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
|
||||||
|
and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
|
||||||
|
):
|
||||||
|
print("✅ All implementations match")
|
||||||
|
else:
|
||||||
|
print("❌ Implementations differ")
|
||||||
|
|
||||||
|
|
||||||
|
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
|
||||||
|
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
|
||||||
|
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
|
||||||
|
|
||||||
|
|
||||||
|
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
|
||||||
|
return list(itertools.product(bsizes, slens, hsizes))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size", "seq_len", "hidden_size"],
|
||||||
|
x_vals=[],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
|
||||||
|
line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
|
||||||
|
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
|
||||||
|
ylabel="µs (median) or × (speed-up)",
|
||||||
|
plot_name="rmsnorm-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
|
||||||
|
weight = torch.ones(hidden_size, dtype=dtype, device=device)
|
||||||
|
residual = torch.randn_like(x) if use_residual else None
|
||||||
|
|
||||||
|
# timing helper
|
||||||
|
def timed(fn):
|
||||||
|
for _ in range(5):
|
||||||
|
fn()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
ms, qmin, qmax = triton.testing.do_bench_cudagraph(
|
||||||
|
fn, quantiles=[0.5, 0.2, 0.8]
|
||||||
|
)
|
||||||
|
return 1000 * ms, 1000 * qmax, 1000 * qmin
|
||||||
|
|
||||||
|
if provider == "huggingface":
|
||||||
|
return timed(
|
||||||
|
lambda: rmsnorm_naive(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif provider == "flashinfer":
|
||||||
|
return timed(
|
||||||
|
lambda: rmsnorm_flashinfer(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif provider == "vllm":
|
||||||
|
return timed(
|
||||||
|
lambda: rmsnorm_vllm(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif provider == "sglang":
|
||||||
|
return timed(
|
||||||
|
lambda: rmsnorm_sglang(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# provider == "speedup"
|
||||||
|
t_ref, _, _ = timed(
|
||||||
|
lambda: rmsnorm_vllm(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
t_sgl, _, _ = timed(
|
||||||
|
lambda: rmsnorm_sglang(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
spd = t_ref / t_sgl
|
||||||
|
return (spd, spd, spd)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
p = argparse.ArgumentParser("RMSNorm kernel benchmark")
|
||||||
|
p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
|
||||||
|
p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
|
||||||
|
p.add_argument("--hidden_sizes", type=str2int_list, default=default_hidden_sizes)
|
||||||
|
p.add_argument(
|
||||||
|
"--use_residual", action="store_true", help="Whether to use residual connection"
|
||||||
|
)
|
||||||
|
p.add_argument("--verify_only", action="store_true")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
# coerce lists
|
||||||
|
if isinstance(args.batch_sizes, str):
|
||||||
|
args.batch_sizes = str2int_list(args.batch_sizes)
|
||||||
|
if isinstance(args.seq_lens, str):
|
||||||
|
args.seq_lens = str2int_list(args.seq_lens)
|
||||||
|
if isinstance(args.hidden_sizes, str):
|
||||||
|
args.hidden_sizes = str2int_list(args.hidden_sizes)
|
||||||
|
|
||||||
|
# patch perf_report grid
|
||||||
|
benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.hidden_sizes)
|
||||||
|
if hasattr(benchmark, "benchmarks"):
|
||||||
|
benchmark.benchmarks.x_vals = benchmark_grid
|
||||||
|
else:
|
||||||
|
benchmark.benchmark.x_vals = benchmark_grid
|
||||||
|
|
||||||
|
if args.verify_only:
|
||||||
|
ok = calculate_diff(4, 128, args.hidden_sizes[0], args.use_residual)
|
||||||
|
print("✅ sanity pass" if ok else "❌ mismatch")
|
||||||
|
else:
|
||||||
|
benchmark.run(print_data=True, use_residual=args.use_residual)
|
||||||
@@ -114,7 +114,9 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
|
|||||||
filter_apply_order="joint",
|
filter_apply_order="joint",
|
||||||
)
|
)
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
fn, quantiles=[0.5, 0.2, 0.8]
|
||||||
|
)
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import sgl_kernel
|
import sgl_kernel
|
||||||
import torch
|
import torch
|
||||||
|
from sgl_kernel.utils import is_arch_support_pdl
|
||||||
|
|
||||||
|
|
||||||
def llama_rms_norm(x, w, eps=1e-6):
|
def llama_rms_norm(x, w, eps=1e-6):
|
||||||
@@ -58,11 +59,12 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
|
|||||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||||
|
|
||||||
y_ref = llama_rms_norm(x, w)
|
y_ref = llama_rms_norm(x, w)
|
||||||
|
enable_pdl = is_arch_support_pdl()
|
||||||
if specify_out:
|
if specify_out:
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
sgl_kernel.rmsnorm(x, w, out=y)
|
sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
|
||||||
else:
|
else:
|
||||||
y = sgl_kernel.rmsnorm(x, w)
|
y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)
|
||||||
|
|
||||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
@@ -83,7 +85,10 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
|
|||||||
|
|
||||||
x_fused = x.clone()
|
x_fused = x.clone()
|
||||||
residual_fused = residual.clone()
|
residual_fused = residual.clone()
|
||||||
sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
|
enable_pdl = is_arch_support_pdl()
|
||||||
|
sgl_kernel.fused_add_rmsnorm(
|
||||||
|
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
|
||||||
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
||||||
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||||
@@ -98,11 +103,12 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
|
|||||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||||
|
|
||||||
y_ref = gemma_rms_norm(x, w)
|
y_ref = gemma_rms_norm(x, w)
|
||||||
|
enable_pdl = is_arch_support_pdl()
|
||||||
if specify_out:
|
if specify_out:
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
sgl_kernel.gemma_rmsnorm(x, w, out=y)
|
sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
|
||||||
else:
|
else:
|
||||||
y = sgl_kernel.gemma_rmsnorm(x, w)
|
y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)
|
||||||
|
|
||||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
@@ -123,7 +129,10 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
|
|||||||
|
|
||||||
x_fused = x.clone()
|
x_fused = x.clone()
|
||||||
residual_fused = residual.clone()
|
residual_fused = residual.clone()
|
||||||
sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
|
enable_pdl = is_arch_support_pdl()
|
||||||
|
sgl_kernel.gemma_fused_add_rmsnorm(
|
||||||
|
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
|
||||||
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
||||||
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||||
|
|||||||
Reference in New Issue
Block a user