Code structure refactor (#807)

This commit is contained in:
Liangsheng Yin
2024-07-29 23:04:48 -07:00
committed by GitHub
parent 21e22b9e96
commit cdcbde5fc3
41 changed files with 106 additions and 105 deletions

View File

@@ -0,0 +1,33 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Flush the KV cache.
Usage:
python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
"""
import argparse
import requests
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="http://localhost:30000")
    args = parser.parse_args()

    # Ask the server to drop its KV cache.
    response = requests.get(args.url + "/flush_cache")
    # Fail loudly on any HTTP error status. The previous `assert
    # response.status_code == 200` is silently stripped under `python -O`,
    # turning a failed flush into a false success.
    response.raise_for_status()

View File

@@ -0,0 +1,141 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Memory pool."""
import logging
import torch
# Module-level logger following the standard `getLogger(__name__)` convention.
logger = logging.getLogger(__name__)
class ReqToTokenPool:
    """A memory pool that maps a request to its token locations."""

    def __init__(self, size: int, max_context_len: int):
        self.size = size
        # True marks a free slot; False marks an allocated one.
        self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
        # Row i holds the token locations of the request occupying slot i.
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device="cuda"
        )
        self.can_use_mem_size = size

    def alloc(self, need_size: int):
        """Reserve `need_size` request slots; return their indices, or None if full."""
        if self.can_use_mem_size < need_size:
            return None

        free_slots = torch.nonzero(self.mem_state).squeeze(1)
        chosen = free_slots[:need_size].to(torch.int32)
        self.mem_state[chosen] = False
        self.can_use_mem_size -= need_size
        return chosen

    def free(self, free_index):
        """Return one slot (int index) or several (tensor of indices) to the pool."""
        self.mem_state[free_index] = True
        if isinstance(free_index, int):
            self.can_use_mem_size += 1
        else:
            self.can_use_mem_size += free_index.shape[0]

    def clear(self):
        """Mark every slot as free again."""
        self.mem_state.fill_(True)
        self.can_use_mem_size = len(self.mem_state)
class TokenToKVPool:
    """A memory pool that maps a token to its kv cache locations"""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
    ):
        self.size = size
        # One extra slot is reserved as a dummy write target for padded tokens.
        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
        # Per-layer key/value caches of shape [size + 1, head_num, head_dim].
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
            for _ in range(layer_num)
        ]
        # Slots reserved ahead of time so most allocations skip the
        # full scan of mem_state.
        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
        self.prefetch_chunk_size = 512
        self.can_use_mem_size = self.size
        self.clear()

    def get_key_buffer(self, layer_id: int):
        """Key cache tensor for one layer."""
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        """Value cache tensor for one layer."""
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        """(key, value) cache tensors for one layer."""
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def available_size(self):
        """Free slots, counting those parked in the prefetch buffer."""
        return self.can_use_mem_size + len(self.prefetch_buffer)

    def alloc(self, need_size: int):
        """Reserve `need_size` token slots; return their indices, or None if full."""
        cached = len(self.prefetch_buffer)
        if need_size <= cached:
            # Fast path: served entirely from the prefetch buffer.
            out = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return out

        # Slow path: refill the prefetch buffer with at least one full chunk.
        shortfall = need_size - cached
        grab = max(shortfall, self.prefetch_chunk_size)
        fresh = torch.nonzero(self.mem_state).squeeze(1)[:grab].to(torch.int32)
        if fresh.shape[0] < shortfall:
            return None
        self.mem_state[fresh] = False
        self.can_use_mem_size -= len(fresh)
        self.prefetch_buffer = torch.cat((self.prefetch_buffer, fresh))

        out = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return out

    def free(self, free_index: torch.Tensor):
        """Return the given slot indices to the pool."""
        self.mem_state[free_index] = True
        self.can_use_mem_size += len(free_index)

    def clear(self):
        """Reset all bookkeeping; every slot except the dummy becomes free."""
        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
        self.mem_state.fill_(True)
        self.can_use_mem_size = self.size
        # Slot 0 is permanently reserved as the dummy write target for padded tokens.
        self.mem_state[0] = False

View File

