[HICache] introduce evict policy (#10190)

Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Co-authored-by: Teng Ma <sima.mt@alibaba-inc.com>
This commit is contained in:
Xuchun Shang
2025-09-18 11:10:20 +08:00
committed by GitHub
parent c32fb7a24d
commit 1ccd59c715
6 changed files with 78 additions and 13 deletions

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union
if TYPE_CHECKING:
from sglang.srt.mem_cache.radix_cache import TreeNode
class EvictionStrategy(ABC):
@abstractmethod
def get_priority(self, node: "TreeNode") -> Union[float, Tuple]:
pass
class LRUStrategy(EvictionStrategy):
def get_priority(self, node: "TreeNode") -> float:
return node.last_access_time
class LFUStrategy(EvictionStrategy):
def get_priority(self, node: "TreeNode") -> Tuple[int, float]:
return (node.hit_count, node.last_access_time)

View File

@@ -39,6 +39,7 @@ class HiRadixCache(RadixCache):
hicache_io_backend: str,
hicache_mem_layout: str,
enable_metrics: bool,
eviction_policy: str = "lru",
hicache_storage_backend: Optional[str] = None,
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
model_name: Optional[str] = None,
@@ -117,8 +118,13 @@ class HiRadixCache(RadixCache):
1 if hicache_write_policy == "write_through" else 2
)
self.load_back_threshold = 10
super().__init__(
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
req_to_token_pool,
token_to_kv_pool_allocator,
page_size,
disable=False,
eviction_policy=eviction_policy,
)
def reset(self):
@@ -258,12 +264,15 @@ class HiRadixCache(RadixCache):
def evict(self, num_tokens: int):
leaves = self._collect_leaves_device()
heapq.heapify(leaves)
eviction_heap = [
(self.eviction_strategy.get_priority(node), node) for node in leaves
]
heapq.heapify(eviction_heap)
num_evicted = 0
write_back_nodes = []
while num_evicted < num_tokens and len(leaves):
x = heapq.heappop(leaves)
while num_evicted < num_tokens and len(eviction_heap):
_priority, x = heapq.heappop(eviction_heap)
if x.lock_ref > 0:
continue
@@ -285,7 +294,8 @@ class HiRadixCache(RadixCache):
break
else:
# all children are evicted or no children
heapq.heappush(leaves, x.parent)
new_priority = self.eviction_strategy.get_priority(x.parent)
heapq.heappush(eviction_heap, (new_priority, x.parent))
if self.cache_controller.write_policy == "write_back":
self.writing_check(write_back=True)
@@ -310,11 +320,14 @@ class HiRadixCache(RadixCache):
def evict_host(self, num_tokens: int):
leaves = self._collect_leaves()
heapq.heapify(leaves)
eviction_heap = [
(self.eviction_strategy.get_priority(node), node) for node in leaves
]
heapq.heapify(eviction_heap)
num_evicted = 0
while num_evicted < num_tokens and len(leaves):
x = heapq.heappop(leaves)
while num_evicted < num_tokens and len(eviction_heap):
_priority, x = heapq.heappop(eviction_heap)
if x == self.root_node:
break
# only evict the host value of evicted nodes
@@ -333,7 +346,8 @@ class HiRadixCache(RadixCache):
del x.parent.children[k]
if len(x.parent.children) == 0 and x.parent.evicted:
heapq.heappush(leaves, x.parent)
new_priority = self.eviction_strategy.get_priority(x.parent)
heapq.heappush(eviction_heap, (new_priority, x.parent))
def load_back(
self, node: TreeNode, mem_quota: Optional[int] = None

View File

@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.kv_events import (
)
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
if TYPE_CHECKING:
@@ -122,6 +123,7 @@ class RadixCache(BasePrefixCache):
page_size: int,
disable: bool = False,
enable_kv_cache_events: bool = False,
eviction_policy: str = "lru",
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -141,6 +143,15 @@ class RadixCache(BasePrefixCache):
else:
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = lambda key: tuple(key[:page_size])
if eviction_policy.lower() == "lru":
self.eviction_strategy: EvictionStrategy = LRUStrategy()
elif eviction_policy.lower() == "lfu":
self.eviction_strategy: EvictionStrategy = LFUStrategy()
else:
raise ValueError(
f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
)
self.reset()
##### Public API #####
@@ -296,11 +307,14 @@ class RadixCache(BasePrefixCache):
return
leaves = self._collect_leaves()
heapq.heapify(leaves)
eviction_heap = [
(self.eviction_strategy.get_priority(node), node) for node in leaves
]
heapq.heapify(eviction_heap)
num_evicted = 0
while num_evicted < num_tokens and len(leaves):
x = heapq.heappop(leaves)
while num_evicted < num_tokens and len(eviction_heap):
_priority, x = heapq.heappop(eviction_heap)
if x == self.root_node:
break
@@ -312,7 +326,8 @@ class RadixCache(BasePrefixCache):
self._delete_leaf(x)
if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent)
new_priority = self.eviction_strategy.get_priority(x.parent)
heapq.heappush(eviction_heap, (new_priority, x.parent))
self._record_remove_event(x)

View File

@@ -78,6 +78,7 @@ class LMCRadixCache(RadixCache):
tp_size: int = 1,
rank: int = 0,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
eviction_policy: str = "lru",
):
super().__init__(
req_to_token_pool=req_to_token_pool,
@@ -85,6 +86,7 @@ class LMCRadixCache(RadixCache):
page_size=page_size,
disable=disable,
enable_kv_cache_events=enable_kv_cache_events,
eviction_policy=eviction_policy,
)
kvcache = self.token_to_kv_pool_allocator.get_kvcache()