[feature] support pcp + mtp (in pd co-locate scenario) (#4098)

1. Support PCP + MTP in the PD co-locate scenario.
2. LLMDataDist connector: PCP-related bugfixes and code cleanup.

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Author: zhangsicheng5
Committed: 2025-11-12 17:22:21 +08:00 (by GitHub)
Parent: 1b4ce63ec9
Commit: a123f355e9
6 changed files with 246 additions and 97 deletions


@@ -79,7 +79,7 @@ class ReqMeta:
     remote_port: str
     engine_id: str
     remote_tp_size: str
-    remote_cp_size: str
+    remote_pcp_size: str
     remote_dcp_size: str
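For orientation, here is a minimal sketch of `ReqMeta` after the rename, abridged to the fields this hunk shows (the real class carries more fields; the sizes travel as strings because they arrive via `kv_transfer_params`):

```python
from dataclasses import dataclass


@dataclass
class ReqMeta:
    # Abridged sketch: only the fields visible in this hunk.
    remote_port: str
    engine_id: str
    remote_tp_size: str
    remote_pcp_size: str  # renamed from remote_cp_size (prefill context parallel)
    remote_dcp_size: str  # decode context parallel
```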
@@ -97,7 +97,7 @@ class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
             remote_host=kv_transfer_params["remote_host"],
             remote_port=kv_transfer_params["remote_port"],
             remote_tp_size=kv_transfer_params["remote_tp_size"],
-            remote_cp_size=kv_transfer_params["remote_cp_size"],
+            remote_pcp_size=kv_transfer_params["remote_pcp_size"],
             remote_dcp_size=kv_transfer_params["remote_dcp_size"],
         )
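Producers of `kv_transfer_params` must emit the new key accordingly. A hypothetical payload matching the keys read here (host, port, and size values are illustrative only):

```python
# Hypothetical example: pulling from a P node with tp=4, pcp=2, dcp=1.
kv_transfer_params = {
    "remote_host": "10.0.0.1",   # illustrative address
    "remote_port": "15500",      # illustrative base port
    "remote_tp_size": "4",
    "remote_pcp_size": "2",      # was "remote_cp_size" before this change
    "remote_dcp_size": "1",
}
```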
@@ -318,7 +318,7 @@ class LLMDataDistCMgrConnectorScheduler():
             remote_port=self.port,
             remote_tp_size=str(
                 self.vllm_config.parallel_config.tensor_parallel_size),
-            remote_cp_size=str(self.pcp_size),
+            remote_pcp_size=str(self.pcp_size),
             remote_dcp_size=str(self.dcp_size),
         )
@@ -677,7 +677,7 @@ class LLMDataDistCMgrConnectorWorker():
                 remote_engine_id=meta.engine_id,
                 request_id=req_id,
                 remote_tp_size=meta.remote_tp_size,
-                remote_cp_size=meta.remote_cp_size,
+                remote_pcp_size=meta.remote_pcp_size,
                 remote_dcp_size=meta.remote_dcp_size,
             )
             futures.append(future)
@@ -876,39 +876,40 @@ class LLMDataDistCMgrConnectorWorker():
         remote_block_ids: list[int],
         remote_port: int,
         remote_tp_size: int,
-        remote_cp_size: int,
+        remote_pcp_size: int,
         remote_dcp_size: int,
     ) -> tuple[int, list[int], list[int]]:
         """
         In the cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P nodes.
         Use this function to calculate the remote port and remote block number of each remote P node that we need to pull from.
         """
-        if self.pcp_size == remote_cp_size and self.dcp_size == remote_dcp_size:
+        if self.pcp_size == remote_pcp_size and self.dcp_size == remote_dcp_size:
             # remote & local cp/dcp are equal, do kv transfer point-to-point
             remote_kv_num = 1
             remote_ports = [remote_port + self.pcp_rank * self.tp_size + tp_offset \
                 for tp_offset in range(self.tp_rank, int(remote_tp_size), self.tp_size)]
             remote_block_nums = [len(remote_block_ids)]
         elif (self.use_mla and self.pcp_size == 1 and self.dcp_size == 1) \
-                or (not self.use_mla and self.pcp_size == 1 and remote_tp_size == self.tp_size and remote_dcp_size == self.dcp_size):
+                or (not self.use_mla and self.pcp_size == 1 and self.dcp_size == 1 and remote_tp_size == self.tp_size):
             # remote & local cp/dcp are not equal, each D node needs to pull from cp(*dcp) P nodes
             # 1. for mla, support D cp_size = dcp_size = 1
             # 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
             remote_dcp_size = remote_dcp_size // self.dcp_size
-            remote_kv_num = remote_cp_size * remote_dcp_size
+            remote_kv_num = remote_pcp_size * remote_dcp_size
             cp_dcp_offsets = []
-            for cp_idx in range(remote_cp_size):
+            for cp_idx in range(remote_pcp_size):
                 cp_offset = cp_idx * remote_tp_size
                 cp_dcp_offsets += list(
                     range(cp_offset, cp_offset + remote_dcp_size))
-            remote_ports = [remote_port + cp_dcp_offset + (self.tp_rank if not self.use_mla else 0) \
+            tp_offset = 0 if self.use_mla else self.tp_rank // remote_dcp_size * remote_dcp_size
+            remote_ports = [remote_port + cp_dcp_offset + tp_offset \
                 for cp_dcp_offset in cp_dcp_offsets]
             # recompute cp/dcp block assign here, maybe we can also pass it from P node meta
             local_block_num = len(local_block_ids)
             remote_block_nums = [
-                local_block_num // (remote_cp_size * remote_dcp_size)
-            ] * remote_cp_size * remote_dcp_size
-            num_remain_blocks = local_block_num % (remote_cp_size *
+                local_block_num // (remote_pcp_size * remote_dcp_size)
+            ] * remote_pcp_size * remote_dcp_size
+            num_remain_blocks = local_block_num % (remote_pcp_size *
                                                    remote_dcp_size)
             for i in range(num_remain_blocks):
                 remote_block_nums[i] += 1
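To make the port and block math concrete, here is a standalone sketch mirroring the pcp/dcp-mismatch branch above. The class state (`self.dcp_size`, `self.tp_rank`, `self.use_mla`) is passed as plain arguments, and the worked numbers are illustrative, not taken from the patch:

```python
def split_kv_pull(remote_port: int, remote_tp_size: int,
                  remote_pcp_size: int, remote_dcp_size: int,
                  local_dcp_size: int, tp_rank: int, use_mla: bool,
                  local_block_num: int):
    """Sketch of the mismatch branch: which ports to pull from, and how many
    blocks each remote P shard contributes."""
    # Each local dcp rank covers remote_dcp_size // local_dcp_size remote dcp ranks.
    remote_dcp_size //= local_dcp_size
    cp_dcp_offsets = []
    for cp_idx in range(remote_pcp_size):
        cp_offset = cp_idx * remote_tp_size
        cp_dcp_offsets += list(range(cp_offset, cp_offset + remote_dcp_size))
    # For GQA, align the tp offset to the start of this rank's dcp group;
    # MLA replicates KV across tp, so no tp offset is needed.
    tp_offset = 0 if use_mla else tp_rank // remote_dcp_size * remote_dcp_size
    remote_ports = [remote_port + off + tp_offset for off in cp_dcp_offsets]
    # Round-robin the local blocks across all remote P shards.
    n = remote_pcp_size * remote_dcp_size
    remote_block_nums = [local_block_num // n] * n
    for i in range(local_block_num % n):
        remote_block_nums[i] += 1
    return remote_ports, remote_block_nums


# Worked example: MLA decode (pcp=dcp=1) pulling from a P node with
# tp=4, pcp=2, dcp=1, base port 15500, and 7 local blocks.
ports, nums = split_kv_pull(15500, 4, 2, 1, 1, 0, True, 7)
assert ports == [15500, 15504]  # one shard per remote pcp rank, tp_size apart
assert nums == [4, 3]           # 7 blocks round-robined over 2 shards
```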
@@ -921,7 +922,7 @@ class LLMDataDistCMgrConnectorWorker():
             # Other cases are not supported now, maybe need to reshard kv_cache.
             raise NotImplementedError(
                 f'Current case is not supported now: use_mla={self.use_mla}, '
-                f'P tp={remote_tp_size}, pcp={remote_cp_size}, dcp={remote_dcp_size}, '
+                f'P tp={remote_tp_size}, pcp={remote_pcp_size}, dcp={remote_dcp_size}, '
                 f'D tp={self.tp_size}, pcp={self.pcp_size}, dcp={self.dcp_size}'
             )
         return remote_kv_num, remote_ports, remote_block_nums
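Read together, the two supported branches plus this fallback give a small support matrix. A hedged predicate summarizing my reading of the conditions above (a sketch, not code from the patch):

```python
def kv_pull_supported(use_mla: bool,
                      p_tp: int, p_pcp: int, p_dcp: int,
                      d_tp: int, d_pcp: int, d_dcp: int) -> bool:
    """Sketch of which P/D parallel layouts the connector can serve."""
    # Case 1: identical pcp/dcp layouts -> plain point-to-point transfer.
    if (d_pcp, d_dcp) == (p_pcp, p_dcp):
        return True
    # Case 2 (MLA): the D side must run without pcp/dcp.
    if use_mla and d_pcp == 1 and d_dcp == 1:
        return True
    # Case 3 (GQA): D side without pcp/dcp, and tp sizes must match.
    if not use_mla and d_pcp == 1 and d_dcp == 1 and d_tp == p_tp:
        return True
    # Everything else hits the NotImplementedError above.
    return False
```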
@@ -935,7 +936,7 @@ class LLMDataDistCMgrConnectorWorker():
         remote_engine_id: str,
         request_id: str,
         remote_tp_size: str,
-        remote_cp_size: str,
+        remote_pcp_size: str,
         remote_dcp_size: str,
     ):
         remote_kv_num, remote_ports, remote_block_nums = self._get_kv_split_metadata(
@@ -943,7 +944,7 @@ class LLMDataDistCMgrConnectorWorker():
             remote_block_ids=remote_block_ids,
             remote_port=remote_port,
             remote_tp_size=int(remote_tp_size),
-            remote_cp_size=int(remote_cp_size),
+            remote_pcp_size=int(remote_pcp_size),
             remote_dcp_size=int(remote_dcp_size),
         )
         logger.debug(