[feat]pd disaggregated support cross-machine (#5008)

### What this PR does / why we need it?
PD disaggregation now supports cross-machine deployment.
We send the primary- and secondary-node information of the prefill node (P) to the decode node (D).
When node D pulls the KV data, it looks up the corresponding primary or
secondary node information in that mapping.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2025-12-17 09:28:03 +08:00
committed by GitHub
parent 153eeaa621
commit bf97048bce
2 changed files with 92 additions and 7 deletions

View File

@@ -24,7 +24,8 @@ from mooncake.engine import TransferEngine # type: ignore
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
KVConnectorRole)
from vllm.distributed.parallel_state import (
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size, get_pp_group,
@@ -64,6 +65,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
te_rpc_port: int
kv_caches_base_addr: list[int]
num_blocks: int
local_ip: str = ""
@dataclass
@@ -75,6 +77,7 @@ class ReqMeta:
remote_engine_id: str
remote_pcp_size: int
remote_dcp_size: int
remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
@dataclass
@@ -685,6 +688,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
remote_port=kv_transfer_params["remote_port"],
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
remote_multi_nodes_meta_mapping=kv_transfer_params.get(
"remote_multi_nodes_meta_mapping", {}),
)
@@ -772,6 +777,30 @@ class MooncakeConnector(KVConnectorBase_V1):
"""MooncakeConnector does not save explicitly."""
pass
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """Return the out-of-band P/D handshake metadata for this connector.

    The metadata is produced by the worker side and exchanged between
    prefill and decode workers during the connector handshake.

    Returns:
        The worker's handshake metadata, or ``None`` when the worker has
        not produced any yet.
    """
    worker = self.connector_worker
    assert worker is not None
    return worker.xfer_handshake_metadata
def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """Forward per-rank KV handshake metadata to the scheduler side.

    Args:
        metadata (dict): mapping from worker rank to that rank's
            handshake metadata.
    """
    scheduler = self.connector_scheduler
    assert scheduler is not None
    scheduler.set_xfer_handshake_metadata(metadata)
class MooncakeConnectorScheduler:
"""Implementation of Scheduler side methods"""
@@ -805,6 +834,9 @@ class MooncakeConnectorScheduler:
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}
# Per-rank primary/secondary node metadata (host, engine_id) used for
# cross-node KV transfer; keyed by stringified local rank.
self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
@@ -928,8 +960,23 @@ class MooncakeConnectorScheduler:
remote_pcp_size=self.pcp_size,
remote_dcp_size=self.dcp_size,
last_token_id=request.output_token_ids[-1],
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
)
def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """Record each rank's host/engine info for cross-node KV lookups.

    Args:
        metadata (dict): mapping from local rank to that rank's
            handshake metadata (must expose ``local_ip`` and
            ``engine_id``).
    """
    # Keys are stringified ranks so they survive JSON round-trips in
    # kv_transfer_params.
    self.multi_nodes_meta_mapping.update(
        (str(rank), {
            "host": rank_meta.local_ip,
            "engine_id": rank_meta.engine_id,
        }) for rank, rank_meta in metadata.items())
class MooncakeConnectorWorker:
"""Implementation of Worker side methods"""
@@ -989,6 +1036,9 @@ class MooncakeConnectorWorker:
self.kv_send_thread: Optional[KVCacheSendingThread] = None
self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
# Handshake metadata of this worker
self.xfer_handshake_metadata: MooncakeAgentMetadata | None = None
# kv_transfer variables
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
@@ -1118,7 +1168,9 @@ class MooncakeConnectorWorker:
te_rpc_port=self.te_rpc_port,
kv_caches_base_addr=kv_caches_base_addr,
num_blocks=self.num_blocks,
local_ip=get_ip(),
)
self.xfer_handshake_metadata = metadata
ready_event = threading.Event()
if self.kv_role == 'kv_producer':
@@ -1266,13 +1318,18 @@ class MooncakeConnectorWorker:
continue
for i in range(self.tp_num_need_pulls):
assert self.kv_recv_thread is not None
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
meta.remote_port,
remote_handshake_port_list[pcp_dcp_rank][i],
meta.remote_host, meta.remote_engine_id,
meta.remote_multi_nodes_meta_mapping)
self.kv_recv_thread.add_request(
request_id=req_id,
local_block_ids=local_block_ids_list[pcp_dcp_rank],
remote_block_ids=remote_block_ids_list[
pcp_dcp_rank],
remote_engine_id=meta.remote_engine_id,
remote_host=meta.remote_host,
remote_engine_id=remote_engine_id,
remote_host=remote_host,
remote_handshake_port=remote_handshake_port_list[
pcp_dcp_rank][i],
offset=i,
@@ -1287,12 +1344,16 @@ class MooncakeConnectorWorker:
for x in choosen_rank_list]
for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
assert self.kv_recv_thread is not None
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
meta.remote_port, remote_handshake_port_list[i][0],
meta.remote_host, meta.remote_engine_id,
meta.remote_multi_nodes_meta_mapping)
self.kv_recv_thread.add_request(
request_id=req_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_engine_id=meta.remote_engine_id,
remote_host=meta.remote_host,
remote_engine_id=remote_engine_id,
remote_host=remote_host,
remote_handshake_port=remote_handshake_port_list[i][0],
offset=i,
tp_num_need_pulls=self.tp_num_need_pulls,
@@ -1307,6 +1368,18 @@ class MooncakeConnectorWorker:
else:
self.kv_send_thread.add_not_transfer_request(req_id)
def _get_remote_host_info_by_port(self, base_port: int,
remote_handshake_port: int,
remote_host: str, remote_engine_id: str,
remote_multi_nodes_meta_mapping: dict):
rank = str(remote_handshake_port - base_port)
if remote_multi_nodes_meta_mapping is None or remote_multi_nodes_meta_mapping.get(
rank, None) is None:
return remote_host, remote_engine_id
info = remote_multi_nodes_meta_mapping[rank]
return info.get("host", remote_host), info.get("engine_id",
remote_engine_id)
def _prefill_get_remote_rank(self, req_id: str) -> List[int]:
return sum(self._get_remote_ranks_for_req(req_id), [])

View File

@@ -31,7 +31,9 @@ from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
@@ -374,7 +376,17 @@ class NPUWorker(WorkerBase):
return self.model_runner.get_model()
def get_kv_connector_handshake_metadata(self) -> Optional[dict]:
    """Get KV connector handshake metadata from this worker, if available.

    As rendered, a leftover unconditional ``return None`` preceded the
    body, making it unreachable; this is the intended implementation.

    Returns:
        A single-entry dict mapping this worker's rank to its handshake
        metadata, or None when no KV transfer group exists or the
        connector has nothing to exchange.
    """
    if not has_kv_transfer_group():
        return None
    connector = get_kv_transfer_group()
    # Return None for connectors that don't need to exchange handshake
    # metadata across workers.
    if (metadata := connector.get_handshake_metadata()) is None:
        return None
    return {self.rank: metadata}
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()