Restructure sgl-kernel benchmark (#10861)

This commit is contained in:
Xiaoyu Zhang
2025-09-25 07:45:25 +08:00
committed by GitHub
parent 7a06ef984d
commit c4e314f986
27 changed files with 425 additions and 319 deletions

View File

@@ -3,6 +3,7 @@
import pytest
import sgl_kernel
import torch
from sgl_kernel.utils import is_arch_support_pdl
def llama_rms_norm(x, w, eps=1e-6):
@@ -58,11 +59,12 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
w = torch.randn(hidden_size).to(0).to(dtype)
y_ref = llama_rms_norm(x, w)
enable_pdl = is_arch_support_pdl()
if specify_out:
y = torch.empty_like(x)
sgl_kernel.rmsnorm(x, w, out=y)
sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
else:
y = sgl_kernel.rmsnorm(x, w)
y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@@ -83,7 +85,10 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
x_fused = x.clone()
residual_fused = residual.clone()
sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
enable_pdl = is_arch_support_pdl()
sgl_kernel.fused_add_rmsnorm(
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
)
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
@@ -98,11 +103,12 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
w = torch.randn(hidden_size).to(0).to(dtype)
y_ref = gemma_rms_norm(x, w)
enable_pdl = is_arch_support_pdl()
if specify_out:
y = torch.empty_like(x)
sgl_kernel.gemma_rmsnorm(x, w, out=y)
sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
else:
y = sgl_kernel.gemma_rmsnorm(x, w)
y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@@ -123,7 +129,10 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
x_fused = x.clone()
residual_fused = residual.clone()
sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
enable_pdl = is_arch_support_pdl()
sgl_kernel.gemma_fused_add_rmsnorm(
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
)
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)