[UT] Add mooncake ut test (#5080)
### What this PR does / why we need it?
Add UT for mooncake
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: tongyuzhou <tongyuzhou1@huawei.com>
Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: tongyuzhou <tongyuzhou1@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -1260,6 +1260,82 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
get_tp_rank(4, 4, 4, 1, 1, True),
|
||||
[[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])
|
||||
|
||||
def test_get_kv_split_metadata(self):
|
||||
|
||||
def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
|
||||
tp_rank, pcp_rank, _prefill_tp_size,
|
||||
remote_pcp_size, remote_dcp_size,
|
||||
remote_port, remote_block_ids,
|
||||
local_block_ids):
|
||||
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
|
||||
worker.use_mla = use_mla
|
||||
worker.pcp_size = pcp_size
|
||||
worker.dcp_size = dcp_size
|
||||
worker.tp_size = tp_size
|
||||
worker.tp_rank = tp_rank
|
||||
worker.pcp_rank = pcp_rank
|
||||
worker._prefill_tp_size = _prefill_tp_size
|
||||
worker.local_remote_block_port_mapping = None
|
||||
|
||||
meta = types.SimpleNamespace()
|
||||
|
||||
meta.remote_pcp_size = remote_pcp_size
|
||||
meta.remote_dcp_size = remote_dcp_size
|
||||
meta.remote_port = remote_port
|
||||
meta.remote_block_ids = remote_block_ids
|
||||
meta.local_block_ids = local_block_ids
|
||||
|
||||
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
|
||||
'0', meta)
|
||||
|
||||
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
|
||||
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
|
||||
[1]),
|
||||
([[30001], [30002], [30003], [30004], [30005], [30006], [30007],
|
||||
[30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [],
|
||||
[], [], [1]]))
|
||||
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1],
|
||||
[1]),
|
||||
([[30001], [30002], [30003], [30004], [30005], [30006], [30007],
|
||||
[30008], [30009], [30010], [30011], [30012], [30013], [30014],
|
||||
[30015], [30000]
|
||||
], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [],
|
||||
[1]], [[], [], [], [], [], [], [], [], [], [], [], [], [],
|
||||
[], [], [1]]))
|
||||
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1],
|
||||
[1]),
|
||||
([[30001], [30008], [30009], [30000]], [[], [], [], [1]
|
||||
], [[], [], [], [1]]))
|
||||
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1],
|
||||
[1]),
|
||||
([[30001], [30008], [30009], [30000]], [[], [], [], [1]
|
||||
], [[], [], [], [1]]))
|
||||
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
|
||||
[1]),
|
||||
([[30009], [30001]], [[], [1]], [[], [1]]))
|
||||
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
|
||||
[1]),
|
||||
([[30009], [30001]], [[], [1]], [[], [1]]))
|
||||
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000,
|
||||
[1, 2, 3], [1, 2, 3, 4, 5]),
|
||||
([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user