Add CPU optimized kernels for topk and rope fusions (#6456)
This commit is contained in:
@@ -8,7 +8,9 @@ from utils import precision
|
||||
from sglang.srt.layers.moe.topk import (
|
||||
biased_grouped_topk_impl as native_biased_grouped_topk,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import fused_topk_native as native_fused_topk
|
||||
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
|
||||
from sglang.srt.models.llama4 import Llama4MoE
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -94,5 +96,86 @@ class TestBiasedGroupedTopK(CustomTestCase):
|
||||
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
|
||||
|
||||
|
||||
class TestTopK(CustomTestCase):
    """Compare the fused CPU topk-softmax kernel against the native reference."""

    def _run_single_test(self, M, E, topk, renormalize, dtype):
        """Check topk_softmax_cpu against native_fused_topk for one shape.

        Args:
            M: number of tokens (rows).
            E: number of experts (gating logits per token).
            topk: number of experts selected per token.
            renormalize: whether the selected topk weights are renormalized.
            dtype: dtype of the hidden states and gating logits.
        """
        torch.manual_seed(1998)

        # Scale gating_output by M; otherwise bfloat16 truncation collapses
        # nearby logits into identical values and ties make the topk ambiguous.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        # Reference path: native implementation in float32.
        ref_topk_weights, ref_topk_ids = native_fused_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )

        # Path under test: fused CPU kernel.
        topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
            hidden_states, gating_output, topk, renormalize
        )

        # Scatter (id, weight) pairs into dense [M, E] maps so the comparison
        # is insensitive to the ordering of the selected experts.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
            # NOTE: the (123, 32, 3) case was previously listed twice;
            # the duplicate run has been removed.
            self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)
||||
|
||||
|
||||
class TestCustomTopK(CustomTestCase):
    """Verify fused custom-routing topk kernels against their native references."""

    def _run_single_test(
        self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
    ):
        """Run one (M, E, topk) shape through both routing paths and compare."""
        torch.manual_seed(16)

        # Amplify the gating logits by M so that bfloat16 truncation does not
        # collapse nearby scores into identical values.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        # Reference path: the native routing function, evaluated in float32.
        ref_topk_weights, ref_topk_ids = native_custom_f(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )

        # Path under test: the fused CPU kernel.
        topk_weights, topk_ids = fused_custom_f(
            hidden_states, gating_output, topk, renormalize
        )

        # Scatter both (id, weight) selections into dense [M, E] grids so the
        # comparison ignores the order in which experts were emitted.
        dense_out = torch.zeros(M, E, dtype=torch.float)
        dense_ref = torch.zeros(M, E, dtype=torch.float)
        dense_out.scatter_(1, topk_ids.long(), topk_weights)
        dense_ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(dense_out, dense_ref)

    def test_custom_topk(self):
        test_custom_functions = [
            (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
        ]
        for native_custom_f, fused_custom_f in test_custom_functions:
            # Same three expert counts as before, expressed as a loop.
            for num_experts in (8, 16, 32):
                self._run_single_test(
                    123,
                    num_experts,
                    1,
                    False,
                    torch.bfloat16,
                    native_custom_f,
                    fused_custom_f,
                )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Discover and run all test cases in this module when executed directly.
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user