From d8189660a9bbd4b5b5fe2526424d42c8ffcf7195 Mon Sep 17 00:00:00 2001
From: YanbingJiang
Date: Fri, 23 May 2025 17:03:15 +0800
Subject: [PATCH] Update sgl-kernel UTs for activation/topk/norm/rope kernels (#6452)

---
 test/srt/cpu/test_activation.py | 33 +++++++++++
 test/srt/cpu/test_norm.py       | 73 ++++++++++++++++++++++++
 test/srt/cpu/test_rope.py       | 78 ++++++++++++++++++++++++++
 test/srt/cpu/test_topk.py       | 98 +++++++++++++++++++++++++++++++++
 4 files changed, 282 insertions(+)
 create mode 100644 test/srt/cpu/test_activation.py
 create mode 100644 test/srt/cpu/test_norm.py
 create mode 100644 test/srt/cpu/test_rope.py
 create mode 100644 test/srt/cpu/test_topk.py

diff --git a/test/srt/cpu/test_activation.py b/test/srt/cpu/test_activation.py
new file mode 100644
index 000000000..7602445dd
--- /dev/null
+++ b/test/srt/cpu/test_activation.py
@@ -0,0 +1,33 @@
+import itertools
+import unittest
+
+import sgl_kernel
+import torch
+import torch.nn.functional as F
+from utils import SiluAndMul, precision
+
+from sglang.test.test_utils import CustomTestCase
+
+
+class TestActivation(CustomTestCase):
+    M = [128, 129, 257]
+    N = [22016, 22018]
+    dtype = [torch.float16, torch.bfloat16]
+
+    def _activation_test(self, m, n, dtype):
+        x = torch.randn([m, n], dtype=dtype)
+
+        out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
+        ref_out = SiluAndMul(x)
+
+        atol = rtol = precision[ref_out.dtype]
+        self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
+
+    def test_activation(self):
+        for params in itertools.product(self.M, self.N, self.dtype):
+            with self.subTest(m=params[0], n=params[1], dtype=params[2]):
+                self._activation_test(*params)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/cpu/test_norm.py b/test/srt/cpu/test_norm.py
new file mode 100644
index 000000000..8af46c6a1
--- /dev/null
+++ b/test/srt/cpu/test_norm.py
@@ -0,0 +1,73 @@
+import itertools
+import unittest
+from typing import Optional, Tuple, Union
+
+import sgl_kernel
+import torch
+from utils import precision
+
+from sglang.test.test_utils import CustomTestCase
+
+
+class TestNorm(CustomTestCase):
+    M = [4096, 1024]
+    N = [4096, 4096 + 13]
+    dtype = [torch.float16, torch.bfloat16]
+
+    def _forward_native(
+        self,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        variance_epsilon: float = 1e-6,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + variance_epsilon)
+        x = x.to(orig_dtype) * weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
+    def _norm_test(self, m, n, dtype):
+
+        x = torch.randn([m, n], dtype=dtype)
+        hidden_size = x.size(-1)
+        weight = torch.randn(hidden_size, dtype=dtype)
+        variance_epsilon = 1e-6
+
+        out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
+        ref_out = self._forward_native(x, weight, variance_epsilon)
+
+        atol = rtol = precision[ref_out.dtype]
+        self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
+
+        ref_x = x.clone()
+        residual = torch.randn([m, n], dtype=dtype)
+        ref_residual = residual.clone()
+
+        torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+            x, residual, weight, variance_epsilon
+        )
+
+        ref_x, ref_residual = self._forward_native(
+            ref_x, weight, variance_epsilon, ref_residual
+        )
+
+        self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
+        self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
+
+    def test_norm(self):
+        for params in itertools.product(self.M, self.N, self.dtype):
+            with self.subTest(m=params[0], n=params[1], dtype=params[2]):
+                self._norm_test(*params)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/cpu/test_rope.py b/test/srt/cpu/test_rope.py
new file mode 100644
index 000000000..33b6fc623
--- /dev/null
+++ b/test/srt/cpu/test_rope.py
@@ -0,0 +1,78 @@
+import unittest
+
+import sgl_kernel
+import torch
+from utils import precision
+
+from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
+from sglang.test.test_utils import CustomTestCase
+
+
+class TestROPE(CustomTestCase):
+    def test_deepseek_v2_rope(self):
+        num_head = 16
+        seq_len = 1024
+        q_head_dim = 192
+        qk_nope_head_dim = 128
+        qk_rope_head_dim = 64
+        max_pos = 256
+        k_dim = 576
+        rotary_dim = 64
+        is_neox_style = False
+
+        # Create cos_sin_cache
+        freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
+        cos = freqs.cos() * 0.7
+        sin = freqs.sin() * 0.7
+        cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
+        positions = torch.randint(0, max_pos, (seq_len,))
+
+        rope = DeepseekScalingRotaryEmbedding(
+            qk_rope_head_dim,
+            rotary_dim,
+            max_pos,
+            16,  # not used since cos_sin_cache is provided
+            is_neox_style,
+            1.0,
+            torch.bfloat16,
+            device="cpu",
+        )
+        rope.register_buffer("cos_sin_cache", cos_sin_cache)
+
+        for dtype in [torch.bfloat16]:
+            enable_autocast = True
+
+            with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast):
+                q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
+                q_clone = q.clone()
+                k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
+                k_clone = k.clone()
+                _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
+                _, q_pe_clone = q_clone.split(
+                    [qk_nope_head_dim, qk_rope_head_dim], dim=-1
+                )
+                k_pe = k[:, :, k_dim - qk_rope_head_dim :]
+                k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]
+
+                # ref kernel
+                q_pe, k_pe = rope.forward_native(
+                    query=q_pe,
+                    key=k_pe,
+                    positions=positions,
+                )
+
+                # fused rope kernel
+                q_pe_clone, k_pe_clone = (
+                    torch.ops.sgl_kernel.rotary_position_embedding_cpu(
+                        positions, q_pe_clone, k_pe_clone, cos_sin_cache
+                    )
+                )
+
+                atol = rtol = precision[q_pe.dtype]
+                self.assertTrue(torch.allclose(q_pe, q_pe_clone, atol=atol, rtol=rtol))
+                self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
+                torch.testing.assert_close(k_pe, k_pe_clone)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/cpu/test_topk.py b/test/srt/cpu/test_topk.py
new file mode 100644
index 000000000..22c9e2784
--- /dev/null
+++ b/test/srt/cpu/test_topk.py
@@ -0,0 +1,98 @@
+import itertools
+import unittest
+
+import sgl_kernel
+import torch
+from utils import precision
+
+from sglang.srt.layers.moe.topk import (
+    biased_grouped_topk_impl as native_biased_grouped_topk,
+)
+from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
+from sglang.test.test_utils import CustomTestCase
+
+
+# This is used by the DeepSeek-V2 model
+class TestGroupedTopK(CustomTestCase):
+    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
+        torch.manual_seed(1234)
+
+        # scale gating_output by M, otherwise bfloat16 values collapse to the same value after truncation
+        hidden_states = torch.randn(M, 100, dtype=dtype)
+        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
+
+        ref_topk_weights, ref_topk_ids = native_grouped_topk(
+            hidden_states.float(),
+            gating_output.float(),
+            topk,
+            renormalize,
+            G,
+            topk_group,
+        )
+
+        # fused version
+        topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
+            hidden_states, gating_output, topk, renormalize, G, topk_group
+        )
+
+        res = torch.zeros(M, E, dtype=torch.float)
+        ref = torch.zeros(M, E, dtype=torch.float)
+        res.scatter_(1, topk_ids.long(), topk_weights)
+        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
+        torch.testing.assert_close(res, ref)
+
+    def test_grouped_topk(self):
+        for renormalize in [True, False]:
+            self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
+            self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
+            self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
+            self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
+            self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
+            self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
+            self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)
+
+
+# DeepSeek V2/V3/R1 use biased_grouped_topk
+class TestBiasedGroupedTopK(CustomTestCase):
+    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
+        torch.manual_seed(1234)
+
+        # scale gating_output by M, otherwise bfloat16 values collapse to the same value after truncation
+        hidden_states = torch.randn(M, 100, dtype=dtype)
+        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
+        correction_bias = torch.randn(E, dtype=dtype)
+
+        ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
+            hidden_states.float(),
+            gating_output.float(),
+            correction_bias.float(),
+            topk,
+            renormalize,
+            G,
+            topk_group,
+        )
+
+        # fused version
+        topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            G,
+            topk_group,
+        )
+
+        res = torch.zeros(M, E, dtype=torch.float)
+        ref = torch.zeros(M, E, dtype=torch.float)
+        res.scatter_(1, topk_ids.long(), topk_weights)
+        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
+        torch.testing.assert_close(res, ref)
+
+    def test_biased_grouped_topk(self):
+        for renormalize in [True, False]:
+            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
+
+
+if __name__ == "__main__":
+    unittest.main()
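
Note on the shared helpers: all four tests import precision (and test_activation.py
also imports SiluAndMul) from a sibling test/srt/cpu/utils.py module that this
patch does not touch. For readers without the repo checked out, here is a minimal
sketch of what those helpers plausibly look like; the function body and the
tolerance values are illustrative assumptions, not the repository's actual code:

    import torch
    import torch.nn.functional as F

    # Assumed dtype -> tolerance table; the tests use it as
    # `atol = rtol = precision[dtype]`. The exact values are a guess.
    precision = {
        torch.float32: 1e-5,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
        # Standard SiLU-and-mul reference: split the last dimension in half,
        # apply SiLU to the first half, and multiply by the second half, so an
        # input of shape [m, n] yields an output of shape [m, n // 2].
        # Assumes x.shape[-1] is even (the tests use N = 22016 or 22018).
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

Each test file can also be run standalone, e.g.
python test/srt/cpu/test_activation.py.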