@@ -0,0 +1,287 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
The radix tree data structure for managing the KV cache.
"""
import heapq
import time
from collections import defaultdict
import torch
class TreeNode:
    """One edge/node of the radix tree; ordered by recency for LRU eviction."""

    def __init__(self):
        self.children = defaultdict(TreeNode)
        self.parent = None
        self.key = None
        self.value = None
        # Number of in-flight requests pinning this node (0 == evictable).
        self.lock_ref = 0
        self.last_access_time = time.time()

    def __lt__(self, other: "TreeNode"):
        # Heap ordering: least-recently-accessed node first.
        return self.last_access_time < other.last_access_time
def _key_match(key0, key1):
i = 0
for k0, k1 in zip(key0, key1):
if k0 != k1:
break
i += 1
return i
class RadixCache:
    """A radix (prefix) tree over token-id sequences for KV-cache reuse.

    Each node stores a run of token ids (node.key) together with the
    KV-cache slot indices that hold their cached states (node.value).
    Nodes with lock_ref > 0 are pinned by in-flight requests and cannot be
    evicted; the rest may be reclaimed in LRU order via `evict`.
    """
    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool
        # When True, every public method degrades to a cheap no-op / empty result.
        self.disable = disable
        self.reset()
    ##### Public API #####
    def reset(self):
        """Drop the whole tree, leaving a single permanently-locked root."""
        self.root_node = TreeNode()
        self.root_node.key = []
        self.root_node.value = []
        # The root is locked once forever so it can never be evicted.
        self.root_node.lock_ref = 1
        # Running count of cached tokens that are currently evictable.
        self.evictable_size_ = 0
    def match_prefix(self, key):
        """Return (indices of the longest cached prefix of `key`, deepest matched node)."""
        if self.disable:
            return [], self.root_node
        value = []
        # One-element list so the recursive helper can overwrite it in place.
        last_node = [self.root_node]
        self._match_prefix_helper(self.root_node, key, value, last_node)
        if value:
            value = torch.concat(value)
        else:
            value = torch.tensor([], dtype=torch.int64)
        return value, last_node[0]
    def insert(self, key, value=None):
        """Insert `key` (with cache indices `value`); return the already-cached prefix length."""
        if self.disable:
            return 0
        if value is None:
            # Debug convenience (see __main__): use the key elements as values.
            value = [x for x in key]
        return self._insert_helper(self.root_node, key, value)
    def cache_req(
        self,
        token_ids,
        last_uncached_pos,
        req_pool_idx,
        del_in_memory_pool=True,
        old_last_node=None,
    ):
        """Cache a request's tokens in the tree.

        With del_in_memory_pool=True the request's pool slot is released
        (request finished); otherwise the request keeps running and its
        prefix lock is moved from old_last_node to the new deepest node.
        """
        # Insert the request into radix cache
        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
        new_prefix_len = self.insert(token_ids, indices.clone())
        if self.disable:
            if del_in_memory_pool:
                # NOTE: execution then falls through below, but new_prefix_len
                # is 0 when disabled, so the second free is an empty slice.
                self.token_to_kv_pool.free(indices)
            else:
                return torch.tensor([], dtype=torch.int64), self.root_node
        # Radix Cache takes one ref in memory pool
        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
        if del_in_memory_pool:
            self.req_to_token_pool.free(req_pool_idx)
        else:
            cached_indices, new_last_node = self.match_prefix(token_ids)
            assert len(cached_indices) == len(token_ids)
            # Redirect the request's token table to the (possibly shared) cached slots.
            self.req_to_token_pool.req_to_token[
                req_pool_idx, last_uncached_pos : len(cached_indices)
            ] = cached_indices[last_uncached_pos:]
            self.dec_lock_ref(old_last_node)
            self.inc_lock_ref(new_last_node)
            return cached_indices, new_last_node
    def pretty_print(self):
        """Dump the tree structure and total cached token count (debugging aid)."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")
    def total_size(self):
        """Total number of tokens stored anywhere in the tree."""
        return self._total_size_helper(self.root_node)
    def evict(self, num_tokens, evict_callback):
        """Evict at least `num_tokens` tokens from unlocked leaves in LRU order."""
        if self.disable:
            return
        leaves = self._collect_leaves()
        heapq.heapify(leaves)  # min-heap keyed on last_access_time (TreeNode.__lt__)
        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            if x.lock_ref > 0:
                # Pinned by a running request; skip it (not re-pushed).
                continue
            evict_callback(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)
            # Deleting x may have turned its parent into a new evictable leaf.
            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)
    def inc_lock_ref(self, node: TreeNode):
        """Pin the path from `node` up to the root; return the change in evictable size."""
        delta = 0
        while node != self.root_node:
            if node.lock_ref == 0:
                # First lock on this node: its tokens stop being evictable.
                self.evictable_size_ -= len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta
    def dec_lock_ref(self, node: TreeNode):
        """Unpin the path from `node` up to the root; return the change in evictable size."""
        delta = 0
        while node != self.root_node:
            if node.lock_ref == 1:
                # Last lock released: its tokens become evictable again.
                self.evictable_size_ += len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta
    def evictable_size(self):
        """Number of cached tokens that `evict` could currently reclaim."""
        return self.evictable_size_
    ##### Internal Helper Functions #####
    def _match_prefix_helper(self, node, key, value, last_node):
        """Walk down along `key`, collecting values; split an edge on a partial match."""
        node.last_access_time = time.time()
        if len(key) == 0:
            return
        if key[0] in node.children.keys():
            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            if prefix_len < len(child.key):
                # `key` diverges inside this edge: split it and stop there.
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                last_node[0] = new_node
            else:
                value.append(child.value)
                last_node[0] = child
                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
    def _split_node(self, key, child: TreeNode, split_len):
        """Split `child`'s edge at `split_len`, inserting a new parent that keeps the prefix."""
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len:][0]: child}
        new_node.parent = child.parent
        # The new node inherits the lock count so pinned paths stay pinned.
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        new_node.parent.children[key[:split_len][0]] = new_node
        return new_node
    def _insert_helper(self, node, key, value):
        """Recursive insert; returns how many leading tokens were already cached."""
        node.last_access_time = time.time()
        if len(key) == 0:
            return 0
        if key[0] in node.children.keys():
            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            if prefix_len == len(child.key):
                if prefix_len == len(key):
                    # Exact edge match: everything was already cached.
                    return prefix_len
                else:
                    # Edge fully matched; recurse with the remainder.
                    key = key[prefix_len:]
                    value = value[prefix_len:]
                    return prefix_len + self._insert_helper(child, key, value)
            # Partial edge match: split the edge, then insert under the new node.
            new_node = self._split_node(child.key, child, prefix_len)
            return prefix_len + self._insert_helper(
                new_node, key[prefix_len:], value[prefix_len:]
            )
        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[key[0]] = new_node
            # Fresh nodes start unlocked, so their tokens count as evictable.
            self.evictable_size_ += len(value)
        return 0
    def _print_helper(self, node: TreeNode, indent):
        """Recursively print each child's key length, key prefix, and lock count."""
        for _, child in node.children.items():
            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
            self._print_helper(child, indent=indent + 2)
    def _delete_leaf(self, node):
        """Remove a leaf from its parent and shrink the evictable-size counter."""
        # Find the child-map key that points at this node, then delete it.
        for k, v in node.parent.children.items():
            if v == node:
                break
        del node.parent.children[k]
        self.evictable_size_ -= len(node.key)
    def _total_size_helper(self, node):
        """Sum of value lengths over this subtree."""
        x = len(node.value)
        for child in node.children.values():
            x += self._total_size_helper(child)
        return x
    def _collect_leaves(self):
        """Return every node with no children (the root qualifies when the tree is empty)."""
        ret_list = []
        def dfs_(cur_node):
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            for x in cur_node.children.values():
                dfs_(x)
        dfs_(self.root_node)
        return ret_list
if __name__ == "__main__":
    # Ad-hoc smoke test: plain strings stand in for token-id sequences.
    tree = RadixCache(None, None, False)
    for text in ("Hello", "Hello", "Hello_L.A.!"):
        tree.insert(text)
    tree.pretty_print()
    # Further manual experiments:
    # print(tree.match_prefix("I love you! aha"))
    # def evict_callback(x):
    #     print("evict", x)
    #     return len(x)
    # tree.evict(5, evict_callback)
    # tree.evict(10, evict_callback)
    # tree.pretty_print()