[P/D] layerwise connector: support DeepSeek-V3.2 sparse attention and distribute transfer tasks to redundant kv_head cards (#5722)

### What this PR does / why we need it?
Add new functionality to the mooncake layerwise connector:
1. Support sparse attention for DeepSeek-V3.2.
2. Distribute transfer tasks across redundant kv_head cards.

This PR is related to [[RFC]: CDCP Scheduling for Disaggregated
Prefilling with KV Cache Layerwise Push
Support](https://github.com/vllm-project/vllm-ascend/issues/4842)

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
zxr2333
2026-01-10 23:04:16 +08:00
committed by GitHub
parent c316679e65
commit 78b554dda9
3 changed files with 142 additions and 111 deletions

View File

@@ -170,36 +170,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
self.thread._transfer_kv_cache(send_task) self.thread._transfer_kv_cache(send_task)
self.engine.batch_transfer_sync_write.assert_not_called() self.engine.batch_transfer_sync_write.assert_not_called()
def test_transfer_skips_when_tp_not_sender(self):
thread = KVCacheSendingLayerThread(
engine=self.engine,
total_layers=2,
ready_event=self.ready_event,
tp_rank=1,
pd_head_ratio=1,
num_head_replica=2,
kv_cache_base_addr=[1000, 2000, 3000, 4000],
use_mla=False,
block_len=[1024],
decode_tp_size=1,
first_kv_cache=self.first_kv_cache,
k_buffer=MagicMock(),
v_buffer=MagicMock(),
resharding_stream=MagicMock(),
callback_func=MagicMock())
req_meta = self.req_meta_base
send_task = SendTask(
send_request={"req3": req_meta},
wait_event=MagicMock(),
k_cache=self.key,
v_cache=self.value,
layer_idx=1,
rearrange_block_ids=[],
)
thread._transfer_kv_cache(send_task)
self.engine.batch_transfer_sync_write.assert_not_called()
@patch( @patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous", "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous",
side_effect=group_concurrent_contiguous) side_effect=group_concurrent_contiguous)
@@ -425,6 +395,7 @@ class MockVllmConfig:
self.parallel_config.data_parallel_size = 1 self.parallel_config.data_parallel_size = 1
self.parallel_config.data_parallel_rank = 0 self.parallel_config.data_parallel_rank = 0
self.cache_config.block_size = 16 self.cache_config.block_size = 16
self.model_config.hf_config.num_key_value_heads = 1
self.kv_transfer_config.engine_id = "test_engine" self.kv_transfer_config.engine_id = "test_engine"
self.kv_transfer_config.kv_port = 5000 self.kv_transfer_config.kv_port = 5000

View File

@@ -108,6 +108,7 @@ class AscendSFAMetadata:
# chunked prefill by default if no attn_states passed # chunked prefill by default if no attn_states passed
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
sfa_cp_context: Optional[SfaCpContext] = None sfa_cp_context: Optional[SfaCpContext] = None
reshape_cache_event: torch.npu.Event = None
M = TypeVar("M", bound=AscendSFAMetadata) M = TypeVar("M", bound=AscendSFAMetadata)
@@ -369,6 +370,7 @@ class AscendSFAImpl(MLAAttentionImpl):
self.enable_sfa_cp = enable_dsa_cp() self.enable_sfa_cp = enable_dsa_cp()
self.local_num_heads = self.num_heads self.local_num_heads = self.num_heads
self.vllm_config = get_current_vllm_config() self.vllm_config = get_current_vllm_config()
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
if self.enable_sfa_cp: if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size self.local_num_heads = self.num_heads * self.tp_size
self.layer_sharding_kwargs = [] self.layer_sharding_kwargs = []
@@ -897,11 +899,15 @@ class AscendSFAImpl(MLAAttentionImpl):
k = get_tp_group().all_gather(k, 0) k = get_tp_group().all_gather(k, 0)
if kv_cache is not None: if kv_cache is not None:
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view( attn_metadata.slot_mapping.view(
-1, 1), -1, 1),
k.view(-1, k.view(-1,
k.shape[-1])) # b, s, n, d k.shape[-1])) # b, s, n, d
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
weights, _ = self.weights_proj(x) weights, _ = self.weights_proj(x)
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(

View File

@@ -162,6 +162,7 @@ class KVCacheSendingLayerThread(threading.Thread):
self.kv_caches_base_addr = kv_cache_base_addr self.kv_caches_base_addr = kv_cache_base_addr
self.total_layers = total_layers self.total_layers = total_layers
self.use_mla = use_mla self.use_mla = use_mla
self.use_sparse = len(block_len) == 3
self.block_len = block_len self.block_len = block_len
self._decode_tp_size = decode_tp_size self._decode_tp_size = decode_tp_size
self.resharding_stream = resharding_stream self.resharding_stream = resharding_stream
@@ -195,17 +196,6 @@ class KVCacheSendingLayerThread(threading.Thread):
src_list: list[str] = [] src_list: list[str] = []
dst_list: list[str] = [] dst_list: list[str] = []
length_list: list[int] = [] length_list: list[int] = []
# not need to send kv cache
if self.tp_rank % self.num_head_replica != 0:
logger.debug(
f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
)
return (src_list, dst_list, length_list)
if self.use_mla and self.tp_rank >= self._decode_tp_size:
logger.debug(
f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
)
return (src_list, dst_list, length_list)
layer_idx = send_task.layer_idx layer_idx = send_task.layer_idx
remote_block_ids = req_meta.remote_block_ids remote_block_ids = req_meta.remote_block_ids
@@ -214,21 +204,36 @@ class KVCacheSendingLayerThread(threading.Thread):
local_block_ids = req_meta.local_block_ids local_block_ids = req_meta.local_block_ids
if self.pd_head_ratio == 1: if self.pd_head_ratio == 1:
layer_local_kv_base_addr = [ if self.use_sparse:
local_kv_base_addr[i] layer_local_kv_base_addr = [
for i in [2 * layer_idx, 2 * layer_idx + 1] local_kv_base_addr[i] for i in
] [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
layer_remote_kv_base_addr = [ ]
remote_kv_base_addrs[i] # type:ignore layer_remote_kv_base_addr = [
for i in [2 * layer_idx, 2 * layer_idx + 1] remote_kv_base_addrs[i] # type:ignore
] for i in
[3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
]
else:
layer_local_kv_base_addr = [
local_kv_base_addr[i]
for i in [2 * layer_idx, 2 * layer_idx + 1]
]
layer_remote_kv_base_addr = [
remote_kv_base_addrs[i] # type:ignore
for i in [2 * layer_idx, 2 * layer_idx + 1]
]
grouped_remote_block_ids, grouped_local_block_ids = \ grouped_remote_block_ids, grouped_local_block_ids = \
group_concurrent_contiguous(remote_block_ids, local_block_ids) group_concurrent_contiguous(remote_block_ids, local_block_ids)
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
block_len = self.block_len[ if self.use_mla:
k % 2] if self.use_mla else self.block_len[0] block_len = (self.block_len[k % 2])
elif self.use_sparse:
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
for group_remote_block_id, group_local_block_id in zip( for group_remote_block_id, group_local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids): grouped_remote_block_ids, grouped_local_block_ids):
src = src_layer_base_addr + group_local_block_id[ src = src_layer_base_addr + group_local_block_id[
@@ -931,7 +936,9 @@ class MooncakeLayerwiseConnectorWorker:
# TODO(tms): Find a more robust way to detect and handle MLA # TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = first_kv_cache_tuple[0].size( self.use_mla = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1) -1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sparse = len(first_kv_cache_tuple) == 3
if self.use_mla: if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim] # MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0] self.num_blocks = first_kv_cache.shape[0]
@@ -945,6 +952,21 @@ class MooncakeLayerwiseConnectorWorker:
logger.info( logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe) self.num_blocks, block_shape_norm, block_shape_pe)
elif self.use_sparse:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
first_kv_cache[2].element_size() * math.prod(block_shape_k)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
self.num_blocks, block_shape_norm, block_shape_pe,
block_shape_k)
else: else:
# [num_block, block_size, num_head, hidden_dim] # [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0] self.num_blocks = first_kv_cache.shape[0]
@@ -955,8 +977,9 @@ class MooncakeLayerwiseConnectorWorker:
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape) block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s", logger.info(
self.use_mla, first_kv_cache.shape) "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
self.use_mla, self.use_sparse, first_kv_cache.shape)
self.kv_caches = kv_caches self.kv_caches = kv_caches
kv_caches_base_addr = [] kv_caches_base_addr = []
@@ -971,9 +994,17 @@ class MooncakeLayerwiseConnectorWorker:
kv_caches_base_addr.append(base_addr) kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr) ptrs.append(base_addr)
lengths.append(region_len) lengths.append(region_len)
elif self.use_sparse:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
else: else:
cache_list = [cache_or_caches cache_list = [
] if self.use_mla else cache_or_caches cache_or_caches
] if self.use_mla or self.use_sparse else cache_or_caches
for cache in cache_list: for cache in cache_list:
base_addr = cache.data_ptr() base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0] region_len = self.num_blocks * self.block_len[0]
@@ -1046,56 +1077,72 @@ class MooncakeLayerwiseConnectorWorker:
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
): ):
# enable decode prefix cache # enable decode prefix cache
if self.use_mla: if self.use_mla or self.use_sparse:
reshape_cache_event = attn_metadata[ num_kv_head = self._decode_tp_size
layer_name].reshape_cache_event
else: else:
reshape_cache_event = attn_metadata.reshape_cache_event num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
num_replica_groups = self.tp_size // num_kv_head if self.tp_size >= num_kv_head else 1
replica_group_idx = self.tp_rank % num_replica_groups
req_ids = sorted(list(connector_metadata.requests.keys()))
selected_req_ids = [
req_id for i, req_id in enumerate(req_ids)
if i % num_replica_groups == replica_group_idx
]
if selected_req_ids:
if self.use_mla or self.use_sparse:
reshape_cache_event = attn_metadata[
layer_name].reshape_cache_event
else:
reshape_cache_event = attn_metadata.reshape_cache_event
if self.pd_head_ratio != 1: if self.pd_head_ratio != 1:
assert self.resharding_stream is not None assert self.resharding_stream is not None
with npu_stream_switch(self.resharding_stream): with npu_stream_switch(self.resharding_stream):
reshape_cache_event.wait() reshape_cache_event.wait()
rearrange_block_ids = sorted({ rearrange_block_ids = sorted({
block_id block_id
for request in connector_metadata.requests.values() for req_id in selected_req_ids
for block_id in request.local_block_ids for block_id in
}) connector_metadata.requests[req_id].local_block_ids
})
keys = kv_layer[0][rearrange_block_ids].clone() keys = kv_layer[0][rearrange_block_ids].clone()
values = kv_layer[1][rearrange_block_ids].clone() values = kv_layer[1][rearrange_block_ids].clone()
# sort kv caches for each block # sort kv caches for each block
keys = keys.view(keys.size(0), self.pd_head_ratio, -1, keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
*keys.shape[2:]).transpose( *keys.shape[2:]).transpose(
0, 1).reshape_as(keys) 0, 1).reshape_as(keys)
values = values.view(values.size(0), self.pd_head_ratio, values = values.view(values.size(0),
-1, *values.shape[2:]).transpose( self.pd_head_ratio, -1,
0, 1).reshape_as(values) *values.shape[2:]).transpose(
# reshard kv cache 0, 1).reshape_as(values)
keys = keys.reshape(-1, *kv_layer[0].shape[2:]) # reshard kv cache
values = values.reshape(-1, *kv_layer[1].shape[2:]) keys = keys.reshape(-1, *kv_layer[0].shape[2:])
(keys, values) = kv_alltoall_and_rearrange( values = values.reshape(-1, *kv_layer[1].shape[2:])
self.pd_head_ratio, keys, values) (keys, values) = kv_alltoall_and_rearrange(
else: self.pd_head_ratio, keys, values)
keys = None else:
values = None keys = None
rearrange_block_ids = None values = None
rearrange_block_ids = None
assert self.kv_send_layer_thread is not None assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None assert reshape_cache_event is not None
send_task = SendTask(wait_event=reshape_cache_event, send_task = SendTask(wait_event=reshape_cache_event,
k_cache=keys, k_cache=keys,
v_cache=values, v_cache=values,
layer_idx=self.current_layer, layer_idx=self.current_layer,
rearrange_block_ids=rearrange_block_ids) rearrange_block_ids=rearrange_block_ids)
for req_id, req_meta in connector_metadata.requests.items(): for req_id, req_meta in connector_metadata.requests.items():
req_meta_update = self.update_decoder_info(req_id, req_meta) if req_id in selected_req_ids:
logger.debug( req_meta_update = self.update_decoder_info(
f"Add request {req_id} to kv send layer thread. {req_meta_update=}" req_id, req_meta)
) logger.debug(
send_task.send_request[req_id] = req_meta_update f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
)
send_task.send_request[req_id] = req_meta_update
self.kv_send_layer_thread.send_queue.put(send_task) self.kv_send_layer_thread.send_queue.put(send_task)
self.current_layer += 1 self.current_layer += 1
def _get_remote_socket( def _get_remote_socket(
@@ -1121,8 +1168,13 @@ class MooncakeLayerwiseConnectorWorker:
def update_decoder_info(self, req_id, req_meta): def update_decoder_info(self, req_id, req_meta):
req_meta_update = copy.deepcopy(req_meta) req_meta_update = copy.deepcopy(req_meta)
req_meta_update.remote_port = req_meta_update.remote_port + ( if self.use_mla or self.use_sparse:
self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size pd_tp_ratio = self.tp_size // self._decode_tp_size
req_meta_update.remote_port = req_meta_update.remote_port + (
self.tp_rank // pd_tp_ratio) % self._decode_tp_size
else:
req_meta_update.remote_port = req_meta_update.remote_port + (
self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \ if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]: req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
try: try:
@@ -1146,14 +1198,16 @@ class MooncakeLayerwiseConnectorWorker:
logger.info( logger.info(
f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
) )
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}" if self.pd_head_ratio > 1:
ret = self.engine.batch_transfer_sync_write( # for tp inequal, pre-create link to prevent alltoall out of memory
session_id, [self.kv_caches_base_addr[0]], session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
[agent_meta.kv_caches_base_addr[0]], [128]) ret = self.engine.batch_transfer_sync_write(
if ret < 0: session_id, [self.kv_caches_base_addr[0]],
logger.error( [agent_meta.kv_caches_base_addr[0]], [128])
f"Mooncake transfer failed to create link to device {session_id}" if ret < 0:
) logger.error(
f"Mooncake transfer failed to create link to device {session_id}"
)
req_meta_update.remote_te_rpc_port = self.remote_te_port[ req_meta_update.remote_te_rpc_port = self.remote_te_port[
req_meta_update.remote_engine_id][req_meta_update.remote_port] req_meta_update.remote_engine_id][req_meta_update.remote_port]
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[ req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[