[UT] Add mooncake ut test (#5080)

### What this PR does / why we need it?

Add UT for mooncake

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: tongyuzhou <tongyuzhou1@huawei.com>
Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: tongyuzhou <tongyuzhou1@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
Yuzhou Tong
2025-12-18 15:07:14 +08:00
committed by GitHub
parent 9045843c90
commit 78602eab4f

View File

@@ -1260,6 +1260,82 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
get_tp_rank(4, 4, 4, 1, 1, True),
[[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])
def test_get_kv_split_metadata(self):
def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
tp_rank, pcp_rank, _prefill_tp_size,
remote_pcp_size, remote_dcp_size,
remote_port, remote_block_ids,
local_block_ids):
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
worker.use_mla = use_mla
worker.pcp_size = pcp_size
worker.dcp_size = dcp_size
worker.tp_size = tp_size
worker.tp_rank = tp_rank
worker.pcp_rank = pcp_rank
worker._prefill_tp_size = _prefill_tp_size
worker.local_remote_block_port_mapping = None
meta = types.SimpleNamespace()
meta.remote_pcp_size = remote_pcp_size
meta.remote_dcp_size = remote_dcp_size
meta.remote_port = remote_port
meta.remote_block_ids = remote_block_ids
meta.local_block_ids = local_block_ids
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
'0', meta)
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
self.assertEqual(
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
[1]),
([[30001], [30002], [30003], [30004], [30005], [30006], [30007],
[30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [],
[], [], [1]]))
self.assertEqual(
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1],
[1]),
([[30001], [30002], [30003], [30004], [30005], [30006], [30007],
[30008], [30009], [30010], [30011], [30012], [30013], [30014],
[30015], [30000]
], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [],
[1]], [[], [], [], [], [], [], [], [], [], [], [], [], [],
[], [], [1]]))
self.assertEqual(
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30001], [30008], [30009], [30000]], [[], [], [], [1]
], [[], [], [], [1]]))
self.assertEqual(
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30001], [30008], [30009], [30000]], [[], [], [], [1]
], [[], [], [], [1]]))
self.assertEqual(
get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30009], [30001]], [[], [1]], [[], [1]]))
self.assertEqual(
get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30009], [30001]], [[], [1]], [[], [1]]))
self.assertEqual(
get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000,
[1, 2, 3], [1, 2, 3, 4, 5]),
([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]]))
if __name__ == '__main__':
unittest.main()