diff --git a/python/sglang/srt/managers/router/radix_cache.py b/python/sglang/srt/managers/router/radix_cache.py index 6ee670309..7bb8a4b2a 100644 --- a/python/sglang/srt/managers/router/radix_cache.py +++ b/python/sglang/srt/managers/router/radix_cache.py @@ -11,6 +11,7 @@ class TreeNode: def __init__(self): self.children = defaultdict(TreeNode) self.parent = None + self.key = None self.value = None self.ref_counter = 0 self.last_access_time = time.time() @@ -37,6 +38,7 @@ class RadixCache: def reset(self): self.root_node = TreeNode() + self.root_node.key = [] self.root_node.value = [] self.root_node.ref_counter = 1 self.evictable_size_ = 0 @@ -115,40 +117,45 @@ class RadixCache: ##### Internal Helper Functions ##### def _match_prefix_helper(self, node, key, value, last_node): node.last_access_time = time.time() + if len(key) == 0: + return - for c_key, child in node.children.items(): - prefix_len = match(c_key, key) - if prefix_len != 0: - if prefix_len < len(c_key): - new_node = self._split_node(c_key, child, prefix_len) - value.append(new_node.value) - last_node[0] = new_node - else: - value.append(child.value) - last_node[0] = child - self._match_prefix_helper(child, key[prefix_len:], value, last_node) - break + if key[0] in node.children.keys(): + child = node.children[key[0]] + prefix_len = match(child.key, key) + if prefix_len < len(child.key): + new_node = self._split_node(child.key, child, prefix_len) + value.append(new_node.value) + last_node[0] = new_node + else: + value.append(child.value) + last_node[0] = child + self._match_prefix_helper(child, key[prefix_len:], value, last_node) def _split_node(self, key, child, split_len): # new_node -> child new_node = TreeNode() - new_node.children = {key[split_len:]: child} + new_node.children = {key[split_len:][0]: child} new_node.parent = child.parent new_node.ref_counter = child.ref_counter + new_node.key = child.key[:split_len] new_node.value = child.value[:split_len] child.parent = new_node + child.key = child.key[split_len:] child.value = child.value[split_len:] - new_node.parent.children[key[:split_len]] = new_node - del new_node.parent.children[key] + new_node.parent.children[key[:split_len][0]] = new_node return new_node def _insert_helper(self, node, key, value): node.last_access_time = time.time() + if len(key) == 0: + return 0 - for c_key, child in node.children.items(): - prefix_len = match(c_key, key) + if key[0] in node.children.keys(): + child = node.children[key[0]] + prefix_len = match(child.key, key) - if prefix_len == len(c_key): + if prefix_len == len(child.key): if prefix_len == len(key): return prefix_len else: @@ -156,23 +163,23 @@ class RadixCache: value = value[prefix_len:] return prefix_len + self._insert_helper(child, key, value) - if prefix_len: - new_node = self._split_node(c_key, child, prefix_len) - return prefix_len + self._insert_helper( - new_node, key[prefix_len:], value[prefix_len:] - ) + new_node = self._split_node(child.key, child, prefix_len) + return prefix_len + self._insert_helper( + new_node, key[prefix_len:], value[prefix_len:] + ) if len(key): new_node = TreeNode() new_node.parent = node + new_node.key = key new_node.value = value - node.children[key] = new_node + node.children[key[0]] = new_node self.evictable_size_ += len(value) return 0 def _print_helper(self, node, indent): - for key, child in node.children.items(): - print(" " * indent, len(key), key[:10], f"r={child.ref_counter}") + for _, child in node.children.items(): + print(" " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}") self._print_helper(child, indent=indent + 2) def _delete_leaf(self, node): @@ -180,7 +187,7 @@ class RadixCache: if v == node: break del node.parent.children[k] - self.evictable_size_ -= len(k) + self.evictable_size_ -= len(node.key) def _total_size_helper(self, node): x = len(node.value)