2024-01-08 04:37:50 +00:00
|
|
|
import heapq
|
|
|
|
|
import time
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TreeNode:
    """A single node of the radix tree.

    Attributes:
        children: ``defaultdict`` of child nodes. Note that indexing a
            missing entry creates a fresh ``TreeNode`` as a side effect,
            so membership should be checked before indexing.
        parent: The node above this one, or ``None`` until it is linked.
        key: The key fragment stored on this node (``None`` until set).
        value: The payload stored for ``key`` (``None`` until set).
        ref_counter: Integer reference count, starting at zero.
        last_access_time: ``time.time()`` stamp taken at construction.
    """

    def __init__(self):
        self.last_access_time = time.time()
        self.ref_counter = 0
        # Nothing is attached yet; the owner of the tree fills these in.
        self.parent = self.key = self.value = None
        self.children = defaultdict(TreeNode)

    def __lt__(self, other: "TreeNode"):
        """Order nodes by access time so the oldest node sorts first."""
        return other.last_access_time > self.last_access_time
|
|
|
|
|
|
|
|
|
|
|
2024-04-26 01:01:36 +08:00
|
|
|
def _key_match(key0, key1):
|
2024-01-08 04:37:50 +00:00
|
|
|
i = 0
|
2024-04-26 01:01:36 +08:00
|
|
|
for k0, k1 in zip(key0, key1):
|
|
|
|
|
if k0 != k1:
|
2024-01-08 04:37:50 +00:00
|
|
|
break
|
|
|
|
|
i += 1
|
|
|
|
|
return i
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RadixCache:
    """Radix-tree cache that maps key prefixes to stored values.

    Every non-root node holds a fragment of a key together with the
    matching slice of the value. A node is registered in its parent's
    ``children`` mapping under the first element of its own key.

    ``evictable_size_`` tracks the summed value length of all nodes whose
    ``ref_counter`` is zero; it is maintained by ``insert``,
    ``inc_ref_counter``, ``dec_ref_counter`` and eviction.

    Args:
        disable: When true, ``match_prefix``, ``insert`` and ``evict``
            become no-ops that leave the tree untouched.
    """

    def __init__(self, disable: bool = False):
        self.disable = disable
        self.reset()

    ##### Public API #####

    def reset(self):
        """Drop the whole tree and start again from an empty root."""
        self.root_node = TreeNode()
        self.root_node.key = []
        self.root_node.value = []
        # The root carries a permanent reference, so the ref_counter check
        # in evict() never lets it be evicted.
        self.root_node.ref_counter = 1
        self.evictable_size_ = 0

    def match_prefix(self, key):
        """Find the longest cached prefix of ``key``.

        A node that is only partially covered by ``key`` is split, so the
        returned node always ends exactly where the match ends.

        Returns:
            A ``(value, last_node)`` pair. ``value`` is the
            ``torch.concat`` of the matched value pieces, or an empty list
            when nothing matched or the cache is disabled. ``last_node``
            is the deepest matched node (the root when nothing matched).
            Because of the ``torch.concat`` call, stored values must be
            tensors whenever a match is found.
        """
        if self.disable:
            return [], self.root_node

        value = []
        last_node = [self.root_node]
        self._match_prefix_helper(self.root_node, key, value, last_node)
        if value:
            value = torch.concat(value)
        return value, last_node[0]

    def insert(self, key, value=None):
        """Insert ``key`` with its ``value`` into the tree.

        Args:
            key: Sequence to store.
            value: Payload aligned element-by-element with ``key``. When
                omitted, a list copy of ``key`` is stored instead.

        Returns:
            The number of leading elements of ``key`` that were already
            present in the tree, or ``len(key)`` when the cache is
            disabled.
        """
        if self.disable:
            return len(key)

        if value is None:
            value = list(key)
        return self._insert_helper(self.root_node, key, value)

    def pretty_print(self):
        """Print the tree structure followed by the total stored size."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Return the summed value length over every node in the tree."""
        return self._total_size_helper(self.root_node)

    def evict(self, num_tokens, evict_callback):
        """Evict least-recently-accessed leaves.

        Leaves are removed oldest-first until the amounts reported by
        ``evict_callback`` add up to at least ``num_tokens`` or no
        candidate is left. Leaves with a positive ``ref_counter`` are
        skipped.

        Args:
            num_tokens: Target amount to free.
            evict_callback: Called with the value of each evicted node;
                its return value is added to the running total.
        """
        if self.disable:
            return

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and leaves:
            candidate = heapq.heappop(leaves)

            if candidate is self.root_node:
                break
            if candidate.ref_counter > 0:
                continue

            num_evicted += evict_callback(candidate.value)
            self._delete_leaf(candidate)

            # Removing the last child turns the parent into a leaf, which
            # makes it a candidate for eviction as well.
            if len(candidate.parent.children) == 0:
                heapq.heappush(leaves, candidate.parent)

    def inc_ref_counter(self, node):
        """Add one reference to ``node`` and to each ancestor below the root.

        Returns:
            The change applied to the evictable size (zero or negative).
        """
        delta = 0
        while node is not self.root_node:
            if node.ref_counter == 0:
                # First reference: the node stops being evictable.
                self.evictable_size_ -= len(node.value)
                delta -= len(node.value)
            node.ref_counter += 1
            node = node.parent
        return delta

    def dec_ref_counter(self, node):
        """Drop one reference from ``node`` and each ancestor below the root.

        Returns:
            The change applied to the evictable size (zero or positive).
        """
        delta = 0
        while node is not self.root_node:
            if node.ref_counter == 1:
                # Last reference: the node becomes evictable again.
                self.evictable_size_ += len(node.value)
                delta += len(node.value)
            node.ref_counter -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        """Return the current evictable size counter."""
        return self.evictable_size_

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node, key, value, last_node):
        """Walk down from ``node`` following ``key``.

        Matched value pieces are appended to ``value`` and the deepest
        matched node is written to ``last_node[0]``. Implemented as a loop
        so that deep paths cannot exhaust the interpreter call stack.
        """
        while True:
            node.last_access_time = time.time()
            if len(key) == 0:
                return
            if key[0] not in node.children:
                return

            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            if prefix_len < len(child.key):
                # Only part of the child matches: split it and stop there.
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                last_node[0] = new_node
                return

            value.append(child.value)
            last_node[0] = child
            node = child
            key = key[prefix_len:]

    def _split_node(self, key, child, split_len):
        """Split ``child`` after ``split_len`` elements.

        A new node holding the first ``split_len`` elements of the key and
        value takes ``child``'s place under its parent; ``child`` keeps
        the remainder and becomes the new node's only child.

        Args:
            key: The key of ``child`` before the split.
            child: The node to split.
            split_len: Number of leading elements moved to the new node.

        Returns:
            The newly created upper node.
        """
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len]: child}
        new_node.parent = child.parent
        new_node.ref_counter = child.ref_counter
        new_node.key = child.key[:split_len]
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        new_node.parent.children[new_node.key[0]] = new_node
        return new_node

    def _insert_helper(self, node, key, value):
        """Insert ``key``/``value`` below ``node``.

        Returns the number of leading key elements that were already in
        the tree. Implemented as a loop so that deep paths cannot exhaust
        the interpreter call stack.
        """
        matched = 0
        while True:
            node.last_access_time = time.time()
            if len(key) == 0:
                return matched
            if key[0] not in node.children:
                break

            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            matched += prefix_len

            if prefix_len == len(child.key):
                if prefix_len == len(key):
                    # The whole key is already present.
                    return matched
                node = child
            else:
                # Key diverges inside the child: split and continue from
                # the upper half.
                node = self._split_node(child.key, child, prefix_len)
            key = key[prefix_len:]
            value = value[prefix_len:]

        # The rest of the key is new: hang it off ``node`` as a leaf.
        new_node = TreeNode()
        new_node.parent = node
        new_node.key = key
        new_node.value = value
        node.children[key[0]] = new_node
        self.evictable_size_ += len(value)
        return matched

    def _print_helper(self, node, indent):
        """Print every descendant of ``node``, indenting by depth."""
        for child in node.children.values():
            print(
                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
            )
            self._print_helper(child, indent=indent + 2)

    def _delete_leaf(self, node):
        """Detach ``node`` from its parent and update the evictable size.

        Children are registered under the first element of their key, so
        the entry is removed directly instead of scanning the siblings.
        """
        del node.parent.children[node.key[0]]
        # Use the value length, the same quantity that insertion added.
        self.evictable_size_ -= len(node.value)

    def _total_size_helper(self, node):
        """Return the summed value length of ``node`` and its descendants."""
        total = 0
        pending = [node]
        while pending:
            current = pending.pop()
            total += len(current.value)
            pending.extend(current.children.values())
        return total

    def _collect_leaves(self):
        """Return every childless node, in depth-first order."""
        ret_list = []
        pending = [self.root_node]
        while pending:
            current = pending.pop()
            if len(current.children) == 0:
                ret_list.append(current)
            else:
                # Push in reverse so children are visited in mapping order.
                pending.extend(reversed(list(current.children.values())))
        return ret_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Small manual demo: insert a few overlapping strings and dump the tree.
    tree = RadixCache()

    for text in ("Hello", "Hello", "Hello_L.A.!"):
        tree.insert(text)
    # tree.insert("Hello_world! Happy")
    # tree.insert("I love you!")
    tree.pretty_print()

    # Further experiments, left disabled:
    #
    # print(tree.match_prefix("I love you! aha"))
    #
    # def evict_callback(x):
    #     print("evict", x)
    #     return len(x)
    #
    # tree.evict(5, evict_callback)
    # tree.evict(10, evict_callback)
    # tree.pretty_print()
|