feat: add flashinfer as 3rdparty and use rmsnorm as example (#3033)
This commit is contained in:
31
sgl-kernel/tests/test_rmsnorm.py
Normal file
31
sgl-kernel/tests/test_rmsnorm.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import rmsnorm
|
||||
|
||||
|
||||
def llama_rms_norm(x, w, eps=1e-6):
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + eps)
|
||||
x = x * w.float()
|
||||
x = x.to(orig_dtype)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("specify_out", [True, False])
|
||||
def test_norm(batch_size, hidden_size, dtype, specify_out):
|
||||
x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
|
||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||
|
||||
y_ref = llama_rms_norm(x, w)
|
||||
if specify_out:
|
||||
y = torch.empty_like(x)
|
||||
rmsnorm(x, w, out=y)
|
||||
else:
|
||||
y = rmsnorm(x, w)
|
||||
|
||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||
Reference in New Issue
Block a user