[feat]pd disaggregated support cross-machine (#5008)
### What this PR does / why we need it?
pd disaggregated support cross-machine.
We send the primary and secondary node information of node p to node d.
When node d pulls the KV data, it retrieves the corresponding primary or
secondary node information from the mapping.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -24,7 +24,8 @@ from mooncake.engine import TransferEngine # type: ignore
|
|||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
|
||||||
|
KVConnectorRole)
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_decode_context_model_parallel_rank,
|
get_decode_context_model_parallel_rank,
|
||||||
get_decode_context_model_parallel_world_size, get_pp_group,
|
get_decode_context_model_parallel_world_size, get_pp_group,
|
||||||
@@ -64,6 +65,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
|
|||||||
te_rpc_port: int
|
te_rpc_port: int
|
||||||
kv_caches_base_addr: list[int]
|
kv_caches_base_addr: list[int]
|
||||||
num_blocks: int
|
num_blocks: int
|
||||||
|
local_ip: str = ""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -75,6 +77,7 @@ class ReqMeta:
|
|||||||
remote_engine_id: str
|
remote_engine_id: str
|
||||||
remote_pcp_size: int
|
remote_pcp_size: int
|
||||||
remote_dcp_size: int
|
remote_dcp_size: int
|
||||||
|
remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -685,6 +688,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
|||||||
remote_port=kv_transfer_params["remote_port"],
|
remote_port=kv_transfer_params["remote_port"],
|
||||||
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
|
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
|
||||||
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
|
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
|
||||||
|
remote_multi_nodes_meta_mapping=kv_transfer_params.get(
|
||||||
|
"remote_multi_nodes_meta_mapping", {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -772,6 +777,30 @@ class MooncakeConnector(KVConnectorBase_V1):
|
|||||||
"""MooncakeConnector does not save explicitly."""
|
"""MooncakeConnector does not save explicitly."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||||
|
"""
|
||||||
|
Get the KVConnector handshake metadata for this connector.
|
||||||
|
This metadata is used for out-of-band connector handshake
|
||||||
|
between P/D workers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KVConnectorHandshakeMetadata: the handshake metadata.
|
||||||
|
None if no handshake metadata is available.
|
||||||
|
"""
|
||||||
|
assert self.connector_worker is not None
|
||||||
|
return self.connector_worker.xfer_handshake_metadata
|
||||||
|
|
||||||
|
def set_xfer_handshake_metadata(
|
||||||
|
self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
|
||||||
|
"""
|
||||||
|
Set the KV connector handshake metadata for this connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (dict): the handshake metadata to set.
|
||||||
|
"""
|
||||||
|
assert self.connector_scheduler is not None
|
||||||
|
self.connector_scheduler.set_xfer_handshake_metadata(metadata)
|
||||||
|
|
||||||
|
|
||||||
class MooncakeConnectorScheduler:
|
class MooncakeConnectorScheduler:
|
||||||
"""Implementation of Scheduler side methods"""
|
"""Implementation of Scheduler side methods"""
|
||||||
@@ -805,6 +834,9 @@ class MooncakeConnectorScheduler:
|
|||||||
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
||||||
self._reqs_need_send: dict[str, float] = {}
|
self._reqs_need_send: dict[str, float] = {}
|
||||||
|
|
||||||
|
# master-slave meta information for cross-nodes
|
||||||
|
self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(
|
||||||
self, request: "Request",
|
self, request: "Request",
|
||||||
num_computed_tokens: int) -> tuple[int, bool]:
|
num_computed_tokens: int) -> tuple[int, bool]:
|
||||||
@@ -928,8 +960,23 @@ class MooncakeConnectorScheduler:
|
|||||||
remote_pcp_size=self.pcp_size,
|
remote_pcp_size=self.pcp_size,
|
||||||
remote_dcp_size=self.dcp_size,
|
remote_dcp_size=self.dcp_size,
|
||||||
last_token_id=request.output_token_ids[-1],
|
last_token_id=request.output_token_ids[-1],
|
||||||
|
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_xfer_handshake_metadata(
|
||||||
|
self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
|
||||||
|
"""
|
||||||
|
Set the KV connector handshake metadata for this connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (dict): the handshake metadata to set.
|
||||||
|
"""
|
||||||
|
for local_rank, rank_metadata in metadata.items():
|
||||||
|
self.multi_nodes_meta_mapping[str(local_rank)] = {
|
||||||
|
"host": rank_metadata.local_ip,
|
||||||
|
"engine_id": rank_metadata.engine_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class MooncakeConnectorWorker:
|
class MooncakeConnectorWorker:
|
||||||
"""Implementation of Worker side methods"""
|
"""Implementation of Worker side methods"""
|
||||||
@@ -989,6 +1036,9 @@ class MooncakeConnectorWorker:
|
|||||||
self.kv_send_thread: Optional[KVCacheSendingThread] = None
|
self.kv_send_thread: Optional[KVCacheSendingThread] = None
|
||||||
self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
|
self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
|
||||||
|
|
||||||
|
# Handshake metadata of this worker
|
||||||
|
self.xfer_handshake_metadata: MooncakeAgentMetadata | None = None
|
||||||
|
|
||||||
# kv_transfer variables
|
# kv_transfer variables
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
@@ -1118,7 +1168,9 @@ class MooncakeConnectorWorker:
|
|||||||
te_rpc_port=self.te_rpc_port,
|
te_rpc_port=self.te_rpc_port,
|
||||||
kv_caches_base_addr=kv_caches_base_addr,
|
kv_caches_base_addr=kv_caches_base_addr,
|
||||||
num_blocks=self.num_blocks,
|
num_blocks=self.num_blocks,
|
||||||
|
local_ip=get_ip(),
|
||||||
)
|
)
|
||||||
|
self.xfer_handshake_metadata = metadata
|
||||||
|
|
||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
if self.kv_role == 'kv_producer':
|
if self.kv_role == 'kv_producer':
|
||||||
@@ -1266,13 +1318,18 @@ class MooncakeConnectorWorker:
|
|||||||
continue
|
continue
|
||||||
for i in range(self.tp_num_need_pulls):
|
for i in range(self.tp_num_need_pulls):
|
||||||
assert self.kv_recv_thread is not None
|
assert self.kv_recv_thread is not None
|
||||||
|
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
|
||||||
|
meta.remote_port,
|
||||||
|
remote_handshake_port_list[pcp_dcp_rank][i],
|
||||||
|
meta.remote_host, meta.remote_engine_id,
|
||||||
|
meta.remote_multi_nodes_meta_mapping)
|
||||||
self.kv_recv_thread.add_request(
|
self.kv_recv_thread.add_request(
|
||||||
request_id=req_id,
|
request_id=req_id,
|
||||||
local_block_ids=local_block_ids_list[pcp_dcp_rank],
|
local_block_ids=local_block_ids_list[pcp_dcp_rank],
|
||||||
remote_block_ids=remote_block_ids_list[
|
remote_block_ids=remote_block_ids_list[
|
||||||
pcp_dcp_rank],
|
pcp_dcp_rank],
|
||||||
remote_engine_id=meta.remote_engine_id,
|
remote_engine_id=remote_engine_id,
|
||||||
remote_host=meta.remote_host,
|
remote_host=remote_host,
|
||||||
remote_handshake_port=remote_handshake_port_list[
|
remote_handshake_port=remote_handshake_port_list[
|
||||||
pcp_dcp_rank][i],
|
pcp_dcp_rank][i],
|
||||||
offset=i,
|
offset=i,
|
||||||
@@ -1287,12 +1344,16 @@ class MooncakeConnectorWorker:
|
|||||||
for x in choosen_rank_list]
|
for x in choosen_rank_list]
|
||||||
for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
|
for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
|
||||||
assert self.kv_recv_thread is not None
|
assert self.kv_recv_thread is not None
|
||||||
|
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
|
||||||
|
meta.remote_port, remote_handshake_port_list[i][0],
|
||||||
|
meta.remote_host, meta.remote_engine_id,
|
||||||
|
meta.remote_multi_nodes_meta_mapping)
|
||||||
self.kv_recv_thread.add_request(
|
self.kv_recv_thread.add_request(
|
||||||
request_id=req_id,
|
request_id=req_id,
|
||||||
local_block_ids=meta.local_block_ids,
|
local_block_ids=meta.local_block_ids,
|
||||||
remote_block_ids=meta.remote_block_ids,
|
remote_block_ids=meta.remote_block_ids,
|
||||||
remote_engine_id=meta.remote_engine_id,
|
remote_engine_id=remote_engine_id,
|
||||||
remote_host=meta.remote_host,
|
remote_host=remote_host,
|
||||||
remote_handshake_port=remote_handshake_port_list[i][0],
|
remote_handshake_port=remote_handshake_port_list[i][0],
|
||||||
offset=i,
|
offset=i,
|
||||||
tp_num_need_pulls=self.tp_num_need_pulls,
|
tp_num_need_pulls=self.tp_num_need_pulls,
|
||||||
@@ -1307,6 +1368,18 @@ class MooncakeConnectorWorker:
|
|||||||
else:
|
else:
|
||||||
self.kv_send_thread.add_not_transfer_request(req_id)
|
self.kv_send_thread.add_not_transfer_request(req_id)
|
||||||
|
|
||||||
|
def _get_remote_host_info_by_port(self, base_port: int,
|
||||||
|
remote_handshake_port: int,
|
||||||
|
remote_host: str, remote_engine_id: str,
|
||||||
|
remote_multi_nodes_meta_mapping: dict):
|
||||||
|
rank = str(remote_handshake_port - base_port)
|
||||||
|
if remote_multi_nodes_meta_mapping is None or remote_multi_nodes_meta_mapping.get(
|
||||||
|
rank, None) is None:
|
||||||
|
return remote_host, remote_engine_id
|
||||||
|
info = remote_multi_nodes_meta_mapping[rank]
|
||||||
|
return info.get("host", remote_host), info.get("engine_id",
|
||||||
|
remote_engine_id)
|
||||||
|
|
||||||
def _prefill_get_remote_rank(self, req_id: str) -> List[int]:
|
def _prefill_get_remote_rank(self, req_id: str) -> List[int]:
|
||||||
return sum(self._get_remote_ranks_for_req(req_id), [])
|
return sum(self._get_remote_ranks_for_req(req_id), [])
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||||
init_distributed_environment)
|
init_distributed_environment)
|
||||||
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
|
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
|
||||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
|
||||||
|
get_kv_transfer_group,
|
||||||
|
has_kv_transfer_group)
|
||||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@@ -374,8 +376,18 @@ class NPUWorker(WorkerBase):
|
|||||||
return self.model_runner.get_model()
|
return self.model_runner.get_model()
|
||||||
|
|
||||||
def get_kv_connector_handshake_metadata(self) -> Optional[dict]:
|
def get_kv_connector_handshake_metadata(self) -> Optional[dict]:
|
||||||
|
"""Get KV connector metadata from this worker if available."""
|
||||||
|
if not has_kv_transfer_group():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
connector = get_kv_transfer_group()
|
||||||
|
|
||||||
|
# Return None for connectors that don't need to exchange handshake
|
||||||
|
# metadata across workers.
|
||||||
|
if (metadata := connector.get_handshake_metadata()) is None:
|
||||||
|
return None
|
||||||
|
return {self.rank: metadata}
|
||||||
|
|
||||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||||
return self.model_runner.get_kv_cache_spec()
|
return self.model_runner.get_kv_cache_spec()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user