xc-llm-ascend/vllm_ascend/distributed/kv_transfer/utils/utils.py
wangxiaoteng888 b881fab416 [P/D][PCP] mooncake layerwise support pcp function (#6627)
### What this PR does / why we need it?
Add PCP support to the Mooncake layerwise KV cache transfer path.
PCP (Prefill Context Parallelism) Support: Introduced explicit support
for Prefill Context Parallelism (PCP) and Decode Context Parallelism
(DCP) in the Mooncake layerwise KV cache transfer mechanism, allowing
for more granular control and awareness of parallel configurations
during data transfer.
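
For orientation, here is a minimal sketch of the round-robin block layout the new utilities assume (the `block_owner` helper below is illustrative only, not part of this PR): each logical KV block is assigned a DCP rank by `idx % dcp_size` and a PCP rank by `(idx // dcp_size) % pcp_size`.

```python
# Illustrative only: how a logical KV block index maps to a (pcp_rank, dcp_rank)
# pair under the round-robin layout assumed by the helpers in utils.py.
def block_owner(logic_block_idx: int, pcp_size: int, dcp_size: int) -> tuple[int, int]:
    pcp_rank = (logic_block_idx // dcp_size) % pcp_size
    dcp_rank = logic_block_idx % dcp_size
    return pcp_rank, dcp_rank

# With pcp_size=2 and dcp_size=2, blocks 0..7 land on
# (0,0), (0,1), (1,0), (1,1), (0,0), (0,1), (1,0), (1,1).
print([block_owner(i, 2, 2) for i in range(8)])
```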

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
By CI.

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
2026-02-12 11:02:25 +08:00


import math
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any

import torch
import torch.distributed as dist
from vllm.logger import logger

from vllm_ascend.distributed.parallel_state import get_p_tp_group


def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, value: torch.Tensor):
    if pd_tp_ratio <= 1:
        return None, None
    elif key is None or value is None:
        raise ValueError("key or value is None")
    k_output = alltoall_and_rearrange(pd_tp_ratio, key)
    v_output = alltoall_and_rearrange(pd_tp_ratio, value)
    return k_output, v_output
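# Note: pd_tp_ratio is presumably the prefill-TP to decode-TP size ratio; for a
# ratio <= 1 the function returns (None, None) without touching key/value, and
# otherwise both tensors are redistributed via all_to_all and re-interleaved.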


def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
    num_kv_heads = input_tensor.size(1)
    output_tensor = torch.zeros_like(input_tensor)
    dist.all_to_all_single(output_tensor, input_tensor, group=get_p_tp_group().device_group)
    # Drop the local references to the large buffers as soon as possible so they
    # can be reclaimed (if no other references remain).
    input_tensor = 0
    result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
    output_tensor = 0
    return result


def rearrange_output(base_output: torch.Tensor, cut_num: int, num_kv_heads: int):
    size_0 = base_output.size(0)
    if size_0 % cut_num != 0:
        raise ValueError(f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]")
    chunk_size = size_0 // cut_num
    reshaped = base_output.view(cut_num, chunk_size, -1)
    transposed = reshaped.transpose(0, 1)
    return transposed.contiguous().view(size_0, num_kv_heads, -1)
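# Worked example (derived from the logic above): with cut_num=2 and rows
# [a0, a1, b0, b1] along dim 0 (a*/b* being the two chunks received from
# all_to_all), view(2, 2, -1) followed by transpose(0, 1) yields the
# interleaved order [a0, b0, a1, b1], i.e. matching rows of every chunk end up
# adjacent before the final reshape back to (size_0, num_kv_heads, -1).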


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
    return tensor[int(offset):]


def get_transfer_timeout_value():
    ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
    if len(ascend_transfer_timeout) > 0:
        return int(ascend_transfer_timeout)
    hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20"))  # type: ignore
    hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7"))  # type: ignore
    return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)
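# With the defaults above (HCCL_RDMA_TIMEOUT=20, HCCL_RDMA_RETRY_CNT=7) this
# evaluates to int(4.096 * 2**20 * 7 // 1000 + 3000) = 33064, presumably in
# milliseconds; ASCEND_TRANSFER_TIMEOUT overrides the derived value directly.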


@dataclass
class parallel_info:
    tp_size: int
    pcp_size: int
    dcp_size: int
    use_mla: bool
    pd_head_ratio: int


def get_cp_group(tp: int, heads: int, dcp: int):
    # Partition the second dimension of [pcp][head_group][dcp] to obtain a complete head group.
    # A head_group holds all blocks of a request for the same heads, e.g.:
    #   tp=8, dcp=2, heads=4 -> [[0, 1, 2, 3]]
    #   tp=8, dcp=1, heads=4 -> [[0, 2, 4, 6], [1, 3, 5, 7]]
    step = tp // heads
    if step == 0:
        return [[i for i in range(tp // dcp)]]
    else:
        return [
            set([k // dcp for h in range(heads) for k in range(h * step + i * dcp, h * step + (i + 1) * dcp)])
            for i in range(step // dcp)
        ]


def context_parallel_parameters_check(
    remote_pcp_size: int,
    remote_dcp_size: int,
    p_parallel_info: parallel_info,
    d_parallel_info: parallel_info,
    total_num_kv_heads: int,
):
    # Check whether the pcp/dcp ratio is supported
    assert (p_parallel_info.pcp_size * p_parallel_info.dcp_size) % (remote_pcp_size * remote_dcp_size) == 0
    if not p_parallel_info.use_mla:
        p_node_heads_per_rank = math.ceil(total_num_kv_heads / p_parallel_info.tp_size)
        d_node_heads_per_rank = math.ceil(total_num_kv_heads / d_parallel_info.dcp_size)
        assert d_node_heads_per_rank % p_node_heads_per_rank == 0


def get_tp_rank_head_mapping(num_key_value_heads: int, tp_size: int):
    # Get the head_idx list served by each tp_rank, as {tp_rank: [head_idx]}
    mapping = {}
    if tp_size <= num_key_value_heads:
        if num_key_value_heads % tp_size != 0:
            raise ValueError(f"Number of heads ({num_key_value_heads}) cannot be evenly divided by TP ({tp_size}).")
        heads_per_rank = num_key_value_heads // tp_size
        for rank in range(tp_size):
            start_idx = rank * heads_per_rank
            end_idx = start_idx + heads_per_rank
            mapping[rank] = list(range(start_idx, end_idx))
    else:
        if tp_size % num_key_value_heads != 0:
            raise ValueError(f"TP size ({tp_size}) cannot be evenly divided by the number of heads ({num_key_value_heads}).")
        ranks_per_head = tp_size // num_key_value_heads
        for rank in range(tp_size):
            head_idx = rank // ranks_per_head
            mapping[rank] = [head_idx]
    return mapping
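# Examples (derived from the logic above):
#   num_key_value_heads=8, tp_size=4 -> {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
#   num_key_value_heads=2, tp_size=4 -> {0: [0], 1: [0], 2: [1], 3: [1]}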


def get_head_group_mapping(num_key_value_heads: int, tp_size: int, num_groups: int, select_cp_group: list[int]):
    # Get the mapping dictionary, where the key is head_group_rank and the value is the list of head_idx
    if tp_size % num_groups != 0:
        raise ValueError(
            f"Total number of devices ({tp_size}) cannot be divided by the number of groups ({num_groups})."
        )
    ranks_per_group = tp_size // num_groups
    tp_mapping = get_tp_rank_head_mapping(num_key_value_heads, tp_size)
    group_mapping = {}
    for group_rank in range(num_groups):
        if group_rank in select_cp_group:
            start_rank = group_rank * ranks_per_group
            end_rank = start_rank + ranks_per_group
            heads_set = set()
            for rank in range(start_rank, end_rank):
                heads_set.update(tp_mapping[rank])
            group_mapping[group_rank] = sorted(list(heads_set))
    return group_mapping


def get_local_remote_block_port_mappings(
    to_trans_idx: int,
    p_parallel_info: parallel_info,
    d_parallel_info: parallel_info,
    d_hosts: list[str],
    d_port: int,
    selected_p_cp_group: list[int],
    selected_d_cp_group: list[int],
    prompt_len: int,
    block_size: int,
    req_meta,
    total_num_kv_heads: int,
    req_id: str,
):
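    # Returned structures (as built below):
    #   p_rank_block_mapping[pcp_rank][head_group_rank][dcp_rank] -> list of locally owned logic_block_idx
    #   d_block_rank_mapping[logic_block_idx][d_head_group_rank]  -> {"pcp_rank", "dcp_rank", "host", "port", "block_idx"}
    #   pd_head_mapping[p_head_group_rank]                        -> sorted D-side head_group_ranks holding the same heads
    #   d_trans_count_mapping[(host, port)]                       -> number of transfer-done notifications that remote device should expect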
    p_head_group_size = p_parallel_info.tp_size // p_parallel_info.dcp_size
    d_head_group_size = d_parallel_info.tp_size // d_parallel_info.dcp_size
    world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size
    # Compute which logic_block_idx corresponds to each tp_rank
    p_rank_block_mapping: list[list[list[list[int]]]] = [
        [[[] for _ in range(p_parallel_info.dcp_size)] for _ in range(p_head_group_size)]
        for _ in range(p_parallel_info.pcp_size)
    ]
    for logic_block_idx in range(to_trans_idx):
        pcp_rank = (logic_block_idx // p_parallel_info.dcp_size) % p_parallel_info.pcp_size
        dcp_rank = logic_block_idx % p_parallel_info.dcp_size
        for p_head_group_rank in range(p_head_group_size):
            if p_head_group_rank in selected_p_cp_group:
                p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank].append(logic_block_idx)
    # Find the remote device that holds the logic_block_idx
    d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))
    for logic_block_idx in range(to_trans_idx):
        pcp_rank = (logic_block_idx // d_parallel_info.dcp_size) % d_parallel_info.pcp_size
        for d_head_group_rank in range(d_head_group_size):
            if d_head_group_rank in selected_d_cp_group:
                dcp_rank = logic_block_idx % d_parallel_info.dcp_size
                world_rank = (
                    pcp_rank * d_head_group_size * d_parallel_info.dcp_size
                    + d_head_group_rank * d_parallel_info.dcp_size
                    + dcp_rank
                )
                world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size
                host = d_hosts[(len(d_hosts) * world_rank) // world_size]
                port = d_port + world_rank
                block_idx = (logic_block_idx - (pcp_rank * d_parallel_info.pcp_size + dcp_rank)) // (
                    d_parallel_info.pcp_size * d_parallel_info.dcp_size
                )
                d_block_rank_mapping[logic_block_idx][d_head_group_rank] = {
                    "pcp_rank": pcp_rank,
                    "dcp_rank": dcp_rank,
                    "host": host,
                    "port": port,
                    "block_idx": block_idx,
                }
    # Get how many times each device should receive done_single for this request
    d_trans_count_mapping = {}
    trans_block_size = math.ceil(prompt_len / block_size)  # Total number of blocks
    transed_block_size = math.ceil(req_meta.remote_cache_tokens / block_size)  # Number of prefix cache hit blocks
    d_cp_size = d_parallel_info.pcp_size * d_parallel_info.dcp_size
    for d_pcp_rank in range(d_parallel_info.pcp_size):
        for d_head_group_rank in range(d_head_group_size):
            for d_dcp_rank in range(d_parallel_info.dcp_size):
                if trans_block_size >= (p_parallel_info.pcp_size * p_parallel_info.dcp_size):
                    trans_count = (p_parallel_info.pcp_size * p_parallel_info.dcp_size) // d_cp_size
                else:
                    current_rank_idx = d_pcp_rank * d_parallel_info.dcp_size + d_dcp_rank
                    total_global_blocks = transed_block_size + trans_block_size
                    target_total_count = total_global_blocks // d_cp_size
                    if current_rank_idx < (total_global_blocks % d_cp_size):
                        target_total_count += 1
                    prev_processed_count = transed_block_size // d_cp_size
                    if current_rank_idx < (transed_block_size % d_cp_size):
                        prev_processed_count += 1
                    trans_count = target_total_count - prev_processed_count
                world_rank = (
                    d_pcp_rank * d_head_group_size * d_parallel_info.dcp_size
                    + d_head_group_rank * d_parallel_info.dcp_size
                    + d_dcp_rank
                )
                host = d_hosts[(len(d_hosts) * world_rank) // world_size]
                port = d_port + world_rank
                d_trans_count_mapping[(host, port)] = trans_count * p_parallel_info.pd_head_ratio
    # Compute the mapping between local and remote head_group_rank
    p_tp_rank_head_mapping = get_head_group_mapping(
        total_num_kv_heads, p_parallel_info.tp_size, p_head_group_size, selected_p_cp_group
    )
    d_tp_rank_head_mapping = get_head_group_mapping(
        total_num_kv_heads, d_parallel_info.tp_size, d_head_group_size, selected_d_cp_group
    )
    head_to_d_groups = defaultdict(set)
    for d_rank, heads in d_tp_rank_head_mapping.items():
        for head in heads:
            head_to_d_groups[head].add(d_rank)
    pd_head_mapping = {}
    for p_rank, p_heads in p_tp_rank_head_mapping.items():
        target_d_ranks = set()
        for head in p_heads:
            if head in head_to_d_groups:
                target_d_ranks.update(head_to_d_groups[head])
            else:
                logger.info(f"Warning: Head {head} exists in P but not in D mapping.")
        pd_head_mapping[p_rank] = sorted(list(target_d_ranks))
    logger.debug(
        f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} "
        f"P-side logic_block to rank mapping: {p_rank_block_mapping}, "
        f"D-side logic_block to rank mapping: {d_block_rank_mapping}, "
        f"P&D head_group_rank mapping: {pd_head_mapping}"
    )
    return p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping


def get_transfer_mappings(
    p_rank_block_mapping: list[list[list[list[int]]]],
    d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]],
    pd_head_mapping: dict[int, list[int]],
    d_trans_count_mapping: dict[tuple[str, int], int],
    req_meta,
    p_parallel_info: parallel_info,
    req_id: str,
    transed_idx: int,
    to_trans_idx: int,
    tp_rank: int,
    pcp_rank: int,
    dcp_rank: int,
):
    transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {}
    p_head_group_rank = (tp_rank - dcp_rank) // p_parallel_info.dcp_size
    p_block_idxs: list[int] = p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank]
    for p_block_idx, logic_block_idx in enumerate(p_block_idxs):
        if logic_block_idx < transed_idx or logic_block_idx >= to_trans_idx:
            continue
        for d_head_group_rank in pd_head_mapping[p_head_group_rank]:
            p_block_id = req_meta.local_block_ids[p_block_idx]
            remote_host = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["host"]
            remote_port = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["port"]
            d_block_idx = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["block_idx"]
            d_block_id = req_meta.remote_block_ids[d_block_idx]
            if (remote_host, remote_port) not in transfer_mappings:
                transfer_mappings[(remote_host, remote_port)] = {
                    "local_block_ids": [],
                    "remote_block_ids": [],
                    "trans_count": 0,
                }
            transfer_mappings[(remote_host, remote_port)]["local_block_ids"].append(p_block_id)
            transfer_mappings[(remote_host, remote_port)]["remote_block_ids"].append(d_block_id)
    for (host, port), block_dict in transfer_mappings.items():
        block_dict["trans_count"] = d_trans_count_mapping[(host, port)]
    logger.debug(f"MooncakeLayerwiseConnector Request {req_id} transfer tasks: {transfer_mappings}")
    return transfer_mappings