Minor: style improvement of radix_cache and memory_pool (#395)
@@ -236,9 +236,8 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            if not self.tree_cache.disable:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

             if out_cache_loc is None:
                 print("Prefill out of memory. This should never happen.")
@@ -307,8 +306,8 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True

-        if not self.tree_cache.disable:
-            self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)

         if self.token_to_kv_pool.available_size() >= bs:
             return True
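The two `Batch` hunks above share one pattern: try to allocate from the token pool; if it is exhausted, evict from the radix cache, handing the pool's `dec_refs` (formerly `free`) as the eviction callback, then retry the allocation. Because `evict` is now a no-op when the cache is disabled (see the radix_cache hunk further down), callers no longer need the `if not self.tree_cache.disable` guard. A minimal sketch of the pattern with toy stand-ins; `ToyPool` and `ToyCache` are illustrative, not sglang classes:

```python
# Hedged sketch of the allocate -> evict -> retry pattern above.
class ToyPool:
    def __init__(self, size):
        self.free_slots = list(range(size))

    def alloc(self, need):
        if len(self.free_slots) < need:
            return None  # mirrors token_to_kv_pool.alloc returning None
        out, self.free_slots = self.free_slots[:need], self.free_slots[need:]
        return out

    def dec_refs(self, indices):
        self.free_slots.extend(indices)  # dropping the last ref frees the slots


class ToyCache:
    def __init__(self, disable=False):
        self.disable = disable
        self.cached = [7, 8, 9]  # slot indices held by cached prefixes

    def evict(self, num, evict_callback):
        if self.disable:
            return  # no-op, so callers need no disable check
        evict_callback(self.cached[:num])
        del self.cached[:num]


pool, cache = ToyPool(2), ToyCache()
loc = pool.alloc(4)
if loc is None:
    cache.evict(4, pool.dec_refs)  # release cached slots back into the pool
    loc = pool.alloc(4)            # retry once more
```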
@@ -341,7 +340,7 @@ class Batch:
             token_indices = self.req_to_token_pool.req_to_token[
                 req_pool_indices_np[idx]
             ][: seq_lens_np[idx]]
-            self.token_to_kv_pool.free(token_indices)
+            self.token_to_kv_pool.dec_refs(token_indices)

         self.filter_batch(sorted_indices)
@@ -372,7 +371,7 @@ class Batch:
             prefix_len = self.tree_cache.insert(
                 token_ids_in_memory, indices.clone()
             )
-            self.token_to_kv_pool.free(indices[:prefix_len])
+            self.token_to_kv_pool.dec_refs(indices[:prefix_len])
             self.req_to_token_pool.free(req_pool_idx)
             self.tree_cache.dec_ref_counter(req.last_node)
@@ -113,7 +113,7 @@ class ModelRpcServer:
         logger.info(server_args.get_optional_modes_logging())

         # Init cache
-        self.tree_cache = RadixCache(server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -628,7 +628,7 @@ class ModelRpcServer:
                 token_ids[:seq_len], indices.clone()
             )

-            self.token_to_kv_pool.free(indices[:prefix_len])
+            self.token_to_kv_pool.dec_refs(indices[:prefix_len])
             self.req_to_token_pool.free(req_pool_idx)
             self.tree_cache.dec_ref_counter(req.last_node)
@@ -1,8 +1,6 @@
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple

 import torch
@@ -16,23 +14,23 @@ class TreeNode:
         self.ref_counter = 0
         self.last_access_time = time.time()

-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time


-def match(key, seq):
+def _key_match(key0, key1):
     i = 0
-    for k, w in zip(key, seq):
-        if k != w:
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i


 class RadixCache:
-    def __init__(self, disable=False):
-        self.reset()
+    def __init__(self, disable: bool = False):
+        self.disable = disable
+        self.reset()

     ##### Public API #####
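`_key_match` (renamed from `match`, with the symmetric `key0`/`key1` naming) returns the length of the longest common prefix of two token sequences, which is what the radix tree uses to decide where to descend or split. A standalone copy of the function as shown above, with a quick sanity check:

```python
# Standalone copy of _key_match from the diff, for a quick behavioral check.
def _key_match(key0, key1):
    i = 0
    for k0, k1 in zip(key0, key1):
        if k0 != k1:
            break
        i += 1
    return i


assert _key_match([1, 2, 3, 4], [1, 2, 7]) == 2  # shared prefix length
assert _key_match("Hello", "Help") == 3          # works on any sequence type
assert _key_match([], [1]) == 0                  # empty key matches nothing
```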
@@ -71,7 +69,7 @@ class RadixCache:
     def evict(self, num_tokens, evict_callback):
         if self.disable:
-            raise RuntimeError()
+            return

         leaves = self._collect_leaves()
         heapq.heapify(leaves)
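`evict` pops leaves in least-recently-used order: `TreeNode.__lt__` (now typed, above) compares `last_access_time`, so `heapq.heapify` turns the collected leaf list into a min-heap keyed on access time. A toy illustration of that ordering; `Leaf` is a stand-in, not the real `TreeNode`:

```python
import heapq
import time


# Toy illustration of the heap-based LRU eviction order implied by
# TreeNode.__lt__ above.
class Leaf:
    def __init__(self, name, last_access_time):
        self.name = name
        self.last_access_time = last_access_time

    def __lt__(self, other):
        return self.last_access_time < other.last_access_time


now = time.time()
leaves = [Leaf("a", now - 30), Leaf("b", now - 10), Leaf("c", now - 20)]
heapq.heapify(leaves)              # min-heap keyed on access time
print(heapq.heappop(leaves).name)  # "a": least recently used evicted first
```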
@@ -115,6 +113,7 @@ class RadixCache:
         return self.evictable_size_

+    ##### Internal Helper Functions #####

     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
         if len(key) == 0:
@@ -122,7 +121,7 @@ class RadixCache:
         if key[0] in node.children.keys():
             child = node.children[key[0]]
-            prefix_len = match(child.key, key)
+            prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 value.append(new_node.value)
@@ -153,7 +152,7 @@ class RadixCache:
         if key[0] in node.children.keys():
             child = node.children[key[0]]
-            prefix_len = match(child.key, key)
+            prefix_len = _key_match(child.key, key)

             if prefix_len == len(child.key):
                 if prefix_len == len(key):
@@ -212,7 +211,7 @@ class RadixCache:


 if __name__ == "__main__":
-    tree = RadixCache(disable=False)
+    tree = RadixCache()

     tree.insert("Hello")
     tree.insert("Hello")
@@ -31,9 +31,6 @@ class ReqToTokenPool:
         self.can_use_mem_size += free_index.shape[0]
         self.mem_state[free_index] = 1

-        # if self.can_use_mem_size == len(self.mem_state):
-        #     print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
-
     def clear(self):
         self.mem_state.fill_(1)
         self.can_use_mem_size = len(self.mem_state)
@@ -42,7 +39,7 @@ class ReqToTokenPool:
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
-        self.alloc_ct = 0
+        self.total_ref_ct = 0

         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
@@ -83,9 +80,6 @@ class TokenToKVPool:
         self.add_refs(select_index)
         return select_index.to(torch.int32), start_loc, start_loc + need_size

-    def free(self, free_index):
-        return self.decrease_refs(free_index)
-
     def used_size(self):
         return len(torch.nonzero(self.mem_state).squeeze(1))
@@ -93,20 +87,17 @@ class TokenToKVPool:
         return torch.sum(self.mem_state == 0).item()

     def add_refs(self, token_index: torch.Tensor):
-        self.alloc_ct += len(token_index)
+        self.total_ref_ct += len(token_index)
         self.mem_state[token_index] += 1

-    def decrease_refs(self, token_index: torch.Tensor):
-        self.alloc_ct -= len(token_index)
+    def dec_refs(self, token_index: torch.Tensor):
+        self.total_ref_ct -= len(token_index)
         self.mem_state[token_index] -= 1

         num_freed = torch.sum(self.mem_state[token_index] == 0)

-        # if self.alloc_ct == 0:
-        #     print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
-
         return num_freed

     def clear(self):
         self.mem_state.fill_(0)
-        self.alloc_ct = 0
+        self.total_ref_ct = 0
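After this hunk, `TokenToKVPool` exposes a symmetric pair, `add_refs`/`dec_refs`, over the int16 `mem_state` refcount tensor (the old `free` alias is gone), and the aggregate counter is `total_ref_ct` rather than `alloc_ct`; a slot becomes available again only when its count returns to zero. A hedged sketch of the same bookkeeping on a CPU tensor; `MiniKVPool` is an illustrative stand-in, not the real class:

```python
import torch


# Hedged sketch of the refcount bookkeeping above, on CPU instead of CUDA.
class MiniKVPool:
    def __init__(self, size):
        self.mem_state = torch.zeros((size,), dtype=torch.int16)
        self.total_ref_ct = 0

    def add_refs(self, token_index: torch.Tensor):
        self.total_ref_ct += len(token_index)
        self.mem_state[token_index] += 1

    def dec_refs(self, token_index: torch.Tensor):
        self.total_ref_ct -= len(token_index)
        self.mem_state[token_index] -= 1
        # slots whose count fell back to zero are free again
        return torch.sum(self.mem_state[token_index] == 0)


pool = MiniKVPool(8)
idx = torch.tensor([0, 1, 2])
pool.add_refs(idx)         # refcount 1 on slots 0..2
pool.add_refs(idx[:1])     # slot 0 held twice (e.g. a shared cached prefix)
print(pool.dec_refs(idx))  # tensor(2): slots 1 and 2 freed, slot 0 still held
```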
@@ -500,7 +500,7 @@ async def v1_chat_completions(raw_request: Request):
     return response


-def launch_server(server_args, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer):
     global tokenizer_manager
     global chat_template_name