diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 83e6b45cb..a246534cb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -667,6 +667,7 @@ class Scheduler( else self.tp_cpu_group ), page_size=self.page_size, + eviction_policy=server_args.radix_eviction_policy, hicache_ratio=server_args.hicache_ratio, hicache_size=server_args.hicache_size, hicache_write_policy=server_args.hicache_write_policy, @@ -719,6 +720,7 @@ class Scheduler( tp_size=self.tp_size, rank=self.tp_rank, tp_group=self.tp_group, + eviction_policy=server_args.radix_eviction_policy, ) else: self.tree_cache = RadixCache( @@ -727,6 +729,7 @@ class Scheduler( page_size=self.page_size, disable=server_args.disable_radix_cache, enable_kv_cache_events=self.enable_kv_cache_events, + eviction_policy=server_args.radix_eviction_policy, ) self.decode_mem_cache_buf_multiplier = ( diff --git a/python/sglang/srt/mem_cache/evict_policy.py b/python/sglang/srt/mem_cache/evict_policy.py new file mode 100644 index 000000000..ddd2ab6c3 --- /dev/null +++ b/python/sglang/srt/mem_cache/evict_policy.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Tuple, Union + +if TYPE_CHECKING: + from sglang.srt.mem_cache.radix_cache import TreeNode + + +class EvictionStrategy(ABC): + @abstractmethod + def get_priority(self, node: "TreeNode") -> Union[float, Tuple]: + pass + + +class LRUStrategy(EvictionStrategy): + def get_priority(self, node: "TreeNode") -> float: + return node.last_access_time + + +class LFUStrategy(EvictionStrategy): + def get_priority(self, node: "TreeNode") -> Tuple[int, float]: + return (node.hit_count, node.last_access_time) diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 3b00e4619..538c2a450 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -39,6 +39,7 @@ class HiRadixCache(RadixCache): hicache_io_backend: str, hicache_mem_layout: str, enable_metrics: bool, + eviction_policy: str = "lru", hicache_storage_backend: Optional[str] = None, hicache_storage_prefetch_policy: Optional[str] = "best_effort", model_name: Optional[str] = None, @@ -117,8 +118,13 @@ class HiRadixCache(RadixCache): 1 if hicache_write_policy == "write_through" else 2 ) self.load_back_threshold = 10 + super().__init__( - req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False + req_to_token_pool, + token_to_kv_pool_allocator, + page_size, + disable=False, + eviction_policy=eviction_policy, ) def reset(self): @@ -258,12 +264,15 @@ class HiRadixCache(RadixCache): def evict(self, num_tokens: int): leaves = self._collect_leaves_device() - heapq.heapify(leaves) + eviction_heap = [ + (self.eviction_strategy.get_priority(node), node) for node in leaves + ] + heapq.heapify(eviction_heap) num_evicted = 0 write_back_nodes = [] - while num_evicted < num_tokens and len(leaves): - x = heapq.heappop(leaves) + while num_evicted < num_tokens and len(eviction_heap): + _priority, x = heapq.heappop(eviction_heap) if x.lock_ref > 0: continue @@ -285,7 +294,8 @@ class HiRadixCache(RadixCache): break else: # all children are evicted or no children - heapq.heappush(leaves, x.parent) + new_priority = self.eviction_strategy.get_priority(x.parent) + heapq.heappush(eviction_heap, (new_priority, x.parent)) if self.cache_controller.write_policy == "write_back": self.writing_check(write_back=True) @@ -310,11 +320,14 @@ class HiRadixCache(RadixCache): def evict_host(self, num_tokens: int): leaves = self._collect_leaves() - heapq.heapify(leaves) + eviction_heap = [ + (self.eviction_strategy.get_priority(node), node) for node in leaves + ] + heapq.heapify(eviction_heap) num_evicted = 0 - while num_evicted < num_tokens and len(leaves): - x = heapq.heappop(leaves) + while num_evicted < num_tokens and len(eviction_heap): + _priority, x = heapq.heappop(eviction_heap) if x == self.root_node: break # only evict the host value of evicted nodes @@ -333,7 +346,8 @@ class HiRadixCache(RadixCache): del x.parent.children[k] if len(x.parent.children) == 0 and x.parent.evicted: - heapq.heappush(leaves, x.parent) + new_priority = self.eviction_strategy.get_priority(x.parent) + heapq.heappush(eviction_heap, (new_priority, x.parent)) def load_back( self, node: TreeNode, mem_quota: Optional[int] = None diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index d8208e143..4745811de 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -34,6 +34,7 @@ from sglang.srt.disaggregation.kv_events import ( ) from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy from sglang.srt.mem_cache.memory_pool import ReqToTokenPool if TYPE_CHECKING: @@ -122,6 +123,7 @@ class RadixCache(BasePrefixCache): page_size: int, disable: bool = False, enable_kv_cache_events: bool = False, + eviction_policy: str = "lru", ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator @@ -141,6 +143,15 @@ class RadixCache(BasePrefixCache): else: self.key_match_fn = partial(_key_match_paged, page_size=page_size) self.get_child_key_fn = lambda key: tuple(key[:page_size]) + + if eviction_policy.lower() == "lru": + self.eviction_strategy: EvictionStrategy = LRUStrategy() + elif eviction_policy.lower() == "lfu": + self.eviction_strategy: EvictionStrategy = LFUStrategy() + else: + raise ValueError( + f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'." + ) self.reset() ##### Public API ##### @@ -296,11 +307,14 @@ class RadixCache(BasePrefixCache): return leaves = self._collect_leaves() - heapq.heapify(leaves) + eviction_heap = [ + (self.eviction_strategy.get_priority(node), node) for node in leaves + ] + heapq.heapify(eviction_heap) num_evicted = 0 - while num_evicted < num_tokens and len(leaves): - x = heapq.heappop(leaves) + while num_evicted < num_tokens and len(eviction_heap): + _priority, x = heapq.heappop(eviction_heap) if x == self.root_node: break @@ -312,7 +326,8 @@ class RadixCache(BasePrefixCache): self._delete_leaf(x) if len(x.parent.children) == 0: - heapq.heappush(leaves, x.parent) + new_priority = self.eviction_strategy.get_priority(x.parent) + heapq.heappush(eviction_heap, (new_priority, x.parent)) self._record_remove_event(x) diff --git a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py index f8690aec4..99537135e 100644 --- a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +++ b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py @@ -78,6 +78,7 @@ class LMCRadixCache(RadixCache): tp_size: int = 1, rank: int = 0, tp_group: Optional[torch.distributed.ProcessGroup] = None, + eviction_policy: str = "lru", ): super().__init__( req_to_token_pool=req_to_token_pool, @@ -85,6 +86,7 @@ class LMCRadixCache(RadixCache): page_size=page_size, disable=disable, enable_kv_cache_events=enable_kv_cache_events, + eviction_policy=eviction_policy, ) kvcache = self.token_to_kv_pool_allocator.get_kvcache() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 34dae30fb..3459d67ee 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -185,6 +185,7 @@ class ServerArgs: hybrid_kvcache_ratio: Optional[float] = None swa_full_tokens_ratio: float = 0.8 disable_hybrid_swa_memory: bool = False + radix_eviction_policy: str = "lru" # Runtime options device: Optional[str] = None @@ -1907,6 +1908,13 @@ class ServerArgs: default=ServerArgs.hicache_write_policy, help="The write policy of hierarchical cache.", ) + parser.add_argument( + "--radix-eviction-policy", + type=str, + choices=["lru", "lfu"], + default=ServerArgs.radix_eviction_policy, + help="The eviction policy of radix trees. 'lru' stands for Least Recently Used, 'lfu' stands for Least Frequently Used.", + ) parser.add_argument( "--hicache-io-backend", type=str,