[P/D] Layerwise connector: support DeepSeek-V3.2 sparse attention and distribute transfer tasks to redundant kv_head cards (#5722)

### What this PR does / why we need it?
This PR adds new functionality to the Mooncake layerwise connector:
1. Support for sparse attention (DeepSeek-V3.2).
2. Distribution of KV transfer tasks across cards that hold redundant kv_heads (see the sketch after the RFC link below).

This PR is related to [[RFC]: CDCP Scheduling for Disaggregated
Prefilling with KV Cache Layerwise Push
Support](https://github.com/vllm-project/vllm-ascend/issues/4842)
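
The task-distribution logic is not spelled out in the description; the following is a minimal, hypothetical sketch of the idea, assuming a mapping from each kv_head to the set of ranks that redundantly store it (`kv_head_to_ranks`, `assign_transfer_tasks`, and all other names here are illustrative, not the connector's actual API):

```python
# Hypothetical sketch: spread layerwise KV push tasks round-robin over the
# ranks (cards) that hold redundant copies of the same kv_head, instead of
# always pushing from a single owner rank.
from typing import Dict, List


def assign_transfer_tasks(
    layer_ids: List[int],
    kv_head_to_ranks: Dict[int, List[int]],
    kv_head: int,
) -> Dict[int, int]:
    """Map each layer's transfer task to one of the redundant owner ranks."""
    ranks = kv_head_to_ranks[kv_head]
    return {layer: ranks[i % len(ranks)] for i, layer in enumerate(layer_ids)}


# With kv_head 0 replicated on ranks 0 and 4, four layers alternate owners:
# {0: 0, 1: 4, 2: 0, 3: 4}
print(assign_transfer_tasks([0, 1, 2, 3], {0: [0, 4]}, 0))
```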

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Commit 78b554dda9 (parent c316679e65), authored by zxr2333 on 2026-01-10 23:04:16 +08:00 and committed by GitHub. 3 changed files with 142 additions and 111 deletions.


@@ -108,6 +108,7 @@ class AscendSFAMetadata:
     # chunked prefill by default if no attn_states passed
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     sfa_cp_context: Optional[SfaCpContext] = None
+    reshape_cache_event: torch.npu.Event = None

 M = TypeVar("M", bound=AscendSFAMetadata)
@@ -369,6 +370,7 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.enable_sfa_cp = enable_dsa_cp()
         self.local_num_heads = self.num_heads
         self.vllm_config = get_current_vllm_config()
+        self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
             self.layer_sharding_kwargs = []
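
The new `is_kv_producer` flag is derived from `kv_transfer_config`, which in a disaggregated-prefill deployment marks the prefill instance as the KV producer. A minimal sketch of such a config, assuming the `KVTransferConfig` API from upstream vLLM (the connector name below is an assumption for illustration):

```python
# Sketch of the config under which is_kv_producer evaluates to True on the
# prefill side; the connector name is illustrative, not necessarily the one
# this PR registers.
from vllm.config import KVTransferConfig

kv_cfg = KVTransferConfig(
    kv_connector="MooncakeLayerwiseConnector",  # assumed name
    kv_role="kv_producer",  # prefill instance pushes KV to the decode side
)
assert kv_cfg.is_kv_producer  # true for kv_role in ("kv_producer", "kv_both")
```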
@@ -897,11 +899,15 @@ class AscendSFAImpl(MLAAttentionImpl):
             k = get_tp_group().all_gather(k, 0)
         if kv_cache is not None:
+            if self.is_kv_producer:
+                attn_metadata.reshape_cache_event = torch.npu.Event()
             torch_npu.npu_scatter_nd_update_(
                 kv_cache[2].view(-1, k.shape[-1]),
                 attn_metadata.slot_mapping.view(-1, 1),
                 k.view(-1, k.shape[-1]))  # b, s, n, d
+            if self.is_kv_producer:
+                attn_metadata.reshape_cache_event.record()
         weights, _ = self.weights_proj(x)
         weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
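
The pattern in this hunk: the producer creates an NPU event, scatters the layer's keys into the paged cache on the compute stream, and records the event; the layerwise connector can then wait on `reshape_cache_event` before pushing that layer, so the transfer never reads slots that have not been written yet. A minimal, self-contained sketch of the record/wait pattern, using `torch.cuda.Event` as a stand-in for `torch.npu.Event` (torch_npu exposes the same interface):

```python
# Record/wait pattern between a compute stream (cache write) and a transfer
# stream (KV push). torch.cuda.Event stands in for torch.npu.Event here.
import torch

compute = torch.cuda.Stream()
transfer = torch.cuda.Stream()
cache = torch.zeros(1024, device="cuda")
event = torch.cuda.Event()

with torch.cuda.stream(compute):
    cache.add_(1.0)        # stand-in for npu_scatter_nd_update_ into kv_cache
    event.record()         # cache slots are valid once this event completes

with torch.cuda.stream(transfer):
    event.wait()           # transfer stream waits on-device; no host sync
    snapshot = cache.clone()  # safe to read/push the layer's KV now

torch.cuda.synchronize()
```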