Update sgl-kernel UTs for activation/topk/norm/rope kernels (#6452)
This commit is contained in:
33
test/srt/cpu/test_activation.py
Normal file
33
test/srt/cpu/test_activation.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import itertools
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import sgl_kernel
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from utils import SiluAndMul, precision
|
||||||
|
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestActivation(CustomTestCase):
|
||||||
|
M = [128, 129, 257]
|
||||||
|
N = [22016, 22018]
|
||||||
|
dtype = [torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
|
def _activation_test(self, m, n, dtype):
|
||||||
|
x = torch.randn([m, n], dtype=dtype)
|
||||||
|
|
||||||
|
out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
|
||||||
|
ref_out = SiluAndMul(x)
|
||||||
|
|
||||||
|
atol = rtol = precision[ref_out.dtype]
|
||||||
|
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||||
|
|
||||||
|
def test_activation(self):
|
||||||
|
for params in itertools.product(self.M, self.N, self.dtype):
|
||||||
|
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
|
||||||
|
self._activation_test(*params)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
73
test/srt/cpu/test_norm.py
Normal file
73
test/srt/cpu/test_norm.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import itertools
|
||||||
|
import unittest
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import sgl_kernel
|
||||||
|
import torch
|
||||||
|
from utils import precision
|
||||||
|
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestNorm(CustomTestCase):
|
||||||
|
M = [4096, 1024]
|
||||||
|
N = [4096, 4096 + 13]
|
||||||
|
dtype = [torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
|
def _forward_native(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
variance_epsilon: float = 1e-6,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
if residual is not None:
|
||||||
|
x = x + residual.to(torch.float32)
|
||||||
|
residual = x.to(orig_dtype)
|
||||||
|
|
||||||
|
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||||
|
x = x * torch.rsqrt(variance + variance_epsilon)
|
||||||
|
x = x.to(orig_dtype) * weight
|
||||||
|
if residual is None:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return x, residual
|
||||||
|
|
||||||
|
def _norm_test(self, m, n, dtype):
|
||||||
|
|
||||||
|
x = torch.randn([m, n], dtype=dtype)
|
||||||
|
hidden_size = x.size(-1)
|
||||||
|
weight = torch.randn(hidden_size, dtype=dtype)
|
||||||
|
variance_epsilon = 1e-6
|
||||||
|
|
||||||
|
out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
|
||||||
|
ref_out = self._forward_native(x, weight, variance_epsilon)
|
||||||
|
|
||||||
|
atol = rtol = precision[ref_out.dtype]
|
||||||
|
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||||
|
|
||||||
|
ref_x = x.clone()
|
||||||
|
residual = torch.randn([m, n], dtype=dtype)
|
||||||
|
ref_residual = residual.clone()
|
||||||
|
|
||||||
|
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
|
||||||
|
x, residual, weight, variance_epsilon
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_x, ref_residual = self._forward_native(
|
||||||
|
ref_x, weight, variance_epsilon, ref_residual
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
|
||||||
|
self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
|
||||||
|
|
||||||
|
def test_norm(self):
|
||||||
|
for params in itertools.product(self.M, self.N, self.dtype):
|
||||||
|
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
|
||||||
|
self._norm_test(*params)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
78
test/srt/cpu/test_rope.py
Normal file
78
test/srt/cpu/test_rope.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import sgl_kernel
|
||||||
|
import torch
|
||||||
|
from utils import precision
|
||||||
|
|
||||||
|
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestROPE(CustomTestCase):
|
||||||
|
def test_deepseek_v2_rope(self):
|
||||||
|
num_head = 16
|
||||||
|
seq_len = 1024
|
||||||
|
q_head_dim = 192
|
||||||
|
qk_nope_head_dim = 128
|
||||||
|
qk_rope_head_dim = 64
|
||||||
|
max_pos = 256
|
||||||
|
k_dim = 576
|
||||||
|
rotary_dim = 64
|
||||||
|
is_neox_style = False
|
||||||
|
|
||||||
|
# Create cos_sin_cache
|
||||||
|
freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
|
||||||
|
cos = freqs.cos() * 0.7
|
||||||
|
sin = freqs.sin() * 0.7
|
||||||
|
cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
|
||||||
|
positions = torch.randint(0, max_pos, (seq_len,))
|
||||||
|
|
||||||
|
rope = DeepseekScalingRotaryEmbedding(
|
||||||
|
qk_rope_head_dim,
|
||||||
|
rotary_dim,
|
||||||
|
max_pos,
|
||||||
|
16, # not used since cos_sin_cache is provided
|
||||||
|
is_neox_style,
|
||||||
|
1.0,
|
||||||
|
torch.bfloat16,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
rope.register_buffer("cos_sin_cache", cos_sin_cache)
|
||||||
|
|
||||||
|
for dtype in [torch.bfloat16]:
|
||||||
|
enable_autocast = True
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast):
|
||||||
|
q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
|
||||||
|
q_clone = q.clone()
|
||||||
|
k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
|
||||||
|
k_clone = k.clone()
|
||||||
|
_, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||||
|
_, q_pe_clone = q_clone.split(
|
||||||
|
[qk_nope_head_dim, qk_rope_head_dim], dim=-1
|
||||||
|
)
|
||||||
|
k_pe = k[:, :, k_dim - qk_rope_head_dim :]
|
||||||
|
k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]
|
||||||
|
|
||||||
|
# ref kernel
|
||||||
|
q_pe, k_pe = rope.forward_native(
|
||||||
|
query=q_pe,
|
||||||
|
key=k_pe,
|
||||||
|
positions=positions,
|
||||||
|
)
|
||||||
|
|
||||||
|
# fused rope kernel
|
||||||
|
q_pe_clone, k_pe_clone = (
|
||||||
|
torch.ops.sgl_kernel.rotary_position_embedding_cpu(
|
||||||
|
positions, q_pe_clone, k_pe_clone, cos_sin_cache
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
atol = rtol = precision[q_pe.dtype]
|
||||||
|
self.assertTrue(torch.allclose(q_pe, q_pe_clone, atol=atol, rtol=rtol))
|
||||||
|
self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
|
||||||
|
torch.testing.assert_close(k_pe, k_pe_clone)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
98
test/srt/cpu/test_topk.py
Normal file
98
test/srt/cpu/test_topk.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import itertools
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import sgl_kernel
|
||||||
|
import torch
|
||||||
|
from utils import precision
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.topk import (
|
||||||
|
biased_grouped_topk_impl as native_biased_grouped_topk,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
# This is used by the Deepseek-V2 model
|
||||||
|
class TestGroupedTopK(CustomTestCase):
|
||||||
|
def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
|
||||||
|
torch.manual_seed(1234)
|
||||||
|
|
||||||
|
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||||
|
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||||
|
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||||
|
|
||||||
|
ref_topk_weights, ref_topk_ids = native_grouped_topk(
|
||||||
|
hidden_states.float(),
|
||||||
|
gating_output.float(),
|
||||||
|
topk,
|
||||||
|
renormalize,
|
||||||
|
G,
|
||||||
|
topk_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
# fused version
|
||||||
|
topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
|
||||||
|
hidden_states, gating_output, topk, renormalize, G, topk_group
|
||||||
|
)
|
||||||
|
|
||||||
|
res = torch.zeros(M, E, dtype=torch.float)
|
||||||
|
ref = torch.zeros(M, E, dtype=torch.float)
|
||||||
|
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||||
|
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||||
|
torch.testing.assert_close(res, ref)
|
||||||
|
|
||||||
|
def test_grouped_topk(self):
|
||||||
|
for renormalize in [True, False]:
|
||||||
|
self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
|
# DeepSeek V2/V3/R1 uses biased_grouped_top
|
||||||
|
class TestBiasedGroupedTopK(CustomTestCase):
|
||||||
|
def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
|
||||||
|
torch.manual_seed(1234)
|
||||||
|
|
||||||
|
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||||
|
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||||
|
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||||
|
correction_bias = torch.randn(E, dtype=dtype)
|
||||||
|
|
||||||
|
ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
|
||||||
|
hidden_states.float(),
|
||||||
|
gating_output.float(),
|
||||||
|
correction_bias.float(),
|
||||||
|
topk,
|
||||||
|
renormalize,
|
||||||
|
G,
|
||||||
|
topk_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
# fused version
|
||||||
|
topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
|
||||||
|
hidden_states,
|
||||||
|
gating_output,
|
||||||
|
correction_bias,
|
||||||
|
topk,
|
||||||
|
renormalize,
|
||||||
|
G,
|
||||||
|
topk_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
res = torch.zeros(M, E, dtype=torch.float)
|
||||||
|
ref = torch.zeros(M, E, dtype=torch.float)
|
||||||
|
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||||
|
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||||
|
torch.testing.assert_close(res, ref)
|
||||||
|
|
||||||
|
def test_biased_grouped_topk(self):
|
||||||
|
for renormalize in [True, False]:
|
||||||
|
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user