import math
import threading
from typing import Dict, Generator, Optional, Type

import torch
from vllm.config import VllmConfig
from vllm.distributed import (get_decode_context_model_parallel_rank,
                              get_decode_context_model_parallel_world_size,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash

from vllm_ascend.distributed.kvpool.backend.backend import Backend
from vllm_ascend.distributed.kvpool.backend.memcache_backend import \
    MemcacheBackend
from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
    MooncakeBackend
from vllm_ascend.distributed.kvpool.config_data import (
    AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
    LasyerMultiBlockReqMeta)
from vllm_ascend.distributed.kvpool.kv_transfer import (
    KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
    KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_ascend.utils import prefill_context_parallel_enable

if prefill_context_parallel_enable():
    # isort: off
    from vllm.distributed import (get_prefill_context_model_parallel_rank,
                                  get_prefill_context_model_parallel_world_size
                                  )
    # isort: on

# Maps the (lower-cased) "backend" extra-config value to its implementation.
backend_map: Dict[str, Type[Backend]] = {
    "mooncake": MooncakeBackend,
    "memcache": MemcacheBackend,
}


class KVPoolWorker:
    """The main class for the cache engine.

    Registers this worker's KV cache buffers with a KV pool backend
    (see ``backend_map``) and stores/loads KV chunks through background
    transfer threads, either whole-request or layer by layer.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        use_layerwize: bool,
    ):
        """Read parallel/cache configuration and create the backend store.

        Args:
            vllm_config: the vLLM configuration; ``kv_transfer_config`` must
                be set.
            use_layerwize: enable the layer-by-layer store/load path.

        Raises:
            ValueError: if the configured ``backend`` is not in
                ``backend_map``.
        """
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.dp_rank = parallel_config.data_parallel_rank
        # Only an explicit boolean True enables the MLA code paths.
        use_mla = getattr(model_config, "use_mla", False)
        self.use_mla = isinstance(use_mla, bool) and use_mla
        self.use_layerwise = use_layerwize
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()

        self.pcp_size = (get_prefill_context_model_parallel_world_size()
                         if prefill_context_parallel_enable() else 1)
        self.pcp_rank = (get_prefill_context_model_parallel_rank()
                         if self.pcp_size > 1 else 0)
        self.dcp_size = get_decode_context_model_parallel_world_size()
        self.dcp_rank = (get_decode_context_model_parallel_rank()
                         if self.dcp_size > 1 else 0)

        kv_transfer_config = vllm_config.kv_transfer_config
        extra_config = kv_transfer_config.kv_connector_extra_config
        self.kv_role = kv_transfer_config.kv_role
        self.load_async = extra_config.get("load_async", False)
        self.backend = extra_config.get("backend", "mooncake")

        # The pool-side block size is the vLLM block size scaled by the
        # PCP and DCP world sizes.
        self.block_size = vllm_config.cache_config.block_size
        if self.pcp_size > 1:
            self.block_size *= self.pcp_size
        if self.dcp_size > 1:
            self.block_size *= self.dcp_size
        self.current_layer = 0
        self.num_layers = model_config.get_num_layers(parallel_config)

        if self.use_mla:
            self.num_kv_head = 1
        else:
            self.num_kv_head = model_config.get_total_num_kv_heads()

        if self.num_kv_head < self.tp_size:
            # Fewer KV heads than TP ranks: each group of put_step
            # consecutive TP ranks maps to the same head index in the keys.
            self.put_step = self.tp_size // self.num_kv_head
            self.head_or_tp_rank = self.tp_rank // self.put_step
        else:
            self.head_or_tp_rank = self.tp_rank
            self.put_step = 1

        self.metadata = KeyMetadata(
            model_config.model.split('/')[-1],
            self.head_or_tp_rank,
            self.pcp_rank,
            self.dcp_rank,
        )

        self.token_database = ChunkedTokenDatabase(self.metadata,
                                                   self.block_size,
                                                   self.use_mla)

        real_backend = backend_map.get(self.backend.lower())
        if real_backend is None:
            # Previously this surfaced as "'NoneType' object is not callable".
            raise ValueError(
                f"Unsupported KV pool backend {self.backend!r}; "
                f"expected one of {sorted(backend_map)}")
        self.m_store = real_backend(  # type: ignore[misc]
            parallel_config)

        self.kv_send_thread: Optional[KVTransferThread] = None
        self.kv_recv_thread: Optional[KVTransferThread] = None

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register every KV cache buffer with the backend and start threads.

        Computes per-block byte lengths from the first cache entry, registers
        ``(data_ptr, num_blocks * block_len)`` regions with the backend,
        hands base addresses and block lengths to the token database, and
        starts the send/receive threads required by ``kv_role``,
        ``use_layerwise`` and ``load_async``.
        """
        _, first_kv_cache_tuple = next(iter(kv_caches.items()))
        first_kv_cache = first_kv_cache_tuple[0]

        # TODO(tms): Find a more robust way to detect and handle MLA
        if self.use_mla:
            # MLA case.[num_block, block_size, 1, hidden_dim]
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 3  # trailing dims of one block
            block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
            block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
            # Each cache uses its own element size: the norm and pe caches
            # are separate tensors and need not share a dtype.
            self.block_len = [
                first_kv_cache_tuple[0].element_size() *
                math.prod(block_shape_norm),
                first_kv_cache_tuple[1].element_size() *
                math.prod(block_shape_pe),
            ]
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                self.num_blocks, block_shape_norm, block_shape_pe)
        else:
            # [num_block, block_size, num_head, hidden_dim]
            self.num_blocks = first_kv_cache.shape[0]
            kv_elem_size = first_kv_cache.element_size()
            block_rank = 3  # [block_size, kv_heads, head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            self.block_len = [kv_elem_size * math.prod(block_shape)]
            logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                        block_shape)

        logger.info("Registering KV_Caches. use_mla: %s, shape %s",
                    self.use_mla, first_kv_cache.shape)

        self.kv_caches = kv_caches
        self.kv_caches_base_addr = []
        ptrs = []
        lengths = []
        for cache_or_caches in kv_caches.values():
            for i, cache in enumerate(cache_or_caches):
                # MLA alternates norm/pe block lengths; otherwise all caches
                # share block_len[0].
                block_len = (self.block_len[i % 2]
                             if self.use_mla else self.block_len[0])
                base_addr = cache.data_ptr()
                self.kv_caches_base_addr.append(base_addr)
                ptrs.append(base_addr)
                lengths.append(self.num_blocks * block_len)
        self.m_store.register_buffer(ptrs, lengths)
        self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
        self.token_database.set_block_len(self.block_len)

        if self.use_layerwise:
            self.get_event = threading.Event()
            if self.kv_role in ['kv_producer', 'kv_both']:
                ready_event_sending = threading.Event()
                self.kv_send_thread = KVCacheStoreLayerSendingThread(
                    self.m_store, self.token_database, self.tp_rank,
                    self.dcp_size, self.put_step, ready_event_sending,
                    self.num_layers)
                self.kv_send_thread.start()
            ready_event = threading.Event()
            self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
                self.m_store, self.token_database, self.tp_rank,
                self.dcp_size, ready_event, self.get_event)
            self.kv_recv_thread.start()
            ready_event.wait()
        else:
            if self.kv_role in ['kv_producer', 'kv_both']:
                ready_event_sending = threading.Event()
                self.kv_send_thread = KVCacheStoreSendingThread(
                    self.m_store, self.token_database, self.tp_rank,
                    self.dcp_size, self.put_step, ready_event_sending)
                self.kv_send_thread.start()
            if self.load_async:
                ready_event = threading.Event()
                self.kv_recv_thread = KVCacheStoreRecvingThread(
                    self.m_store, self.token_database, self.tp_rank,
                    self.dcp_size, ready_event)
                self.kv_recv_thread.start()
                ready_event.wait()

    def start_load_kv(self, metadata: AscendConnectorMetadata):
        """Begin loading KV for every loadable request in ``metadata``.

        Layerwise mode creates one retriever generator per request and
        submits its first layer. Async mode queues the request on the
        receive thread. Otherwise the chunks are fetched synchronously.
        """
        self.current_layer = 0
        self.layerwise_retrievers = []
        for request in metadata.requests:
            load_spec = request.load_spec
            if load_spec is None or not load_spec.can_load:  # load = 0
                continue
            token_len = request.token_len_chunk
            req_id = request.req_id
            cached_tokens = load_spec.kvpool_cached_tokens
            if (cached_tokens % self.block_size != 0
                    and cached_tokens == token_len - 1):
                token_len = cached_tokens + 1
            else:
                token_len = cached_tokens
            # vllm_cached_tokens, rounded down to a whole block, are masked
            # out of the load.
            mask_num = (load_spec.vllm_cached_tokens // self.block_size *
                        self.block_size)
            if self.use_layerwise:
                layerwise_retriever = self.retrieve_layer(
                    req_id,
                    token_len,
                    request.block_ids,
                    request.block_hashes,
                    mask_num,
                )
                next(layerwise_retriever)  # first layer load
                self.layerwise_retrievers.append(layerwise_retriever)
            elif self.load_async:
                self.kv_recv_thread.add_request(  # type: ignore[union-attr]
                    req_id,
                    token_len,
                    request.block_ids,
                    request.block_hashes,
                    mask_num,
                )
            else:
                self._load_sync(token_len, request.block_ids,
                                request.block_hashes, mask_num)

    def _load_sync(self, token_len: int, block_ids: list[int],
                   block_hashes: list[BlockHash], mask_num: int) -> None:
        """Fetch one request's chunks from the backend with a blocking get."""
        key_list = []
        addr_list = []
        size_list = []
        for start, end, key in self.token_database.process_tokens(
                token_len, block_hashes, mask_num):
            addr, size, _ = self.token_database.prepare_value(
                start, end, block_ids)
            key_list.append(key.to_string())
            addr_list.append(addr)
            size_list.append(size)
        if not key_list:
            # Nothing to load; also avoids `tp_rank % 0` below.
            return
        # Rotate all three lists by tp_rank so different TP ranks start at
        # different chunks.
        shift = self.tp_rank % len(key_list)
        self.m_store.get(key_list[shift:] + key_list[:shift],
                         addr_list[shift:] + addr_list[:shift],
                         size_list[shift:] + size_list[:shift])

    def wait_for_layer_load(self) -> None:
        """Advance every layerwise retriever by one layer."""
        for layerwise_retriever in self.layerwise_retrievers:
            ret_token_mask = next(layerwise_retriever)
            if self.current_layer == self.num_layers - 1:
                assert ret_token_mask is not None
                num_retrieved_tokens = ret_token_mask.sum().item()
                logger.debug("Retrieved %s tokens", num_retrieved_tokens)

    def _store_mask_num(self, request) -> Optional[int]:
        """Return the block-aligned number of leading tokens to skip storing.

        Returns None when *request* is not stored this step: either it is not
        savable, or ``lookup`` reports all of its tokens already present (a
        last chunk is then reported finished to the send thread).
        """
        if not request.can_save:
            return None
        token_len = request.token_len_chunk
        # TODO: whether need to remov saveThread
        # no lookup, skipmask
        skip_leading_tokens = self.lookup(token_len, request.block_hashes,
                                          self.use_layerwise)
        if skip_leading_tokens == token_len:
            if request.is_last_chunk:
                self.kv_send_thread.set_finished_request(  # type: ignore[union-attr]
                    request.req_id)
            return None  # skip this request
        logger.info(
            "Storing KV cache for %d out of %d tokens "
            "(skip_leading_tokens=%d) for request %s",
            token_len - skip_leading_tokens,
            token_len,
            skip_leading_tokens,
            request.req_id,
        )
        return skip_leading_tokens // self.block_size * self.block_size

    def save_kv_layer(self,
                      connector_metadata: AscendConnectorMetadata) -> None:
        """Submit the current layer's store for every savable request.

        On layer 0 one ``store_layer`` generator is created per request. Each
        call advances all generators by one layer and increments
        ``current_layer``.
        """
        if self.current_layer == 0:
            self.layerwise_storers = []
            for request in connector_metadata.requests:
                mask_num = self._store_mask_num(request)
                if mask_num is None:
                    continue
                layerwise_storer = self.store_layer(
                    request.req_id,
                    request.token_len_chunk,
                    block_hashes=request.block_hashes,
                    mask_num=mask_num,
                    block_ids=request.block_ids,
                    is_last_chunk=request.is_last_chunk,
                )
                self.layerwise_storers.append(layerwise_storer)
        for layerwise_storer in self.layerwise_storers:
            next(layerwise_storer)
        self.current_layer = self.current_layer + 1

    def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
        """Queue a whole-request store on the send thread for each request."""
        for request in connector_metadata.requests:
            mask_num = self._store_mask_num(request)
            if mask_num is None:
                continue
            self.kv_send_thread.add_request(  # type: ignore[union-attr]
                request.req_id,
                request.token_len_chunk,
                request.block_ids,
                request.block_hashes,
                mask_num,
                request.is_last_chunk,
            )

    def _wait_layer_get(self) -> None:
        """Wait up to 3 s for the receive thread's get_event, then clear it."""
        is_finish = self.get_event.wait(timeout=3)  # try---cache
        if not is_finish:
            logger.info("Layerwise get failed")
        self.get_event.clear()

    def retrieve_layer(
        self,
        req_id: str,
        token_len: int,
        block_ids: list[int],
        block_hashes: list[BlockHash],
        mask_num: int = 0,
    ) -> Generator[Optional[torch.Tensor], None, None]:
        """Load one request's KV cache layer by layer.

        Args:
            req_id: request id forwarded with every layer request.
            token_len: number of tokens to consider.
            block_ids: block ids forwarded with every layer request.
            block_hashes: block hashes used to build the pool keys.
            mask_num: number of leading tokens to skip.

        Yields:
            None after submitting each layer's request to the receive thread.
            Every resume after the first waits (up to 3 s) on ``get_event``
            before submitting the next layer. The final value is a boolean
            CPU tensor of length ``token_len``, True over the ranges of the
            chunks that were requested.
        """
        num_required_tokens = token_len - mask_num
        ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
        starts = []
        ends = []
        keys = []
        for start, end, key in self.token_database.process_tokens(
                token_len, block_hashes, mask_num):
            starts.append(start)
            ends.append(end)
            keys.append(key.split_layers(self.num_layers))
            ret_mask[start:end] = True

        if keys:
            # Transpose the keys into layer major format
            keys = [list(row) for row in zip(*keys)]  # [num_layer,block_num]
            for layer_id, keys_multi_chunk in enumerate(keys):
                if layer_id > 0:
                    self._wait_layer_get()
                req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
                                                   starts, ends, block_ids,
                                                   layer_id)
                self.kv_recv_thread.add_request(  # type: ignore[union-attr, call-arg]
                    req_meta)  # type: ignore[union-attr, call-arg, arg-type]
                yield None
            # NOTE(review): the final layer's get_event is not awaited before
            # the mask is yielded, so the last wait_for_layer_load may return
            # before that layer's load completes (and the event may stay set
            # into the next step). Confirm against
            # KVCacheStoreLayerRecvingThread before adding a final wait.
        else:
            # If no cache are found, we still need to yield to avoid
            # `StopIteration`
            for _ in range(self.num_layers):
                yield None

        retrieved_tokens = torch.sum(ret_mask)
        logger.debug("Retrieved %s out of %s out of total %s tokens",
                     retrieved_tokens, num_required_tokens, token_len)
        yield ret_mask

    def store_layer(
        self,
        req_id: str,
        token_len: int,
        block_ids: list[int],
        block_hashes: list[BlockHash],
        is_last_chunk: bool,
        mask_num: int = 0,
    ) -> Generator[None, None, None]:
        """Store one request's KV cache layer by layer.

        Args:
            req_id: request id forwarded with every layer request.
            token_len: number of tokens to consider.
            block_ids: block ids forwarded with every layer request.
            block_hashes: block hashes used to build the pool keys.
            is_last_chunk: forwarded with every layer request.
            mask_num: number of leading tokens to skip.

        Yields:
            None once per layer, after submitting that layer's store request
            to the send thread (or ``num_layers`` times if there are no
            chunks to store).
        """
        num_stored_tokens = token_len - mask_num
        starts = []
        ends = []
        keys = []
        for start, end, key in self.token_database.process_tokens(
                token_len, block_hashes, mask_num):
            starts.append(start)
            ends.append(end)
            keys.append(key.split_layers(self.num_layers))  # [block_num,layer_num]

        if keys:
            keys = [list(row) for row in zip(*keys)]  # [layer_num,block_num]
            for layer_id, keys_multi_chunk in enumerate(keys):
                req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
                                                   starts, ends, block_ids,
                                                   layer_id, is_last_chunk)
                self.kv_send_thread.add_request(  # type: ignore[union-attr, call-arg]
                    req_meta)  # type: ignore[union-attr, call-arg, arg-type]
                yield
        else:
            for _ in range(self.num_layers):
                yield
        logger.debug("Stored %s out of total %s tokens", num_stored_tokens,
                     token_len)

    def get_finished(self) -> tuple[set[str], set[str]]:
        """Return and clear the ids of finished send and receive requests."""
        done_sending = (
            self.kv_send_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
            ) if self.kv_role in ['kv_producer', 'kv_both'] else set())

        done_recving = (
            self.kv_recv_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
            ) if self.load_async else set())

        logger.debug(
            "Number of completed KV cache send requests: %d, receive "
            "requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
            self.tp_rank)
        return done_sending, done_recving

    def _build_lookup_keys(
        self,
        token_len: int,
        block_hashes: list[BlockHash],
        use_layerwise: bool,
    ) -> tuple[list[int], list[str], int]:
        """Collect chunk starts and key strings for an existence lookup.

        Returns ``(starts, keys, last_end)``: one start per chunk, the key
        strings (``num_layers`` consecutive keys per chunk when
        *use_layerwise*), and the end offset of the last chunk (0 if none).
        """
        starts = []
        keys = []
        last_end = 0
        for start, end, key in self.token_database.process_tokens(
                token_len, block_hashes):
            if use_layerwise:
                keys.extend(item.to_string()
                            for item in key.split_layers(self.num_layers))
            else:
                keys.append(key.to_string())
            starts.append(start)
            last_end = end
        return starts, keys, last_end

    def lookup(
        self,
        token_len: int,
        block_hashes: list[BlockHash],
        use_layerwise: bool,
    ) -> int:
        """Return how many leading tokens are present in the pool.

        Uses this worker's own keys. The result is the start of the first
        missing chunk, or the end of the last chunk if all are present. With
        *use_layerwise* a chunk counts as present only if every layer's key
        exists. Returns 0 if the backend query fails, since nothing could be
        confirmed.
        """
        try:
            starts, keys, end = self._build_lookup_keys(
                token_len, block_hashes, use_layerwise)
            res = self.m_store.exists(keys)  # type: ignore[assignment]
            if use_layerwise:
                res = self.check_all_layers_exists(res, self.num_layers)
            for index, value in enumerate(res):  # type: ignore[arg-type]
                if value != 1:
                    return starts[index]
        except Exception as e:
            logger.error("Remote connection failed in contains: %s", e)
            return 0
        # all tokens where found, return the maximal end
        return end

    def lookup_scheduler(
        self,
        token_len: int,
        block_hashes: list[BlockHash],
        use_layerwise: bool,
    ) -> int:
        """Like ``lookup``, but require a chunk on every head/TP key variant.

        Keys are duplicated for ``head_or_tp_rank`` 1..min(tp_size,
        num_kv_head)-1 by rewriting ``@head_or_tp_rank:0``. The result is the
        start of the first chunk missing in any variant, the end of the last
        chunk if none is missing, or 0 if the backend query fails.
        """
        try:
            starts, keys, end = self._build_lookup_keys(
                token_len, block_hashes, use_layerwise)
            num_variants = min(self.tp_size, self.num_kv_head)
            multi_tp_keys = keys[:]
            for i in range(1, num_variants):
                multi_tp_keys.extend(
                    item.replace(  # type: ignore[attr-defined]
                        "@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
                    for item in keys)
            res = self.m_store.exists(
                multi_tp_keys)  # type: ignore[assignment]
            num_block = len(keys)
            if use_layerwise:
                res = self.check_all_layers_exists(res, self.num_layers)
                num_block = len(keys) // self.num_layers
            multi_tp_values = [
                res[i * num_block:(i + 1) * num_block]  # type: ignore[index]
                for i in range(num_variants)
            ]
            index = self.find_min_first_non_one_index(multi_tp_values)
            if index != -1:
                return starts[index]
        except Exception as e:
            logger.error("Remote connection failed in contains: %s", e)
            return 0
        # all tokens where found, return the maximal end
        return end

    def check_all_layers_exists(self, res: list[int],
                                num_layers: int) -> list[int]:
        """Collapse per-layer flags into per-chunk flags.

        ``res`` holds ``num_layers`` consecutive flags per chunk. A chunk is 1
        only if all of its flags equal 1. Trailing flags that do not fill a
        whole chunk are ignored.
        """
        return [
            1 if all(x == 1
                     for x in res[i * num_layers:(i + 1) * num_layers]) else 0
            for i in range(len(res) // num_layers)
        ]

    def find_min_first_non_one_index(self, arr):
        """Return the smallest column index holding a value != 1 in any row.

        Returns -1 if every value is 1 (or *arr* is empty).
        """
        return min((idx for row in arr for idx, val in enumerate(row)
                    if val != 1),
                   default=-1)