KVCache Transfer via Layer-wise Strategy in Disaggregation (#2602)

### What this PR does / why we need it?
See RFC: https://github.com/vllm-project/vllm-ascend/issues/2470 This PR
add a new kv connector for layer-wised kv transfer

### Does this PR introduce _any_ user-facing change?
yes, a new kv connector is added. User can use layer wised feature now.
### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.0

---------

Signed-off-by: leichao.lc <leichao139636@163.com>
Signed-off-by: CaveNightingale <2859066733@qq.com>
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: hanxinlong <50882499@qq.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: CaveNightingale <2859066733@qq.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: hanxinlong <50882499@qq.com>
This commit is contained in:
Chao Lei
2025-09-30 15:10:29 +08:00
committed by GitHub
parent f8c93d8d24
commit a486ff8c11
10 changed files with 3012 additions and 4 deletions

View File

@@ -94,6 +94,17 @@ class AscendConfig:
raise AssertionError(
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
)
self.pd_tp_ratio = 1
if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {"tp_size": 1})["tp_size"]
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
"decode", {"tp_size": 1})["tp_size"]
pd_tp_ratio: int = prefill_tp_size // decode_tp_size
self.pd_tp_ratio = pd_tp_ratio
if self.pd_tp_ratio == 0:
raise AssertionError(
"Only support P node tp size lagger then D node tp size")
class TorchairGraphConfig:

View File

@@ -31,3 +31,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnectorStoreV1",
"vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
"MooncakeConnectorV1")
KVConnectorFactory.register_connector(
"MooncakeLayerwiseConnector",
"vllm_ascend.distributed.mooncake_layerwise_connector",
"MooncakeLayerwiseConnector")

View File

@@ -1109,4 +1109,4 @@ def ensure_zmq_recv(
logger.error(f"Receive failed after all retries: {e}")
raise RuntimeError(
f"Failed to receive data after {max_retries} "
f"retries: {e}")
f"retries: {e}")

File diff suppressed because it is too large Load Diff

View File

@@ -13,6 +13,7 @@ _MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_P_TP: Optional[GroupCoordinator] = None
def get_mc2_group() -> GroupCoordinator:
@@ -37,6 +38,12 @@ def get_mlp_tp_group() -> GroupCoordinator:
return _MLP_TP
def get_p_tp_group() -> GroupCoordinator:
assert _P_TP is not None, (
"distributed prefill tensor parallel group is not initialized")
return _P_TP
def model_parallel_initialized():
return (_MC2 is not None)
@@ -54,6 +61,22 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
pd_tp_ratio = get_ascend_config().pd_tp_ratio
global _P_TP
assert _P_TP is None, (
"distributed prefill tensor parallel group is already initialized")
prefill_tensor_model_parallel_size = pd_tp_ratio if \
pd_tp_ratio > 0 and pd_tp_ratio < parallel_config.tensor_parallel_size else parallel_config.tensor_parallel_size
group_ranks = all_ranks.view(-1,
prefill_tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
num = get_world_group().local_rank // pd_tp_ratio
_P_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name=f"p_tp_{num}")
global _MC2
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
@@ -142,3 +165,8 @@ def destroy_ascend_model_parallel():
if _OTP:
_OTP.destroy()
_OTP = None
global _P_TP
if _P_TP:
_P_TP.destroy()
_P_TP = None

View File

@@ -0,0 +1,47 @@
import torch
import torch.distributed as dist
from vllm_ascend.distributed.parallel_state import get_p_tp_group
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
value: torch.TensorType):
if pd_tp_ratio <= 1:
return None, None
elif key is None or value is None:
raise ValueError("key or value is None")
k_output = alltoall_and_rearrange(pd_tp_ratio, key)
v_output = alltoall_and_rearrange(pd_tp_ratio, value)
return k_output, v_output
def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
num_kv_heads = input_tensor.size(1)
output_tensor = torch.zeros_like(input_tensor)
dist.all_to_all_single(output_tensor,
input_tensor,
group=get_p_tp_group().device_group)
input_tensor = 0
result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
output_tensor = 0
return result
def rearrange_output(base_output: torch.Tensor, cut_num: int,
num_kv_heads: int):
size_0 = base_output.size(0)
if size_0 % cut_num != 0:
raise ValueError(
f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
)
chunk_size = size_0 // cut_num
reshaped = base_output.view(cut_num, chunk_size, -1)
transposed = reshaped.transpose(0, 1)
return transposed.contiguous().view(size_0, num_kv_heads, -1)
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]