### What this PR does / why we need it?
In multi-Tensor-Parallel (TP) scenarios, the KV pool existence query only checks the first GPU card. When the keys on the other cards have already been released, the query still reports a hit, which introduces accuracy issues. This PR changes the KV pool's query logic to check all cards, resolving the problem.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: fems14 <1804143737@qq.com>
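As the `lookup_scheduler` method in the file below shows, the scheduler-side lookup now expands every block key to all tensor-parallel ranks and treats a prefix as cached only if every rank still holds it. The following is a minimal standalone sketch of that check, not the connector's actual API: `store` stands for any object exposing Mooncake's `batch_exists`, `cached_prefix_blocks` is a hypothetical helper name, and keys are assumed to carry an `@<tp_rank>` suffix as the rank-0 keys in this file do.

```python
def cached_prefix_blocks(rank0_keys: list[str], tp_size: int, store) -> int:
    """Count leading blocks whose KV entries exist on every TP rank."""
    # Expand the rank-0 keys, e.g. "...@0" -> "...@1", for every rank.
    all_keys = [
        key.replace("@0", f"@{rank}", 1)
        for rank in range(tp_size)
        for key in rank0_keys
    ]
    exists = store.batch_exists(all_keys)  # 1 means the key is present
    num_blocks = len(rank0_keys)
    per_rank = [
        exists[rank * num_blocks:(rank + 1) * num_blocks]
        for rank in range(tp_size)
    ]
    # The first block missing on any rank ends the usable prefix.
    for block_idx in range(num_blocks):
        if any(per_rank[rank][block_idx] != 1 for rank in range(tp_size)):
            return block_idx
    return num_blocks
```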
622 lines · 26 KiB · Python
# Standard
import math
import threading
import time
from typing import Generator, List, Optional, Union

# Third Party
import torch
from vllm.config import VllmConfig
from vllm.utils import get_kv_cache_torch_dtype, logger

from vllm_ascend.distributed.mooncake.config_data import (
    ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
    MooncakeEngineMetadata)
from vllm_ascend.distributed.mooncake.kv_transfer import (
    KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
    KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore

class MooncakeEngine:
    # The main class for the cache engine.

    def __init__(
        self,
        vllm_config: VllmConfig,
        use_layerwize: bool,
    ):
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.use_mla = False
        if (hasattr(model_config, "use_mla")
                and isinstance(model_config.use_mla, bool)
                and model_config.use_mla):
            self.use_mla = True
        self.use_layerwise = use_layerwize
        self.tp_rank = parallel_config.rank
        self.tp_size = parallel_config.tensor_parallel_size
        self.kv_role = vllm_config.kv_transfer_config.kv_role
        self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "load_async", False)
        self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "register_buffer", False)
        self.block_size = vllm_config.cache_config.block_size
        self.current_layer = 0
        # self.use_mla = first_kv_cache_tuple[0].size(
        #     -1) != first_kv_cache_tuple[1].size(-1)
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.block_size = vllm_config.cache_config.block_size
        num_kv_head = model_config.get_num_kv_heads(parallel_config)
        head_size = model_config.get_head_size()
        kv_dtype = get_kv_cache_torch_dtype(
            vllm_config.cache_config.cache_dtype, model_config.dtype)
        self.hidden_dim_size = num_kv_head * head_size
        if self.use_mla:
            kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
        else:
            kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
                        head_size)
        self.metadata = MooncakeEngineMetadata(
            model_config.model,
            parallel_config.world_size,
            parallel_config.rank,
            kv_dtype,
            kv_shape,
            self.block_size,
            self.use_mla,
        )

        self.token_database = ChunkedTokenDatabase(self.metadata)

        self.m_store = Mooncakestore(parallel_config)

        self.kv_send_thread: Optional[KVTransferThread] = None
        self.kv_recv_thread: Optional[KVTransferThread] = None

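    # register_kv_caches records the base device pointer and per-block byte
    # length of every KV cache tensor handed over by vLLM; with
    # `register_buffer` enabled it also registers each region with the
    # Mooncake store via _register(), and it then starts the KV send/receive
    # threads that service transfers.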
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        _, first_kv_cache_tuple = next(iter(kv_caches.items()))
        first_kv_cache = first_kv_cache_tuple[0]

        # TODO(tms): Find a more robust way to detect and handle MLA
        if self.use_mla:
            # MLA case. [num_block, block_size, 1, hidden_dim]
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 3  # [block_size, latent_dim]
            block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
            block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
            self.block_len = [
                first_kv_cache[0].element_size() * math.prod(block_shape_norm),
                first_kv_cache[1].element_size() * math.prod(block_shape_pe)
            ]
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                self.num_blocks, block_shape_norm, block_shape_pe)
        else:
            # [num_block, block_size, num_head, hidden_dim]
            self.num_blocks = first_kv_cache.shape[0]
            kv_elem_size = first_kv_cache.element_size()
            block_rank = 3  # [block_size, kv_heads, head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            self.block_len = [kv_elem_size * math.prod(block_shape)]
            logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                        block_shape)

        logger.info("Registering KV_Caches. use_mla: %s, shape %s",
                    self.use_mla, first_kv_cache.shape)

        self.kv_caches = kv_caches
        self.kv_caches_base_addr = []
        for cache_or_caches in kv_caches.values():
            # Normalize to always be a list of caches
            if self.use_mla:
                for i, cache in enumerate(cache_or_caches, 0):
                    base_addr = cache.data_ptr()
                    self.kv_caches_base_addr.append(base_addr)
                    if self.register_buffer:
                        region_len = self.num_blocks * self.block_len[i % 2]
                        self._register(base_addr, region_len)
            else:
                cache_list = [cache_or_caches
                              ] if self.use_mla else cache_or_caches
                for cache in cache_list:
                    base_addr = cache.data_ptr()
                    self.kv_caches_base_addr.append(base_addr)
                    if self.register_buffer:
                        region_len = self.num_blocks * self.block_len[0]
                        self._register(base_addr, region_len)

        if self.use_layerwise:
            self.get_event = threading.Event()
            if self.kv_role in ['kv_producer', 'kv_both']:
                ready_event_sending = threading.Event()
                self.kv_send_thread = KVCacheStoreLayerSendingThread(
                    self.tp_rank, self.tp_size, self.m_store,
                    self.kv_caches_base_addr, self.token_database,
                    self.block_len, self.block_size, ready_event_sending,
                    self.num_layers)
                self.kv_send_thread.start()
            ready_event = threading.Event()
            self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
                self.tp_rank, self.tp_size, self.m_store,
                self.kv_caches_base_addr, self.token_database, self.block_len,
                self.block_size, ready_event, self.get_event)
            self.kv_recv_thread.start()
            ready_event.wait()
        else:
            if self.kv_role in ['kv_producer', 'kv_both']:
                ready_event_sending = threading.Event()
                self.kv_send_thread = KVCacheStoreSendingThread(
                    self.tp_rank, self.tp_size, self.m_store,
                    self.kv_caches_base_addr, self.token_database,
                    self.block_len, self.block_size, ready_event_sending)
                self.kv_send_thread.start()
            if self.load_async:
                ready_event = threading.Event()
                self.kv_recv_thread = KVCacheStoreRecvingThread(
                    self.tp_rank, self.tp_size, self.m_store,
                    self.kv_caches_base_addr, self.token_database,
                    self.block_len, self.block_size, ready_event)
                self.kv_recv_thread.start()
                ready_event.wait()

    def _register(self, ptr, length):
        logger.debug(
            "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
            "block_lens=%s", ptr, length, self.num_blocks, self.block_len)
        try:
            self.m_store.register_buffer(ptr, length)
        except Exception as e:
            raise RuntimeError(
                f"Mooncake memory registration failed. Error is: {e}")

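    # start_load_kv kicks off KV loading for every request with a loadable
    # prefix: layerwise mode creates a per-request retriever generator, the
    # async path enqueues the work on the receive thread, and the synchronous
    # path fetches the blocks from the Mooncake store directly.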
    def start_load_kv(self, metadata: MooncakeConnectorMetadata):
        self.current_layer = 0
        self.layerwise_retrievers = []
        for request in metadata.requests:
            load_spec = request.load_spec
            if load_spec is None or not load_spec.can_load:  # load = 0
                continue
            tokens = request.token_ids
            req_id = request.req_id
            if (load_spec.mooncake_cached_tokens % self.block_size
                    != 0) and (load_spec.mooncake_cached_tokens
                               == tokens.shape[0] - 1):
                tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1]
            else:
                tokens = tokens[:request.load_spec.mooncake_cached_tokens]
            masked_token_count = (request.load_spec.vllm_cached_tokens //
                                  self.block_size * self.block_size)
            token_mask = torch.ones_like(tokens, dtype=torch.bool)
            token_mask[:masked_token_count] = False
            if self.use_layerwise:
                layerwise_retriever = self.retrieve_layer(
                    req_id,
                    tokens,
                    request.block_ids,
                    token_mask,
                )
                next(layerwise_retriever)  # first layer load
                self.layerwise_retrievers.append(layerwise_retriever)
            else:
                if self.load_async:
                    self.kv_recv_thread.add_request(  # type: ignore[union-attr]
                        req_id,
                        tokens,
                        request.block_ids,
                        token_mask,
                    )
                else:
                    if self.m_store.config.use_ascend_direct:
                        addr_list = []
                        size_list = []
                        key_list = []
                        blockIds = []
                        for start, end, key in self.token_database.process_tokens(
                                tokens, token_mask):
                            addr, size, block_id = self.prepare_value(
                                start, end, request.block_ids)
                            key_list.append(key.to_string())
                            addr_list.append(addr)
                            size_list.append(size)
                            blockIds.append(block_id)
                        self.m_store.get_batch(key_list, addr_list, size_list,
                                               blockIds)
                    else:
                        for start, end, key in self.token_database.process_tokens(
                                tokens, token_mask):
                            addr, size, _ = self.prepare_value(
                                start, end, request.block_ids)
                            self.m_store.get(key, addr, size)

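    # prepare_value maps a token chunk [start, end) to the device address and
    # byte length of the matching KV block in every registered cache tensor
    # (alternating norm/pe block lengths in the MLA case).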
    def prepare_value(self, start: int, end: int, block_ids: list[int]):
        addr_list = []
        size_list = []
        block_id = block_ids[start // self.block_size]
        for index, base_addr in enumerate(self.kv_caches_base_addr):
            block_len = (self.block_len[index % 2]
                         if self.use_mla else self.block_len[0])

            addr = base_addr + block_id * block_len
            length = int(block_len / self.block_size * (end - start))
            addr_list.append(addr)
            size_list.append(length)
        return addr_list, size_list, block_id

    def wait_for_layer_load(self) -> None:
        """Advance each layerwise retriever by one layer and log the
        retrieved token count after the last layer."""
        for layerwise_retriever in self.layerwise_retrievers:
            ret_token_mask = next(layerwise_retriever)
            if self.current_layer == self.num_layers - 1:
                assert ret_token_mask is not None
                num_retrieved_tokens = ret_token_mask.sum().item()
                logger.info(f"Retrieved {num_retrieved_tokens} tokens")

    def save_kv_layer(self,
                      connector_metadata: MooncakeConnectorMetadata) -> None:
        """On the first layer, create a layerwise storer for every savable
        request; then advance each storer by one layer."""
        if self.current_layer == 0:
            self.layerwise_storers = []
            for request in connector_metadata.requests:
                save_spec = request.save_spec
                if save_spec is None or not save_spec.can_save:
                    continue

                token_ids = request.token_ids
                req_id = request.req_id
                assert isinstance(token_ids, torch.Tensor)
                assert token_ids.is_cpu

                # TODO: whether we need to remove saveThread
                # no lookup, skip mask
                skip_leading_tokens = max(
                    self.lookup(token_ids, self.use_layerwise),
                    save_spec.skip_leading_tokens,
                )
                if skip_leading_tokens == len(token_ids):
                    if request.is_last_chunk:
                        self.kv_send_thread.set_finished_request(  # type: ignore[union-attr]
                            req_id)
                    continue  # skip this request

                skip_leading_tokens = (skip_leading_tokens // self.block_size *
                                       self.block_size)

                store_mask = torch.ones_like(token_ids, dtype=torch.bool)
                store_mask[:skip_leading_tokens] = False
                logger.info(
                    "Storing KV cache for %d out of %d tokens "
                    "(skip_leading_tokens=%d) for request %s",
                    len(token_ids) - skip_leading_tokens,
                    len(token_ids),
                    skip_leading_tokens,
                    request.req_id,
                )

                layerwise_storer = self.store_layer(
                    req_id,
                    token_ids,
                    mask=store_mask,
                    block_ids=request.block_ids,
                )
                self.layerwise_storers.append(layerwise_storer)
        for layerwise_storer in self.layerwise_storers:
            try:
                next(layerwise_storer)
            except Exception:
                raise
        self.current_layer = self.current_layer + 1

    def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
        """Enqueue the KV cache of every savable request on the send thread."""
        for request in connector_metadata.requests:
            save_spec = request.save_spec
            if save_spec is None or not save_spec.can_save:
                continue

            token_ids = request.token_ids
            req_id = request.req_id
            assert isinstance(token_ids, torch.Tensor)
            assert token_ids.is_cpu

            skip_leading_tokens = max(
                self.lookup(token_ids, self.use_layerwise),
                save_spec.skip_leading_tokens,
            )
            if skip_leading_tokens == len(token_ids):
                if request.is_last_chunk:
                    self.kv_send_thread.set_finished_request(  # type: ignore[union-attr]
                        req_id)
                continue  # skip this request

            skip_leading_tokens = (skip_leading_tokens // self.block_size *
                                   self.block_size)

            store_mask = torch.ones_like(token_ids, dtype=torch.bool)
            store_mask[:skip_leading_tokens] = False

            logger.info(
                "Storing KV cache for %d out of %d tokens "
                "(skip_leading_tokens=%d) for request %s",
                len(token_ids) - skip_leading_tokens,
                len(token_ids),
                skip_leading_tokens,
                request.req_id,
            )

            self.kv_send_thread.add_request(  # type: ignore[union-attr]
                req_id,
                token_ids,
                request.block_ids,
                store_mask,
                request.is_last_chunk,
            )

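    # retrieve_layer and store_layer drive the layerwise pipeline: each
    # next() call hands one layer's block keys to the receive or send thread,
    # so the transfer of one layer can overlap with the layers that follow.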
    def retrieve_layer(
        self,
        req_id: str,
        tokens: torch.Tensor,
        block_ids: list[int],
        mask: Optional[torch.Tensor] = None,
    ) -> Generator[Optional[torch.Tensor], None, None]:
        """
        Retrieve the KV cache in a layerwise manner.

        :param torch.Tensor tokens: The tokens of the corresponding KV caches.

        :param Optional[torch.Tensor] mask: The mask for the tokens. Should
            have the same length as tokens. And the mask should ALWAYS be like
            FFFFFTTTTTTT, where True means the tokens need to be matched.

        :param **kwargs: The additional arguments for the KV transfer which
            will be passed into the npu_transfer.

        :return: A generator that yields Optional[torch.Tensor]. The tensor
            will be the boolean mask indicating which tokens are retrieved and
            will only be returned in the last iteration.
        """

        if mask is not None:
            num_required_tokens = torch.sum(mask).item()
        else:
            num_required_tokens = len(tokens)

        ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")

        starts = []
        ends = []
        keys = []
        first_flag = True
        for start, end, key in self.token_database.process_tokens(
                tokens, mask):
            keys_multi_layer = key.split_layers(self.num_layers)
            starts.append(start)
            ends.append(end)
            keys.append(keys_multi_layer)
            ret_mask[start:end] = True

        if keys:
            # Transpose the keys into layer-major format
            keys = [list(row) for row in zip(*keys)]  # [num_layer, block_num]
            for layer_id, keys_multi_chunk in enumerate(keys):
                if not first_flag:
                    is_finish = self.get_event.wait(timeout=3)
                    if not is_finish:
                        logger.info("Layerwise get failed")
                    self.get_event.clear()
                req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
                                                   starts, ends, block_ids,
                                                   layer_id)
                self.kv_recv_thread.add_request(  # type: ignore[union-attr, call-arg]
                    req_meta)  # type: ignore[union-attr, call-arg, arg-type]
                first_flag = False
                yield None
        else:
            # If no cache is found, we still need to yield to avoid
            # `StopIteration`
            for layer_id in range(self.num_layers):
                yield None

        retrieved_tokens = torch.sum(ret_mask)
        logger.debug(f"Retrieved {retrieved_tokens} "
                     f"out of {num_required_tokens} "
                     f"out of total {len(tokens)} tokens")

        yield ret_mask

    def store_layer(
        self,
        req_id: str,
        tokens: torch.Tensor,
        block_ids: list[int],
        mask: Optional[torch.Tensor] = None,
    ) -> Generator[None, None, None]:
        """
        Store the KV cache in a layerwise manner.

        :param torch.Tensor tokens: The tokens of the corresponding KV caches.

        :param Optional[torch.Tensor] mask: The mask for the tokens. Should
            have the same length as tokens. And the mask should ALWAYS be like
            FFFFFTTTTTTT, where True means the tokens need to be matched.

        :param **kwargs: The additional arguments for the storage backend
            which will be passed into the gpu_connector.

        :return: A generator that yields None. In the first iteration, the
            generator allocates the memory objects for all layers and moves
            the KV cache of the first layer from GPU to CPU. In the next
            iterations, it moves the KV cache of layer i from GPU to the
            memory objects (on CPU) and puts the memory objects of layer i-1
            to the storage backends. In the last iteration, it puts the memory
            objects of the last layer to the storage backends.
        """

        if mask is not None:
            num_stored_tokens = torch.sum(mask).item()
        else:
            num_stored_tokens = len(tokens)

        starts = []
        ends = []
        keys = []
        for start, end, key in self.token_database.process_tokens(
                tokens, mask):
            keys_multi_layer = key.split_layers(self.num_layers)
            starts.append(start)
            ends.append(end)
            keys.append(keys_multi_layer)  # [block_num, layer_num]

        if keys:
            keys = [list(row) for row in zip(*keys)]  # [layer_num, block_num]
            for layer_id, keys_multi_chunk in enumerate(keys):
                req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
                                                   starts, ends, block_ids,
                                                   layer_id)
                self.kv_send_thread.add_request(  # type: ignore[union-attr, call-arg]
                    req_meta)  # type: ignore[union-attr, call-arg, arg-type]
                yield
        else:
            for layer_id in range(self.num_layers):
                yield
        logger.debug(
            f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")

    def get_finished(self) -> tuple[set[str], set[str]]:
        done_sending = (
            self.kv_send_thread.
            get_and_clear_finished_requests(  # type: ignore[union-attr]
            ) if self.kv_role in ['kv_producer', 'kv_both'] else set())

        done_recving = (
            self.kv_recv_thread.
            get_and_clear_finished_requests(  # type: ignore[union-attr]
            ) if self.load_async else set())

        logger.debug(
            "Number of completed KV cache send requests: %d, receive "
            "requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
            self.tp_rank)
        return done_sending, done_recving

    def wait_layer_transfer_finish(self):
        time.sleep(10)

    def lookup(
        self,
        tokens: Union[torch.Tensor, List[int]],
        use_layerwise: bool,
    ) -> int:
        """
        Check whether the KV cache for the given tokens exists in the cache
        engine.
        :param tokens: the input tokens, with shape [seq_len]
        :return: An int indicating how many prefix tokens are cached.
        """
        end = 0
        keys = []
        try:
            if use_layerwise:
                for start, end, key in self.token_database.process_tokens(
                        tokens):
                    keys_multi_layer = key.split_layers(self.num_layers)
                    for item in keys_multi_layer:
                        keys.append(item.to_string())
                    # batch is_exists
                    ress = self.m_store.batch_exists(keys)
                    res = 1
                    for value in ress:
                        if value != 1:
                            res = 0
                            break
                    if res == 1:
                        continue
                    else:
                        return start
            else:
                starts = []
                for start, end, key in self.token_database.process_tokens(
                        tokens):
                    keys.append(key.to_string())
                    starts.append(start)
                res = self.m_store.batch_exists(
                    keys)  # type: ignore[assignment]
                for index, value in enumerate(res):  # type: ignore[arg-type]
                    if value != 1:
                        return starts[index]
                # all tokens were found, return the maximal end
        except Exception as e:
            logger.error(f"Remote connection failed in contains: {e}")
            return start
        return end

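    # Scheduler-side existence check across all TP ranks: the rank-0 keys are
    # expanded to their "@1".."@{tp_size-1}" variants and a prefix only counts
    # as cached if batch_exists() reports a hit on every rank, so a key that
    # was evicted on any card is treated as a miss.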
    def lookup_scheduler(
        self,
        tokens: Union[torch.Tensor, List[int]],
        use_layerwise: bool,
    ) -> int:
        """
        Check whether the KV cache for the given tokens exists on all
        tensor-parallel ranks.
        :param tokens: the input tokens, with shape [seq_len]
        :return: An int indicating how many prefix tokens are cached.
        """
        end = 0
        keys = []
        try:
            if use_layerwise:
                for start, end, key in self.token_database.process_tokens(
                        tokens):
                    keys_multi_layer = key.split_layers(self.num_layers)
                    for item in keys_multi_layer:
                        keys.append(item.to_string())
                    # batch is_exists
                    ress = self.m_store.batch_exists(keys)
                    res = 1
                    for value in ress:
                        if value != 1:
                            res = 0
                            break
                    if res == 1:
                        continue
                    else:
                        return start
            else:
                starts = []
                for start, end, key in self.token_database.process_tokens(
                        tokens):
                    keys.append(key.to_string())
                    starts.append(start)
                multi_tp_keys = keys[:]
                for i in range(1, self.tp_size):
                    for item in keys:
                        new_str = item.replace(  # type: ignore[attr-defined]
                            "@0", f"@{i}", 1)
                        multi_tp_keys.append(new_str)
                res = self.m_store.batch_exists(
                    multi_tp_keys)  # type: ignore[assignment]
                num_block = len(keys)
                multi_tp_values = [
                    res[i * num_block:(i + 1) *
                        num_block]  # type: ignore[index]
                    for i in range(self.tp_size)
                ]
                index = self.find_min_first_non_one_index(multi_tp_values)
                if index != -1:
                    return starts[index]
                # all tokens were found, return the maximal end
        except Exception as e:
            logger.error(f"Remote connection failed in contains: {e}")
            return start
        return end

    def find_min_first_non_one_index(self, arr):
        # Return the smallest column index at which any row holds a value
        # other than 1, or -1 if every entry is 1.
        try:
            return min(idx for row in arr for idx, val in enumerate(row)
                       if val != 1)
        except ValueError:
            return -1

    def close(self) -> None:
        """Close the cache engine and free all the resources."""
        self.m_store.close()