Files
sglang/python/sglang/srt/mem_cache/radix_cache.py
2024-07-29 23:04:48 -07:00

288 lines
8.6 KiB
Python

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
The radix tree data structure for managing the KV cache.
"""
import heapq
import time
from collections import defaultdict
import torch
class TreeNode:
    """A single node of the radix tree.

    Attributes (all plain instance attributes):
        children: child nodes, keyed by the first element of each child's
            ``key``. This is a ``defaultdict``, so indexing a missing key
            creates (and inserts) a fresh node; callers test membership first.
        parent: the parent node, or ``None`` until the node is attached.
        key: the run of tokens stored on this node.
        value: the entries parallel to ``key`` (sliced together with it).
        lock_ref: number of outstanding locks; a node with 0 is evictable.
        last_access_time: ``time.time()`` of the last touch, used for ordering.
    """

    def __init__(self):
        self.last_access_time = time.time()
        self.lock_ref = 0
        self.value = None
        self.key = None
        self.parent = None
        self.children = defaultdict(TreeNode)

    def __lt__(self, other: "TreeNode"):
        # Least-recently-accessed nodes compare as "smaller", so a heapq of
        # nodes pops the oldest one first.
        return other.last_access_time > self.last_access_time
def _key_match(key0, key1):
i = 0
for k0, k1 in zip(key0, key1):
if k0 != k1:
break
i += 1
return i
class RadixCache:
    """Radix tree that maps token-id sequences to KV-cache indices.

    Each ``TreeNode`` stores a run of tokens in ``key`` and the parallel cache
    entries in ``value``. Nodes with ``lock_ref == 0`` are evictable, and
    ``evictable_size_`` tracks how many tokens they hold in total. When
    ``disable`` is true, matching, inserting and evicting become no-ops.

    All tree walks are iterative, so very deep trees (long chains of split
    nodes) cannot hit Python's recursion limit.
    """

    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
        # NOTE(review): the pool objects are opaque here; cache_req only uses
        # ``req_to_token_pool.req_to_token`` / ``.free`` and
        # ``token_to_kv_pool.free``.
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool
        self.disable = disable
        self.reset()

    ##### Public API #####

    def reset(self):
        """Discard every cached entry and start from an empty root."""
        self.root_node = TreeNode()
        self.root_node.key = []
        self.root_node.value = []
        # The root stays locked: lock walks stop before it, it is never
        # counted as evictable, and ``evict`` never removes it.
        self.root_node.lock_ref = 1
        self.evictable_size_ = 0

    def match_prefix(self, key):
        """Find the longest cached prefix of ``key``.

        Returns ``(value, last_node)``. ``value`` holds the matched nodes'
        values joined with ``torch.concat``; node values are presumably 1-D
        index tensors, so confirm with callers. An empty int64 tensor is
        returned when nothing matches. ``last_node`` is the deepest matched
        node. If the match ends inside an edge, that edge is split so that
        ``last_node`` ends exactly at the match boundary.

        When the cache is disabled this returns ``([], root_node)``. In that
        case the empty value is a plain list, not a tensor.
        """
        if self.disable:
            return [], self.root_node

        value = []
        last_node = [self.root_node]
        self._match_prefix_helper(self.root_node, key, value, last_node)
        if value:
            value = torch.concat(value)
        else:
            value = torch.tensor([], dtype=torch.int64)
        return value, last_node[0]

    def insert(self, key, value=None):
        """Insert ``key`` -> ``value``.

        Returns how many leading tokens of ``key`` were already in the tree.
        If ``value`` is omitted, the elements of ``key`` themselves are stored
        as the values. Returns 0 (and inserts nothing) when the cache is
        disabled.
        """
        if self.disable:
            return 0
        if value is None:
            value = list(key)
        return self._insert_helper(self.root_node, key, value)

    def cache_req(
        self,
        token_ids,
        last_uncached_pos,
        req_pool_idx,
        del_in_memory_pool=True,
        old_last_node=None,
    ):
        """Insert a request's tokens into the tree and reconcile pool references.

        Args:
            token_ids: the request's token ids.
            last_uncached_pos: first position whose indices belong to the
                request rather than to the tree. This is an assumption;
                confirm with callers.
            req_pool_idx: row of ``req_to_token_pool.req_to_token`` holding
                the request's cache indices.
            del_in_memory_pool: if true, free the request's row. If false,
                rewrite the row to the tree's indices and move the request's
                lock from ``old_last_node`` to the new prefix node.
            old_last_node: node currently locked for the request, if any.
                ``None`` means there is no lock to release.

        Returns:
            ``(cached_indices, new_last_node)`` when ``del_in_memory_pool`` is
            false. If the cache is also disabled, this is
            ``(empty tensor, root_node)``. Otherwise returns ``None``.
        """
        # Insert the request into radix cache
        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
        new_prefix_len = self.insert(token_ids, indices.clone())

        if self.disable:
            if del_in_memory_pool:
                self.token_to_kv_pool.free(indices)
            else:
                return torch.tensor([], dtype=torch.int64), self.root_node

        # Radix Cache takes one ref in memory pool. Positions
        # [last_uncached_pos, new_prefix_len) were already in the tree, so the
        # request's own copies of those indices are released.
        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])

        if del_in_memory_pool:
            self.req_to_token_pool.free(req_pool_idx)
            return None

        cached_indices, new_last_node = self.match_prefix(token_ids)
        assert len(cached_indices) == len(token_ids)

        self.req_to_token_pool.req_to_token[
            req_pool_idx, last_uncached_pos : len(cached_indices)
        ] = cached_indices[last_uncached_pos:]
        # Move the request's lock to the (possibly deeper) new prefix node.
        # Without this guard, dec_lock_ref(None) raised AttributeError.
        if old_last_node is not None:
            self.dec_lock_ref(old_last_node)
        self.inc_lock_ref(new_last_node)
        return cached_indices, new_last_node

    def pretty_print(self):
        """Print the tree as an indented outline, followed by the total token count."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Return the total number of values stored across all nodes."""
        return self._total_size_helper(self.root_node)

    def evict(self, num_tokens, evict_callback):
        """Evict unlocked leaves, oldest access time first, until ``num_tokens`` are freed.

        Each evicted node's ``value`` is passed to ``evict_callback``,
        presumably so that those cache slots can be released. A parent that
        becomes childless is queued as a new eviction candidate. Eviction
        stops early once no candidate leaves remain.
        """
        if self.disable:
            return

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and leaves:
            x = heapq.heappop(leaves)

            # The root only becomes a candidate once the tree is empty.
            if x is self.root_node:
                break
            # Locked nodes (presumably held by in-flight requests) stay.
            if x.lock_ref > 0:
                continue

            evict_callback(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)

            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_lock_ref(self, node: TreeNode):
        """Lock ``node`` and every ancestor below the root against eviction.

        Returns the change applied to ``evictable_size_``, which is zero or
        negative.
        """
        delta = 0
        while node is not self.root_node:
            if node.lock_ref == 0:
                self.evictable_size_ -= len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(self, node: TreeNode):
        """Release one lock on ``node`` and every ancestor below the root.

        Returns the change applied to ``evictable_size_``, which is zero or
        positive.
        """
        delta = 0
        while node is not self.root_node:
            if node.lock_ref == 1:
                self.evictable_size_ += len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        """Return the number of tokens held by currently unlocked nodes."""
        return self.evictable_size_

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node, key, value, last_node):
        """Walk down from ``node`` along ``key``, collecting matched values.

        Each fully matched child's ``value`` is appended to the ``value``
        list, and the deepest matched node is stored in ``last_node[0]``. If
        the match ends inside an edge, that edge is split and the walk stops.
        Every node passed through has its ``last_access_time`` refreshed.
        """
        node.last_access_time = time.time()
        while len(key) > 0 and key[0] in node.children:
            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            if prefix_len < len(child.key):
                # The match ends inside this edge: split it and stop.
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                last_node[0] = new_node
                return
            value.append(child.value)
            last_node[0] = child
            node = child
            key = key[prefix_len:]
            node.last_access_time = time.time()

    def _split_node(self, key, child: TreeNode, split_len):
        """Split ``child`` after ``split_len`` tokens and return the new upper node.

        Before the split the path is parent -> child(key). Afterwards it is
        parent -> new_node(key[:split_len]) -> child(key[split_len:]).
        ``new_node`` inherits ``child``'s lock count, so evictable
        accounting is unaffected. Callers in this class always pass
        ``child.key`` as ``key``.
        """
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len:][0]: child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        new_node.parent.children[key[:split_len][0]] = new_node
        return new_node

    def _insert_helper(self, node, key, value):
        """Insert ``key``/``value`` below ``node``.

        Returns how many leading tokens of ``key`` already existed below
        ``node``. The walk follows matching edges and splits one if the key
        diverges or ends inside it. Any remaining tokens are stored on a new
        leaf, which is counted as evictable.
        """
        node.last_access_time = time.time()
        matched = 0
        while len(key) > 0:
            if key[0] not in node.children:
                # Nothing shares the next token: store the rest as a new leaf.
                new_node = TreeNode()
                new_node.parent = node
                new_node.key = key
                new_node.value = value
                node.children[key[0]] = new_node
                self.evictable_size_ += len(value)
                return matched

            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            matched += prefix_len

            if prefix_len == len(child.key):
                if prefix_len == len(key):
                    # The key ends exactly at an existing node.
                    return matched
                node = child
            else:
                # The key diverges (or ends) inside this edge: split it first.
                node = self._split_node(child.key, child, prefix_len)

            key = key[prefix_len:]
            value = value[prefix_len:]
            node.last_access_time = time.time()
        return matched

    def _print_helper(self, node: TreeNode, indent):
        """Print the subtree under ``node`` in pre-order.

        Each level of depth indents by two more spaces.
        """
        stack = [(child, indent) for child in reversed(list(node.children.values()))]
        while stack:
            cur, depth = stack.pop()
            print(" " * depth, len(cur.key), cur.key[:10], f"r={cur.lock_ref}")
            # Push children reversed so they are printed in dict order.
            stack.extend(
                (grandchild, depth + 2)
                for grandchild in reversed(list(cur.children.values()))
            )

    def _delete_leaf(self, node):
        """Detach ``node`` from its parent and remove its tokens from the evictable count."""
        siblings = node.parent.children
        for k, v in siblings.items():
            if v is node:
                # Breaking right after the delete keeps the iteration safe.
                # Previously, a missing node made the loop fall through and
                # delete whatever child was iterated last.
                del siblings[k]
                break
        self.evictable_size_ -= len(node.key)

    def _total_size_helper(self, node):
        """Return the summed ``len(value)`` of ``node`` and all of its descendants."""
        total = 0
        stack = [node]
        while stack:
            cur = stack.pop()
            total += len(cur.value)
            stack.extend(cur.children.values())
        return total

    def _collect_leaves(self):
        """Return every childless node in pre-order, including the root when the tree is empty."""
        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            # Reverse so children are visited in dict order, as the old DFS was.
            stack.extend(reversed(list(cur_node.children.values())))
        return ret_list
if __name__ == "__main__":
    # Tiny smoke demo. Strings stand in for token sequences. Because no value
    # is passed to insert(), each character is also stored as its own value.
    tree = RadixCache(None, None, False)

    for sample in ("Hello", "Hello", "Hello_L.A.!"):
        tree.insert(sample)
    # Other inputs worth trying:
    # tree.insert("Hello_world! Happy")
    # tree.insert("I love you!")

    tree.pretty_print()

    # Matching and eviction can be exercised the same way.
    # NOTE(review): match_prefix joins values with torch.concat, so with
    # character values the next line would likely fail. Confirm before enabling.
    # print(tree.match_prefix("I love you! aha"))
    # def evict_callback(x):
    #     print("evict", x)
    #     return len(x)
    # tree.evict(5, evict_callback)
    # tree.evict(10, evict_callback)
    # tree.pretty_print()