Add CPU optimized kernels for topk and rope fusions (#6456)

This commit is contained in:
jianan-gu
2025-06-03 08:37:34 +08:00
committed by GitHub
parent ff91474825
commit ff00895c46
7 changed files with 829 additions and 98 deletions

View File

@@ -63,10 +63,24 @@ class TestNorm(CustomTestCase):
self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
def _l2norm_test(self, m, n, dtype):
    """Check the fused CPU l2norm kernel against the native reference.

    The native path is RMS-norm; feeding it an all-ones weight reduces it
    to a pure L2 normalization, which is what the fused kernel computes.
    """
    inp = torch.randn([m, n], dtype=dtype)
    eps = 1e-6
    # All-ones weight makes the RMS-norm reference equivalent to l2norm.
    ones_weight = torch.ones(inp.size(-1), dtype=dtype)
    fused_out = torch.ops.sgl_kernel.l2norm_cpu(inp, eps)
    native_out = self._forward_native(inp, ones_weight, eps)
    tol = precision[native_out.dtype]
    self.assertTrue(torch.allclose(native_out, fused_out, atol=tol, rtol=tol))
def test_norm(self):
    """Run both norm variants over the full (M, N, dtype) grid."""
    for m, n, dtype in itertools.product(self.M, self.N, self.dtype):
        with self.subTest(m=m, n=n, dtype=dtype):
            self._norm_test(m, n, dtype)
            self._l2norm_test(m, n, dtype)
if __name__ == "__main__":

View File

@@ -4,7 +4,10 @@ import sgl_kernel
import torch
from utils import precision
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
from sglang.srt.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding,
RotaryEmbedding,
)
from sglang.test.test_utils import CustomTestCase
@@ -62,10 +65,13 @@ class TestROPE(CustomTestCase):
)
# fused rope kernel
q_pe_clone, k_pe_clone = (
torch.ops.sgl_kernel.rotary_position_embedding_cpu(
positions, q_pe_clone, k_pe_clone, cos_sin_cache
)
q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
positions,
q_pe_clone,
k_pe_clone,
rope.head_size,
cos_sin_cache,
False,
)
atol = rtol = precision[q_pe.dtype]
@@ -73,6 +79,98 @@ class TestROPE(CustomTestCase):
self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
torch.testing.assert_close(k_pe, k_pe_clone)
def test_origin_rope(self):
    """Compare the fused CPU rotary-embedding kernel against the native
    ``RotaryEmbedding.forward_native`` over neox and non-neox configs."""

    def run_case(
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        device: str,
        batch_size: int,
        seq_len: int,
        num_q_heads: int,
        num_kv_heads: int,
    ):
        # Fixed seed so reference and fused paths see identical inputs.
        torch.manual_seed(100)
        reference = RotaryEmbedding(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        ).to(device)
        positions = torch.arange(seq_len, device=device).repeat(batch_size)
        tokens = batch_size * seq_len
        query = torch.randn(
            tokens,
            num_q_heads * head_size,
            dtype=dtype,
            device=device,
        )
        key = torch.randn(
            tokens,
            num_kv_heads * head_size,
            dtype=dtype,
            device=device,
        )
        # Native reference path.
        q_native, k_native = reference.forward_native(
            positions, query.clone(), key.clone()
        )
        # Fused CPU kernel under test.
        q_fused, k_fused = torch.ops.sgl_kernel.rotary_embedding_cpu(
            positions,
            query.clone(),
            key.clone(),
            reference.head_size,
            reference.cos_sin_cache.to(query.dtype),
            reference.is_neox_style,
        )
        torch.testing.assert_close(q_native, q_fused, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_native, k_fused, atol=1e-2, rtol=1e-2)

    # (head_size, rotary_dim, max_pos, base, is_neox, dtype, device,
    #  batch_size, seq_len, num_q_heads, num_kv_heads)
    cases = [
        (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
        (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
        (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
    ]
    for cfg in cases:
        run_case(*cfg)
# Run the suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

View File

@@ -8,7 +8,9 @@ from utils import precision
from sglang.srt.layers.moe.topk import (
biased_grouped_topk_impl as native_biased_grouped_topk,
)
from sglang.srt.layers.moe.topk import fused_topk_native as native_fused_topk
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
from sglang.srt.models.llama4 import Llama4MoE
from sglang.test.test_utils import CustomTestCase
@@ -94,5 +96,86 @@ class TestBiasedGroupedTopK(CustomTestCase):
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
class TestTopK(CustomTestCase):
    """Validate the fused CPU topk_softmax kernel against the native top-k."""

    def _run_single_test(self, M, E, topk, renormalize, dtype):
        torch.manual_seed(1998)
        hidden = torch.randn(M, 100, dtype=dtype)
        # Scale the logits by 2*M; otherwise bfloat16 truncation collapses
        # distinct gating scores onto the same value.
        logits = torch.randn(M, E, dtype=dtype) * 2 * M
        ref_weights, ref_ids = native_fused_topk(
            hidden.float(),
            logits.float(),
            topk,
            renormalize,
        )
        # Fused kernel under test.
        fused_weights, fused_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
            hidden, logits, topk, renormalize
        )
        # Scatter (id, weight) pairs into dense [M, E] maps so the comparison
        # is independent of the ordering of the selected experts.
        dense_fused = torch.zeros(M, E, dtype=torch.float)
        dense_ref = torch.zeros(M, E, dtype=torch.float)
        dense_fused.scatter_(1, fused_ids.long(), fused_weights)
        dense_ref.scatter_(1, ref_ids.long(), ref_weights)
        torch.testing.assert_close(dense_fused, dense_ref)

    def test_topk(self):
        # (M, E, topk) combinations; the (123, 32, 3) case is intentionally
        # present twice, matching the original suite.
        configs = [
            (123, 8, 2),
            (123, 16, 3),
            (123, 32, 3),
            (123, 32, 3),
            (123, 64, 6),
            (123, 256, 4),
            (123, 160, 6),
        ]
        for renormalize in (True, False):
            for M, E, topk in configs:
                self._run_single_test(M, E, topk, renormalize, torch.bfloat16)
class TestCustomTopK(CustomTestCase):
    """Compare fused custom-routing top-k kernels with their native versions."""

    def _run_single_test(
        self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
    ):
        torch.manual_seed(16)
        hidden = torch.randn(M, 100, dtype=dtype)
        # Scale the logits by 2*M; otherwise bfloat16 truncation collapses
        # distinct gating scores onto the same value.
        logits = torch.randn(M, E, dtype=dtype) * 2 * M
        ref_weights, ref_ids = native_custom_f(
            hidden.float(),
            logits.float(),
            topk,
            renormalize,
        )
        # Fused kernel under test.
        fused_weights, fused_ids = fused_custom_f(hidden, logits, topk, renormalize)
        # Densify into [M, E] so the check ignores the order of the
        # selected experts.
        dense_fused = torch.zeros(M, E, dtype=torch.float)
        dense_ref = torch.zeros(M, E, dtype=torch.float)
        dense_fused.scatter_(1, fused_ids.long(), fused_weights)
        dense_ref.scatter_(1, ref_ids.long(), ref_weights)
        torch.testing.assert_close(dense_fused, dense_ref)

    def test_custom_topk(self):
        # Pairs of (native routing function, fused CPU kernel).
        pairings = [
            (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
        ]
        for native_f, fused_f in pairings:
            for num_experts in (8, 16, 32):
                self._run_single_test(
                    123, num_experts, 1, False, torch.bfloat16, native_f, fused_f
                )
# Run the suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()