### What this PR does / why we need it?
Add PCP support to the Mooncake layerwise KV connector.

PCP (Prefill Context Parallelism) support: the Mooncake layerwise KV cache
transfer mechanism is now explicitly aware of both Prefill Context Parallelism
(PCP) and Decode Context Parallelism (DCP), allowing more granular control over
how KV blocks are routed between the prefill and decode parallel configurations
during transfer.
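As a rough illustration of the layout these helpers assume (hypothetical sizes, mirroring the indexing used in `get_local_remote_block_port_mappings` below), logical KV blocks are distributed round-robin over the DCP ranks first and then the PCP ranks:

```python
# Illustrative sketch only, not part of this PR.
pcp_size, dcp_size = 2, 2
for logic_block_idx in range(6):
    pcp_rank = (logic_block_idx // dcp_size) % pcp_size
    dcp_rank = logic_block_idx % dcp_size
    print(logic_block_idx, (pcp_rank, dcp_rank))
# 0 (0, 0), 1 (0, 1), 2 (1, 0), 3 (1, 1), 4 (0, 0), 5 (0, 1)
```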
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
By CI.
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
---------
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
import math
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any

import torch
import torch.distributed as dist
from vllm.logger import logger

from vllm_ascend.distributed.parallel_state import get_p_tp_group


def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, value: torch.Tensor):
    if pd_tp_ratio <= 1:
        return None, None
    elif key is None or value is None:
        raise ValueError("key or value is None")
    k_output = alltoall_and_rearrange(pd_tp_ratio, key)
    v_output = alltoall_and_rearrange(pd_tp_ratio, value)
    return k_output, v_output


def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
    num_kv_heads = input_tensor.size(1)
    output_tensor = torch.zeros_like(input_tensor)
    dist.all_to_all_single(output_tensor, input_tensor, group=get_p_tp_group().device_group)
    input_tensor = 0
    result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
    output_tensor = 0
    return result


def rearrange_output(base_output: torch.Tensor, cut_num: int, num_kv_heads: int):
    size_0 = base_output.size(0)
    if size_0 % cut_num != 0:
        raise ValueError(f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]")
    chunk_size = size_0 // cut_num
    reshaped = base_output.view(cut_num, chunk_size, -1)
    transposed = reshaped.transpose(0, 1)
    return transposed.contiguous().view(size_0, num_kv_heads, -1)
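
# Illustration (added for clarity, not in the original module): with size_0 = 4
# and cut_num = 2, rows [r0, r1, r2, r3] are viewed as chunks [r0, r1] and
# [r2, r3], transposed, and flattened back to [r0, r2, r1, r3]; the cut_num
# chunks produced by all_to_all_single are interleaved row by row while keeping
# the original (size_0, num_kv_heads, head_dim) layout.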


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
    return tensor[int(offset) :]
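
# Note (added for clarity, not in the original module): align_memory returns a
# view starting at the next `alignment`-byte boundary. For example, if
# tensor.data_ptr() is 12 bytes below a 512-byte boundary and element_size() is
# 2, the first 6 elements are skipped and tensor[6:] is returned; the helper
# implicitly assumes the byte offset is a multiple of the element size.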


def get_transfer_timeout_value():
    ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
    if len(ascend_transfer_timeout) > 0:
        return int(ascend_transfer_timeout)
    hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20"))  # type: ignore
    hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7"))  # type: ignore
    return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)
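
# Worked example (added for clarity): with the defaults HCCL_RDMA_TIMEOUT=20 and
# HCCL_RDMA_RETRY_CNT=7, the fallback evaluates to
# int(4.096 * 2**20 * 7 // 1000 + 3000) == 33064, i.e. roughly 33 seconds if the
# value is read in milliseconds (the HCCL RDMA retransmission window plus a
# 3000 ms margin).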


@dataclass
class parallel_info:
    tp_size: int
    pcp_size: int
    dcp_size: int
    use_mla: bool
    pd_head_ratio: int


def get_cp_group(tp: int, heads: int, dcp: int):
    # Partition the second dimension of [pcp][head_group][dcp] to obtain a complete head group
    # head_group is all blocks for request in the same head
    # tp8 dcp2 heads4 return [[0, 1, 2, 3]]
    # tp8 dcp1 heads4 return [[0, 2, 4, 6], [1, 3, 5, 7]]
    step = tp // heads
    if step == 0:
        return [[i for i in range(tp // dcp)]]
    else:
        return [
            set([k // dcp for h in range(heads) for k in range(h * step + i * dcp, h * step + (i + 1) * dcp)])
            for i in range(step // dcp)
        ]
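
# Worked example (added for clarity): get_cp_group(tp=8, heads=4, dcp=2) gives
# step = 2 and returns [{0, 1, 2, 3}], a single head group spanning ranks 0-3;
# get_cp_group(tp=8, heads=4, dcp=1) returns [{0, 2, 4, 6}, {1, 3, 5, 7}],
# matching the two cases listed in the comments above.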


def context_parallel_parameters_check(
    remote_pcp_size: int,
    remote_dcp_size: int,
    p_parallel_info: parallel_info,
    d_parallel_info: parallel_info,
    total_num_kv_heads: int,
):
    # Check whether the pcp–dcp ratio is supported
    assert (p_parallel_info.pcp_size * p_parallel_info.dcp_size) % (remote_pcp_size * remote_dcp_size) == 0
    if not p_parallel_info.use_mla:
        p_node_heads_per_rank = math.ceil(total_num_kv_heads / p_parallel_info.tp_size)
        d_node_heads_per_rank = math.ceil(total_num_kv_heads / d_parallel_info.dcp_size)
        assert d_node_heads_per_rank % p_node_heads_per_rank == 0
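
# Example (added for clarity): a local side with pcp_size * dcp_size = 4 and a
# remote side with remote_pcp_size * remote_dcp_size = 2 passes the first check
# (4 % 2 == 0); a remote product of 3 would not be supported.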


def get_tp_rank_head_mapping(num_key_value_heads: int, tp_size: int):
    # Get the head_idx corresponding to each tp_rank, as {tp_rank: [head_idx]}
    mapping = {}
    if tp_size <= num_key_value_heads:
        if num_key_value_heads % tp_size != 0:
            raise ValueError(f"Number of heads ({num_key_value_heads}) cannot be evenly divided by TP ({tp_size}).")

        heads_per_rank = num_key_value_heads // tp_size

        for rank in range(tp_size):
            start_idx = rank * heads_per_rank
            end_idx = start_idx + heads_per_rank
            mapping[rank] = list(range(start_idx, end_idx))
    else:
        if tp_size % num_key_value_heads != 0:
            raise ValueError(f"TP size ({tp_size}) cannot be evenly divided by the number of heads ({num_key_value_heads}).")
        ranks_per_head = tp_size // num_key_value_heads
        for rank in range(tp_size):
            head_idx = rank // ranks_per_head
            mapping[rank] = [head_idx]
    return mapping
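
# Worked examples (added for clarity):
#   get_tp_rank_head_mapping(8, 4) -> {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
#   get_tp_rank_head_mapping(2, 4) -> {0: [0], 1: [0], 2: [1], 3: [1]}
# i.e. heads are split across ranks when TP <= heads, and each head is shared by
# several ranks when TP > heads.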


def get_head_group_mapping(num_key_value_heads: int, tp_size: int, num_groups: int, select_cp_group: list[int]):
    # Get the mapping dictionary, where the key is head_group_rank and the value is the list of head_idx
    if tp_size % num_groups != 0:
        raise ValueError(
            f"Total number of devices ({tp_size}) cannot be divided by the number of groups ({num_groups})."
        )
    ranks_per_group = tp_size // num_groups
    tp_mapping = get_tp_rank_head_mapping(num_key_value_heads, tp_size)
    group_mapping = {}
    for group_rank in range(num_groups):
        if group_rank in select_cp_group:
            start_rank = group_rank * ranks_per_group
            end_rank = start_rank + ranks_per_group
            heads_set = set()

            for rank in range(start_rank, end_rank):
                heads_set.update(tp_mapping[rank])
            group_mapping[group_rank] = sorted(list(heads_set))
    return group_mapping
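
# Worked example (added for clarity): with 8 KV heads, tp_size=8, num_groups=4
# and select_cp_group=[0, 2], ranks_per_group is 2 and the result is
# {0: [0, 1], 2: [4, 5]}; groups 1 and 3 are skipped because they are not in
# select_cp_group.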


def get_local_remote_block_port_mappings(
    to_trans_idx: int,
    p_parallel_info: parallel_info,
    d_parallel_info: parallel_info,
    d_hosts: list[str],
    d_port: int,
    selected_p_cp_group: list[int],
    selected_d_cp_group: list[int],
    prompt_len: int,
    block_size: int,
    req_meta,
    total_num_kv_heads: int,
    req_id: str,
):
    p_head_group_size = p_parallel_info.tp_size // p_parallel_info.dcp_size
    d_head_group_size = d_parallel_info.tp_size // d_parallel_info.dcp_size
    world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size
    # Compute which logic_block_idx corresponds to each tp_rank
    p_rank_block_mapping: list[list[list[list[int]]]] = [
        [[[] for _ in range(p_parallel_info.dcp_size)] for _ in range(p_head_group_size)]
        for _ in range(p_parallel_info.pcp_size)
    ]
    for logic_block_idx in range(to_trans_idx):
        pcp_rank = (logic_block_idx // p_parallel_info.dcp_size) % p_parallel_info.pcp_size
        dcp_rank = logic_block_idx % p_parallel_info.dcp_size
        for p_head_group_rank in range(p_head_group_size):
            if p_head_group_rank in selected_p_cp_group:
                p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank].append(logic_block_idx)

    # Find the remote device that holds the logic_block_idx
    d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))
    for logic_block_idx in range(to_trans_idx):
        pcp_rank = (logic_block_idx // d_parallel_info.dcp_size) % d_parallel_info.pcp_size
        for d_head_group_rank in range(d_head_group_size):
            if d_head_group_rank in selected_d_cp_group:
                dcp_rank = logic_block_idx % d_parallel_info.dcp_size
                world_rank = (
                    pcp_rank * d_head_group_size * d_parallel_info.dcp_size
                    + d_head_group_rank * d_parallel_info.dcp_size
                    + dcp_rank
                )
                world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size
                host = d_hosts[(len(d_hosts) * world_rank) // world_size]
                port = d_port + world_rank
                block_idx = (logic_block_idx - (pcp_rank * d_parallel_info.pcp_size + dcp_rank)) // (
                    d_parallel_info.pcp_size * d_parallel_info.dcp_size
                )
                d_block_rank_mapping[logic_block_idx][d_head_group_rank] = {
                    "pcp_rank": pcp_rank,
                    "dcp_rank": dcp_rank,
                    "host": host,
                    "port": port,
                    "block_idx": block_idx,
                }
    # Get how many times each device should receive the done signal for this request
    d_trans_count_mapping = {}
    trans_block_size = math.ceil(prompt_len / block_size)  # Total number of blocks
    transed_block_size = math.ceil(req_meta.remote_cache_tokens / block_size)  # Number of prefix cache hit blocks
    d_cp_size = d_parallel_info.pcp_size * d_parallel_info.dcp_size
    for d_pcp_rank in range(d_parallel_info.pcp_size):
        for d_head_group_rank in range(d_head_group_size):
            for d_dcp_rank in range(d_parallel_info.dcp_size):
                if trans_block_size >= (p_parallel_info.pcp_size * p_parallel_info.dcp_size):
                    trans_count = (p_parallel_info.pcp_size * p_parallel_info.dcp_size) // d_cp_size
                else:
                    current_rank_idx = d_pcp_rank * d_parallel_info.dcp_size + d_dcp_rank
                    total_global_blocks = transed_block_size + trans_block_size

                    target_total_count = total_global_blocks // d_cp_size
                    if current_rank_idx < (total_global_blocks % d_cp_size):
                        target_total_count += 1

                    prev_processed_count = transed_block_size // d_cp_size
                    if current_rank_idx < (transed_block_size % d_cp_size):
                        prev_processed_count += 1

                    trans_count = target_total_count - prev_processed_count
                world_rank = (
                    d_pcp_rank * d_head_group_size * d_parallel_info.dcp_size
                    + d_head_group_rank * d_parallel_info.dcp_size
                    + d_dcp_rank
                )
                host = d_hosts[(len(d_hosts) * world_rank) // world_size]
                port = d_port + world_rank
                d_trans_count_mapping[(host, port)] = trans_count * p_parallel_info.pd_head_ratio

    # Compute the mapping between local and remote head_group_rank
    p_tp_rank_head_mapping = get_head_group_mapping(
        total_num_kv_heads, p_parallel_info.tp_size, p_head_group_size, selected_p_cp_group
    )
    d_tp_rank_head_mapping = get_head_group_mapping(
        total_num_kv_heads, d_parallel_info.tp_size, d_head_group_size, selected_d_cp_group
    )
    head_to_d_groups = defaultdict(set)
    for d_rank, heads in d_tp_rank_head_mapping.items():
        for head in heads:
            head_to_d_groups[head].add(d_rank)
    pd_head_mapping = {}
    for p_rank, p_heads in p_tp_rank_head_mapping.items():
        target_d_ranks = set()
        for head in p_heads:
            if head in head_to_d_groups:
                target_d_ranks.update(head_to_d_groups[head])
            else:
                logger.info(f"Warning: Head {head} exists in P but not in D mapping.")
        pd_head_mapping[p_rank] = sorted(list(target_d_ranks))
    logger.debug(
        f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} "
        f"P-side logic_block to rank mapping: {p_rank_block_mapping}, "
        f"D-side logic_block to rank mapping: {d_block_rank_mapping}, "
        f"P&D head_group_rank mapping: {pd_head_mapping}"
    )
    return p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping
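
# Worked example (added for clarity, hypothetical values): on the D side, with
# pcp_size=2, d_head_group_size=2 and dcp_size=1, world_size is 4, so
# (pcp_rank=1, d_head_group_rank=0, dcp_rank=0) maps to world_rank 2; with
# d_hosts = ["hostA", "hostB"] and d_port = 20000, that rank is addressed as
# ("hostB", 20002).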


def get_transfer_mappings(
    p_rank_block_mapping: list[list[list[list[int]]]],
    d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]],
    pd_head_mapping: dict[int, list[int]],
    d_trans_count_mapping: dict[tuple[str, int], int],
    req_meta,
    p_parallel_info: parallel_info,
    req_id: str,
    transed_idx: int,
    to_trans_idx: int,
    tp_rank: int,
    pcp_rank: int,
    dcp_rank: int,
):
    transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {}
    p_head_group_rank = (tp_rank - dcp_rank) // p_parallel_info.dcp_size
    p_block_idxs: list[int] = p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank]
    for p_block_idx, logic_block_idx in enumerate(p_block_idxs):
        if logic_block_idx < transed_idx or logic_block_idx >= to_trans_idx:
            continue
        for d_head_group_rank in pd_head_mapping[p_head_group_rank]:
            p_block_id = req_meta.local_block_ids[p_block_idx]
            remote_host = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["host"]
            remote_port = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["port"]
            d_block_idx = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["block_idx"]
            d_block_id = req_meta.remote_block_ids[d_block_idx]
            if (remote_host, remote_port) not in transfer_mappings:
                transfer_mappings[(remote_host, remote_port)] = {
                    "local_block_ids": [],
                    "remote_block_ids": [],
                    "trans_count": 0,
                }
            transfer_mappings[(remote_host, remote_port)]["local_block_ids"].append(p_block_id)
            transfer_mappings[(remote_host, remote_port)]["remote_block_ids"].append(d_block_id)
    for (host, port), block_dict in transfer_mappings.items():
        block_dict["trans_count"] = d_trans_count_mapping[(host, port)]
    logger.debug(f"MooncakeLayerwiseConnector Request {req_id} transfer tasks: {transfer_mappings}")
    return transfer_mappings