From 17dd9ae42c9f1905af5873577f8344f2c3442b33 Mon Sep 17 00:00:00 2001 From: fems14 <74094523+fems14@users.noreply.github.com> Date: Fri, 24 Oct 2025 18:22:45 +0800 Subject: [PATCH] [0.11.0][bugfix]look up multi_tp key (#3699) (#3723) ### What this PR does / why we need it? In multi-Tensor Parallel (TP) scenarios, the KV pool only queries the first GPU card. When keys on other cards are released, the query result still returns as successful, introducing accuracy issues. This PR modifies the KV pool's query logic to check all cards, resolving this problem. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: fems14 <1804143737@qq.com> --- .../distributed/mooncake/mooncake_engine.py | 112 ++++++++++++++---- .../distributed/mooncake/mooncake_store.py | 2 +- .../mooncake/mooncake_store_connector_v1.py | 3 +- 3 files changed, 95 insertions(+), 22 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index 3992618..bff3f3e 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -510,20 +510,18 @@ class MooncakeEngine: ) -> int: """ Checks the existence of KV cache of the tokens from the cache engine. - :param tokens: the input tokens, with shape [seq_len] - :return: An int indicating how many prefix tokens are cached. """ end = 0 - - for start, end, key in self.token_database.process_tokens(tokens): - try: - if use_layerwise: - keys = [] + keys = [] + try: + if use_layerwise: + for start, end, key in self.token_database.process_tokens( + tokens): keys_multi_layer = key.split_layers(self.num_layers) - for key in keys_multi_layer: - keys.append(key.to_string()) + for item in keys_multi_layer: + keys.append(item.to_string()) # batch is_exists ress = self.m_store.batch_exists(keys) res = 1 @@ -531,19 +529,93 @@ class MooncakeEngine: if value != 1: res = 0 break - else: - res = self.m_store.exists(key) - if res == 1: - continue - else: - return start - except Exception as e: - logger.warning(f"Remote connection failed in contains: {e}") - return start - - # all tokens where found, return the maximal end + if res == 1: + continue + else: + return start + else: + starts = [] + for start, end, key in self.token_database.process_tokens( + tokens): + keys.append(key.to_string()) + starts.append(start) + res = self.m_store.batch_exists( + keys) # type: ignore[assignment] + for index, value in enumerate(res): # type: ignore[arg-type] + if value != 1: + return starts[index] + # all tokens where found, return the maximal end + except Exception as e: + logger.error(f"Remote connection failed in contains: {e}") + return start return end + def lookup_scheduler( + self, + tokens: Union[torch.Tensor, List[int]], + use_layerwise: bool, + ) -> int: + """ + Checks the existence of KV cache of the tokens from the cache engine. + :param tokens: the input tokens, with shape [seq_len] + :return: An int indicating how many prefix tokens are cached. + """ + end = 0 + keys = [] + try: + if use_layerwise: + for start, end, key in self.token_database.process_tokens( + tokens): + keys_multi_layer = key.split_layers(self.num_layers) + for item in keys_multi_layer: + keys.append(item.to_string()) + # batch is_exists + ress = self.m_store.batch_exists(keys) + res = 1 + for value in ress: + if value != 1: + res = 0 + break + if res == 1: + continue + else: + return start + else: + starts = [] + for start, end, key in self.token_database.process_tokens( + tokens): + keys.append(key.to_string()) + starts.append(start) + multi_tp_keys = keys[:] + for i in range(1, self.tp_size): + for item in keys: + new_str = item.replace( # type: ignore[attr-defined] + "@0", f"@{i}", 1) + multi_tp_keys.append(new_str) + res = self.m_store.batch_exists( + multi_tp_keys) # type: ignore[assignment] + num_block = len(keys) + multi_tp_values = [ + res[i * num_block:(i + 1) * + num_block] # type: ignore[index] + for i in range(self.tp_size) + ] + index = self.find_min_first_non_one_index(multi_tp_values) + if index != -1: + return starts[index] + # all tokens where found, return the maximal end + except Exception as e: + logger.error(f"Remote connection failed in contains: {e}") + return start + return end + + def find_min_first_non_one_index(self, arr): + try: + return min(idx for row in arr for idx, val in enumerate(row) + if val != 1) + except ValueError: + return -1 + def close(self) -> None: """Close the cache engine and free all the resources""" self.m_store.close() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py index ed2f4bc..cee07c6 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store.py @@ -68,7 +68,7 @@ class Mooncakestore(): def exists(self, key: MooncakeEngineKey) -> bool: return self.store.is_exist(key.to_string()) == 1 - def batch_exists(self, keys: list[str]) -> list[bool]: + def batch_exists(self, keys: list[str]) -> list[int]: return self.store.batch_is_exist(keys) def register_buffer(self, ptr, length): diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py index 954b78a..f55dd03 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -479,7 +479,8 @@ class MooncakeLookupServer: while self.running: frames = self.socket.recv_multipart(copy=False) token_ids = self.decoder.decode(frames) - result = self.mooncake_engine.lookup(token_ids, use_layerwise) + result = self.mooncake_engine.lookup_scheduler( + token_ids, use_layerwise) response = result.to_bytes(4, "big") self.socket.send(response)