Minor: style improvement of radix_cache and memory_pool (#395)
This commit is contained in:
@@ -236,9 +236,8 @@ class Batch:
|
|||||||
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
if not self.tree_cache.disable:
|
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
|
||||||
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
|
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
|
||||||
|
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
print("Prefill out of memory. This should never happen.")
|
print("Prefill out of memory. This should never happen.")
|
||||||
@@ -307,8 +306,8 @@ class Batch:
|
|||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if not self.tree_cache.disable:
|
self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
|
||||||
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
|
||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -341,7 +340,7 @@ class Batch:
|
|||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
req_pool_indices_np[idx]
|
req_pool_indices_np[idx]
|
||||||
][: seq_lens_np[idx]]
|
][: seq_lens_np[idx]]
|
||||||
self.token_to_kv_pool.free(token_indices)
|
self.token_to_kv_pool.dec_refs(token_indices)
|
||||||
|
|
||||||
self.filter_batch(sorted_indices)
|
self.filter_batch(sorted_indices)
|
||||||
|
|
||||||
@@ -372,7 +371,7 @@ class Batch:
|
|||||||
prefix_len = self.tree_cache.insert(
|
prefix_len = self.tree_cache.insert(
|
||||||
token_ids_in_memory, indices.clone()
|
token_ids_in_memory, indices.clone()
|
||||||
)
|
)
|
||||||
self.token_to_kv_pool.free(indices[:prefix_len])
|
self.token_to_kv_pool.dec_refs(indices[:prefix_len])
|
||||||
self.req_to_token_pool.free(req_pool_idx)
|
self.req_to_token_pool.free(req_pool_idx)
|
||||||
self.tree_cache.dec_ref_counter(req.last_node)
|
self.tree_cache.dec_ref_counter(req.last_node)
|
||||||
|
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class ModelRpcServer:
|
|||||||
logger.info(server_args.get_optional_modes_logging())
|
logger.info(server_args.get_optional_modes_logging())
|
||||||
|
|
||||||
# Init cache
|
# Init cache
|
||||||
self.tree_cache = RadixCache(server_args.disable_radix_cache)
|
self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
|
||||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||||
self.scheduler = Scheduler(
|
self.scheduler = Scheduler(
|
||||||
self.schedule_heuristic,
|
self.schedule_heuristic,
|
||||||
@@ -628,7 +628,7 @@ class ModelRpcServer:
|
|||||||
token_ids[:seq_len], indices.clone()
|
token_ids[:seq_len], indices.clone()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.token_to_kv_pool.free(indices[:prefix_len])
|
self.token_to_kv_pool.dec_refs(indices[:prefix_len])
|
||||||
self.req_to_token_pool.free(req_pool_idx)
|
self.req_to_token_pool.free(req_pool_idx)
|
||||||
self.tree_cache.dec_ref_counter(req.last_node)
|
self.tree_cache.dec_ref_counter(req.last_node)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import heapq
|
import heapq
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -16,23 +14,23 @@ class TreeNode:
|
|||||||
self.ref_counter = 0
|
self.ref_counter = 0
|
||||||
self.last_access_time = time.time()
|
self.last_access_time = time.time()
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other: "TreeNode"):
|
||||||
return self.last_access_time < other.last_access_time
|
return self.last_access_time < other.last_access_time
|
||||||
|
|
||||||
|
|
||||||
def match(key, seq):
|
def _key_match(key0, key1):
|
||||||
i = 0
|
i = 0
|
||||||
for k, w in zip(key, seq):
|
for k0, k1 in zip(key0, key1):
|
||||||
if k != w:
|
if k0 != k1:
|
||||||
break
|
break
|
||||||
i += 1
|
i += 1
|
||||||
return i
|
return i
|
||||||
|
|
||||||
|
|
||||||
class RadixCache:
|
class RadixCache:
|
||||||
def __init__(self, disable=False):
|
def __init__(self, disable: bool = False):
|
||||||
self.reset()
|
|
||||||
self.disable = disable
|
self.disable = disable
|
||||||
|
self.reset()
|
||||||
|
|
||||||
##### Public API #####
|
##### Public API #####
|
||||||
|
|
||||||
@@ -71,7 +69,7 @@ class RadixCache:
|
|||||||
|
|
||||||
def evict(self, num_tokens, evict_callback):
|
def evict(self, num_tokens, evict_callback):
|
||||||
if self.disable:
|
if self.disable:
|
||||||
raise RuntimeError()
|
return
|
||||||
|
|
||||||
leaves = self._collect_leaves()
|
leaves = self._collect_leaves()
|
||||||
heapq.heapify(leaves)
|
heapq.heapify(leaves)
|
||||||
@@ -115,6 +113,7 @@ class RadixCache:
|
|||||||
return self.evictable_size_
|
return self.evictable_size_
|
||||||
|
|
||||||
##### Internal Helper Functions #####
|
##### Internal Helper Functions #####
|
||||||
|
|
||||||
def _match_prefix_helper(self, node, key, value, last_node):
|
def _match_prefix_helper(self, node, key, value, last_node):
|
||||||
node.last_access_time = time.time()
|
node.last_access_time = time.time()
|
||||||
if len(key) == 0:
|
if len(key) == 0:
|
||||||
@@ -122,7 +121,7 @@ class RadixCache:
|
|||||||
|
|
||||||
if key[0] in node.children.keys():
|
if key[0] in node.children.keys():
|
||||||
child = node.children[key[0]]
|
child = node.children[key[0]]
|
||||||
prefix_len = match(child.key, key)
|
prefix_len = _key_match(child.key, key)
|
||||||
if prefix_len < len(child.key):
|
if prefix_len < len(child.key):
|
||||||
new_node = self._split_node(child.key, child, prefix_len)
|
new_node = self._split_node(child.key, child, prefix_len)
|
||||||
value.append(new_node.value)
|
value.append(new_node.value)
|
||||||
@@ -153,7 +152,7 @@ class RadixCache:
|
|||||||
|
|
||||||
if key[0] in node.children.keys():
|
if key[0] in node.children.keys():
|
||||||
child = node.children[key[0]]
|
child = node.children[key[0]]
|
||||||
prefix_len = match(child.key, key)
|
prefix_len = _key_match(child.key, key)
|
||||||
|
|
||||||
if prefix_len == len(child.key):
|
if prefix_len == len(child.key):
|
||||||
if prefix_len == len(key):
|
if prefix_len == len(key):
|
||||||
@@ -212,7 +211,7 @@ class RadixCache:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tree = RadixCache(disable=False)
|
tree = RadixCache()
|
||||||
|
|
||||||
tree.insert("Hello")
|
tree.insert("Hello")
|
||||||
tree.insert("Hello")
|
tree.insert("Hello")
|
||||||
|
|||||||
@@ -31,9 +31,6 @@ class ReqToTokenPool:
|
|||||||
self.can_use_mem_size += free_index.shape[0]
|
self.can_use_mem_size += free_index.shape[0]
|
||||||
self.mem_state[free_index] = 1
|
self.mem_state[free_index] = 1
|
||||||
|
|
||||||
# if self.can_use_mem_size == len(self.mem_state):
|
|
||||||
# print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.mem_state.fill_(1)
|
self.mem_state.fill_(1)
|
||||||
self.can_use_mem_size = len(self.mem_state)
|
self.can_use_mem_size = len(self.mem_state)
|
||||||
@@ -42,7 +39,7 @@ class ReqToTokenPool:
|
|||||||
class TokenToKVPool:
|
class TokenToKVPool:
|
||||||
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
||||||
self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
|
self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
|
||||||
self.alloc_ct = 0
|
self.total_ref_ct = 0
|
||||||
|
|
||||||
# [size, key/value, head_num, head_dim] for each layer
|
# [size, key/value, head_num, head_dim] for each layer
|
||||||
self.kv_data = [
|
self.kv_data = [
|
||||||
@@ -83,9 +80,6 @@ class TokenToKVPool:
|
|||||||
self.add_refs(select_index)
|
self.add_refs(select_index)
|
||||||
return select_index.to(torch.int32), start_loc, start_loc + need_size
|
return select_index.to(torch.int32), start_loc, start_loc + need_size
|
||||||
|
|
||||||
def free(self, free_index):
|
|
||||||
return self.decrease_refs(free_index)
|
|
||||||
|
|
||||||
def used_size(self):
|
def used_size(self):
|
||||||
return len(torch.nonzero(self.mem_state).squeeze(1))
|
return len(torch.nonzero(self.mem_state).squeeze(1))
|
||||||
|
|
||||||
@@ -93,20 +87,17 @@ class TokenToKVPool:
|
|||||||
return torch.sum(self.mem_state == 0).item()
|
return torch.sum(self.mem_state == 0).item()
|
||||||
|
|
||||||
def add_refs(self, token_index: torch.Tensor):
|
def add_refs(self, token_index: torch.Tensor):
|
||||||
self.alloc_ct += len(token_index)
|
self.total_ref_ct += len(token_index)
|
||||||
self.mem_state[token_index] += 1
|
self.mem_state[token_index] += 1
|
||||||
|
|
||||||
def decrease_refs(self, token_index: torch.Tensor):
|
def dec_refs(self, token_index: torch.Tensor):
|
||||||
self.alloc_ct -= len(token_index)
|
self.total_ref_ct -= len(token_index)
|
||||||
self.mem_state[token_index] -= 1
|
self.mem_state[token_index] -= 1
|
||||||
|
|
||||||
num_freed = torch.sum(self.mem_state[token_index] == 0)
|
num_freed = torch.sum(self.mem_state[token_index] == 0)
|
||||||
|
|
||||||
# if self.alloc_ct == 0:
|
|
||||||
# print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
|
|
||||||
|
|
||||||
return num_freed
|
return num_freed
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.mem_state.fill_(0)
|
self.mem_state.fill_(0)
|
||||||
self.alloc_ct = 0
|
self.total_ref_ct = 0
|
||||||
|
|||||||
@@ -500,7 +500,7 @@ async def v1_chat_completions(raw_request: Request):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def launch_server(server_args, pipe_finish_writer):
|
def launch_server(server_args: ServerArgs, pipe_finish_writer):
|
||||||
global tokenizer_manager
|
global tokenizer_manager
|
||||||
global chat_template_name
|
global chat_template_name
|
||||||
|
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ def test_generate_worker(
|
|||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
req_idx = req_pool_indices[i].item()
|
req_idx = req_pool_indices[i].item()
|
||||||
model.token_to_kv_pool.free(
|
model.token_to_kv_pool.dec_refs(
|
||||||
model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
|
model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
|
||||||
)
|
)
|
||||||
model.req_to_token_pool.free(req_pool_indices)
|
model.req_to_token_pool.free(req_pool_indices)
|
||||||
|
|||||||
Reference in New Issue
Block a user