[P/D] Layerwise connector: support DeepSeek-V3.2 sparse attention and distribute transfer tasks to redundant kv_head cards (#5722)
### What this PR does / why we need it?
Add new functionality to the Mooncake layerwise connector:
1. Support sparse attention for DeepSeek-V3.2.
2. Distribute transfer tasks across cards that hold redundant kv_head replicas (a minimal sketch of the selection rule follows below).

This PR is related to [[RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache Layerwise Push Support](https://github.com/vllm-project/vllm-ascend/issues/4842).
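
The distribution rule is small enough to sketch here. The standalone function below is illustrative only; in the patch this logic runs inside `MooncakeLayerwiseConnectorWorker`, with `num_kv_head` taken from the decode TP size for MLA/sparse models and from `hf_config.num_key_value_heads` otherwise:

```python
# Minimal sketch (illustrative standalone form, not the patch itself) of how
# requests are spread over ranks that hold the same kv_head replica.

def select_requests_for_rank(req_ids: list[str], tp_rank: int, tp_size: int,
                             num_kv_head: int) -> list[str]:
    # Ranks holding the same kv_head replica form one "replica group";
    # each group takes every num_replica_groups-th request.
    num_replica_groups = tp_size // num_kv_head if tp_size >= num_kv_head else 1
    replica_group_idx = tp_rank % num_replica_groups
    req_ids = sorted(req_ids)
    return [
        req_id for i, req_id in enumerate(req_ids)
        if i % num_replica_groups == replica_group_idx
    ]

# Example: tp_size=8, num_kv_head=4 -> 2 replica groups; rank 0 sends the
# even-indexed requests and rank 1 the odd-indexed ones, so redundant
# kv_head cards share the transfer load instead of duplicating it.
print(select_requests_for_rank(["r0", "r1", "r2", "r3"],
                               tp_rank=1, tp_size=8, num_kv_head=4))  # ['r1', 'r3']
```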
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
@@ -170,36 +170,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
         self.thread._transfer_kv_cache(send_task)
         self.engine.batch_transfer_sync_write.assert_not_called()

-    def test_transfer_skips_when_tp_not_sender(self):
-
-        thread = KVCacheSendingLayerThread(
-            engine=self.engine,
-            total_layers=2,
-            ready_event=self.ready_event,
-            tp_rank=1,
-            pd_head_ratio=1,
-            num_head_replica=2,
-            kv_cache_base_addr=[1000, 2000, 3000, 4000],
-            use_mla=False,
-            block_len=[1024],
-            decode_tp_size=1,
-            first_kv_cache=self.first_kv_cache,
-            k_buffer=MagicMock(),
-            v_buffer=MagicMock(),
-            resharding_stream=MagicMock(),
-            callback_func=MagicMock())
-        req_meta = self.req_meta_base
-        send_task = SendTask(
-            send_request={"req3": req_meta},
-            wait_event=MagicMock(),
-            k_cache=self.key,
-            v_cache=self.value,
-            layer_idx=1,
-            rearrange_block_ids=[],
-        )
-        thread._transfer_kv_cache(send_task)
-        self.engine.batch_transfer_sync_write.assert_not_called()
-
     @patch(
         "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous",
         side_effect=group_concurrent_contiguous)
@@ -425,6 +395,7 @@ class MockVllmConfig:
         self.parallel_config.data_parallel_size = 1
         self.parallel_config.data_parallel_rank = 0
         self.cache_config.block_size = 16
+        self.model_config.hf_config.num_key_value_heads = 1

         self.kv_transfer_config.engine_id = "test_engine"
         self.kv_transfer_config.kv_port = 5000
@@ -108,6 +108,7 @@ class AscendSFAMetadata:
     # chunked prefill by default if no attn_states passed
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     sfa_cp_context: Optional[SfaCpContext] = None
+    reshape_cache_event: torch.npu.Event = None


 M = TypeVar("M", bound=AscendSFAMetadata)
@@ -369,6 +370,7 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.enable_sfa_cp = enable_dsa_cp()
         self.local_num_heads = self.num_heads
         self.vllm_config = get_current_vllm_config()
+        self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
         self.layer_sharding_kwargs = []
@@ -897,11 +899,15 @@ class AscendSFAImpl(MLAAttentionImpl):
             k = get_tp_group().all_gather(k, 0)

         if kv_cache is not None:
+            if self.is_kv_producer:
+                attn_metadata.reshape_cache_event = torch.npu.Event()
             torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
                                              attn_metadata.slot_mapping.view(
                                                  -1, 1),
                                              k.view(-1,
                                                     k.shape[-1]))  # b, s, n, d
+            if self.is_kv_producer:
+                attn_metadata.reshape_cache_event.record()

         weights, _ = self.weights_proj(x)
         weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
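For orientation, a minimal sketch of the producer/consumer event handshake this hunk introduces, assuming torch_npu's CUDA-like stream and event API; the helper functions are hypothetical, and the connector actually uses vllm-ascend's `npu_stream_switch` wrapper rather than `torch.npu.stream` directly:

```python
import torch
import torch_npu  # noqa: F401  # provides the torch.npu backend and npu_scatter_nd_update_

copy_stream = torch.npu.Stream()

def producer_write(kv_cache, k, slot_mapping, attn_metadata):
    # Producer (attention layer): create the event, scatter the new K
    # entries into the paged cache, then record so readers know when
    # the write is visible.
    attn_metadata.reshape_cache_event = torch.npu.Event()
    torch_npu.npu_scatter_nd_update_(kv_cache.view(-1, k.shape[-1]),
                                     slot_mapping.view(-1, 1),
                                     k.view(-1, k.shape[-1]))
    attn_metadata.reshape_cache_event.record()

def connector_read(kv_cache, block_ids, attn_metadata):
    # Consumer (layerwise connector): make the side stream wait on the
    # recorded event before cloning the blocks that will be pushed out.
    with torch.npu.stream(copy_stream):
        attn_metadata.reshape_cache_event.wait()
        return kv_cache[block_ids].clone()
```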
@@ -162,6 +162,7 @@ class KVCacheSendingLayerThread(threading.Thread):
         self.kv_caches_base_addr = kv_cache_base_addr
         self.total_layers = total_layers
         self.use_mla = use_mla
+        self.use_sparse = len(block_len) == 3
         self.block_len = block_len
         self._decode_tp_size = decode_tp_size
         self.resharding_stream = resharding_stream
@@ -195,17 +196,6 @@ class KVCacheSendingLayerThread(threading.Thread):
         src_list: list[str] = []
         dst_list: list[str] = []
         length_list: list[int] = []
-        # not need to send kv cache
-        if self.tp_rank % self.num_head_replica != 0:
-            logger.debug(
-                f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
-            )
-            return (src_list, dst_list, length_list)
-        if self.use_mla and self.tp_rank >= self._decode_tp_size:
-            logger.debug(
-                f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
-            )
-            return (src_list, dst_list, length_list)

         layer_idx = send_task.layer_idx
         remote_block_ids = req_meta.remote_block_ids
@@ -214,21 +204,36 @@ class KVCacheSendingLayerThread(threading.Thread):
         local_block_ids = req_meta.local_block_ids

         if self.pd_head_ratio == 1:
-            layer_local_kv_base_addr = [
-                local_kv_base_addr[i]
-                for i in [2 * layer_idx, 2 * layer_idx + 1]
-            ]
-            layer_remote_kv_base_addr = [
-                remote_kv_base_addrs[i]  # type:ignore
-                for i in [2 * layer_idx, 2 * layer_idx + 1]
-            ]
+            if self.use_sparse:
+                layer_local_kv_base_addr = [
+                    local_kv_base_addr[i] for i in
+                    [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
+                ]
+                layer_remote_kv_base_addr = [
+                    remote_kv_base_addrs[i]  # type:ignore
+                    for i in
+                    [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
+                ]
+            else:
+                layer_local_kv_base_addr = [
+                    local_kv_base_addr[i]
+                    for i in [2 * layer_idx, 2 * layer_idx + 1]
+                ]
+                layer_remote_kv_base_addr = [
+                    remote_kv_base_addrs[i]  # type:ignore
+                    for i in [2 * layer_idx, 2 * layer_idx + 1]
+                ]
             grouped_remote_block_ids, grouped_local_block_ids = \
                 group_concurrent_contiguous(remote_block_ids, local_block_ids)

             for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
                     zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
-                block_len = self.block_len[
-                    k % 2] if self.use_mla else self.block_len[0]
+                if self.use_mla:
+                    block_len = (self.block_len[k % 2])
+                elif self.use_sparse:
+                    block_len = (self.block_len[k % 3])
+                else:
+                    block_len = (self.block_len[0])
                 for group_remote_block_id, group_local_block_id in zip(
                         grouped_remote_block_ids, grouped_local_block_ids):
                     src = src_layer_base_addr + group_local_block_id[
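For readers new to the cache layout, a short illustration of the indexing the hunk above relies on; the helper functions are hypothetical and only restate the addressing rule:

```python
# Hypothetical helpers (not from the patch): per-layer base addresses are
# laid out as [layer0_region0, layer0_region1, (layer0_region2,)
# layer1_region0, ...], with two regions per layer for MLA/dense and three
# for DeepSeek-V3.2 sparse attention; block_len holds one byte-length per region.

def layer_region_addrs(base_addrs: list[int], layer_idx: int,
                       use_sparse: bool) -> list[int]:
    regions = 3 if use_sparse else 2
    return [base_addrs[regions * layer_idx + r] for r in range(regions)]

def region_block_len(block_len: list[int], region_idx: int,
                     use_mla: bool, use_sparse: bool) -> int:
    if use_mla:
        return block_len[region_idx % 2]  # norm / pe regions
    if use_sparse:
        return block_len[region_idx % 3]  # norm / pe / k regions
    return block_len[0]                   # dense K and V share one length

# e.g. sparse model, layer 1, addresses registered for 2 layers:
print(layer_region_addrs([10, 11, 12, 20, 21, 22], layer_idx=1,
                         use_sparse=True))  # [20, 21, 22]
```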
@@ -931,7 +936,9 @@ class MooncakeLayerwiseConnectorWorker:

         # TODO(tms): Find a more robust way to detect and handle MLA
         self.use_mla = first_kv_cache_tuple[0].size(
-            -1) != first_kv_cache_tuple[1].size(-1)
+            -1) != first_kv_cache_tuple[1].size(-1) and len(
+                first_kv_cache_tuple) == 2
+        self.use_sparse = len(first_kv_cache_tuple) == 3
         if self.use_mla:
             # MLA case.[num_block, block_size, 1, hidden_dim]
             self.num_blocks = first_kv_cache.shape[0]
@@ -945,6 +952,21 @@ class MooncakeLayerwiseConnectorWorker:
             logger.info(
                 "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                 self.num_blocks, block_shape_norm, block_shape_pe)
+        elif self.use_sparse:
+            self.num_blocks = first_kv_cache.shape[0]
+            block_rank = 3  # [block_size, latent_dim]
+            block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
+            block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
+            block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
+            self.block_len = [
+                first_kv_cache[0].element_size() * math.prod(block_shape_norm),
+                first_kv_cache[1].element_size() * math.prod(block_shape_pe),
+                first_kv_cache[2].element_size() * math.prod(block_shape_k)
+            ]
+            logger.info(
+                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
+                self.num_blocks, block_shape_norm, block_shape_pe,
+                block_shape_k)
         else:
             # [num_block, block_size, num_head, hidden_dim]
             self.num_blocks = first_kv_cache.shape[0]
@@ -955,8 +977,9 @@ class MooncakeLayerwiseConnectorWorker:
             logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                         block_shape)

-        logger.info("Registering KV_Caches. use_mla: %s, shape %s",
-                    self.use_mla, first_kv_cache.shape)
+        logger.info(
+            "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
+            self.use_mla, self.use_sparse, first_kv_cache.shape)

         self.kv_caches = kv_caches
         kv_caches_base_addr = []
@@ -971,9 +994,17 @@ class MooncakeLayerwiseConnectorWorker:
                     kv_caches_base_addr.append(base_addr)
                     ptrs.append(base_addr)
                     lengths.append(region_len)
+            elif self.use_sparse:
+                for i, cache in enumerate(cache_or_caches, 0):
+                    base_addr = cache.data_ptr()
+                    region_len = self.num_blocks * self.block_len[i % 3]
+                    kv_caches_base_addr.append(base_addr)
+                    ptrs.append(base_addr)
+                    lengths.append(region_len)
             else:
-                cache_list = [cache_or_caches
-                              ] if self.use_mla else cache_or_caches
+                cache_list = [
+                    cache_or_caches
+                ] if self.use_mla or self.use_sparse else cache_or_caches
                 for cache in cache_list:
                     base_addr = cache.data_ptr()
                     region_len = self.num_blocks * self.block_len[0]
@@ -1046,56 +1077,72 @@ class MooncakeLayerwiseConnectorWorker:
         if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
         ):
             # enable decode prefix cache
-            if self.use_mla:
-                reshape_cache_event = attn_metadata[
-                    layer_name].reshape_cache_event
+            if self.use_mla or self.use_sparse:
+                num_kv_head = self._decode_tp_size
             else:
-                reshape_cache_event = attn_metadata.reshape_cache_event
+                num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
+            num_replica_groups = self.tp_size // num_kv_head if self.tp_size >= num_kv_head else 1
+            replica_group_idx = self.tp_rank % num_replica_groups
+            req_ids = sorted(list(connector_metadata.requests.keys()))
+            selected_req_ids = [
+                req_id for i, req_id in enumerate(req_ids)
+                if i % num_replica_groups == replica_group_idx
+            ]
+            if selected_req_ids:
+                if self.use_mla or self.use_sparse:
+                    reshape_cache_event = attn_metadata[
+                        layer_name].reshape_cache_event
+                else:
+                    reshape_cache_event = attn_metadata.reshape_cache_event

-            if self.pd_head_ratio != 1:
-                assert self.resharding_stream is not None
-                with npu_stream_switch(self.resharding_stream):
-                    reshape_cache_event.wait()
-                    rearrange_block_ids = sorted({
-                        block_id
-                        for request in connector_metadata.requests.values()
-                        for block_id in request.local_block_ids
-                    })
+                if self.pd_head_ratio != 1:
+                    assert self.resharding_stream is not None
+                    with npu_stream_switch(self.resharding_stream):
+                        reshape_cache_event.wait()
+                        rearrange_block_ids = sorted({
+                            block_id
+                            for req_id in selected_req_ids
+                            for block_id in
+                            connector_metadata.requests[req_id].local_block_ids
+                        })

-                    keys = kv_layer[0][rearrange_block_ids].clone()
-                    values = kv_layer[1][rearrange_block_ids].clone()
-                    # sort kv caches for each block
-                    keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
-                                     *keys.shape[2:]).transpose(
-                                         0, 1).reshape_as(keys)
-                    values = values.view(values.size(0), self.pd_head_ratio,
-                                         -1, *values.shape[2:]).transpose(
-                                             0, 1).reshape_as(values)
-                    # reshard kv cache
-                    keys = keys.reshape(-1, *kv_layer[0].shape[2:])
-                    values = values.reshape(-1, *kv_layer[1].shape[2:])
-                    (keys, values) = kv_alltoall_and_rearrange(
-                        self.pd_head_ratio, keys, values)
-            else:
-                keys = None
-                values = None
-                rearrange_block_ids = None
+                        keys = kv_layer[0][rearrange_block_ids].clone()
+                        values = kv_layer[1][rearrange_block_ids].clone()
+                        # sort kv caches for each block
+                        keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
+                                         *keys.shape[2:]).transpose(
+                                             0, 1).reshape_as(keys)
+                        values = values.view(values.size(0),
+                                             self.pd_head_ratio, -1,
+                                             *values.shape[2:]).transpose(
+                                                 0, 1).reshape_as(values)
+                        # reshard kv cache
+                        keys = keys.reshape(-1, *kv_layer[0].shape[2:])
+                        values = values.reshape(-1, *kv_layer[1].shape[2:])
+                        (keys, values) = kv_alltoall_and_rearrange(
+                            self.pd_head_ratio, keys, values)
+                else:
+                    keys = None
+                    values = None
+                    rearrange_block_ids = None

-            assert self.kv_send_layer_thread is not None
-            assert reshape_cache_event is not None
-            send_task = SendTask(wait_event=reshape_cache_event,
-                                 k_cache=keys,
-                                 v_cache=values,
-                                 layer_idx=self.current_layer,
-                                 rearrange_block_ids=rearrange_block_ids)
-            for req_id, req_meta in connector_metadata.requests.items():
-                req_meta_update = self.update_decoder_info(req_id, req_meta)
-                logger.debug(
-                    f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
-                )
-                send_task.send_request[req_id] = req_meta_update
+                assert self.kv_send_layer_thread is not None
+                assert reshape_cache_event is not None
+                send_task = SendTask(wait_event=reshape_cache_event,
+                                     k_cache=keys,
+                                     v_cache=values,
+                                     layer_idx=self.current_layer,
+                                     rearrange_block_ids=rearrange_block_ids)
+                for req_id, req_meta in connector_metadata.requests.items():
+                    if req_id in selected_req_ids:
+                        req_meta_update = self.update_decoder_info(
+                            req_id, req_meta)
+                        logger.debug(
+                            f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
+                        )
+                        send_task.send_request[req_id] = req_meta_update

-            self.kv_send_layer_thread.send_queue.put(send_task)
+                self.kv_send_layer_thread.send_queue.put(send_task)
         self.current_layer += 1

     def _get_remote_socket(
@@ -1121,8 +1168,13 @@ class MooncakeLayerwiseConnectorWorker:

     def update_decoder_info(self, req_id, req_meta):
         req_meta_update = copy.deepcopy(req_meta)
-        req_meta_update.remote_port = req_meta_update.remote_port + (
-            self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
+        if self.use_mla or self.use_sparse:
+            pd_tp_ratio = self.tp_size // self._decode_tp_size
+            req_meta_update.remote_port = req_meta_update.remote_port + (
+                self.tp_rank // pd_tp_ratio) % self._decode_tp_size
+        else:
+            req_meta_update.remote_port = req_meta_update.remote_port + (
+                self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
         if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
             req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
             try:
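A small worked example of the port arithmetic above (the standalone helper is illustrative only):

```python
# Hedged sketch of the remote-port offset computed in update_decoder_info:
# each prefill rank targets the decode rank that owns the matching KV shard.

def decode_port_offset(tp_rank: int, tp_size: int, decode_tp_size: int,
                       pd_tp_ratio: int, mla_or_sparse: bool) -> int:
    if mla_or_sparse:
        # For MLA / sparse models the ratio is re-derived from the TP sizes
        # instead of taken from the head-based pd_tp_ratio.
        pd_tp_ratio = tp_size // decode_tp_size
    return (tp_rank // pd_tp_ratio) % decode_tp_size

# Example: prefill TP=16 pushing to decode TP=4 -> ratio 4, so prefill
# ranks 0-3 add offset 0 to the base port, ranks 4-7 add offset 1, etc.
offsets = [decode_port_offset(r, 16, 4, 1, True) for r in range(8)]
print(offsets)  # [0, 0, 0, 0, 1, 1, 1, 1]
```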
@@ -1146,14 +1198,16 @@
                 logger.info(
                     f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
                 )
-                session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
-                ret = self.engine.batch_transfer_sync_write(
-                    session_id, [self.kv_caches_base_addr[0]],
-                    [agent_meta.kv_caches_base_addr[0]], [128])
-                if ret < 0:
-                    logger.error(
-                        f"Mooncake transfer failed to create link to device {session_id}"
-                    )
+                if self.pd_head_ratio > 1:
+                    # for tp inequal, pre-create link to prevent alltoall out of memory
+                    session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
+                    ret = self.engine.batch_transfer_sync_write(
+                        session_id, [self.kv_caches_base_addr[0]],
+                        [agent_meta.kv_caches_base_addr[0]], [128])
+                    if ret < 0:
+                        logger.error(
+                            f"Mooncake transfer failed to create link to device {session_id}"
+                        )
         req_meta_update.remote_te_rpc_port = self.remote_te_port[
             req_meta_update.remote_engine_id][req_meta_update.remote_port]
         req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[