Restruct sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import pytest
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from sgl_kernel.utils import is_arch_support_pdl
|
||||
|
||||
|
||||
def llama_rms_norm(x, w, eps=1e-6):
|
||||
@@ -58,11 +59,12 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
|
||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||
|
||||
y_ref = llama_rms_norm(x, w)
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
if specify_out:
|
||||
y = torch.empty_like(x)
|
||||
sgl_kernel.rmsnorm(x, w, out=y)
|
||||
sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
|
||||
else:
|
||||
y = sgl_kernel.rmsnorm(x, w)
|
||||
y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)
|
||||
|
||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@@ -83,7 +85,10 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
|
||||
|
||||
x_fused = x.clone()
|
||||
residual_fused = residual.clone()
|
||||
sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
sgl_kernel.fused_add_rmsnorm(
|
||||
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||
@@ -98,11 +103,12 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
|
||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||
|
||||
y_ref = gemma_rms_norm(x, w)
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
if specify_out:
|
||||
y = torch.empty_like(x)
|
||||
sgl_kernel.gemma_rmsnorm(x, w, out=y)
|
||||
sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
|
||||
else:
|
||||
y = sgl_kernel.gemma_rmsnorm(x, w)
|
||||
y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)
|
||||
|
||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@@ -123,7 +129,10 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
|
||||
|
||||
x_fused = x.clone()
|
||||
residual_fused = residual.clone()
|
||||
sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
sgl_kernel.gemma_fused_add_rmsnorm(
|
||||
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user