support cp&dcp (#3260)

### What this PR does / why we need it?
This PR adds the Prefill Context Parallelism (PCP) feature, the
prefill-stage counterpart of Decode Context Parallelism (DCP). For the
implementation details, please refer to the RFC:
https://github.com/vllm-project/vllm/issues/25749.
TL;DR: PCP improves long-sequence inference by partitioning the sequence
dimension during the prefill stage.
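To illustrate the idea, here is a minimal sketch of partitioning a prompt's token positions across PCP ranks. The helper name and the contiguous-chunk scheme are illustrative assumptions, not code from this PR:

```python
# Minimal sketch of prefill context parallelism (illustrative only; the
# chunking scheme and helper name are assumptions, not this PR's code).
def split_prompt_for_pcp(token_ids: list[int], pcp_size: int,
                         pcp_rank: int) -> list[int]:
    """Return the contiguous slice of the prompt owned by `pcp_rank`."""
    n = len(token_ids)
    chunk = (n + pcp_size - 1) // pcp_size  # ceil-divide the sequence
    start = pcp_rank * chunk
    return token_ids[start:start + chunk]

# Example: 10 prompt tokens split across pcp_size=4 ranks.
prompt = list(range(10))
print([split_prompt_for_pcp(prompt, 4, r) for r in range(4)])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```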
### Does this PR introduce _any_ user-facing change?
Yes. The current implementation primarily includes the following changes:

- Modified ModelRunner.py to add the CP token-partitioning logic;
- Modified attention_v1.py and mla_v1.py to adapt the GQA/MLA attention backends to PCP;
- Modified block_tables.py to extend the KV-cache storage for DCP & PCP;
- Added the command-line arguments needed to control PCP parallelism, as sketched below.
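For reference, a hypothetical offline launch. The `prefill_context_parallel_size` and `decode_context_parallel_size` keyword arguments below are assumed to mirror the `parallel_config` fields touched in this diff; the actual flag spelling added by this PR is not confirmed here:

```python
# Hypothetical launch sketch; the keyword arguments for PCP/DCP are assumed
# to mirror parallel_config.prefill_context_parallel_size and
# parallel_config.decode_context_parallel_size and may differ from the
# actual arguments added by this PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",   # any long-context model
    tensor_parallel_size=2,
    prefill_context_parallel_size=2,    # PCP: split prefill over 2 ranks
    decode_context_parallel_size=2,     # DCP: split decode KV over 2 ranks
)
outputs = llm.generate(["A very long prompt ..."],
                       SamplingParams(max_tokens=32))
```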
### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: chenjie <chenjie137@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: Feng Liu <liufeng248@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: chenjie <chenjie137@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Co-authored-by: Feng Liu <liufeng248@huawei.com>
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: w00896881 <wangzixuan40@huawei.com>
Author: LookAround0301
Date: 2025-10-24 10:32:01 +08:00
Committed by: GitHub
Parent: 2bcadcb9d5
Commit: b54d44e664
18 changed files with 1729 additions and 211 deletions


@@ -22,7 +22,8 @@ from vllm import envs
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_tp_group, get_world_group
from vllm.distributed.parallel_state import (get_dcp_group, get_tp_group,
get_world_group)
from vllm.forward_context import ForwardContext
from vllm.utils import get_ip, logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -30,7 +31,12 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
prefill_context_parallel_enable)
if prefill_context_parallel_enable():
from vllm.distributed.parallel_state import \
get_prefill_context_model_parallel_rank
TORCH_DTYPE_TO_NPU_DTYPE = {
torch.half: llm_datadist.DataType.DT_FLOAT16,
@@ -66,6 +72,8 @@ class ReqMeta:
remote_port: str
engine_id: str
remote_tp_size: str
remote_cp_size: str
remote_dcp_size: str
class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
@@ -82,6 +90,8 @@ class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_tp_size=kv_transfer_params["remote_tp_size"],
remote_cp_size=kv_transfer_params["remote_cp_size"],
remote_dcp_size=kv_transfer_params["remote_dcp_size"],
)
@@ -185,8 +195,11 @@ class LLMDataDistCMgrConnectorScheduler():
else:
dp_rank_local = vllm_config.parallel_config.data_parallel_rank_local
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
self.pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size if prefill_context_parallel_enable(
) else 1
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
self.port = dp_rank_local * self.pcp_size * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}
@@ -298,6 +311,8 @@ class LLMDataDistCMgrConnectorScheduler():
remote_port=self.port,
remote_tp_size=str(
self.vllm_config.parallel_config.tensor_parallel_size),
remote_cp_size=str(self.pcp_size),
remote_dcp_size=str(self.dcp_size),
)
@@ -322,6 +337,11 @@ class LLMDataDistCMgrConnectorWorker():
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.tp_rank = get_tp_group().rank_in_group
self.rank = get_world_group().rank
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size if prefill_context_parallel_enable(
) else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if prefill_context_parallel_enable() else 0
self.dcp_size = get_dcp_group().world_size
self.local_ip = get_ip()
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
self.local_agent_metadata: Optional[
@@ -362,7 +382,8 @@ class LLMDataDistCMgrConnectorWorker():
def listen_for_agent_metadata_req(self, event: threading.Event):
assert self.local_agent_metadata is not None
port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.pcp_size * self.tp_size + self.pcp_rank * self.tp_size + self.tp_rank \
if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
url = f"tcp://{envs_ascend.VLLM_ASCEND_LLMDD_RPC_IP}:{port}"
msg_encoder = msgspec.msgpack.Encoder()
msg_decoder = msgspec.msgpack.Decoder()
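To make the new port layout concrete, here is a small worked example combining the scheduler base-port formula and the worker listener-port formula above. The base port value 5557 stands in for `VLLM_ASCEND_LLMDD_RPC_PORT` and is an assumption for illustration:

```python
# Worked example of the (dp, pcp, tp) -> port layout implied by the new
# formulas; the base port value 5557 is an assumption for illustration.
BASE = 5557          # stands in for envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
pcp_size, tp_size = 2, 2

def scheduler_base_port(dp_rank_local: int) -> int:
    return dp_rank_local * pcp_size * tp_size + BASE

def worker_listen_port(dp_rank_local: int, pcp_rank: int, tp_rank: int) -> int:
    return (BASE + dp_rank_local * pcp_size * tp_size
            + pcp_rank * tp_size + tp_rank)

for dp in range(2):
    print(dp, scheduler_base_port(dp),
          [worker_listen_port(dp, p, t) for p in range(pcp_size)
           for t in range(tp_size)])
# 0 5557 [5557, 5558, 5559, 5560]
# 1 5561 [5561, 5562, 5563, 5564]
```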
@@ -472,9 +493,10 @@ class LLMDataDistCMgrConnectorWorker():
d for d in device_list if d.get("server_id") == self.local_ip
and device_filter(d.get("device_id", ""))
]
if len(device_list) <= self.tp_rank:
if len(device_list) <= self.pcp_rank * self.tp_size + self.tp_rank:
continue
device_info = device_list[self.tp_rank]
device_info = device_list[self.pcp_rank * self.tp_size +
self.tp_rank]
super_pod_id_ = device_info.get("super_pod_id", None)
server_id_ = device_info["server_id"]
device_id_ = device_info["device_id"]
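The new device-list indexing can be illustrated with hypothetical values: each `(pcp_rank, tp_rank)` pair now selects a distinct local device:

```python
# Illustration of the new device selection: with pcp_size=2 and tp_size=2,
# index pcp_rank * tp_size + tp_rank picks a distinct device per rank pair.
# Device ids are hypothetical.
tp_size = 2
device_list = [{"device_id": f"npu{i}"} for i in range(4)]
for pcp_rank in range(2):
    for tp_rank in range(tp_size):
        idx = pcp_rank * tp_size + tp_rank
        print(pcp_rank, tp_rank, device_list[idx]["device_id"])
# (0,0)->npu0  (0,1)->npu1  (1,0)->npu2  (1,1)->npu3
```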
@@ -648,6 +670,8 @@ class LLMDataDistCMgrConnectorWorker():
remote_engine_id=meta.engine_id,
request_id=req_id,
remote_tp_size=meta.remote_tp_size,
remote_cp_size=meta.remote_cp_size,
remote_dcp_size=meta.remote_dcp_size,
)
futures.append(future)
@@ -839,6 +863,62 @@ class LLMDataDistCMgrConnectorWorker():
f"Failed to send reqest_id {request_id} to prefill: {e}"
)
def _get_kv_split_metadata(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_port: int,
remote_tp_size: int,
remote_cp_size: int,
remote_dcp_size: int,
) -> tuple[int, list[int], list[int]]:
"""
In the cp/dcp scenario, the kv_cache may be split, so we may need to pull blocks from multiple remote P nodes.
Use this function to calculate the remote port and the number of remote blocks to pull from each remote P node.
"""
if self.pcp_size == remote_cp_size and self.dcp_size == remote_dcp_size:
# remote & local cp/dcp are equal, do kv transfer point-to-point
remote_kv_num = 1
remote_ports = [remote_port + self.pcp_rank * self.tp_size + tp_offset \
for tp_offset in range(self.tp_rank, int(remote_tp_size), self.tp_size)]
remote_block_nums = [len(remote_block_ids)]
elif (self.use_mla and self.pcp_size == 1 and self.dcp_size == 1) \
or (not self.use_mla and self.pcp_size == 1 and remote_tp_size == self.tp_size and remote_dcp_size == self.dcp_size):
# remote & local cp/dcp are not equal, each D node needs to pull from cp(*dcp) P nodes
# 1. for mla, support D cp_size = dcp_size = 1
# 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
remote_dcp_size = remote_dcp_size // self.dcp_size
remote_kv_num = remote_cp_size * remote_dcp_size
cp_dcp_offsets = []
for cp_idx in range(remote_cp_size):
cp_offset = cp_idx * remote_tp_size
cp_dcp_offsets += list(
range(cp_offset, cp_offset + remote_dcp_size))
remote_ports = [remote_port + cp_dcp_offset + (self.tp_rank if not self.use_mla else 0) \
for cp_dcp_offset in cp_dcp_offsets]
# recompute the cp/dcp block assignment here; maybe we can also pass it from the P node meta
local_block_num = len(local_block_ids)
remote_block_nums = [
local_block_num // (remote_cp_size * remote_dcp_size)
] * remote_cp_size * remote_dcp_size
num_remain_blocks = local_block_num % (remote_cp_size *
remote_dcp_size)
for i in range(num_remain_blocks):
remote_block_nums[i] += 1
# make sure the last block (which may be only partially filled) of the P nodes is placed as the last block on the D node
remote_ports = remote_ports[
num_remain_blocks:] + remote_ports[:num_remain_blocks]
remote_block_nums = remote_block_nums[
num_remain_blocks:] + remote_block_nums[:num_remain_blocks]
else:
# Other cases are not supported now, maybe need to reshard kv_cache.
raise NotImplementedError(
f'Current case is not supported now: use_mla={self.use_mla}, '
f'P tp={remote_tp_size}, pcp={remote_cp_size}, dcp={remote_dcp_size}, '
f'D tp={self.tp_size}, pcp={self.pcp_size}, dcp={self.dcp_size}'
)
return remote_kv_num, remote_ports, remote_block_nums
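The block-count split in the second branch can be exercised in isolation. Below is a minimal re-implementation for illustration; the function name and inputs are local to this sketch:

```python
# Standalone sketch of the block-count split used above: local blocks are
# spread across remote_cp_size * remote_dcp_size P nodes, then rotated so
# the possibly partial last block lands on the last D-node slot.
def split_blocks(local_block_num: int, remote_cp_size: int,
                 remote_dcp_size: int) -> list[int]:
    n_remote = remote_cp_size * remote_dcp_size
    nums = [local_block_num // n_remote] * n_remote
    remainder = local_block_num % n_remote
    for i in range(remainder):
        nums[i] += 1
    # rotate so the node holding the (possibly unfull) last block comes last
    return nums[remainder:] + nums[:remainder]

print(split_blocks(10, 2, 2))  # [2, 2, 3, 3]
```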
def _read_blocks(
self,
local_block_ids: list[int],
@@ -848,97 +928,119 @@ class LLMDataDistCMgrConnectorWorker():
remote_engine_id: str,
request_id: str,
remote_tp_size: str,
remote_cp_size: str,
remote_dcp_size: str,
):
# if remote_ip not in self.linked_cluster:
tp_offset = self.tp_rank % int(remote_tp_size)
remote_cluster_id = self.connect_to_remote_agent(
remote_ip, remote_port + tp_offset)
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
return
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
remote_kv_num, remote_ports, remote_block_nums = self._get_kv_split_metadata(
local_block_ids=local_block_ids,
remote_block_ids=remote_block_ids,
remote_port=remote_port,
remote_tp_size=int(remote_tp_size),
remote_cp_size=int(remote_cp_size),
remote_dcp_size=int(remote_dcp_size),
)
logger.debug(
f'Pull blocks from remote: remote_kv_num={remote_kv_num}, remote_ports={remote_ports}, '
f'remote_block_nums={remote_block_nums}, local_block_ids={local_block_ids}'
)
logger.info(f"remote cluster id is: {remote_cluster_id}")
if self.use_mla:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
elif self.use_sparse:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
remote_cache_key_k_idx = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=2)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_idx,
self.cache[2], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
else:
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key,
self.cache, # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
remote_ports = list(
range(remote_port + self.tp_rank,
remote_port + int(remote_tp_size), self.tp_size))
local_block_offset = 0
remote_block_ids_full = remote_block_ids
local_block_ids_full = local_block_ids
for remote_kv_id in range(remote_kv_num):
remote_port = remote_ports[remote_kv_id]
num_blocks_to_pull = remote_block_nums[remote_kv_id]
if num_blocks_to_pull == 0:
continue
remote_block_ids = remote_block_ids_full[:num_blocks_to_pull]
local_block_ids = local_block_ids_full[
local_block_offset:local_block_offset + num_blocks_to_pull]
local_block_offset += num_blocks_to_pull
remote_cluster_id = self.connect_to_remote_agent(
remote_ip, remote_port)
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
return
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
logger.info(f"remote cluster id is: {remote_cluster_id}")
if self.use_mla:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
elif self.use_sparse:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
remote_cache_key_k_idx = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=2)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_idx,
self.cache[2], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
else:
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key,
self.cache, # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
with self.thread_lock:
self.finished_reqs.add(request_id)
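A sketch of how the rewritten pull loop walks the local block-id list, one remote KV source at a time; the inputs are hypothetical:

```python
# Sketch of the per-remote slicing done by the pull loop above: each remote
# KV source contributes a consecutive run of local blocks. Inputs are
# hypothetical.
local_block_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
remote_block_nums = [2, 2, 3, 3]        # from _get_kv_split_metadata
offset = 0
for n in remote_block_nums:
    chunk = local_block_ids[offset:offset + n]
    offset += n
    print(chunk)   # blocks pulled from this remote P node
# [0, 1]  [2, 3]  [4, 5, 6]  [7, 8, 9]
```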
@@ -990,4 +1092,4 @@ def zmq_ctx(socket_type: Any,
yield socket
finally:
if ctx is not None:
ctx.destroy(linger=0)
ctx.destroy(linger=0)


@@ -7,6 +7,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import prefill_context_parallel_enable
# Currently, mc2 ops need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -58,9 +59,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
if prefill_context_parallel_enable():
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.prefill_context_parallel_size *
parallel_config.tensor_parallel_size)
else:
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
pd_tp_ratio = get_ascend_config().pd_tp_ratio
pd_head_ratio = get_ascend_config().pd_head_ratio
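A worked example of the rank layout with PCP enabled, under assumed sizes (external DP = 1, dp = 2, pcp = 2, tp = 2):

```python
# Worked example of the rank reshape with PCP enabled; the sizes are
# hypothetical (world_size = 8, dp = 2, pcp = 2, tp = 2).
import torch

world_size, dp, pcp, tp = 8, 2, 2, 2
all_ranks = torch.arange(world_size).reshape(-1, dp * pcp * tp)
print(all_ranks)
# tensor([[0, 1, 2, 3, 4, 5, 6, 7]])  -> one ExternalDP row of dp*pcp*tp ranks
```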