[P/D] layerwise connector: support DeepSeek-V3.2 sparse attention and distribute transfer tasks to redundant kv_head cards (#5722)
### What this PR does / why we need it?
Add new functionality to the mooncake layerwise connector, including:
1. Support for sparse attention (DeepSeek-V3.2)
2. Distribution of transfer tasks to redundant kv_head cards
This PR is related to [[RFC]: CDCP Scheduling for Disaggregated
Prefilling with KV Cache Layerwise Push
Support](https://github.com/vllm-project/vllm-ascend/issues/4842)
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
@@ -162,6 +162,7 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
self.kv_caches_base_addr = kv_cache_base_addr
|
||||
self.total_layers = total_layers
|
||||
self.use_mla = use_mla
|
||||
self.use_sparse = len(block_len) == 3
|
||||
self.block_len = block_len
|
||||
self._decode_tp_size = decode_tp_size
|
||||
self.resharding_stream = resharding_stream
|
||||
@@ -195,17 +196,6 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
src_list: list[str] = []
|
||||
dst_list: list[str] = []
|
||||
length_list: list[int] = []
|
||||
# no need to send kv cache
|
||||
if self.tp_rank % self.num_head_replica != 0:
|
||||
logger.debug(
|
||||
f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
|
||||
)
|
||||
return (src_list, dst_list, length_list)
|
||||
if self.use_mla and self.tp_rank >= self._decode_tp_size:
|
||||
logger.debug(
|
||||
f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
|
||||
)
|
||||
return (src_list, dst_list, length_list)
|
||||
|
||||
layer_idx = send_task.layer_idx
|
||||
remote_block_ids = req_meta.remote_block_ids
|
||||
@@ -214,21 +204,36 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
local_block_ids = req_meta.local_block_ids
|
||||
|
||||
if self.pd_head_ratio == 1:
|
||||
layer_local_kv_base_addr = [
|
||||
local_kv_base_addr[i]
|
||||
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
||||
]
|
||||
layer_remote_kv_base_addr = [
|
||||
remote_kv_base_addrs[i] # type:ignore
|
||||
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
||||
]
|
||||
if self.use_sparse:
|
||||
layer_local_kv_base_addr = [
|
||||
local_kv_base_addr[i] for i in
|
||||
[3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
|
||||
]
|
||||
layer_remote_kv_base_addr = [
|
||||
remote_kv_base_addrs[i] # type:ignore
|
||||
for i in
|
||||
[3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
|
||||
]
|
||||
else:
|
||||
layer_local_kv_base_addr = [
|
||||
local_kv_base_addr[i]
|
||||
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
||||
]
|
||||
layer_remote_kv_base_addr = [
|
||||
remote_kv_base_addrs[i] # type:ignore
|
||||
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
||||
]
|
||||
grouped_remote_block_ids, grouped_local_block_ids = \
|
||||
group_concurrent_contiguous(remote_block_ids, local_block_ids)
|
||||
|
||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
|
||||
block_len = self.block_len[
|
||||
k % 2] if self.use_mla else self.block_len[0]
|
||||
if self.use_mla:
|
||||
block_len = (self.block_len[k % 2])
|
||||
elif self.use_sparse:
|
||||
block_len = (self.block_len[k % 3])
|
||||
else:
|
||||
block_len = (self.block_len[0])
|
||||
for group_remote_block_id, group_local_block_id in zip(
|
||||
grouped_remote_block_ids, grouped_local_block_ids):
|
||||
src = src_layer_base_addr + group_local_block_id[
|
||||
@@ -931,7 +936,9 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
self.use_mla = first_kv_cache_tuple[0].size(
|
||||
-1) != first_kv_cache_tuple[1].size(-1)
|
||||
-1) != first_kv_cache_tuple[1].size(-1) and len(
|
||||
first_kv_cache_tuple) == 2
|
||||
self.use_sparse = len(first_kv_cache_tuple) == 3
|
||||
if self.use_mla:
|
||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
@@ -945,6 +952,21 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||
elif self.use_sparse:
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
|
||||
self.block_len = [
|
||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
||||
first_kv_cache[2].element_size() * math.prod(block_shape_k)
|
||||
]
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe,
|
||||
block_shape_k)
|
||||
else:
|
||||
# [num_block, block_size, num_head, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
@@ -955,8 +977,9 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||
block_shape)
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
||||
self.use_mla, first_kv_cache.shape)
|
||||
logger.info(
|
||||
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
|
||||
self.use_mla, self.use_sparse, first_kv_cache.shape)
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
kv_caches_base_addr = []
|
||||
@@ -971,9 +994,17 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
elif self.use_sparse:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[i % 3]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
else:
|
||||
cache_list = [cache_or_caches
|
||||
] if self.use_mla else cache_or_caches
|
||||
cache_list = [
|
||||
cache_or_caches
|
||||
] if self.use_mla or self.use_sparse else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
@@ -1046,56 +1077,72 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
|
||||
):
|
||||
# enable decode prefix cache
|
||||
if self.use_mla:
|
||||
reshape_cache_event = attn_metadata[
|
||||
layer_name].reshape_cache_event
|
||||
if self.use_mla or self.use_sparse:
|
||||
num_kv_head = self._decode_tp_size
|
||||
else:
|
||||
reshape_cache_event = attn_metadata.reshape_cache_event
|
||||
num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
|
||||
num_replica_groups = self.tp_size // num_kv_head if self.tp_size >= num_kv_head else 1
|
||||
replica_group_idx = self.tp_rank % num_replica_groups
|
||||
req_ids = sorted(list(connector_metadata.requests.keys()))
|
||||
selected_req_ids = [
|
||||
req_id for i, req_id in enumerate(req_ids)
|
||||
if i % num_replica_groups == replica_group_idx
|
||||
]
|
||||
if selected_req_ids:
|
||||
if self.use_mla or self.use_sparse:
|
||||
reshape_cache_event = attn_metadata[
|
||||
layer_name].reshape_cache_event
|
||||
else:
|
||||
reshape_cache_event = attn_metadata.reshape_cache_event
|
||||
|
||||
if self.pd_head_ratio != 1:
|
||||
assert self.resharding_stream is not None
|
||||
with npu_stream_switch(self.resharding_stream):
|
||||
reshape_cache_event.wait()
|
||||
rearrange_block_ids = sorted({
|
||||
block_id
|
||||
for request in connector_metadata.requests.values()
|
||||
for block_id in request.local_block_ids
|
||||
})
|
||||
if self.pd_head_ratio != 1:
|
||||
assert self.resharding_stream is not None
|
||||
with npu_stream_switch(self.resharding_stream):
|
||||
reshape_cache_event.wait()
|
||||
rearrange_block_ids = sorted({
|
||||
block_id
|
||||
for req_id in selected_req_ids
|
||||
for block_id in
|
||||
connector_metadata.requests[req_id].local_block_ids
|
||||
})
|
||||
|
||||
keys = kv_layer[0][rearrange_block_ids].clone()
|
||||
values = kv_layer[1][rearrange_block_ids].clone()
|
||||
# sort kv caches for each block
|
||||
keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
|
||||
*keys.shape[2:]).transpose(
|
||||
0, 1).reshape_as(keys)
|
||||
values = values.view(values.size(0), self.pd_head_ratio,
|
||||
-1, *values.shape[2:]).transpose(
|
||||
0, 1).reshape_as(values)
|
||||
# reshard kv cache
|
||||
keys = keys.reshape(-1, *kv_layer[0].shape[2:])
|
||||
values = values.reshape(-1, *kv_layer[1].shape[2:])
|
||||
(keys, values) = kv_alltoall_and_rearrange(
|
||||
self.pd_head_ratio, keys, values)
|
||||
else:
|
||||
keys = None
|
||||
values = None
|
||||
rearrange_block_ids = None
|
||||
keys = kv_layer[0][rearrange_block_ids].clone()
|
||||
values = kv_layer[1][rearrange_block_ids].clone()
|
||||
# sort kv caches for each block
|
||||
keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
|
||||
*keys.shape[2:]).transpose(
|
||||
0, 1).reshape_as(keys)
|
||||
values = values.view(values.size(0),
|
||||
self.pd_head_ratio, -1,
|
||||
*values.shape[2:]).transpose(
|
||||
0, 1).reshape_as(values)
|
||||
# reshard kv cache
|
||||
keys = keys.reshape(-1, *kv_layer[0].shape[2:])
|
||||
values = values.reshape(-1, *kv_layer[1].shape[2:])
|
||||
(keys, values) = kv_alltoall_and_rearrange(
|
||||
self.pd_head_ratio, keys, values)
|
||||
else:
|
||||
keys = None
|
||||
values = None
|
||||
rearrange_block_ids = None
|
||||
|
||||
assert self.kv_send_layer_thread is not None
|
||||
assert reshape_cache_event is not None
|
||||
send_task = SendTask(wait_event=reshape_cache_event,
|
||||
k_cache=keys,
|
||||
v_cache=values,
|
||||
layer_idx=self.current_layer,
|
||||
rearrange_block_ids=rearrange_block_ids)
|
||||
for req_id, req_meta in connector_metadata.requests.items():
|
||||
req_meta_update = self.update_decoder_info(req_id, req_meta)
|
||||
logger.debug(
|
||||
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
|
||||
)
|
||||
send_task.send_request[req_id] = req_meta_update
|
||||
assert self.kv_send_layer_thread is not None
|
||||
assert reshape_cache_event is not None
|
||||
send_task = SendTask(wait_event=reshape_cache_event,
|
||||
k_cache=keys,
|
||||
v_cache=values,
|
||||
layer_idx=self.current_layer,
|
||||
rearrange_block_ids=rearrange_block_ids)
|
||||
for req_id, req_meta in connector_metadata.requests.items():
|
||||
if req_id in selected_req_ids:
|
||||
req_meta_update = self.update_decoder_info(
|
||||
req_id, req_meta)
|
||||
logger.debug(
|
||||
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
|
||||
)
|
||||
send_task.send_request[req_id] = req_meta_update
|
||||
|
||||
self.kv_send_layer_thread.send_queue.put(send_task)
|
||||
self.kv_send_layer_thread.send_queue.put(send_task)
|
||||
self.current_layer += 1
|
||||
|
||||
def _get_remote_socket(
|
||||
@@ -1121,8 +1168,13 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
|
||||
def update_decoder_info(self, req_id, req_meta):
|
||||
req_meta_update = copy.deepcopy(req_meta)
|
||||
req_meta_update.remote_port = req_meta_update.remote_port + (
|
||||
self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
|
||||
if self.use_mla or self.use_sparse:
|
||||
pd_tp_ratio = self.tp_size // self._decode_tp_size
|
||||
req_meta_update.remote_port = req_meta_update.remote_port + (
|
||||
self.tp_rank // pd_tp_ratio) % self._decode_tp_size
|
||||
else:
|
||||
req_meta_update.remote_port = req_meta_update.remote_port + (
|
||||
self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
|
||||
if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
|
||||
req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
|
||||
try:
|
||||
@@ -1146,14 +1198,16 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
logger.info(
|
||||
f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
|
||||
)
|
||||
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
session_id, [self.kv_caches_base_addr[0]],
|
||||
[agent_meta.kv_caches_base_addr[0]], [128])
|
||||
if ret < 0:
|
||||
logger.error(
|
||||
f"Mooncake transfer failed to create link to device {session_id}"
|
||||
)
|
||||
if self.pd_head_ratio > 1:
|
||||
# for unequal tp, pre-create link to prevent alltoall out of memory
|
||||
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
session_id, [self.kv_caches_base_addr[0]],
|
||||
[agent_meta.kv_caches_base_addr[0]], [128])
|
||||
if ret < 0:
|
||||
logger.error(
|
||||
f"Mooncake transfer failed to create link to device {session_id}"
|
||||
)
|
||||
req_meta_update.remote_te_rpc_port = self.remote_te_port[
|
||||
req_meta_update.remote_engine_id][req_meta_update.remote_port]
|
||||
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[
|
||||
|
||||
Reference in New Issue
Block a user