From bf97048bce9393a6b79228f585d3d97f177b89ed Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Wed, 17 Dec 2025 09:28:03 +0800 Subject: [PATCH] [feat]pd disaggregated support cross-machine (#5008) ### What this PR does / why we need it? pd disaggregated support cross-machine. We send the primary and secondary node information of node p to node d. When node d pulls the KV data, it retrieves the corresponding primary or secondary node information from the mapping. - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 --------- Signed-off-by: weiguihua2 --- vllm_ascend/distributed/mooncake_connector.py | 83 +++++++++++++++++-- vllm_ascend/worker/worker_v1.py | 16 +++- 2 files changed, 92 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index b8b66b4c..81c80145 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -24,7 +24,8 @@ from mooncake.engine import TransferEngine # type: ignore from vllm import envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata, + KVConnectorRole) from vllm.distributed.parallel_state import ( get_decode_context_model_parallel_rank, get_decode_context_model_parallel_world_size, get_pp_group, @@ -64,6 +65,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): te_rpc_port: int kv_caches_base_addr: list[int] num_blocks: int + local_ip: str = "" @dataclass @@ -75,6 +77,7 @@ class ReqMeta: remote_engine_id: str remote_pcp_size: int remote_dcp_size: int + remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]] @dataclass @@ -685,6 +688,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): remote_port=kv_transfer_params["remote_port"], remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1), remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1), + remote_multi_nodes_meta_mapping=kv_transfer_params.get( + "remote_multi_nodes_meta_mapping", {}), ) @@ -772,6 +777,30 @@ class MooncakeConnector(KVConnectorBase_V1): """MooncakeConnector does not save explicitly.""" pass + def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: + """ + Get the KVConnector handshake metadata for this connector. + This metadata is used for out-of-band connector handshake + between P/D workers. + + Returns: + KVConnectorHandshakeMetadata: the handshake metadata. + None if no handshake metadata is available. + """ + assert self.connector_worker is not None + return self.connector_worker.xfer_handshake_metadata + + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None: + """ + Set the KV connector handshake metadata for this connector. + + Args: + metadata (dict): the handshake metadata to set. + """ + assert self.connector_scheduler is not None + self.connector_scheduler.set_xfer_handshake_metadata(metadata) + class MooncakeConnectorScheduler: """Implementation of Scheduler side methods""" @@ -805,6 +834,9 @@ class MooncakeConnectorScheduler: self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} self._reqs_need_send: dict[str, float] = {} + # master-slave meta information for cross-nodes + self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {} + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]: @@ -928,8 +960,23 @@ class MooncakeConnectorScheduler: remote_pcp_size=self.pcp_size, remote_dcp_size=self.dcp_size, last_token_id=request.output_token_ids[-1], + remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping, ) + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None: + """ + Set the KV connector handshake metadata for this connector. + + Args: + metadata (dict): the handshake metadata to set. + """ + for local_rank, rank_metadata in metadata.items(): + self.multi_nodes_meta_mapping[str(local_rank)] = { + "host": rank_metadata.local_ip, + "engine_id": rank_metadata.engine_id, + } + class MooncakeConnectorWorker: """Implementation of Worker side methods""" @@ -989,6 +1036,9 @@ class MooncakeConnectorWorker: self.kv_send_thread: Optional[KVCacheSendingThread] = None self.kv_recv_thread: Optional[KVCacheRecvingThread] = None + # Handshake metadata of this worker + self.xfer_handshake_metadata: MooncakeAgentMetadata | None = None + # kv_transfer variables self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -1118,7 +1168,9 @@ class MooncakeConnectorWorker: te_rpc_port=self.te_rpc_port, kv_caches_base_addr=kv_caches_base_addr, num_blocks=self.num_blocks, + local_ip=get_ip(), ) + self.xfer_handshake_metadata = metadata ready_event = threading.Event() if self.kv_role == 'kv_producer': @@ -1266,13 +1318,18 @@ class MooncakeConnectorWorker: continue for i in range(self.tp_num_need_pulls): assert self.kv_recv_thread is not None + remote_host, remote_engine_id = self._get_remote_host_info_by_port( + meta.remote_port, + remote_handshake_port_list[pcp_dcp_rank][i], + meta.remote_host, meta.remote_engine_id, + meta.remote_multi_nodes_meta_mapping) self.kv_recv_thread.add_request( request_id=req_id, local_block_ids=local_block_ids_list[pcp_dcp_rank], remote_block_ids=remote_block_ids_list[ pcp_dcp_rank], - remote_engine_id=meta.remote_engine_id, - remote_host=meta.remote_host, + remote_engine_id=remote_engine_id, + remote_host=remote_host, remote_handshake_port=remote_handshake_port_list[ pcp_dcp_rank][i], offset=i, @@ -1287,12 +1344,16 @@ class MooncakeConnectorWorker: for x in choosen_rank_list] for i in range(self.tp_num_need_pulls * self._prefill_pp_size): assert self.kv_recv_thread is not None + remote_host, remote_engine_id = self._get_remote_host_info_by_port( + meta.remote_port, remote_handshake_port_list[i][0], + meta.remote_host, meta.remote_engine_id, + meta.remote_multi_nodes_meta_mapping) self.kv_recv_thread.add_request( request_id=req_id, local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, - remote_engine_id=meta.remote_engine_id, - remote_host=meta.remote_host, + remote_engine_id=remote_engine_id, + remote_host=remote_host, remote_handshake_port=remote_handshake_port_list[i][0], offset=i, tp_num_need_pulls=self.tp_num_need_pulls, @@ -1307,6 +1368,18 @@ class MooncakeConnectorWorker: else: self.kv_send_thread.add_not_transfer_request(req_id) + def _get_remote_host_info_by_port(self, base_port: int, + remote_handshake_port: int, + remote_host: str, remote_engine_id: str, + remote_multi_nodes_meta_mapping: dict): + rank = str(remote_handshake_port - base_port) + if remote_multi_nodes_meta_mapping is None or remote_multi_nodes_meta_mapping.get( + rank, None) is None: + return remote_host, remote_engine_id + info = remote_multi_nodes_meta_mapping[rank] + return info.get("host", remote_host), info.get("engine_id", + remote_engine_id) + def _prefill_get_remote_rank(self, req_id: str) -> List[int]: return sum(self._get_remote_ranks_for_req(req_id), []) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index bd7bebfa..29e2fb85 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -31,7 +31,9 @@ from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + get_kv_transfer_group, + has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import logger from vllm.lora.request import LoRARequest @@ -374,7 +376,17 @@ class NPUWorker(WorkerBase): return self.model_runner.get_model() def get_kv_connector_handshake_metadata(self) -> Optional[dict]: - return None + """Get KV connector metadata from this worker if available.""" + if not has_kv_transfer_group(): + return None + + connector = get_kv_transfer_group() + + # Return None for connectors that don't need to exchange handshake + # metadata across workers. + if (metadata := connector.get_handshake_metadata()) is None: + return None + return {self.rank: metadata} def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec()