diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index dc34d06e..fabb67b9 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -1260,6 +1260,82 @@ class TestMooncakeConnectorWorker(unittest.TestCase): get_tp_rank(4, 4, 4, 1, 1, True), [[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]]) + def test_get_kv_split_metadata(self): + + def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size, + tp_rank, pcp_rank, _prefill_tp_size, + remote_pcp_size, remote_dcp_size, + remote_port, remote_block_ids, + local_block_ids): + + worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) + + worker.use_mla = use_mla + worker.pcp_size = pcp_size + worker.dcp_size = dcp_size + worker.tp_size = tp_size + worker.tp_rank = tp_rank + worker.pcp_rank = pcp_rank + worker._prefill_tp_size = _prefill_tp_size + worker.local_remote_block_port_mapping = None + + meta = types.SimpleNamespace() + + meta.remote_pcp_size = remote_pcp_size + meta.remote_dcp_size = remote_dcp_size + meta.remote_port = remote_port + meta.remote_block_ids = remote_block_ids + meta.local_block_ids = local_block_ids + + remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata( + '0', meta) + + return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list + + self.assertEqual( + get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1], + [1]), + ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], + [30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [], + [], [], [1]])) + + self.assertEqual( + get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1], + [1]), + ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], + [30008], [30009], [30010], [30011], [30012], [30013], [30014], + [30015], [30000] + ], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], + [1]], [[], [], [], [], [], [], [], [], [], [], [], [], [], + [], [], [1]])) + + self.assertEqual( + get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], + [1]), + ([[30001], [30008], [30009], [30000]], [[], [], [], [1] + ], [[], [], [], [1]])) + + self.assertEqual( + get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], + [1]), + ([[30001], [30008], [30009], [30000]], [[], [], [], [1] + ], [[], [], [], [1]])) + + self.assertEqual( + get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], + [1]), + ([[30009], [30001]], [[], [1]], [[], [1]])) + + self.assertEqual( + get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], + [1]), + ([[30009], [30001]], [[], [1]], [[], [1]])) + + self.assertEqual( + get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000, + [1, 2, 3], [1, 2, 3, 4, 5]), + ([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]])) + if __name__ == '__main__': unittest.main()