32 lines
929 B
Python
32 lines
929 B
Python
import pytest
|
|
import torch
|
|
from sgl_kernel import rmsnorm
|
|
|
|
|
|
def llama_rms_norm(x, w, eps=1e-6):
    """Reference RMS normalization over the last dimension.

    The input is upcast to float32, scaled by the reciprocal square root of
    the mean of its squared elements (plus ``eps``), multiplied by the
    float32 weight, and finally cast back to the dtype ``x`` arrived in.

    Args:
        x: Input tensor; normalization runs along ``dim=-1``.
        w: Weight tensor multiplied into the normalized values.
        eps: Constant added to the mean square before ``rsqrt``.

    Returns:
        A tensor with the same dtype as the original ``x``.
    """
    input_dtype = x.dtype

    # Do all arithmetic in float32 so low-precision inputs are not squared
    # and averaged in their native dtype.
    hidden = x.to(torch.float32)
    mean_square = torch.pow(hidden, 2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)

    # Normalize first, then apply the weight (same order as the reference).
    normalized = hidden * inv_rms
    scaled = normalized * w.to(torch.float32)
    return scaled.to(input_dtype)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    """Check ``rmsnorm`` against the ``llama_rms_norm`` reference.

    Random inputs and weights are generated, moved to device index 0 and
    cast to ``dtype``. ``rmsnorm`` is exercised either by writing into a
    preallocated ``out`` tensor or by using its return value, depending on
    ``specify_out``; the result must match the reference within
    ``rtol=1e-3`` / ``atol=1e-3``.
    """
    inputs = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    weight = torch.randn(hidden_size).to(0).to(dtype)

    expected = llama_rms_norm(inputs, weight)

    if not specify_out:
        # Let the kernel allocate and return its own output.
        actual = rmsnorm(inputs, weight)
    else:
        # Have the kernel fill a caller-provided buffer.
        actual = torch.empty_like(inputs)
        rmsnorm(inputs, weight, out=actual)

    torch.testing.assert_close(expected, actual, rtol=1e-3, atol=1e-3)
|