Add CPU optimized kernels for topk and rope fusions (#6456)
This commit is contained in:
@@ -63,10 +63,24 @@ class TestNorm(CustomTestCase):
|
||||
self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
|
||||
|
||||
def _l2norm_test(self, m, n, dtype):
    """Compare the fused ``l2norm_cpu`` kernel with the native reference.

    A random ``[m, n]`` tensor of ``dtype`` is passed to
    ``torch.ops.sgl_kernel.l2norm_cpu`` and to ``self._forward_native``.
    The reference is called with an all-ones weight vector whose length is
    the last dimension of the input.
    NOTE(review): presumably the reference is a weighted RMS-style norm, so
    the unit weight makes it comparable to the weight-free kernel - confirm
    against ``_forward_native``.
    """
    eps = 1e-6
    inp = torch.randn([m, n], dtype=dtype)
    unit_weight = torch.ones(inp.size(-1), dtype=dtype)

    # The fused kernel runs first, then the reference, on the same input.
    fused_out = torch.ops.sgl_kernel.l2norm_cpu(inp, eps)
    expected_out = self._forward_native(inp, unit_weight, eps)

    # One tolerance, looked up by the reference output dtype, is used for
    # both the absolute and the relative bound.
    tol = precision[expected_out.dtype]
    self.assertTrue(torch.allclose(expected_out, fused_out, atol=tol, rtol=tol))
|
||||
|
||||
def test_norm(self):
    """Run both norm checks for every (M, N, dtype) combination.

    Each combination is wrapped in its own ``subTest`` labelled with the
    values of ``m``, ``n`` and ``dtype``.
    """
    for m, n, dtype in itertools.product(self.M, self.N, self.dtype):
        with self.subTest(m=m, n=n, dtype=dtype):
            self._norm_test(m, n, dtype)
            self._l2norm_test(m, n, dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,7 +4,10 @@ import sgl_kernel
|
||||
import torch
|
||||
from utils import precision
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
|
||||
from sglang.srt.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding,
|
||||
RotaryEmbedding,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -62,10 +65,13 @@ class TestROPE(CustomTestCase):
|
||||
)
|
||||
|
||||
# fused rope kernel
|
||||
q_pe_clone, k_pe_clone = (
|
||||
torch.ops.sgl_kernel.rotary_position_embedding_cpu(
|
||||
positions, q_pe_clone, k_pe_clone, cos_sin_cache
|
||||
)
|
||||
q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||
positions,
|
||||
q_pe_clone,
|
||||
k_pe_clone,
|
||||
rope.head_size,
|
||||
cos_sin_cache,
|
||||
False,
|
||||
)
|
||||
|
||||
atol = rtol = precision[q_pe.dtype]
|
||||
@@ -73,6 +79,98 @@ class TestROPE(CustomTestCase):
|
||||
self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
|
||||
torch.testing.assert_close(k_pe, k_pe_clone)
|
||||
|
||||
def test_origin_rope(self):
    """Compare ``rotary_embedding_cpu`` with ``RotaryEmbedding.forward_native``.

    Every entry of ``test_config`` runs inside its own ``subTest``, so a
    failure reports the offending configuration and the remaining
    configurations still execute.
    """

    def single_test(
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        device: str,
        batch_size: int,
        seq_len: int,
        num_q_heads: int,
        num_kv_heads: int,
    ):
        """Run one configuration and assert both outputs match the reference."""
        # Fixed seed so the random query/key tensors are reproducible.
        torch.manual_seed(100)
        rope_ref = RotaryEmbedding(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        ).to(device)

        # Positions 0..seq_len-1, repeated batch_size times.
        pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
        query = torch.randn(
            batch_size * seq_len,
            num_q_heads * head_size,
            dtype=dtype,
            device=device,
        )
        key = torch.randn(
            batch_size * seq_len,
            num_kv_heads * head_size,
            dtype=dtype,
            device=device,
        )

        # Each implementation gets its own copies of the inputs.
        query_ref, key_ref = query.clone(), key.clone()
        query_cpu, key_cpu = query.clone(), key.clone()

        query_ref_out, key_ref_out = rope_ref.forward_native(
            pos_ids, query_ref, key_ref
        )
        query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
            pos_ids,
            query_cpu,
            key_cpu,
            rope_ref.head_size,
            # The kernel is handed the cache cast to the query dtype.
            rope_ref.cos_sin_cache.to(query.dtype),
            rope_ref.is_neox_style,
        )
        torch.testing.assert_close(
            query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
        )
        torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)

    # Columns: head_size, rotary_dim, max_position_embeddings, base,
    # is_neox_style, dtype, device, batch_size, seq_len, num_q_heads,
    # num_kv_heads (same order as the single_test parameters).
    test_config = [
        (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
        (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
        (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
    ]

    for config in test_config:
        with self.subTest(config=config):
            single_test(*config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -8,7 +8,9 @@ from utils import precision
|
||||
from sglang.srt.layers.moe.topk import (
|
||||
biased_grouped_topk_impl as native_biased_grouped_topk,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import fused_topk_native as native_fused_topk
|
||||
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
|
||||
from sglang.srt.models.llama4 import Llama4MoE
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -94,5 +96,86 @@ class TestBiasedGroupedTopK(CustomTestCase):
|
||||
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
|
||||
|
||||
|
||||
class TestTopK(CustomTestCase):
    """Checks ``topk_softmax_cpu`` against ``native_fused_topk``."""

    def _run_single_test(self, M, E, topk, renormalize, dtype):
        """Compare the fused kernel with the native reference for one shape.

        Args:
            M: number of rows of ``hidden_states`` and ``gating_output``.
            E: number of columns of ``gating_output``.
            topk: passed through to both implementations.
            renormalize: passed through to both implementations.
            dtype: dtype of the random inputs; the reference receives
                ``float`` copies of them.
        """
        torch.manual_seed(1998)

        # Scale gating_output by 2 * M; otherwise bfloat16 values collapse to
        # the same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_topk_weights, ref_topk_ids = native_fused_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )

        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
            hidden_states, gating_output, topk, renormalize
        )

        # Scatter the selected weights into dense (M, E) tensors so the
        # comparison does not depend on the order of the ids within a row.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_topk(self):
        """Run every (E, topk) case with and without renormalization."""
        # (E, topk) pairs; each distinct pair is listed once because the seed
        # is fixed and a repeated pair would re-run an identical check.
        cases = [(8, 2), (16, 3), (32, 3), (64, 6), (256, 4), (160, 6)]
        for renormalize in [True, False]:
            for num_experts, topk in cases:
                with self.subTest(
                    E=num_experts, topk=topk, renormalize=renormalize
                ):
                    self._run_single_test(
                        123, num_experts, topk, renormalize, torch.bfloat16
                    )
|
||||
|
||||
|
||||
class TestCustomTopK(CustomTestCase):
    """Checks fused custom top-k kernels against their native counterparts."""

    def _run_single_test(
        self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
    ):
        """Compare ``fused_custom_f`` with ``native_custom_f`` on random input.

        The native function receives ``float`` copies of the inputs; the
        fused function receives them in ``dtype``.
        """
        torch.manual_seed(16)

        # gating_output is scaled by 2 * M; otherwise bfloat16 values collapse
        # to the same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        # Reference first, then the fused implementation.
        expected_weights, expected_ids = native_custom_f(
            hidden_states.float(), gating_output.float(), topk, renormalize
        )
        fused_weights, fused_ids = fused_custom_f(
            hidden_states, gating_output, topk, renormalize
        )

        # Scatter the selected weights into dense (M, E) float tensors and
        # compare those.
        res = torch.zeros(M, E, dtype=torch.float).scatter_(
            1, fused_ids.long(), fused_weights
        )
        ref = torch.zeros(M, E, dtype=torch.float).scatter_(
            1, expected_ids.long(), expected_weights
        )
        torch.testing.assert_close(res, ref)

    def test_custom_topk(self):
        """Run each (native, fused) pair for E in 8, 16, 32 with topk=1."""
        test_custom_functions = [
            (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
        ]
        for native_custom_f, fused_custom_f in test_custom_functions:
            for num_experts in (8, 16, 32):
                self._run_single_test(
                    123,
                    num_experts,
                    1,
                    False,
                    torch.bfloat16,
                    native_custom_f,
                    fused_custom_f,
                )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user