sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
90
test/srt/cpu/test_norm.py
Normal file
90
test/srt/cpu/test_norm.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import itertools
|
||||
import unittest
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import make_non_contiguous, precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestNorm(CustomTestCase):
|
||||
M = [4096, 1024]
|
||||
N = [4096, 4096 + 13]
|
||||
dtype = [torch.float16, torch.bfloat16]
|
||||
|
||||
def _forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float = 1e-6,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
if residual is not None:
|
||||
x = x + residual.to(torch.float32)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + variance_epsilon)
|
||||
x = x.to(orig_dtype) * weight
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
return x, residual
|
||||
|
||||
def _norm_test(self, m, n, dtype):
|
||||
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
x = make_non_contiguous(x)
|
||||
hidden_size = x.size(-1)
|
||||
weight = torch.randn(hidden_size, dtype=dtype)
|
||||
variance_epsilon = 1e-6
|
||||
|
||||
out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
|
||||
ref_out = self._forward_native(x, weight, variance_epsilon)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
ref_x = x.clone()
|
||||
residual = torch.randn([m, hidden_size], dtype=dtype)
|
||||
ref_residual = residual.clone()
|
||||
|
||||
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
|
||||
x, residual, weight, variance_epsilon
|
||||
)
|
||||
|
||||
ref_x, ref_residual = self._forward_native(
|
||||
ref_x, weight, variance_epsilon, ref_residual
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)
|
||||
|
||||
def _l2norm_test(self, m, n, dtype):
|
||||
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
hidden_size = x.size(-1)
|
||||
fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
|
||||
variance_epsilon = 1e-6
|
||||
|
||||
out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
|
||||
ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
def test_norm(self):
|
||||
for params in itertools.product(self.M, self.N, self.dtype):
|
||||
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
|
||||
self._norm_test(*params)
|
||||
self._l2norm_test(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user