diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 71e6f1a4..b09174a9 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -1264,7 +1264,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase): tp_rank, pcp_rank, _prefill_tp_size, remote_pcp_size, remote_dcp_size, remote_port, remote_block_ids, - local_block_ids): + local_block_ids, remote_engine_id): worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) @@ -1275,7 +1275,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase): worker.tp_rank = tp_rank worker.pcp_rank = pcp_rank worker._prefill_tp_size = _prefill_tp_size - worker.local_remote_block_port_mapping = None + worker.local_remote_block_port_mapping = {} worker.block_size = 16 worker.num_key_value_heads = 1 @@ -1289,6 +1289,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase): meta.num_external_tokens = pcp_size * dcp_size * len( local_block_ids) * worker.block_size meta.num_prompt_blocks = pcp_size * dcp_size * len(local_block_ids) + meta.remote_engine_id = remote_engine_id remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata( '0', meta) @@ -1297,14 +1298,14 @@ class TestMooncakeConnectorWorker(unittest.TestCase): self.assertEqual( get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1], - [1]), + [1], 0), ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], [30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [], [], [], [1]])) self.assertEqual( get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1], - [1]), + [1], 0), ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], [30008], [30009], [30010], [30011], [30012], [30013], [30014], [30015], [30000] @@ -1314,29 +1315,29 @@ class TestMooncakeConnectorWorker(unittest.TestCase): self.assertEqual( get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], - [1]), + [1], 0), ([[30001], [30008], [30009], [30000]], [[], [], [], [1] ], [[], [], [], [1]])) self.assertEqual( get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], - [1]), + [1], 0), ([[30001], [30008], [30009], [30000]], [[], [], [], [1] ], [[], [], [], [1]])) self.assertEqual( get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], - [1]), + [1], 0), ([[30000], [30008]], [[1], []], [[1], []])) self.assertEqual( get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], - [1]), + [1], 0), ([[30000], [30008]], [[1], []], [[1], []])) self.assertEqual( get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000, - [1, 2, 3], [1, 2, 3, 4, 5]), + [1, 2, 3], [1, 2, 3, 4, 5], 0), ([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]])) diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 1d3619ab..d46d64ba 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -1154,8 +1154,9 @@ class MooncakeConnectorWorker: num_p_block_heads = max( 1, self.num_key_value_heads // self._prefill_tp_size) self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads - self.local_remote_block_port_mapping = None - self.remote_port_send_num: dict[int, int] = {} + self.local_remote_block_port_mapping: dict[ + str, Optional[List[List[int]]]] = {} + self.remote_port_send_num: dict[str, dict[int, int]] = {} def _get_prefill_decode_size(self, vllm_config: VllmConfig): # get prefill tp and dp size from extra config @@ -1453,16 +1454,20 @@ class MooncakeConnectorWorker: remote_port_send_num[remote_port] += 1 return remote_port_send_num - if self.local_remote_block_port_mapping is None: + if meta.remote_engine_id not in self.local_remote_block_port_mapping: + self.local_remote_block_port_mapping[meta.remote_engine_id] = None + if self.local_remote_block_port_mapping[meta.remote_engine_id] is None: local_remote_block_port_mappings = get_local_remote_block_port_mappings( ) - self.local_remote_block_port_mapping = local_remote_block_port_mappings[ - self.handshake_port] - self.remote_port_send_num = get_remote_port_send_num( - local_remote_block_port_mappings) + self.local_remote_block_port_mapping[ + meta.remote_engine_id] = local_remote_block_port_mappings[ + self.handshake_port] + self.remote_port_send_num[ + meta.remote_engine_id] = get_remote_port_send_num( + local_remote_block_port_mappings) local_remote_block_port_mapping = copy.deepcopy( - self.local_remote_block_port_mapping) + self.local_remote_block_port_mapping[meta.remote_engine_id]) num_external_blocks = math.ceil(meta.num_external_tokens / self.block_size) @@ -1568,7 +1573,8 @@ class MooncakeConnectorWorker: pcp_dcp_rank][i], offset=i, tp_num_need_pulls=self.tp_num_need_pulls, - remote_port_send_num=self.remote_port_send_num, + remote_port_send_num=self.remote_port_send_num[ + meta.remote_engine_id], all_task_done=( pcp_dcp_rank == len(remote_handshake_port_list) - 1