Debug radixcache: refactor recursive helper methods (#3029)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -112,14 +112,12 @@ class RadixCache(BasePrefixCache):
|
|||||||
if self.disable:
|
if self.disable:
|
||||||
return [], self.root_node
|
return [], self.root_node
|
||||||
|
|
||||||
value = []
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||||
last_node = [self.root_node]
|
|
||||||
self._match_prefix_helper(self.root_node, key, value, last_node)
|
|
||||||
if value:
|
if value:
|
||||||
value = torch.concat(value)
|
value = torch.concat(value)
|
||||||
else:
|
else:
|
||||||
value = torch.tensor([], dtype=torch.int32)
|
value = torch.tensor([], dtype=torch.int32)
|
||||||
return value, last_node[0]
|
return value, last_node
|
||||||
|
|
||||||
def insert(self, key: List, value=None):
|
def insert(self, key: List, value=None):
|
||||||
if self.disable:
|
if self.disable:
|
||||||
@@ -196,7 +194,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
print(f"#tokens: {self.total_size()}")
|
print(f"#tokens: {self.total_size()}")
|
||||||
|
|
||||||
def total_size(self):
|
def total_size(self):
|
||||||
return self._total_size_helper(self.root_node)
|
return self._total_size_helper()
|
||||||
|
|
||||||
def evict(self, num_tokens: int, evict_callback: Callable):
|
def evict(self, num_tokens: int, evict_callback: Callable):
|
||||||
if self.disable:
|
if self.disable:
|
||||||
@@ -258,24 +256,23 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
##### Internal Helper Functions #####
|
##### Internal Helper Functions #####
|
||||||
|
|
||||||
def _match_prefix_helper(
|
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||||
self, node: TreeNode, key: List, value, last_node: TreeNode
|
|
||||||
):
|
|
||||||
node.last_access_time = time.time()
|
node.last_access_time = time.time()
|
||||||
if len(key) == 0:
|
value = []
|
||||||
return
|
while len(key) > 0 and key[0] in node.children.keys():
|
||||||
|
|
||||||
if key[0] in node.children.keys():
|
|
||||||
child = node.children[key[0]]
|
child = node.children[key[0]]
|
||||||
|
child.last_access_time = time.time()
|
||||||
prefix_len = _key_match(child.key, key)
|
prefix_len = _key_match(child.key, key)
|
||||||
if prefix_len < len(child.key):
|
if prefix_len < len(child.key):
|
||||||
new_node = self._split_node(child.key, child, prefix_len)
|
new_node = self._split_node(child.key, child, prefix_len)
|
||||||
value.append(new_node.value)
|
value.append(new_node.value)
|
||||||
last_node[0] = new_node
|
node = new_node
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
value.append(child.value)
|
value.append(child.value)
|
||||||
last_node[0] = child
|
node = child
|
||||||
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
|
key = key[prefix_len:]
|
||||||
|
return value, node
|
||||||
|
|
||||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||||
# new_node -> child
|
# new_node -> child
|
||||||
@@ -296,22 +293,18 @@ class RadixCache(BasePrefixCache):
|
|||||||
if len(key) == 0:
|
if len(key) == 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if key[0] in node.children.keys():
|
total_prefix_length = 0
|
||||||
child = node.children[key[0]]
|
while len(key) > 0 and key[0] in node.children.keys():
|
||||||
prefix_len = _key_match(child.key, key)
|
node = node.children[key[0]]
|
||||||
|
node.last_access_time = time.time()
|
||||||
|
prefix_len = _key_match(node.key, key)
|
||||||
|
total_prefix_length += prefix_len
|
||||||
|
key = key[prefix_len:]
|
||||||
|
value = value[prefix_len:]
|
||||||
|
|
||||||
if prefix_len == len(child.key):
|
if prefix_len < len(node.key):
|
||||||
if prefix_len == len(key):
|
new_node = self._split_node(node.key, node, prefix_len)
|
||||||
return prefix_len
|
node = new_node
|
||||||
else:
|
|
||||||
key = key[prefix_len:]
|
|
||||||
value = value[prefix_len:]
|
|
||||||
return prefix_len + self._insert_helper(child, key, value)
|
|
||||||
|
|
||||||
new_node = self._split_node(child.key, child, prefix_len)
|
|
||||||
return prefix_len + self._insert_helper(
|
|
||||||
new_node, key[prefix_len:], value[prefix_len:]
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(key):
|
if len(key):
|
||||||
new_node = TreeNode()
|
new_node = TreeNode()
|
||||||
@@ -320,12 +313,21 @@ class RadixCache(BasePrefixCache):
|
|||||||
new_node.value = value
|
new_node.value = value
|
||||||
node.children[key[0]] = new_node
|
node.children[key[0]] = new_node
|
||||||
self.evictable_size_ += len(value)
|
self.evictable_size_ += len(value)
|
||||||
return 0
|
return total_prefix_length
|
||||||
|
|
||||||
def _print_helper(self, node: TreeNode, indent: int):
|
def _print_helper(self, node: TreeNode, indent: int):
|
||||||
for _, child in node.children.items():
|
"""Prints the radix tree in a human-readable format."""
|
||||||
print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
|
stack = [(node, indent)]
|
||||||
self._print_helper(child, indent=indent + 2)
|
while stack:
|
||||||
|
current_node, current_indent = stack.pop()
|
||||||
|
print(
|
||||||
|
" " * current_indent,
|
||||||
|
len(current_node.key),
|
||||||
|
current_node.key[:10],
|
||||||
|
f"r={current_node.lock_ref}",
|
||||||
|
)
|
||||||
|
for _, child in current_node.children.items():
|
||||||
|
stack.append((child, current_indent + 2))
|
||||||
|
|
||||||
def _delete_leaf(self, node):
|
def _delete_leaf(self, node):
|
||||||
for k, v in node.parent.children.items():
|
for k, v in node.parent.children.items():
|
||||||
@@ -334,13 +336,17 @@ class RadixCache(BasePrefixCache):
|
|||||||
del node.parent.children[k]
|
del node.parent.children[k]
|
||||||
self.evictable_size_ -= len(node.key)
|
self.evictable_size_ -= len(node.key)
|
||||||
|
|
||||||
def _total_size_helper(self, node: TreeNode):
|
def _total_size_helper(self):
|
||||||
if node.evicted:
|
total_size = 0
|
||||||
return 0
|
stack = [self.root_node]
|
||||||
x = len(node.value)
|
while stack:
|
||||||
for child in node.children.values():
|
current_node = stack.pop()
|
||||||
x += self._total_size_helper(child)
|
total_size += len(current_node.value)
|
||||||
return x
|
for child in current_node.children.values():
|
||||||
|
if child.evicted:
|
||||||
|
continue
|
||||||
|
stack.append(child)
|
||||||
|
return total_size
|
||||||
|
|
||||||
def _collect_leaves(self):
|
def _collect_leaves(self):
|
||||||
ret_list = []
|
ret_list = []
|
||||||
|
|||||||
Reference in New Issue
Block a user