diff --git a/vllm_ascend/distributed/kvpool/config_data.py b/vllm_ascend/distributed/kvpool/config_data.py
index c52b982a..5de21350 100644
--- a/vllm_ascend/distributed/kvpool/config_data.py
+++ b/vllm_ascend/distributed/kvpool/config_data.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from typing import Iterable, List, Optional, Tuple, Union
 
+import torch
 from vllm.distributed.kv_transfer.kv_connector.v1.base import \
     KVConnectorMetadata
 from vllm.logger import logger
@@ -284,6 +285,8 @@ class ReqMeta:
 
     is_last_chunk: Optional[bool] = None
 
+    current_event: Optional[torch.npu.Event] = None
+
     @staticmethod
     def from_request_tracker(
         tracker: RequestTracker,
@@ -375,3 +378,4 @@ class LasyerMultiBlockReqMeta:
     block_ids: list[int]
     layer_id: int
     is_last_chunk: Optional[bool] = True
+    current_event: Optional[torch.npu.Event] = None
\ No newline at end of file
diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py
index 3eb4c00a..02d79c88 100644
--- a/vllm_ascend/distributed/kvpool/kv_transfer.py
+++ b/vllm_ascend/distributed/kvpool/kv_transfer.py
@@ -114,6 +114,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
         block_ids = req_meta.block_ids
         req_id = req_meta.req_id
         is_last_chunk = req_meta.is_last_chunk
+        current_event = req_meta.current_event
         starts = []
         ends = []
         keys = []
@@ -161,6 +162,14 @@ class KVCacheStoreSendingThread(KVTransferThread):
             addrs.append(addr)
             sizes.append(size)
         if keys:
+            """
+            Note: due to a bug in ADXL, calling current_event.synchronize()
+            may occasionally hang. The fix is expected in CANN 8.5.RC1; until
+            that release, you can work around the issue by manually building
+            the master branch of https://gitcode.com/cann/hixl.
+            """
+            if current_event is not None:
+                current_event.synchronize()
             self.m_store.put(keys, addrs, sizes)
 
         if is_last_chunk:
@@ -235,6 +244,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         ends = req_meta.ends
         keys = req_meta.keys
         layer_id = req_meta.layer_id
+        current_event = req_meta.current_event
         total_block = len(keys)
         is_last_chunk = req_meta.is_last_chunk
         if not self.dcp_size > 1:
@@ -270,6 +280,8 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
                 addr_list.append(addr)
                 size_list.append(size)
 
+        if current_event is not None:
+            current_event.synchronize()
         self.m_store.put(key_list, addr_list, size_list)
 
         if layer_id == self.final_layer_id and is_last_chunk:
diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py
index 24a39537..97bc4c5f 100644
--- a/vllm_ascend/distributed/kvpool/pool_worker.py
+++ b/vllm_ascend/distributed/kvpool/pool_worker.py
@@ -251,12 +251,20 @@ class KVPoolWorker:
                      connector_metadata: AscendConnectorMetadata) -> None:
         if self.current_layer == 0:
             self.layerwise_storers = []
+            current_event = None
+            for request in connector_metadata.requests:
+                can_save = request.can_save
+                if can_save is None or not can_save:
+                    continue
+                current_event = torch.npu.Event()
+                current_event.record()
+                break
             for request in connector_metadata.requests:
                 can_save = request.can_save
                 if can_save is None or not can_save:
                     continue
-                layerwise_storer = self.store_layer(request)
+                layerwise_storer = self.store_layer(request, current_event)
                 self.layerwise_storers.append(layerwise_storer)
 
         for layerwise_storer in self.layerwise_storers:
             try:
@@ -266,11 +274,21 @@ class KVPoolWorker:
         self.current_layer = self.current_layer + 1
 
     def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
+        current_event = None
+        for request in connector_metadata.requests:
+            can_save = request.can_save
+            if can_save is None or not can_save:
+                continue
+            current_event = torch.npu.Event()
+            current_event.record()
+            break
+
         for request in connector_metadata.requests:
             can_save = request.can_save
             if can_save is None or not can_save:
                 continue
+            request.current_event = current_event
             self.kv_send_thread.add_request(  # type: ignore[union-attr]
                 request, )
 
@@ -347,6 +365,7 @@ class KVPoolWorker:
     def store_layer(
         self,
         request: ReqMeta,
+        current_event: Optional[torch.npu.Event],
     ) -> Generator[None, None, None]:
         """
         Store the KV cache in a layerwise manner.
@@ -385,7 +404,8 @@ class KVPoolWorker:
                 keys_multi_chunk, starts, ends, request.block_ids,
                 layer_id,
-                request.is_last_chunk)
+                request.is_last_chunk,
+                current_event)
             self.kv_send_thread.add_request(  # type: ignore[union-attr, call-arg]
                 req_meta)  # type: ignore[union-attr, call-arg, arg-type]
             yield
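Note for reviewers: the patch serializes the device stream and the sending threads with one recorded event per save batch. The worker records a torch.npu.Event() after the KV-writing kernels for the first saveable request are enqueued, and each sending thread calls synchronize() on that event before m_store.put, so the host-side transfer never reads blocks that are still being written. The sketch below illustrates that producer/consumer ordering; it is not part of the patch, produce_kv and consume_kv are hypothetical names, and it assumes a torch_npu install with an NPU device visible.

    import torch
    import torch_npu  # noqa: F401  (registers the torch.npu backend)

    def produce_kv(buf):
        # Enqueue device work that writes KV data on the current stream.
        buf.add_(1)
        # Record an event *after* the last write; a consumer can then wait
        # on this exact point in the stream instead of a full device sync.
        event = torch.npu.Event()
        event.record()
        return event

    def consume_kv(buf, event):
        # Mirrors the sending threads in the patch: block until the
        # recorded writes have finished before reading device memory.
        if event is not None:
            event.synchronize()
        return buf.cpu()  # safe: the writes above are complete

    buf = torch.zeros(4, device="npu")
    ev = produce_kv(buf)
    print(consume_kv(buf, ev))

Recording a single event for the whole batch (rather than one per request) is sufficient here because all saveable requests are enqueued on the same stream, and an event records a point that all earlier work on that stream precedes.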