diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index 3992618..bff3f3e 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -510,20 +510,18 @@ class MooncakeEngine: ) -> int: """ Checks the existence of KV cache of the tokens from the cache engine. - :param tokens: the input tokens, with shape [seq_len] - :return: An int indicating how many prefix tokens are cached. """ end = 0 - - for start, end, key in self.token_database.process_tokens(tokens): - try: - if use_layerwise: - keys = [] + keys = [] + try: + if use_layerwise: + for start, end, key in self.token_database.process_tokens( + tokens): keys_multi_layer = key.split_layers(self.num_layers) - for key in keys_multi_layer: - keys.append(key.to_string()) + for item in keys_multi_layer: + keys.append(item.to_string()) # batch is_exists ress = self.m_store.batch_exists(keys) res = 1 @@ -531,19 +529,93 @@ class MooncakeEngine: if value != 1: res = 0 break - else: - res = self.m_store.exists(key) - if res == 1: - continue - else: - return start - except Exception as e: - logger.warning(f"Remote connection failed in contains: {e}") - return start - - # all tokens where found, return the maximal end + if res == 1: + continue + else: + return start + else: + starts = [] + for start, end, key in self.token_database.process_tokens( + tokens): + keys.append(key.to_string()) + starts.append(start) + res = self.m_store.batch_exists( + keys) # type: ignore[assignment] + for index, value in enumerate(res): # type: ignore[arg-type] + if value != 1: + return starts[index] + # all tokens where found, return the maximal end + except Exception as e: + logger.error(f"Remote connection failed in contains: {e}") + return start return end + def lookup_scheduler( + self, + tokens: Union[torch.Tensor, List[int]], + use_layerwise: bool, + ) -> int: + """ + Checks the existence of KV cache of the tokens from the cache engine. + :param tokens: the input tokens, with shape [seq_len] + :return: An int indicating how many prefix tokens are cached. + """ + end = 0 + keys = [] + try: + if use_layerwise: + for start, end, key in self.token_database.process_tokens( + tokens): + keys_multi_layer = key.split_layers(self.num_layers) + for item in keys_multi_layer: + keys.append(item.to_string()) + # batch is_exists + ress = self.m_store.batch_exists(keys) + res = 1 + for value in ress: + if value != 1: + res = 0 + break + if res == 1: + continue + else: + return start + else: + starts = [] + for start, end, key in self.token_database.process_tokens( + tokens): + keys.append(key.to_string()) + starts.append(start) + multi_tp_keys = keys[:] + for i in range(1, self.tp_size): + for item in keys: + new_str = item.replace( # type: ignore[attr-defined] + "@0", f"@{i}", 1) + multi_tp_keys.append(new_str) + res = self.m_store.batch_exists( + multi_tp_keys) # type: ignore[assignment] + num_block = len(keys) + multi_tp_values = [ + res[i * num_block:(i + 1) * + num_block] # type: ignore[index] + for i in range(self.tp_size) + ] + index = self.find_min_first_non_one_index(multi_tp_values) + if index != -1: + return starts[index] + # all tokens where found, return the maximal end + except Exception as e: + logger.error(f"Remote connection failed in contains: {e}") + return start + return end + + def find_min_first_non_one_index(self, arr): + try: + return min(idx for row in arr for idx, val in enumerate(row) + if val != 1) + except ValueError: + return -1 + def close(self) -> None: """Close the cache engine and free all the resources""" self.m_store.close() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py index ed2f4bc..cee07c6 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store.py @@ -68,7 +68,7 @@ class Mooncakestore(): def exists(self, key: MooncakeEngineKey) -> bool: return self.store.is_exist(key.to_string()) == 1 - def batch_exists(self, keys: list[str]) -> list[bool]: + def batch_exists(self, keys: list[str]) -> list[int]: return self.store.batch_is_exist(keys) def register_buffer(self, ptr, length): diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py index 954b78a..f55dd03 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -479,7 +479,8 @@ class MooncakeLookupServer: while self.running: frames = self.socket.recv_multipart(copy=False) token_ids = self.decoder.decode(frames) - result = self.mooncake_engine.lookup(token_ids, use_layerwise) + result = self.mooncake_engine.lookup_scheduler( + token_ids, use_layerwise) response = result.to_bytes(4, "big") self.socket.send(response)