[feat]pd disaggregated support cross-machine (#5008)
### What this PR does / why we need it?
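Add cross-machine support for PD (prefill/decode) disaggregation.
The P (prefill) node sends the D (decode) node a per-rank mapping that holds
the host and engine ID of each of its ranks, covering both its primary and
secondary nodes. When the D node pulls the KV data, it looks up the target
rank in this mapping to find the primary or secondary node to pull from.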
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -24,7 +24,8 @@ from mooncake.engine import TransferEngine # type: ignore
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
|
||||
KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_decode_context_model_parallel_rank,
|
||||
get_decode_context_model_parallel_world_size, get_pp_group,
|
||||
@@ -64,6 +65,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
|
||||
te_rpc_port: int
|
||||
kv_caches_base_addr: list[int]
|
||||
num_blocks: int
|
||||
local_ip: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -75,6 +77,7 @@ class ReqMeta:
|
||||
remote_engine_id: str
|
||||
remote_pcp_size: int
|
||||
remote_dcp_size: int
|
||||
remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -685,6 +688,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
|
||||
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
|
||||
remote_multi_nodes_meta_mapping=kv_transfer_params.get(
|
||||
"remote_multi_nodes_meta_mapping", {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -772,6 +777,30 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
pass
|
||||
|
||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """
    Return the handshake metadata held by this connector's worker.

    The value is read from ``connector_worker.xfer_handshake_metadata``
    and is used for the out-of-band connector handshake between
    P/D workers.

    Returns:
        KVConnectorHandshakeMetadata: the worker's handshake metadata.
        None if the worker currently holds no handshake metadata.

    Raises:
        AssertionError: if this connector has no worker-side component.
    """
    worker = self.connector_worker
    assert worker is not None
    return worker.xfer_handshake_metadata
|
||||
|
||||
def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """
    Forward handshake metadata to the scheduler-side connector.

    Args:
        metadata (dict): handshake metadata, passed through unchanged to
            ``connector_scheduler.set_xfer_handshake_metadata``.

    Raises:
        AssertionError: if this connector has no scheduler-side component.
    """
    scheduler = self.connector_scheduler
    assert scheduler is not None
    scheduler.set_xfer_handshake_metadata(metadata)
|
||||
|
||||
|
||||
class MooncakeConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
@@ -805,6 +834,9 @@ class MooncakeConnectorScheduler:
|
||||
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_send: dict[str, float] = {}
|
||||
|
||||
# master-slave meta information for cross-nodes
|
||||
self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
@@ -928,8 +960,23 @@ class MooncakeConnectorScheduler:
|
||||
remote_pcp_size=self.pcp_size,
|
||||
remote_dcp_size=self.dcp_size,
|
||||
last_token_id=request.output_token_ids[-1],
|
||||
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
|
||||
)
|
||||
|
||||
def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """
    Record the host and engine id of every rank found in ``metadata``.

    Each entry is stored in ``self.multi_nodes_meta_mapping`` under the
    stringified rank, replacing any entry already stored for that rank.

    Args:
        metadata (dict): handshake metadata keyed by rank; every value must
            expose ``local_ip`` and ``engine_id`` attributes.
    """
    for rank, handshake in metadata.items():
        entry = {
            "host": handshake.local_ip,
            "engine_id": handshake.engine_id,
        }
        # Keys are strings, matching the string rank key that
        # _get_remote_host_info_by_port uses for its lookup.
        self.multi_nodes_meta_mapping[str(rank)] = entry
|
||||
|
||||
|
||||
class MooncakeConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
@@ -989,6 +1036,9 @@ class MooncakeConnectorWorker:
|
||||
self.kv_send_thread: Optional[KVCacheSendingThread] = None
|
||||
self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
|
||||
|
||||
# Handshake metadata of this worker
|
||||
self.xfer_handshake_metadata: MooncakeAgentMetadata | None = None
|
||||
|
||||
# kv_transfer variables
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
@@ -1118,7 +1168,9 @@ class MooncakeConnectorWorker:
|
||||
te_rpc_port=self.te_rpc_port,
|
||||
kv_caches_base_addr=kv_caches_base_addr,
|
||||
num_blocks=self.num_blocks,
|
||||
local_ip=get_ip(),
|
||||
)
|
||||
self.xfer_handshake_metadata = metadata
|
||||
|
||||
ready_event = threading.Event()
|
||||
if self.kv_role == 'kv_producer':
|
||||
@@ -1266,13 +1318,18 @@ class MooncakeConnectorWorker:
|
||||
continue
|
||||
for i in range(self.tp_num_need_pulls):
|
||||
assert self.kv_recv_thread is not None
|
||||
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
|
||||
meta.remote_port,
|
||||
remote_handshake_port_list[pcp_dcp_rank][i],
|
||||
meta.remote_host, meta.remote_engine_id,
|
||||
meta.remote_multi_nodes_meta_mapping)
|
||||
self.kv_recv_thread.add_request(
|
||||
request_id=req_id,
|
||||
local_block_ids=local_block_ids_list[pcp_dcp_rank],
|
||||
remote_block_ids=remote_block_ids_list[
|
||||
pcp_dcp_rank],
|
||||
remote_engine_id=meta.remote_engine_id,
|
||||
remote_host=meta.remote_host,
|
||||
remote_engine_id=remote_engine_id,
|
||||
remote_host=remote_host,
|
||||
remote_handshake_port=remote_handshake_port_list[
|
||||
pcp_dcp_rank][i],
|
||||
offset=i,
|
||||
@@ -1287,12 +1344,16 @@ class MooncakeConnectorWorker:
|
||||
for x in choosen_rank_list]
|
||||
for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
|
||||
assert self.kv_recv_thread is not None
|
||||
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
|
||||
meta.remote_port, remote_handshake_port_list[i][0],
|
||||
meta.remote_host, meta.remote_engine_id,
|
||||
meta.remote_multi_nodes_meta_mapping)
|
||||
self.kv_recv_thread.add_request(
|
||||
request_id=req_id,
|
||||
local_block_ids=meta.local_block_ids,
|
||||
remote_block_ids=meta.remote_block_ids,
|
||||
remote_engine_id=meta.remote_engine_id,
|
||||
remote_host=meta.remote_host,
|
||||
remote_engine_id=remote_engine_id,
|
||||
remote_host=remote_host,
|
||||
remote_handshake_port=remote_handshake_port_list[i][0],
|
||||
offset=i,
|
||||
tp_num_need_pulls=self.tp_num_need_pulls,
|
||||
@@ -1307,6 +1368,18 @@ class MooncakeConnectorWorker:
|
||||
else:
|
||||
self.kv_send_thread.add_not_transfer_request(req_id)
|
||||
|
||||
def _get_remote_host_info_by_port(self, base_port: int,
|
||||
remote_handshake_port: int,
|
||||
remote_host: str, remote_engine_id: str,
|
||||
remote_multi_nodes_meta_mapping: dict):
|
||||
rank = str(remote_handshake_port - base_port)
|
||||
if remote_multi_nodes_meta_mapping is None or remote_multi_nodes_meta_mapping.get(
|
||||
rank, None) is None:
|
||||
return remote_host, remote_engine_id
|
||||
info = remote_multi_nodes_meta_mapping[rank]
|
||||
return info.get("host", remote_host), info.get("engine_id",
|
||||
remote_engine_id)
|
||||
|
||||
def _prefill_get_remote_rank(self, req_id: str) -> List[int]:
|
||||
return sum(self._get_remote_ranks_for_req(req_id), [])
|
||||
|
||||
|
||||
@@ -31,7 +31,9 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -374,7 +376,17 @@ class NPUWorker(WorkerBase):
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def get_kv_connector_handshake_metadata(self) -> Optional[dict]:
    """Get KV connector metadata from this worker if available.

    Returns:
        ``{self.rank: metadata}`` when a KV transfer group exists and its
        connector provides handshake metadata; ``None`` otherwise.
    """
    # The unconditional `return None` that used to open this method made
    # everything below unreachable, so no metadata was ever reported.
    if not has_kv_transfer_group():
        return None

    connector = get_kv_transfer_group()

    # Return None for connectors that don't need to exchange handshake
    # metadata across workers.
    if (metadata := connector.get_handshake_metadata()) is None:
        return None
    return {self.rank: metadata}
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
    """Return the KV cache spec reported by this worker's model runner."""
    runner = self.model_runner
    return runner.get_kv_cache_spec()
|
||||
|
||||
Reference in New Issue
Block a user