### What this PR does / why we need it? In multi-Tensor-Parallel (TP) scenarios, the KV pool queried only the first GPU card, so when keys on the other cards had already been released the lookup still reported success, causing accuracy issues. This PR changes the KV pool's query logic to check every card, resolving the problem. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
@@ -510,20 +510,18 @@ class MooncakeEngine:
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
|
||||
for start, end, key in self.token_database.process_tokens(tokens):
|
||||
try:
|
||||
if use_layerwise:
|
||||
keys = []
|
||||
keys = []
|
||||
try:
|
||||
if use_layerwise:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for key in keys_multi_layer:
|
||||
keys.append(key.to_string())
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
@@ -531,19 +529,93 @@ class MooncakeEngine:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
else:
|
||||
res = self.m_store.exists(key)
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
except Exception as e:
|
||||
logger.warning(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
|
||||
# all tokens where found, return the maximal end
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
else:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
res = self.m_store.batch_exists(
|
||||
keys) # type: ignore[assignment]
|
||||
for index, value in enumerate(res): # type: ignore[arg-type]
|
||||
if value != 1:
|
||||
return starts[index]
|
||||
# all tokens where found, return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def lookup_scheduler(
    self,
    tokens: Union[torch.Tensor, List[int]],
    use_layerwise: bool,
) -> int:
    """
    Checks the existence of KV cache of the tokens from the cache engine.

    Unlike a plain lookup, the non-layerwise path expands every key to all
    tensor-parallel ranks (rewriting the "@0" rank marker to "@1" ..
    "@{tp_size-1}") and only counts a block as cached when EVERY rank still
    holds it, so a key released on a non-first card is not reported as a hit.

    :param tokens: the input tokens, with shape [seq_len]
    :param use_layerwise: whether keys are stored per layer and must be
        expanded via ``split_layers`` before the existence check
    :return: An int indicating how many prefix tokens are cached.
    """
    # Fix: `start` must be defined even when process_tokens raises before
    # its first yield, otherwise the except-handler's `return start` would
    # itself raise NameError.
    start = 0
    end = 0
    keys: List[str] = []
    try:
        if use_layerwise:
            for start, end, key in self.token_database.process_tokens(
                    tokens):
                keys_multi_layer = key.split_layers(self.num_layers)
                for item in keys_multi_layer:
                    keys.append(item.to_string())
                # batch is_exists over all keys accumulated so far
                ress = self.m_store.batch_exists(keys)
                if any(value != 1 for value in ress):
                    # a layer of this block is missing: prefix ends here
                    return start
        else:
            starts: List[int] = []
            for start, end, key in self.token_database.process_tokens(
                    tokens):
                keys.append(key.to_string())
                starts.append(start)
            # Duplicate every key for each TP rank by rewriting the first
            # "@0" rank marker; NOTE(review): assumes keys embed the rank
            # as "@<idx>" — confirm against the key-format definition.
            multi_tp_keys = keys[:]
            for i in range(1, self.tp_size):
                for item in keys:
                    multi_tp_keys.append(
                        item.replace(  # type: ignore[attr-defined]
                            "@0", f"@{i}", 1))
            res = self.m_store.batch_exists(
                multi_tp_keys)  # type: ignore[assignment]
            num_block = len(keys)
            # One row of per-block results per TP rank.
            multi_tp_values = [
                res[i * num_block:(i + 1) *
                    num_block]  # type: ignore[index]
                for i in range(self.tp_size)
            ]
            index = self.find_min_first_non_one_index(multi_tp_values)
            if index != -1:
                # earliest block missing on ANY rank bounds the prefix
                return starts[index]
        # all tokens were found, return the maximal end
    except Exception as e:
        logger.error(f"Remote connection failed in contains: {e}")
        return start
    return end
|
||||
|
||||
def find_min_first_non_one_index(self, arr):
    """Return the smallest column index at which any row holds a value
    other than 1, or -1 when every value in every row is 1 (including
    when *arr* is empty)."""
    misses = (col for row in arr for col, val in enumerate(row) if val != 1)
    return min(misses, default=-1)
|
||||
|
||||
def close(self) -> None:
    """Close the cache engine and free all the resources"""
    store = self.m_store
    store.close()
|
||||
|
||||
@@ -68,7 +68,7 @@ class Mooncakestore():
|
||||
def exists(self, key: MooncakeEngineKey) -> bool:
    """Report whether a single *key* is present in the remote store
    (the store returns 1 for "present")."""
    serialized = key.to_string()
    return self.store.is_exist(serialized) == 1
|
||||
|
||||
def batch_exists(self, keys: list[str]) -> list[bool]:
|
||||
def batch_exists(self, keys: list[str]) -> list[int]:
|
||||
return self.store.batch_is_exist(keys)
|
||||
|
||||
def register_buffer(self, ptr, length):
|
||||
|
||||
@@ -479,7 +479,8 @@ class MooncakeLookupServer:
|
||||
while self.running:
|
||||
frames = self.socket.recv_multipart(copy=False)
|
||||
token_ids = self.decoder.decode(frames)
|
||||
result = self.mooncake_engine.lookup(token_ids, use_layerwise)
|
||||
result = self.mooncake_engine.lookup_scheduler(
|
||||
token_ids, use_layerwise)
|
||||
response = result.to_bytes(4, "big")
|
||||
self.socket.send(response)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user