Files
sglang/python/sglang/srt/mem_cache/radix_cache.py

315 lines
9.5 KiB
Python
Raw Normal View History

2024-07-28 23:07:12 +10:00
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
2024-06-08 02:06:52 -07:00
"""
The radix tree data structure for managing the KV cache.
"""
import heapq
import time
from collections import defaultdict
2024-08-07 15:52:24 -07:00
from typing import TYPE_CHECKING
import torch
2024-08-07 15:52:24 -07:00
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class TreeNode:
    """One node of the radix tree.

    Attributes are plain instance fields that the cache fills in after
    construction:

    * ``children`` -- mapping from a key element to the child node.
    * ``parent`` -- the node above this one (``None`` until it is linked).
    * ``key`` / ``value`` -- the payload of this node (``None`` until set).
    * ``lock_ref`` -- reference counter, starts at zero.
    * ``last_access_time`` -- wall-clock timestamp taken at construction.
    """

    def __init__(self):
        # Bookkeeping fields.
        self.last_access_time = time.time()
        self.lock_ref = 0

        # Payload, assigned by whoever creates the node.
        self.key = None
        self.value = None

        # Tree links.
        self.parent = None
        self.children = defaultdict(TreeNode)

    def __lt__(self, other: "TreeNode"):
        """Order nodes by access time so that older nodes sort first."""
        return other.last_access_time > self.last_access_time
