[P/D] Layerwise connector supports DeepSeek-V3.2 sparse attention and distributes transfer tasks to redundant kv_head cards (#5722)
### What this PR does / why we need it?
Adds new functionality to the mooncake layerwise connector:
1. Support for sparse attention, as used by DeepSeek-V3.2.
2. Distribution of transfer tasks across redundant kv_head cards (a rough sketch of the distribution idea follows below).

This PR is related to [[RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache Layerwise Push Support](https://github.com/vllm-project/vllm-ascend/issues/4842).
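As an illustration of point 2, the sketch below round-robins per-layer KV transfer tasks over the ranks that hold redundant copies of a kv head, so pushes are spread across cards instead of always leaving from one. The function name `assign_transfer_tasks` and its inputs are hypothetical, not this PR's actual API:

```python
# Hypothetical sketch (not the PR's code): spread per-layer KV transfer
# tasks round-robin over the ranks holding redundant copies of a kv head.
from typing import Dict, List


def assign_transfer_tasks(layer_ids: List[int],
                          redundant_ranks: List[int]) -> Dict[int, List[int]]:
    """Round-robin layer transfer tasks over redundant kv_head ranks."""
    tasks: Dict[int, List[int]] = {rank: [] for rank in redundant_ranks}
    for i, layer_id in enumerate(layer_ids):
        rank = redundant_ranks[i % len(redundant_ranks)]
        tasks[rank].append(layer_id)
    return tasks


# e.g. 61 layers spread over 4 ranks that each hold the same kv head
print(assign_transfer_tasks(list(range(61)), [0, 1, 2, 3]))
```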
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
```diff
@@ -108,6 +108,7 @@ class AscendSFAMetadata:
     # chunked prefill by default if no attn_states passed
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     sfa_cp_context: Optional[SfaCpContext] = None
+    reshape_cache_event: torch.npu.Event = None
 
 
 M = TypeVar("M", bound=AscendSFAMetadata)
```
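For context on how this new field might be consumed, here is a hedged sketch (not code from this PR): the producer records `reshape_cache_event` after scattering the reshaped KV into the cache (see the third hunk below), and a transfer stream can gate its layerwise push on it, assuming `torch_npu` mirrors the CUDA stream/event API (`Event.record`, `Stream.wait_event`):

```python
# Hypothetical consumer-side sketch of reshape_cache_event; assumes the
# Ascend backend mirrors the CUDA stream/event API.
import torch
import torch_npu  # noqa: F401  # registers the torch.npu device module


def gate_layer_push(attn_metadata, transfer_stream: "torch.npu.Stream") -> None:
    """Make the KV transfer stream wait until the reshaped cache is written."""
    event = attn_metadata.reshape_cache_event
    if event is not None:
        # Blocks the transfer stream, not the host thread.
        transfer_stream.wait_event(event)
```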
```diff
@@ -369,6 +370,7 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.enable_sfa_cp = enable_dsa_cp()
         self.local_num_heads = self.num_heads
         self.vllm_config = get_current_vllm_config()
+        self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
         self.layer_sharding_kwargs = []
```
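The new `is_kv_producer` flag is derived from vLLM's `KVTransferConfig`. A minimal sketch of the disaggregated-prefill config that would set it on the prefill side; the connector name here is an assumption, not confirmed by this PR:

```python
from vllm.config import KVTransferConfig

# On the prefill (producer) instance; kv_connector name is illustrative.
kv_transfer_config = KVTransferConfig(
    kv_connector="MooncakeLayerwiseConnector",
    kv_role="kv_producer",
)
assert kv_transfer_config.is_kv_producer  # the property AscendSFAImpl checks
```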
```diff
@@ -897,11 +899,15 @@ class AscendSFAImpl(MLAAttentionImpl):
             k = get_tp_group().all_gather(k, 0)
 
         if kv_cache is not None:
+            if self.is_kv_producer:
+                attn_metadata.reshape_cache_event = torch.npu.Event()
             torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
                                              attn_metadata.slot_mapping.view(
                                                  -1, 1),
                                              k.view(-1,
                                                     k.shape[-1]))  # b, s, n, d
+            if self.is_kv_producer:
+                attn_metadata.reshape_cache_event.record()
 
         weights, _ = self.weights_proj(x)
         weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
```
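The `npu_scatter_nd_update_` call above writes each token's reshaped K row into its cache slot in place. A minimal pure-PyTorch analogue of that indexing, with made-up shapes for illustration:

```python
import torch

num_slots, head_dim, num_tokens = 8, 4, 3
cache = torch.zeros(num_slots, head_dim)  # plays the role of kv_cache[2].view(-1, d)
slot_mapping = torch.tensor([5, 0, 2])    # destination slot per token
k = torch.randn(num_tokens, head_dim)     # plays the role of k.view(-1, d)

# In-place scatter of each token's row into its slot, analogous to
# npu_scatter_nd_update_(cache, slot_mapping.view(-1, 1), k) on the NPU.
cache.index_copy_(0, slot_mapping, k)
assert torch.equal(cache[5], k[0])
```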