### What this PR does / why we need it?
The Mooncake Layerwise Connector now supports the hybrid attention manager with multiple KV cache groups.
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
By CI.
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
import math
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any

import torch
import torch.distributed as dist
from vllm.logger import logger

from vllm_ascend.distributed.parallel_state import get_p_tp_group


def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, value: torch.Tensor):
    if pd_tp_ratio <= 1:
        return None, None
    elif key is None or value is None:
        raise ValueError("key or value is None")
    k_output = alltoall_and_rearrange(pd_tp_ratio, key)
    v_output = alltoall_and_rearrange(pd_tp_ratio, value)
    return k_output, v_output


def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
    num_kv_heads = input_tensor.size(1)
    output_tensor = torch.zeros_like(input_tensor)
    dist.all_to_all_single(output_tensor, input_tensor, group=get_p_tp_group().device_group)
    # Drop references early so the buffers can be reclaimed before the copy below.
    input_tensor = None
    result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
    output_tensor = None
    return result


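# Worked example for alltoall_and_rearrange (illustrative, assuming a 2-rank
# prefill TP group and tp_ratio=2): each rank holds [8, num_kv_heads,
# head_dim]. all_to_all_single splits dim 0 into two 4-row chunks and
# exchanges them, so rank r ends up with chunk r from both ranks stacked
# along dim 0; rearrange_output then interleaves the two chunks back into
# per-token order.
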
def rearrange_output(base_output: torch.Tensor, cut_num: int, num_kv_heads: int):
    size_0 = base_output.size(0)
    if size_0 % cut_num != 0:
        raise ValueError(f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]")
    chunk_size = size_0 // cut_num
    # Interleave the cut_num chunks: output row j * cut_num + c comes from
    # input row c * chunk_size + j.
    reshaped = base_output.view(cut_num, chunk_size, -1)
    transposed = reshaped.transpose(0, 1)
    return transposed.contiguous().view(size_0, num_kv_heads, -1)


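# Worked example for rearrange_output (illustrative): cut_num=2 over rows
# [a0, a1, a2, a3, b0, b1, b2, b3] yields [a0, b0, a1, b1, a2, b2, a3, b3],
# i.e. the two chunks are interleaved row by row.
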
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    # Round the data pointer up to the next multiple of `alignment` (bytes)
    # and return a view of the tensor starting at that address.
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
    return tensor[int(offset):]


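# Worked example for align_memory (illustrative): data_ptr=0x1004,
# alignment=0x100 and a 2-byte element size give aligned_addr=0x1100, so the
# returned view skips (0x1100 - 0x1004) / 2 = 126 leading elements.
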
def get_transfer_timeout_value():
    # An explicit override takes precedence.
    ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
    if len(ascend_transfer_timeout) > 0:
        return int(ascend_transfer_timeout)
    # Otherwise derive the timeout (in ms) from the HCCL RDMA settings: a
    # single retry waits 4.096us * 2**HCCL_RDMA_TIMEOUT, multiplied by the
    # retry count, plus a 3s safety margin.
    hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20"))
    hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7"))
    return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)


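# Worked arithmetic with the defaults (illustrative): 4.096us * 2**20 is
# about 4.295s per retry; 7 retries give 30064ms, plus the 3000ms margin,
# so the function returns 33064 (ms).
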
@dataclass
class parallel_info:
    tp_size: int
    pcp_size: int
    dcp_size: int
    use_mla: bool
    pd_head_ratio: int


def get_cp_group(tp: int, heads: int, dcp: int):
    # Partition the second dimension of [pcp][head_group][dcp] to obtain complete head groups.
    # A head group holds all blocks for a request within the same head.
    # tp8 dcp2 heads4 -> [[0, 1, 2, 3]]
    # tp8 dcp1 heads4 -> [[0, 2, 4, 6], [1, 3, 5, 7]]
    step = tp // heads
    if step == 0:
        return [[i for i in range(tp // dcp)]]
    else:
        # Sort the deduplicated ranks so both branches return lists, matching
        # the examples above.
        return [
            sorted(set(k // dcp for h in range(heads) for k in range(h * step + i * dcp, h * step + (i + 1) * dcp)))
            for i in range(step // dcp)
        ]


def context_parallel_parameters_check(
    remote_pcp_size: int,
    remote_dcp_size: int,
    p_parallel_info: parallel_info,
    d_parallel_info: parallel_info,
    total_num_kv_heads: int,
):
    # Check whether the pcp/dcp ratio between the two sides is supported.
    assert (p_parallel_info.pcp_size * p_parallel_info.dcp_size) % (remote_pcp_size * remote_dcp_size) == 0
    if not p_parallel_info.use_mla:
        # For non-MLA models, the heads held per D-side rank must be a whole
        # multiple of the heads held per P-side rank.
        p_node_heads_per_rank = math.ceil(total_num_kv_heads / p_parallel_info.tp_size)
        d_node_heads_per_rank = math.ceil(total_num_kv_heads / d_parallel_info.dcp_size)
        assert d_node_heads_per_rank % p_node_heads_per_rank == 0


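# Illustrative check for context_parallel_parameters_check (hedged, made-up
# sizes): P at pcp=2, dcp=2 against a remote pcp=1, dcp=2 passes the ratio
# check (4 % 2 == 0); for a non-MLA model with 8 KV heads, P tp=4 holds 2
# heads per rank and D dcp=2 holds 4, and 4 % 2 == 0 also passes.
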
def get_tp_rank_head_mapping(num_key_value_heads: int, tp_size: int):
    # Get the head indices served by each tp_rank: {tp_rank: [head_idx]}
    mapping = {}
    if tp_size <= num_key_value_heads:
        if num_key_value_heads % tp_size != 0:
            raise ValueError(f"Number of heads ({num_key_value_heads}) cannot be evenly divided by TP ({tp_size}).")

        heads_per_rank = num_key_value_heads // tp_size

        for rank in range(tp_size):
            start_idx = rank * heads_per_rank
            end_idx = start_idx + heads_per_rank
            mapping[rank] = list(range(start_idx, end_idx))
    else:
        # More ranks than heads: each head is replicated across
        # tp_size // num_key_value_heads ranks.
        if tp_size % num_key_value_heads != 0:
            raise ValueError(f"TP size ({tp_size}) cannot be evenly divided by the number of heads ({num_key_value_heads}).")
        ranks_per_head = tp_size // num_key_value_heads
        for rank in range(tp_size):
            head_idx = rank // ranks_per_head
            mapping[rank] = [head_idx]
    return mapping


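# Worked examples for get_tp_rank_head_mapping (illustrative): 8 heads on
# tp=4 gives {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}; 2 heads on tp=4
# replicates each head, giving {0: [0], 1: [0], 2: [1], 3: [1]}.
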
def get_head_group_mapping(num_key_value_heads: int, tp_size: int, num_groups: int, select_cp_group: list[int]):
    # Build a mapping from head_group_rank to the head indices that group covers.
    if tp_size % num_groups != 0:
        raise ValueError(
            f"Total number of devices ({tp_size}) cannot be divided by the number of groups ({num_groups})."
        )
    ranks_per_group = tp_size // num_groups
    tp_mapping = get_tp_rank_head_mapping(num_key_value_heads, tp_size)
    group_mapping = {}
    for group_rank in range(num_groups):
        if group_rank in select_cp_group:
            start_rank = group_rank * ranks_per_group
            end_rank = start_rank + ranks_per_group
            heads_set = set()

            for rank in range(start_rank, end_rank):
                heads_set.update(tp_mapping[rank])
            group_mapping[group_rank] = sorted(heads_set)
    return group_mapping


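# Worked example for get_head_group_mapping (illustrative): 8 heads, tp=8,
# num_groups=4, select_cp_group=[0, 2] groups ranks {0, 1} and {4, 5},
# giving {0: [0, 1], 2: [4, 5]}.
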
def get_local_remote_block_port_mappings(
    to_trans_idx: int,
    p_parallel_info: parallel_info,
    d_parallel_info: parallel_info,
    d_hosts: list[str],
    d_port: int,
    selected_p_cp_group: list[int],
    selected_d_cp_group: list[int],
    prompt_len: int,
    block_size: int,
    req_meta,
    total_num_kv_heads: int,
    req_id: str,
):
    p_head_group_size = p_parallel_info.tp_size // p_parallel_info.dcp_size
    d_head_group_size = d_parallel_info.tp_size // d_parallel_info.dcp_size
    world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size
    # Compute which logic_block_idx corresponds to each tp_rank. Blocks are
    # laid out round-robin: first across dcp ranks, then across pcp ranks.
    p_rank_block_mapping: list[list[list[list[int]]]] = [
        [[[] for _ in range(p_parallel_info.dcp_size)] for _ in range(p_head_group_size)]
        for _ in range(p_parallel_info.pcp_size)
    ]
    for logic_block_idx in range(to_trans_idx):
        pcp_rank = (logic_block_idx // p_parallel_info.dcp_size) % p_parallel_info.pcp_size
        dcp_rank = logic_block_idx % p_parallel_info.dcp_size
        for p_head_group_rank in range(p_head_group_size):
            if p_head_group_rank in selected_p_cp_group:
                p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank].append(logic_block_idx)

    # Find the remote device that holds each logic_block_idx.
    d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))
    for logic_block_idx in range(to_trans_idx):
        pcp_rank = (logic_block_idx // d_parallel_info.dcp_size) % d_parallel_info.pcp_size
        for d_head_group_rank in range(d_head_group_size):
            if d_head_group_rank in selected_d_cp_group:
                dcp_rank = logic_block_idx % d_parallel_info.dcp_size
                # Flatten (pcp_rank, head_group_rank, dcp_rank) into a world rank.
                world_rank = (
                    pcp_rank * d_head_group_size * d_parallel_info.dcp_size
                    + d_head_group_rank * d_parallel_info.dcp_size
                    + dcp_rank
                )
                host = d_hosts[(len(d_hosts) * world_rank) // world_size]
                port = d_port + world_rank
                # Local block index on the owning rank (inverse of the
                # round-robin layout above).
                block_idx = (logic_block_idx - (pcp_rank * d_parallel_info.dcp_size + dcp_rank)) // (
                    d_parallel_info.pcp_size * d_parallel_info.dcp_size
                )
                d_block_rank_mapping[logic_block_idx][d_head_group_rank] = {
                    "pcp_rank": pcp_rank,
                    "dcp_rank": dcp_rank,
                    "host": host,
                    "port": port,
                    "block_idx": block_idx,
                }
    # Count how many done signals each device should receive for this request.
    d_trans_count_mapping = {}
    trans_block_size = math.ceil(prompt_len / block_size)  # Total number of blocks
    transed_block_size = math.ceil(req_meta.remote_cache_tokens / block_size)  # Number of prefix-cache-hit blocks
    d_cp_size = d_parallel_info.pcp_size * d_parallel_info.dcp_size
    for d_pcp_rank in range(d_parallel_info.pcp_size):
        for d_head_group_rank in range(d_head_group_size):
            for d_dcp_rank in range(d_parallel_info.dcp_size):
                if trans_block_size >= (p_parallel_info.pcp_size * p_parallel_info.dcp_size):
                    trans_count = (p_parallel_info.pcp_size * p_parallel_info.dcp_size) // d_cp_size
                else:
                    # Fewer blocks than P-side CP ranks: count this rank's
                    # blocks explicitly, excluding blocks already covered by
                    # the prefix cache.
                    current_rank_idx = d_pcp_rank * d_parallel_info.dcp_size + d_dcp_rank
                    total_global_blocks = transed_block_size + trans_block_size

                    target_total_count = total_global_blocks // d_cp_size
                    if current_rank_idx < (total_global_blocks % d_cp_size):
                        target_total_count += 1

                    prev_processed_count = transed_block_size // d_cp_size
                    if current_rank_idx < (transed_block_size % d_cp_size):
                        prev_processed_count += 1

                    trans_count = target_total_count - prev_processed_count
                world_rank = (
                    d_pcp_rank * d_head_group_size * d_parallel_info.dcp_size
                    + d_head_group_rank * d_parallel_info.dcp_size
                    + d_dcp_rank
                )
                host = d_hosts[(len(d_hosts) * world_rank) // world_size]
                port = d_port + world_rank
                d_trans_count_mapping[(host, port)] = trans_count * p_parallel_info.pd_head_ratio

    # Compute the mapping between local and remote head_group_rank.
    p_tp_rank_head_mapping = get_head_group_mapping(
        total_num_kv_heads, p_parallel_info.tp_size, p_head_group_size, selected_p_cp_group
    )
    d_tp_rank_head_mapping = get_head_group_mapping(
        total_num_kv_heads, d_parallel_info.tp_size, d_head_group_size, selected_d_cp_group
    )
    head_to_d_groups = defaultdict(set)
    for d_rank, heads in d_tp_rank_head_mapping.items():
        for head in heads:
            head_to_d_groups[head].add(d_rank)
    pd_head_mapping = {}
    for p_rank, p_heads in p_tp_rank_head_mapping.items():
        target_d_ranks = set()
        for head in p_heads:
            if head in head_to_d_groups:
                target_d_ranks.update(head_to_d_groups[head])
            else:
                logger.warning(f"Head {head} exists in P but not in D mapping.")
        pd_head_mapping[p_rank] = sorted(target_d_ranks)
    logger.debug(
        f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} "
        f"P-side logic_block to rank mapping: {p_rank_block_mapping}, "
        f"D-side logic_block to rank mapping: {d_block_rank_mapping}, "
        f"P&D head_group_rank mapping: {pd_head_mapping}"
    )
    return p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping


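# Worked example for the D-side layout above (illustrative): with pcp=2 and
# dcp=2, logic blocks go round-robin as 0 -> (pcp0, dcp0), 1 -> (pcp0, dcp1),
# 2 -> (pcp1, dcp0), 3 -> (pcp1, dcp1), 4 -> (pcp0, dcp0), ... so logic
# block 6 lands on (pcp1, dcp0) as that rank's local block_idx 1.
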
def get_transfer_mappings(
    p_rank_block_mapping: list[list[list[list[int]]]],
    d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]],
    pd_head_mapping: dict[int, list[int]],
    d_trans_count_mapping: dict[tuple[str, int], int],
    req_meta,
    block_group_idx: int,
    p_parallel_info: parallel_info,
    req_id: str,
    transed_idx: int,
    to_trans_idx: int,
    tp_rank: int,
    pcp_rank: int,
    dcp_rank: int,
):
    transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {}
    p_head_group_rank = (tp_rank - dcp_rank) // p_parallel_info.dcp_size
    p_block_idxs: list[int] = p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank]
    p_block_ids = req_meta.local_block_ids[block_group_idx]
    d_block_ids = req_meta.remote_block_ids[block_group_idx]
    for p_block_idx, logic_block_idx in enumerate(p_block_idxs):
        # Skip blocks already transferred or beyond this batch.
        if logic_block_idx < transed_idx or logic_block_idx >= to_trans_idx:
            continue
        for d_head_group_rank in pd_head_mapping[p_head_group_rank]:
            p_block_id = p_block_ids[p_block_idx]
            remote_host = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["host"]
            remote_port = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["port"]
            d_block_idx = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["block_idx"]
            d_block_id = d_block_ids[d_block_idx]
            if (remote_host, remote_port) not in transfer_mappings:
                transfer_mappings[(remote_host, remote_port)] = {
                    "local_block_ids": [],
                    "remote_block_ids": [],
                    "trans_count": 0,
                }
            transfer_mappings[(remote_host, remote_port)]["local_block_ids"].append(p_block_id)
            transfer_mappings[(remote_host, remote_port)]["remote_block_ids"].append(d_block_id)
    for (host, port), block_dict in transfer_mappings.items():
        block_dict["trans_count"] = d_trans_count_mapping[(host, port)]
    logger.debug(f"MooncakeLayerwiseConnector Request {req_id} transfer tasks: {transfer_mappings}")
    return transfer_mappings
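# Illustrative result shape (hedged, hypothetical host/port/ids): a request
# touching one remote rank might return
# {("10.0.0.1", 5601): {"local_block_ids": [12, 13],
#                       "remote_block_ids": [40, 41], "trans_count": 2}},
# i.e. local blocks 12 and 13 go to the peer's blocks 40 and 41, which then
# expects 2 completion signals for this request.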