def _key_match(key0, key1):
i = 0
for k0, k1 in zip(key0, key1):
if k0 != k1:
break
i += 1
return i
class RadixCache(BasePrefixCache):
    """Prefix cache that keeps KV-cache indices in a radix tree.

    Every tree node holds a run of key elements (``node.key``) together with
    the matching run of values (``node.value``).  ``node.lock_ref`` counts the
    holders currently pinning a node; ``evict`` only removes leaves whose
    ``lock_ref`` is zero.  ``evictable_size_`` is the running total of tokens
    stored in nodes with ``lock_ref == 0``.

    When ``disable`` is true the tree is never populated: lookups return an
    empty match and finished requests simply give their memory back.
    """

    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
        """Store the two memory pools and start with an empty tree."""
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool
        self.disable = disable
        self.reset()

    ##### Public API #####

    def reset(self):
        """Drop the whole tree and start again from a fresh root."""
        self.root_node = TreeNode()
        self.root_node.key = []
        self.root_node.value = []
        # The root starts with one reference so ``evict`` always sees it as
        # locked.
        self.root_node.lock_ref = 1
        self.evictable_size_ = 0

    def match_prefix(self, key, **kwargs):
        """Find the longest cached prefix of ``key``.

        Returns:
            ``(value, last_node)``.  ``value`` is the concatenation of the
            values stored along the matched path (an empty int32 tensor when
            nothing matched) and ``last_node`` is the deepest node matched.
            When the cache is disabled the result is ``([], root_node)``.

        Note:
            A partial match inside a node splits that node, so this lookup
            can modify the tree.
        """
        if self.disable:
            return [], self.root_node

        value = []
        # One-element list so the recursive helper can report the last node.
        last_node = [self.root_node]
        self._match_prefix_helper(self.root_node, key, value, last_node)
        if value:
            value = torch.concat(value)
        else:
            value = torch.tensor([], dtype=torch.int32)
        return value, last_node[0]

    def insert(self, key, value=None):
        """Insert ``key`` -> ``value`` and return the already-cached prefix length.

        If ``value`` is omitted, a list copy of ``key`` is stored as the
        value.  Returns 0 when the cache is disabled.
        """
        if self.disable:
            return 0

        if value is None:
            value = list(key)
        return self._insert_helper(self.root_node, key, value)

    def cache_finished_req(self, req: "Req", token_ids=None):
        """Cache request when it finishes."""
        if token_ids is None:
            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        if self.disable:
            # Nothing is cached: hand back both the KV slots and the request slot.
            self.token_to_kv_pool.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return

        # Radix Cache takes one ref in memory pool
        new_prefix_len = self.insert(token_ids, kv_indices.clone())
        # ``insert`` reports how many leading tokens were already in the tree;
        # the indices in that range past the request's own prefix are freed.
        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])

        # Remove the req slot and release the cache lock
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node)

    def cache_unfinished_req(self, req: "Req", token_ids=None):
        """Cache request when it is unfinished."""
        if self.disable:
            return

        if token_ids is None:
            token_ids = req.fill_ids

        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        # Radix Cache takes one ref in memory pool
        new_prefix_len = self.insert(token_ids, kv_indices.clone())
        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])

        # The prefix indices could be updated, reuse it
        new_indices, new_last_node = self.match_prefix(token_ids)
        assert len(new_indices) == len(token_ids)
        self.req_to_token_pool.req_to_token[
            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
        ] = new_indices[len(req.prefix_indices) :]

        # Move the request's lock from its old last node to the new one.
        self.dec_lock_ref(req.last_node)
        self.inc_lock_ref(new_last_node)
        req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
        """Print the tree structure followed by the total token count."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Return the number of value entries stored in the whole tree."""
        return self._total_size_helper(self.root_node)

    def evict(self, num_tokens, evict_callback):
        """Evict unlocked leaves, least recently accessed first.

        Leaves are popped from a heap ordered by the nodes' ``<`` comparison
        until at least ``num_tokens`` tokens were evicted or no candidate is
        left.  ``evict_callback`` is called with the value of every evicted
        node.  A parent that becomes a leaf is added to the candidates.
        """
        if self.disable:
            return

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and leaves:
            x = heapq.heappop(leaves)

            if x is self.root_node:
                break
            if x.lock_ref > 0:
                continue

            evict_callback(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)

            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_lock_ref(self, node: TreeNode):
        """Pin ``node`` and all its ancestors (the root excluded).

        Returns:
            The change in evictable size: zero or negative, one
            ``-len(node.value)`` for every node that went from unlocked to
            locked.  Always 0 when the cache is disabled.
        """
        if self.disable:
            return 0

        delta = 0
        while node is not self.root_node:
            if node.lock_ref == 0:
                # First holder: this node stops being evictable.
                self.evictable_size_ -= len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(self, node: TreeNode):
        """Unpin ``node`` and all its ancestors (the root excluded).

        Returns:
            The change in evictable size: zero or positive, one
            ``len(node.value)`` for every node whose last holder left.
            Always 0 when the cache is disabled.
        """
        if self.disable:
            return 0

        delta = 0
        while node is not self.root_node:
            if node.lock_ref == 1:
                # Last holder is leaving: the node becomes evictable again.
                self.evictable_size_ += len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        """Return the number of tokens held by unlocked nodes."""
        return self.evictable_size_

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node, key, value, last_node):
        """Walk down from ``node`` following ``key``.

        Matched node values are appended to ``value`` and ``last_node[0]`` is
        set to the deepest matched node.  A child that matches only partially
        is split so that the matched part becomes its own node.
        """
        node.last_access_time = time.time()
        if len(key) == 0:
            return

        child = node.children.get(key[0])
        if child is None:
            return

        prefix_len = _key_match(child.key, key)
        if prefix_len < len(child.key):
            # Partial match: carve out the matching head and stop there.
            new_node = self._split_node(child.key, child, prefix_len)
            value.append(new_node.value)
            last_node[0] = new_node
        else:
            value.append(child.value)
            last_node[0] = child
            self._match_prefix_helper(child, key[prefix_len:], value, last_node)

    def _split_node(self, key, child: TreeNode, split_len):
        """Split ``child`` after ``split_len`` elements and return the new head.

        The new node takes the first ``split_len`` elements of the child's
        key and value, inherits its ``lock_ref`` and takes its place under
        the parent; ``child`` keeps the remainder and hangs below the new
        node.
        """
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len:][0]: child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        # Replace the child's old slot in the parent with the new head node.
        new_node.parent.children[key[:split_len][0]] = new_node
        return new_node

    def _insert_helper(self, node, key, value):
        """Insert ``key``/``value`` below ``node``.

        Returns the number of leading elements of ``key`` that were already
        present in the tree.
        """
        node.last_access_time = time.time()
        if len(key) == 0:
            return 0

        child = node.children.get(key[0])
        if child is not None:
            prefix_len = _key_match(child.key, key)

            if prefix_len == len(child.key):
                if prefix_len == len(key):
                    # The whole key is already stored in this child.
                    return prefix_len
                return prefix_len + self._insert_helper(
                    child, key[prefix_len:], value[prefix_len:]
                )

            # Only the head of the child matches: split it and continue below
            # the new head node.
            new_node = self._split_node(child.key, child, prefix_len)
            return prefix_len + self._insert_helper(
                new_node, key[prefix_len:], value[prefix_len:]
            )

        # No child starts with key[0]; ``key`` is known to be non-empty here,
        # so store the remainder as a new leaf.
        new_node = TreeNode()
        new_node.parent = node
        new_node.key = key
        new_node.value = value
        node.children[key[0]] = new_node
        self.evictable_size_ += len(value)
        return 0

    def _print_helper(self, node: TreeNode, indent):
        """Print every descendant of ``node``, indenting two spaces per level."""
        for _, child in node.children.items():
            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
            self._print_helper(child, indent=indent + 2)

    def _delete_leaf(self, node):
        """Unlink ``node`` from its parent and update the evictable size.

        If ``node`` is not registered among its parent's children, nothing is
        removed and the size counter is left alone, so an unrelated sibling
        can never be deleted by mistake.
        """
        siblings = node.parent.children
        for k, v in siblings.items():
            if v is node:
                break
        else:
            return
        del siblings[k]
        self.evictable_size_ -= len(node.key)

    def _total_size_helper(self, node):
        """Return the number of value entries in the subtree rooted at ``node``."""
        x = len(node.value)
        for child in node.children.values():
            x += self._total_size_helper(child)
        return x

    def _collect_leaves(self):
        """Return all nodes without children, in depth-first order."""
        ret_list = []

        def dfs_(cur_node):
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)

            for x in cur_node.children.values():
                dfs_(x)

        dfs_(self.root_node)
        return ret_list
if __name__ == "__main__":
    # Small manual demo: insert overlapping strings, then dump the tree.
    tree = RadixCache(None, None, False)

    for text in ("Hello", "Hello", "Hello_L.A.!"):
        tree.insert(text)
    # tree.insert("Hello_world! Happy")
    # tree.insert("I love you!")
    tree.pretty_print()

    # print(tree.match_prefix("I love you! aha"))

    # def evict_callback(x):
    #    print("evict", x)
    #    return len(x)

    # tree.evict(5, evict_callback)
    # tree.evict(10, evict_callback)
    # tree.pretty_print()