[0.11.0][bugfix]look up multi_tp key (#3699) (#3723)

### What this PR does / why we need it?
In multi-Tensor-Parallel (TP) scenarios, the KV pool lookup only queried the
first GPU card. When keys on the other cards had already been released, the
lookup still reported success, introducing accuracy issues. This PR changes
the KV pool's lookup logic to check all cards, resolving this
problem.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
fems14
2025-10-24 18:22:45 +08:00
committed by GitHub
parent f0eb3e1d97
commit 17dd9ae42c
3 changed files with 95 additions and 22 deletions

View File

@@ -510,20 +510,18 @@ class MooncakeEngine:
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
end = 0
for start, end, key in self.token_database.process_tokens(tokens):
try:
if use_layerwise:
keys = []
keys = []
try:
if use_layerwise:
for start, end, key in self.token_database.process_tokens(
tokens):
keys_multi_layer = key.split_layers(self.num_layers)
for key in keys_multi_layer:
keys.append(key.to_string())
for item in keys_multi_layer:
keys.append(item.to_string())
# batch is_exists
ress = self.m_store.batch_exists(keys)
res = 1
@@ -531,19 +529,93 @@ class MooncakeEngine:
if value != 1:
res = 0
break
else:
res = self.m_store.exists(key)
if res == 1:
continue
else:
return start
except Exception as e:
logger.warning(f"Remote connection failed in contains: {e}")
return start
# all tokens where found, return the maximal end
if res == 1:
continue
else:
return start
else:
starts = []
for start, end, key in self.token_database.process_tokens(
tokens):
keys.append(key.to_string())
starts.append(start)
res = self.m_store.batch_exists(
keys) # type: ignore[assignment]
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return starts[index]
# all tokens where found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def lookup_scheduler(
    self,
    tokens: Union[torch.Tensor, List[int]],
    use_layerwise: bool,
) -> int:
    """
    Check how many prefix tokens already have KV cache in the cache engine.

    Unlike a single-rank lookup, the non-layerwise path queries the keys of
    every tensor-parallel rank (key marker ``@0`` .. ``@{tp_size-1}``), so a
    prefix only counts as cached when ALL ranks still hold it.

    :param tokens: the input tokens, with shape [seq_len]
    :param use_layerwise: whether keys were stored one per layer
    :return: An int indicating how many prefix tokens are cached.
    """
    end = 0
    # Bugfix: initialize start so the except-path "return start" cannot
    # raise UnboundLocalError when process_tokens fails before yielding.
    start = 0
    keys: List[str] = []
    try:
        if use_layerwise:
            for start, end, key in self.token_database.process_tokens(
                    tokens):
                keys_multi_layer = key.split_layers(self.num_layers)
                for item in keys_multi_layer:
                    keys.append(item.to_string())
                # batch existence check over the accumulated layer keys
                ress = self.m_store.batch_exists(keys)
                if all(value == 1 for value in ress):
                    continue
                return start
        else:
            starts = []
            for start, end, key in self.token_database.process_tokens(
                    tokens):
                keys.append(key.to_string())
                starts.append(start)
            # Expand the rank-0 keys to every TP rank by rewriting the
            # first "@0" marker, then probe them all in one batch call.
            multi_tp_keys = keys[:]
            for i in range(1, self.tp_size):
                for item in keys:
                    multi_tp_keys.append(item.replace("@0", f"@{i}", 1))
            res = self.m_store.batch_exists(
                multi_tp_keys)  # type: ignore[assignment]
            num_block = len(keys)
            # Re-split the flat result list into one row per TP rank.
            multi_tp_values = [
                res[i * num_block:(i + 1) *
                    num_block]  # type: ignore[index]
                for i in range(self.tp_size)
            ]
            # Earliest block index missing on ANY rank bounds the hit.
            index = self.find_min_first_non_one_index(multi_tp_values)
            if index != -1:
                return starts[index]
        # all tokens were found, return the maximal end
    except Exception as e:
        logger.error(f"Remote connection failed in lookup_scheduler: {e}")
        return start
    return end
def find_min_first_non_one_index(self, arr):
    """Return the smallest column index holding a non-1 value in any row
    of *arr*, or -1 when every entry is 1 (or *arr* is empty)."""
    best = -1
    for row in arr:
        for idx, val in enumerate(row):
            if val != 1:
                if best == -1 or idx < best:
                    best = idx
                # later entries in this row cannot have a smaller index
                break
    return best
def close(self) -> None:
    """Shut down the cache engine, releasing all resources it holds."""
    store = self.m_store
    store.close()

View File

@@ -68,7 +68,7 @@ class Mooncakestore():
def exists(self, key: MooncakeEngineKey) -> bool:
    """Return True when *key* is present in the backing store."""
    status = self.store.is_exist(key.to_string())
    return status == 1
def batch_exists(self, keys: list[str]) -> list[bool]:
def batch_exists(self, keys: list[str]) -> list[int]:
    """Probe the store for every key in one call; 1 marks a present key."""
    results = self.store.batch_is_exist(keys)
    return results
def register_buffer(self, ptr, length):

View File

@@ -479,7 +479,8 @@ class MooncakeLookupServer:
while self.running:
frames = self.socket.recv_multipart(copy=False)
token_ids = self.decoder.decode(frames)
result = self.mooncake_engine.lookup(token_ids, use_layerwise)
result = self.mooncake_engine.lookup_scheduler(
token_ids, use_layerwise)
response = result.to_bytes(4, "big")
self.socket.send(response)