[HICache] introduce evict policy (#10190)
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Co-authored-by: Teng Ma <sima.mt@alibaba-inc.com>
This commit is contained in:
@@ -39,6 +39,7 @@ class HiRadixCache(RadixCache):
|
||||
hicache_io_backend: str,
|
||||
hicache_mem_layout: str,
|
||||
enable_metrics: bool,
|
||||
eviction_policy: str = "lru",
|
||||
hicache_storage_backend: Optional[str] = None,
|
||||
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
||||
model_name: Optional[str] = None,
|
||||
@@ -117,8 +118,13 @@ class HiRadixCache(RadixCache):
|
||||
1 if hicache_write_policy == "write_through" else 2
|
||||
)
|
||||
self.load_back_threshold = 10
|
||||
|
||||
super().__init__(
|
||||
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
||||
req_to_token_pool,
|
||||
token_to_kv_pool_allocator,
|
||||
page_size,
|
||||
disable=False,
|
||||
eviction_policy=eviction_policy,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
@@ -258,12 +264,15 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
def evict(self, num_tokens: int):
|
||||
leaves = self._collect_leaves_device()
|
||||
heapq.heapify(leaves)
|
||||
eviction_heap = [
|
||||
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
||||
]
|
||||
heapq.heapify(eviction_heap)
|
||||
|
||||
num_evicted = 0
|
||||
write_back_nodes = []
|
||||
while num_evicted < num_tokens and len(leaves):
|
||||
x = heapq.heappop(leaves)
|
||||
while num_evicted < num_tokens and len(eviction_heap):
|
||||
_priority, x = heapq.heappop(eviction_heap)
|
||||
|
||||
if x.lock_ref > 0:
|
||||
continue
|
||||
@@ -285,7 +294,8 @@ class HiRadixCache(RadixCache):
|
||||
break
|
||||
else:
|
||||
# all children are evicted or no children
|
||||
heapq.heappush(leaves, x.parent)
|
||||
new_priority = self.eviction_strategy.get_priority(x.parent)
|
||||
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
||||
|
||||
if self.cache_controller.write_policy == "write_back":
|
||||
self.writing_check(write_back=True)
|
||||
@@ -310,11 +320,14 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
def evict_host(self, num_tokens: int):
|
||||
leaves = self._collect_leaves()
|
||||
heapq.heapify(leaves)
|
||||
eviction_heap = [
|
||||
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
||||
]
|
||||
heapq.heapify(eviction_heap)
|
||||
|
||||
num_evicted = 0
|
||||
while num_evicted < num_tokens and len(leaves):
|
||||
x = heapq.heappop(leaves)
|
||||
while num_evicted < num_tokens and len(eviction_heap):
|
||||
_priority, x = heapq.heappop(eviction_heap)
|
||||
if x == self.root_node:
|
||||
break
|
||||
# only evict the host value of evicted nodes
|
||||
@@ -333,7 +346,8 @@ class HiRadixCache(RadixCache):
|
||||
del x.parent.children[k]
|
||||
|
||||
if len(x.parent.children) == 0 and x.parent.evicted:
|
||||
heapq.heappush(leaves, x.parent)
|
||||
new_priority = self.eviction_strategy.get_priority(x.parent)
|
||||
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
||||
|
||||
def load_back(
|
||||
self, node: TreeNode, mem_quota: Optional[int] = None
|
||||
|
||||
Reference in New Issue
Block a user