diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py
index f969bd92..717ff70a 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py
@@ -1,9 +1,15 @@
 import threading
+from collections.abc import Iterable
 from typing import Any

 import torch
 import zmq
 from vllm.config import VllmConfig
+from vllm.distributed.kv_events import (
+    KVCacheEvent,
+    KVConnectorKVEvents,
+    KVEventAggregator,
+)
 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
 from vllm.forward_context import ForwardContext
 from vllm.logger import logger
@@ -12,6 +18,7 @@ from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import KVConnectorOutput
 from vllm.v1.request import Request
 from vllm.v1.serial_utils import MsgpackDecoder

@@ -22,6 +29,40 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler imp
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker


+class AscendStoreKVEvents(KVConnectorKVEvents):
+    def __init__(self, num_workers: int) -> None:
+        self._aggregator = KVEventAggregator(num_workers)
+
+    def add_events(self, events: list[KVCacheEvent]) -> None:
+        self._aggregator.add_events(events)
+
+    def aggregate(self) -> "AscendStoreKVEvents":
+        """
+        Aggregate KV events and retain only the events common to all workers.
+        """
+        common_events = self._aggregator.get_common_events()
+        self._aggregator.clear_events()
+        self._aggregator.add_events(common_events)
+        self._aggregator.reset_workers()
+        return self
+
+    def increment_workers(self, count: int = 1) -> None:
+        self._aggregator.increment_workers(count)
+
+    def get_all_events(self) -> list[KVCacheEvent]:
+        return self._aggregator.get_all_events()
+
+    def get_number_of_workers(self) -> int:
+        return self._aggregator.get_number_of_workers()
+
+    def clear_events(self) -> None:
+        self._aggregator.clear_events()
+        self._aggregator.reset_workers()
+
+    def __repr__(self) -> str:
+        return f"AscendStoreKVEvents(num_workers={self.get_number_of_workers()}, num_events={len(self.get_all_events())})"
+
+
 class AscendStoreConnector(KVConnectorBase_V1):
     def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
         super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
@@ -40,6 +81,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
         )

         self.kv_caches: dict[str, torch.Tensor] = {}
+        self._kv_cache_events: AscendStoreKVEvents | None = None

         self.sended_but_unfinished_reqs: set[str] = set()

@@ -82,6 +124,39 @@ class AscendStoreConnector(KVConnectorBase_V1):
         assert self.connector_scheduler is not None
         return self.connector_scheduler.request_finished(request, block_ids)

+    def update_connector_output(self, connector_output: KVConnectorOutput):
+        """
+        Update KVConnector state from worker-side connectors output.
+
+        Args:
+            connector_output (KVConnectorOutput): the worker-side connectors output.
+        """
+        # Get the KV events
+        kv_cache_events = connector_output.kv_cache_events
+        if not kv_cache_events or not isinstance(kv_cache_events, AscendStoreKVEvents):
+            return
+
+        if self._kv_cache_events is None:
+            self._kv_cache_events = kv_cache_events
+        else:
+            self._kv_cache_events.add_events(kv_cache_events.get_all_events())
+            self._kv_cache_events.increment_workers(kv_cache_events.get_number_of_workers())
+        return
+
+    def take_events(self) -> Iterable["KVCacheEvent"]:
+        """
+        Take the KV cache events from the connector.
+
+        Yields:
+            New KV cache events since the last call.
+        """
+        if self._kv_cache_events is not None:
+            self._kv_cache_events.aggregate()
+            kv_cache_events = self._kv_cache_events.get_all_events()
+            yield from kv_cache_events
+            self._kv_cache_events.clear_events()
+            self._kv_cache_events = None
+
     ############################################################
     # Worker Side Methods
     ############################################################
@@ -127,6 +202,18 @@ class AscendStoreConnector(KVConnectorBase_V1):
         )
         return done_sending, done_recving

+    def get_kv_connector_kv_cache_events(self) -> AscendStoreKVEvents | None:
+        """
+        Get the KV cache events collected by this connector during the last interval.
+        """
+        events = self.connector_worker.get_kv_events()
+        if not events:
+            return None
+
+        ascend_store_kv_events = AscendStoreKVEvents(num_workers=1)
+        ascend_store_kv_events.add_events(events)
+        return ascend_store_kv_events
+

 class LookupKeyServer:
     def __init__(
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py
index 8cc1bad1..a7058cbb 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py
@@ -221,7 +221,6 @@ class RequestTracker:
     # Request id
     req_id: str

-    # The token ids that has been scheduled so far
     token_len: int

     # The block ids that has been allocated so far
@@ -233,6 +232,10 @@
     # The number of tokens that has been savd
     num_saved_tokens: int = 0

+    # The token ids that have been scheduled so far
+    # NOTE: this field is only used when KV events are enabled
+    token_ids: list[int] | None = None
+
     @staticmethod
     def from_new_request(
         new_request: "NewRequestData",
@@ -256,6 +259,7 @@

         return RequestTracker(
             req_id=new_request.req_id,
+            token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(),
             token_len=num_tokens_to_compute,
             allocated_block_ids=unfolded_block_ids,
             num_saved_tokens=0,
@@ -268,7 +272,6 @@
         """Update the request tracker when a running request is
         scheduled again
         """
-
         if len(new_block_ids) == 0:
             new_block_ids = []
         elif isinstance(new_block_ids, tuple):
@@ -284,7 +287,7 @@
 class ReqMeta:
     # Request id
     req_id: str
-    # Request tokens
+    # Number of tokens in this chunk
     token_len_chunk: int

     block_ids: list[int]
@@ -299,6 +302,11 @@

     current_event: torch.npu.Event | None = None

+    # The following fields are only used for KV event generation
+    # TODO: add lora_request, which is needed to generate lora_id/lora_name in KV events
+    token_ids: list[int] | None = None
+    original_block_size: int | None = None
+
    @staticmethod
    def from_request_tracker(
        tracker: RequestTracker,
@@ -308,15 +316,18 @@
         block_hashes: list[BlockHash] | None = None,
         is_last_chunk: bool | None = None,
         discard_partial_chunks: bool = True,
+        original_block_size: int | None = None,
     ) -> Optional["ReqMeta"]:
         """Create the request metadata from a request tracker.

         Args:
             tracker (RequestTracker): the request tracker.
-            block_size (int): the block size in vLLM.
+            block_size (int): the block size used by the vLLM scheduler and the AscendConnector.
+                If context parallelism is enabled, this is the worker block size multiplied by pcp_size * dcp_size.
             load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
             skip_save (bool): whether to skip the save operation.
             discard_partial_chunks (bool): whether to discard partial chunks.
+            original_block_size (int | None): the block size in the vLLM worker; only used for KV event generation.

         Returns:
             the request metadata if we need to perform load/save
@@ -342,6 +353,11 @@ class ReqMeta:
         if not skip_save:
             tracker.num_saved_tokens = num_tokens_to_save

+        # Get the token ids for KV event generation in kv_transfer
+        token_ids = None
+        if tracker.token_ids:
+            token_ids = tracker.token_ids
+
         # # For load operation: check whether the request is scheduled to load
         if load_spec is not None and load_spec.can_load:
             logger.debug(
@@ -361,6 +377,8 @@
             load_spec=load_spec,
             block_hashes=block_hashes,
             is_last_chunk=is_last_chunk,
+            token_ids=token_ids,
+            original_block_size=original_block_size,
         )

diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py
index ec8c7041..a7e570fa 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py
@@ -5,7 +5,9 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Any

 import torch
+from vllm.distributed.kv_events import BlockStored
 from vllm.logger import logger
+from vllm.v1.core.kv_cache_utils import maybe_convert_block_hash

 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend

@@ -41,6 +43,8 @@ class KVTransferThread(threading.Thread):
         # TODO(jianzs): make this configurable
         self.executor = ThreadPoolExecutor(max_workers=32)
         self.finished_requests: set[str] = set()
+        self.kv_event_lock = threading.Lock()
+        self.kv_events: list[BlockStored] = []

     def add_request(
         self,
@@ -101,6 +105,16 @@
             return 0
         return len(keys)

+    def update_kv_event(self, events: list[BlockStored]):
+        with self.kv_event_lock:
+            self.kv_events.extend(events)
+
+    def get_kv_events(self) -> list[BlockStored]:
+        with self.kv_event_lock:
+            events = self.kv_events.copy()
+            self.kv_events.clear()
+            return events
+

 class KVCacheStoreSendingThread(KVTransferThread):
     def __init__(
@@ -113,6 +127,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
         put_step: int,
         kv_role: str,
         ready_event: threading.Event,
+        enable_kv_event: bool = False,
     ):
         super().__init__(
             m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread"
         )
         self.put_step = put_step
         self.kv_role = kv_role
         self.stored_requests = defaultdict[str, int](int)
+        self.enable_kv_event = enable_kv_event

     def add_stored_request(self, req_id: str):
         with self.done_task_lock:
@@ -188,11 +204,30 @@ class KVCacheStoreSendingThread(KVTransferThread):
         """
         addrs = []
         sizes = []
+        stored_events: list[BlockStored] = []
+        prev_key = None
+        new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]]
         for index, start in enumerate(starts):
             addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
             addrs.append(addr)
             sizes.append(size)

+            # Create the KV event for this block
+            if self.enable_kv_event:
+                token_ids = req_meta.token_ids[start : ends[index]] if req_meta.token_ids is not None else None
+                stored_event = BlockStored(
+                    block_hashes=[new_block_hashes[index]],
+                    parent_block_hash=prev_key,
+                    token_ids=token_ids,
+                    block_size=req_meta.original_block_size,
+                    lora_id=None,
+                    medium="cpu",
+                    lora_name=None,
+                )
+                stored_events.append(stored_event)
+                prev_key = new_block_hashes[index]
+                logger.debug(f"Added KV cache event '{stored_event}' to the KV cache events queue")
+
         if self.kv_role == "kv_consumer":
             keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes)

@@ -200,6 +235,10 @@
             current_event.synchronize()

         self.m_store.put(keys, addrs, sizes)
+        # TODO: query replica-specific info to update the event
+        if self.enable_kv_event and stored_events:
+            self.update_kv_event(stored_events)
+
         self.dec_stored_request(req_id)
         self.request_queue.task_done()

@@ -253,12 +292,14 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         put_step: int,
         ready_event: threading.Event,
         num_layers: int,
+        enable_kv_event: bool = False,
     ):
         super().__init__(
             m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread"
         )
         self.final_layer_id = num_layers - 1
         self.put_step = put_step
+        self.enable_kv_event = enable_kv_event

     def add_request(  # type: ignore[override]
         self, req_meta: ReqMeta
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py
index 8b802038..80468456 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py
@@ -37,6 +37,7 @@ class KVPoolScheduler:
         self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1)
         self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1)

+        self.original_block_size = vllm_config.cache_config.block_size
         self._block_size = vllm_config.cache_config.block_size
         if self.pcp_size > 1:
             self._block_size *= self.pcp_size
@@ -183,6 +184,7 @@ class KVPoolScheduler:
             token_len=num_tokens_to_compute,
             allocated_block_ids=unfolded_block_ids,
             num_saved_tokens=0,
+            token_ids=request.prompt_token_ids[:num_tokens_to_compute].copy(),
         )
         self._request_trackers[request.req_id] = request_tracker
         last_chunk_tokens_num = (
@@ -199,6 +201,7 @@ class KVPoolScheduler:
             block_hashes=request_real.block_hashes,
             is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
             discard_partial_chunks=self._discard_partial_chunks,
+            original_block_size=self.original_block_size,
         )
         if req_meta is not None:
             meta.add_request(req_meta)
@@ -227,6 +230,7 @@ class KVPoolScheduler:
             token_len=num_tokens_to_compute,
             allocated_block_ids=new_block_ids,
             num_saved_tokens=0,
+            token_ids=request_real.prompt_token_ids[:num_tokens_to_compute].copy(),
         )
         self._request_trackers[req_id] = request_tracker
         last_chunk_tokens_num = (
@@ -242,6 +246,7 @@ class KVPoolScheduler:
             block_hashes=request_real.block_hashes,
             is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
             discard_partial_chunks=self._discard_partial_chunks,
+            original_block_size=self.original_block_size,
         )

         # decode/chunked request
@@ -276,6 +281,7 @@ class KVPoolScheduler:
             block_hashes=request.block_hashes,
             is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
             discard_partial_chunks=self._discard_partial_chunks,
+            original_block_size=self.original_block_size,
         )
         if req_meta is not None:
             meta.add_request(req_meta)
@@ -299,7 +305,6 @@ class KVPoolScheduler:
             )
             self._request_trackers[request_id] = request_tracker

-
         req_meta = ReqMeta.from_request_tracker(
             request_tracker,
             self._block_size,
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py
index 832dbe3c..f82c8fcb 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py
@@ -12,6 +12,7 @@ from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from vllm.distributed.kv_events import BlockStored
 from vllm.logger import logger
 from vllm.v1.core.kv_cache_utils import BlockHash

@@ -74,6 +75,7 @@ class KVPoolWorker:
             "consumer_is_to_put", False
         )
         self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake")
+        self.original_block_size = vllm_config.cache_config.block_size

         self.block_size = vllm_config.cache_config.block_size
         if self.pcp_size > 1:
@@ -146,6 +148,10 @@ class KVPoolWorker:
             self.m_store = real_backend(  # type: ignore[misc]
                 parallel_config
             )
+        kv_event_config = vllm_config.kv_events_config
+        self.enable_kv_events = False
+        if kv_event_config and kv_event_config.enable_kv_cache_events:
+            self.enable_kv_events = True

         self.kv_send_thread: KVTransferThread | None = None
         self.kv_recv_thread: KVTransferThread | None = None
@@ -209,6 +215,7 @@ class KVPoolWorker:
                 self.put_step,
                 ready_event_sending,
                 self.num_layers,
+                self.enable_kv_events,
             )
             self.kv_send_thread.start()
             ready_event = threading.Event()
@@ -235,6 +242,7 @@ class KVPoolWorker:
                 self.put_step,
                 self.kv_role,
                 ready_event_sending,
+                self.enable_kv_events,
             )
             self.kv_send_thread.start()
             if self.load_async:
@@ -641,3 +649,10 @@ class KVPoolWorker:
             return min(idx for row in arr for idx, val in enumerate(row) if val != 1)
         except ValueError:
             return -1
+
+    def get_kv_events(self) -> list[BlockStored]:
+        if self.enable_kv_events and self.kv_send_thread is not None:
+            # Collect stored KV events from the sending thread
+            events = self.kv_send_thread.get_kv_events()
+            return events
+        return []
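
Illustrative usage (not part of the patch): a minimal sketch of how the new KV-event path could be exercised end to end. It assumes this connector is registered with vLLM under the name "AscendStoreConnector", that the model and ZMQ endpoint are placeholders, and that the KVEventsConfig/KVTransferConfig fields shown match the vLLM version in use. With enable_kv_cache_events set, KVPoolWorker collects BlockStored events (medium="cpu") from the sending thread, the scheduler-side connector merges them across workers via AscendStoreKVEvents, and take_events() hands the common events to the configured publisher.

# Hypothetical example, not part of the patch. Connector name, model, and
# endpoint are assumptions; adjust them to your deployment.
from vllm import LLM
from vllm.config import KVEventsConfig, KVTransferConfig

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    # Publish KV cache events over ZMQ; enable_kv_cache_events is what
    # KVPoolWorker reads to turn on event collection in the sending thread.
    kv_events_config=KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint="tcp://*:5557",  # placeholder endpoint
    ),
    # Route KV blocks through the Ascend store connector; "backend" mirrors
    # the kv_connector_extra_config key read in pool_worker.py.
    kv_transfer_config=KVTransferConfig(
        kv_connector="AscendStoreConnector",  # assumed registered name
        kv_role="kv_both",
        kv_connector_extra_config={"backend": "mooncake"},
    ),
)

outputs = llm.generate(["Hello, Ascend KV pool!"])
print(outputs[0].outputs[0].text)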