[feature] support pcp + mtp (in pd co-locate scenario) (#4098)
1. Support pcp + mtp in the pd co-located scenario.
2. Fix pcp-related bugs in the llmdatadist connector and clean up the code.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -79,7 +79,7 @@ class ReqMeta:
|
||||
remote_port: str
|
||||
engine_id: str
|
||||
remote_tp_size: str
|
||||
remote_cp_size: str
|
||||
remote_pcp_size: str
|
||||
remote_dcp_size: str
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
remote_tp_size=kv_transfer_params["remote_tp_size"],
|
||||
remote_cp_size=kv_transfer_params["remote_cp_size"],
|
||||
remote_pcp_size=kv_transfer_params["remote_pcp_size"],
|
||||
remote_dcp_size=kv_transfer_params["remote_dcp_size"],
|
||||
)
|
||||
|
||||
@@ -318,7 +318,7 @@ class LLMDataDistCMgrConnectorScheduler():
|
||||
remote_port=self.port,
|
||||
remote_tp_size=str(
|
||||
self.vllm_config.parallel_config.tensor_parallel_size),
|
||||
remote_cp_size=str(self.pcp_size),
|
||||
remote_pcp_size=str(self.pcp_size),
|
||||
remote_dcp_size=str(self.dcp_size),
|
||||
)
|
||||
|
||||
@@ -677,7 +677,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
remote_engine_id=meta.engine_id,
|
||||
request_id=req_id,
|
||||
remote_tp_size=meta.remote_tp_size,
|
||||
remote_cp_size=meta.remote_cp_size,
|
||||
remote_pcp_size=meta.remote_pcp_size,
|
||||
remote_dcp_size=meta.remote_dcp_size,
|
||||
)
|
||||
futures.append(future)
|
||||
@@ -876,39 +876,40 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
remote_block_ids: list[int],
|
||||
remote_port: int,
|
||||
remote_tp_size: int,
|
||||
remote_cp_size: int,
|
||||
remote_pcp_size: int,
|
||||
remote_dcp_size: int,
|
||||
) -> tuple[int, list[int], list[int]]:
|
||||
"""
|
||||
In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
|
||||
Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
|
||||
"""
|
||||
if self.pcp_size == remote_cp_size and self.dcp_size == remote_dcp_size:
|
||||
if self.pcp_size == remote_pcp_size and self.dcp_size == remote_dcp_size:
|
||||
# remote & local cp/dcp are equal, do kv transfer point-to-point
|
||||
remote_kv_num = 1
|
||||
remote_ports = [remote_port + self.pcp_rank * self.tp_size + tp_offset \
|
||||
for tp_offset in range(self.tp_rank, int(remote_tp_size), self.tp_size)]
|
||||
remote_block_nums = [len(remote_block_ids)]
|
||||
elif (self.use_mla and self.pcp_size == 1 and self.dcp_size == 1) \
|
||||
or (not self.use_mla and self.pcp_size == 1 and remote_tp_size == self.tp_size and remote_dcp_size == self.dcp_size):
|
||||
or (not self.use_mla and self.pcp_size == 1 and self.dcp_size == 1 and remote_tp_size == self.tp_size):
|
||||
# remote & local cp/dcp are not equal, each D node needs to pull from cp(*dcp) P nodes
|
||||
# 1. for mla, support D cp_size = dcp_size = 1
|
||||
# 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
|
||||
remote_dcp_size = remote_dcp_size // self.dcp_size
|
||||
remote_kv_num = remote_cp_size * remote_dcp_size
|
||||
remote_kv_num = remote_pcp_size * remote_dcp_size
|
||||
cp_dcp_offsets = []
|
||||
for cp_idx in range(remote_cp_size):
|
||||
for cp_idx in range(remote_pcp_size):
|
||||
cp_offset = cp_idx * remote_tp_size
|
||||
cp_dcp_offsets += list(
|
||||
range(cp_offset, cp_offset + remote_dcp_size))
|
||||
remote_ports = [remote_port + cp_dcp_offset + (self.tp_rank if not self.use_mla else 0) \
|
||||
tp_offset = 0 if self.use_mla else self.tp_rank // remote_dcp_size * remote_dcp_size
|
||||
remote_ports = [remote_port + cp_dcp_offset + tp_offset \
|
||||
for cp_dcp_offset in cp_dcp_offsets]
|
||||
# recompute cp/dcp block assign here, maybe we can also pass it from P node meta
|
||||
local_block_num = len(local_block_ids)
|
||||
remote_block_nums = [
|
||||
local_block_num // (remote_cp_size * remote_dcp_size)
|
||||
] * remote_cp_size * remote_dcp_size
|
||||
num_remain_blocks = local_block_num % (remote_cp_size *
|
||||
local_block_num // (remote_pcp_size * remote_dcp_size)
|
||||
] * remote_pcp_size * remote_dcp_size
|
||||
num_remain_blocks = local_block_num % (remote_pcp_size *
|
||||
remote_dcp_size)
|
||||
for i in range(num_remain_blocks):
|
||||
remote_block_nums[i] += 1
|
||||
@@ -921,7 +922,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
# Other cases are not supported now, maybe need to reshard kv_cache.
|
||||
raise NotImplementedError(
|
||||
f'Current case is not supported now: use_mla={self.use_mla}, '
|
||||
f'P tp={remote_tp_size}, pcp={remote_cp_size}, dcp={remote_dcp_size}, '
|
||||
f'P tp={remote_tp_size}, pcp={remote_pcp_size}, dcp={remote_dcp_size}, '
|
||||
f'D tp={self.tp_size}, pcp={self.pcp_size}, dcp={self.dcp_size}'
|
||||
)
|
||||
return remote_kv_num, remote_ports, remote_block_nums
|
||||
@@ -935,7 +936,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
remote_engine_id: str,
|
||||
request_id: str,
|
||||
remote_tp_size: str,
|
||||
remote_cp_size: str,
|
||||
remote_pcp_size: str,
|
||||
remote_dcp_size: str,
|
||||
):
|
||||
remote_kv_num, remote_ports, remote_block_nums = self._get_kv_split_metadata(
|
||||
@@ -943,7 +944,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
remote_block_ids=remote_block_ids,
|
||||
remote_port=remote_port,
|
||||
remote_tp_size=int(remote_tp_size),
|
||||
remote_cp_size=int(remote_cp_size),
|
||||
remote_pcp_size=int(remote_pcp_size),
|
||||
remote_dcp_size=int(remote_dcp_size),
|
||||
)
|
||||
logger.debug(
|
||||
|
||||
Reference in New Issue
Block a user