From 62b362b1f134e449985208567adeecef363e1126 Mon Sep 17 00:00:00 2001 From: luzengxiangcn <60803814+luzengxiangcn@users.noreply.github.com> Date: Thu, 6 Mar 2025 08:11:42 +0800 Subject: [PATCH] Debug radixcache: refactor recursive helper methods (#3029) Co-authored-by: Zhiqiang Xie --- python/sglang/srt/mem_cache/radix_cache.py | 88 ++++++++++++---------- 1 file changed, 47 insertions(+), 41 deletions(-) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index c99a47516..d46ec4277 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -112,14 +112,12 @@ class RadixCache(BasePrefixCache): if self.disable: return [], self.root_node - value = [] - last_node = [self.root_node] - self._match_prefix_helper(self.root_node, key, value, last_node) + value, last_node = self._match_prefix_helper(self.root_node, key) if value: value = torch.concat(value) else: value = torch.tensor([], dtype=torch.int32) - return value, last_node[0] + return value, last_node def insert(self, key: List, value=None): if self.disable: @@ -196,7 +194,7 @@ class RadixCache(BasePrefixCache): print(f"#tokens: {self.total_size()}") def total_size(self): - return self._total_size_helper(self.root_node) + return self._total_size_helper() def evict(self, num_tokens: int, evict_callback: Callable): if self.disable: @@ -258,24 +256,23 @@ class RadixCache(BasePrefixCache): ##### Internal Helper Functions ##### - def _match_prefix_helper( - self, node: TreeNode, key: List, value, last_node: TreeNode - ): + def _match_prefix_helper(self, node: TreeNode, key: List): node.last_access_time = time.time() - if len(key) == 0: - return - - if key[0] in node.children.keys(): + value = [] + while len(key) > 0 and key[0] in node.children.keys(): child = node.children[key[0]] + child.last_access_time = time.time() prefix_len = _key_match(child.key, key) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, prefix_len) value.append(new_node.value) - last_node[0] = new_node + node = new_node + break else: value.append(child.value) - last_node[0] = child - self._match_prefix_helper(child, key[prefix_len:], value, last_node) + node = child + key = key[prefix_len:] + return value, node def _split_node(self, key, child: TreeNode, split_len: int): # new_node -> child @@ -296,22 +293,18 @@ class RadixCache(BasePrefixCache): if len(key) == 0: return 0 - if key[0] in node.children.keys(): - child = node.children[key[0]] - prefix_len = _key_match(child.key, key) + total_prefix_length = 0 + while len(key) > 0 and key[0] in node.children.keys(): + node = node.children[key[0]] + node.last_access_time = time.time() + prefix_len = _key_match(node.key, key) + total_prefix_length += prefix_len + key = key[prefix_len:] + value = value[prefix_len:] - if prefix_len == len(child.key): - if prefix_len == len(key): - return prefix_len - else: - key = key[prefix_len:] - value = value[prefix_len:] - return prefix_len + self._insert_helper(child, key, value) - - new_node = self._split_node(child.key, child, prefix_len) - return prefix_len + self._insert_helper( - new_node, key[prefix_len:], value[prefix_len:] - ) + if prefix_len < len(node.key): + new_node = self._split_node(node.key, node, prefix_len) + node = new_node if len(key): new_node = TreeNode() @@ -320,12 +313,21 @@ class RadixCache(BasePrefixCache): new_node.value = value node.children[key[0]] = new_node self.evictable_size_ += len(value) - return 0 + return total_prefix_length def _print_helper(self, node: TreeNode, indent: int): - for _, child in node.children.items(): - print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}") - self._print_helper(child, indent=indent + 2) + """Prints the radix tree in a human-readable format.""" + stack = [(node, indent)] + while stack: + current_node, current_indent = stack.pop() + print( + " " * current_indent, + len(current_node.key), + current_node.key[:10], + f"r={current_node.lock_ref}", + ) + for _, child in current_node.children.items(): + stack.append((child, current_indent + 2)) def _delete_leaf(self, node): for k, v in node.parent.children.items(): @@ -334,13 +336,17 @@ class RadixCache(BasePrefixCache): del node.parent.children[k] self.evictable_size_ -= len(node.key) - def _total_size_helper(self, node: TreeNode): - if node.evicted: - return 0 - x = len(node.value) - for child in node.children.values(): - x += self._total_size_helper(child) - return x + def _total_size_helper(self): + total_size = 0 + stack = [self.root_node] + while stack: + current_node = stack.pop() + total_size += len(current_node.value) + for child in current_node.children.values(): + if child.evicted: + continue + stack.append(child) + return total_size def _collect_leaves(self): ret_list = []