[HICache] introduce evict policy (#10190)
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com> Co-authored-by: Teng Ma <sima.mt@alibaba-inc.com>
This commit is contained in:
@@ -667,6 +667,7 @@ class Scheduler(
|
|||||||
else self.tp_cpu_group
|
else self.tp_cpu_group
|
||||||
),
|
),
|
||||||
page_size=self.page_size,
|
page_size=self.page_size,
|
||||||
|
eviction_policy=server_args.radix_eviction_policy,
|
||||||
hicache_ratio=server_args.hicache_ratio,
|
hicache_ratio=server_args.hicache_ratio,
|
||||||
hicache_size=server_args.hicache_size,
|
hicache_size=server_args.hicache_size,
|
||||||
hicache_write_policy=server_args.hicache_write_policy,
|
hicache_write_policy=server_args.hicache_write_policy,
|
||||||
@@ -719,6 +720,7 @@ class Scheduler(
|
|||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
rank=self.tp_rank,
|
rank=self.tp_rank,
|
||||||
tp_group=self.tp_group,
|
tp_group=self.tp_group,
|
||||||
|
eviction_policy=server_args.radix_eviction_policy,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.tree_cache = RadixCache(
|
self.tree_cache = RadixCache(
|
||||||
@@ -727,6 +729,7 @@ class Scheduler(
|
|||||||
page_size=self.page_size,
|
page_size=self.page_size,
|
||||||
disable=server_args.disable_radix_cache,
|
disable=server_args.disable_radix_cache,
|
||||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||||
|
eviction_policy=server_args.radix_eviction_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decode_mem_cache_buf_multiplier = (
|
self.decode_mem_cache_buf_multiplier = (
|
||||||
|
|||||||
23
python/sglang/srt/mem_cache/evict_policy.py
Normal file
23
python/sglang/srt/mem_cache/evict_policy.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.mem_cache.radix_cache import TreeNode
|
||||||
|
|
||||||
|
|
||||||
|
class EvictionStrategy(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def get_priority(self, node: "TreeNode") -> Union[float, Tuple]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LRUStrategy(EvictionStrategy):
|
||||||
|
def get_priority(self, node: "TreeNode") -> float:
|
||||||
|
return node.last_access_time
|
||||||
|
|
||||||
|
|
||||||
|
class LFUStrategy(EvictionStrategy):
|
||||||
|
def get_priority(self, node: "TreeNode") -> Tuple[int, float]:
|
||||||
|
return (node.hit_count, node.last_access_time)
|
||||||
@@ -39,6 +39,7 @@ class HiRadixCache(RadixCache):
|
|||||||
hicache_io_backend: str,
|
hicache_io_backend: str,
|
||||||
hicache_mem_layout: str,
|
hicache_mem_layout: str,
|
||||||
enable_metrics: bool,
|
enable_metrics: bool,
|
||||||
|
eviction_policy: str = "lru",
|
||||||
hicache_storage_backend: Optional[str] = None,
|
hicache_storage_backend: Optional[str] = None,
|
||||||
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
||||||
model_name: Optional[str] = None,
|
model_name: Optional[str] = None,
|
||||||
@@ -117,8 +118,13 @@ class HiRadixCache(RadixCache):
|
|||||||
1 if hicache_write_policy == "write_through" else 2
|
1 if hicache_write_policy == "write_through" else 2
|
||||||
)
|
)
|
||||||
self.load_back_threshold = 10
|
self.load_back_threshold = 10
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
req_to_token_pool,
|
||||||
|
token_to_kv_pool_allocator,
|
||||||
|
page_size,
|
||||||
|
disable=False,
|
||||||
|
eviction_policy=eviction_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@@ -258,12 +264,15 @@ class HiRadixCache(RadixCache):
|
|||||||
|
|
||||||
def evict(self, num_tokens: int):
|
def evict(self, num_tokens: int):
|
||||||
leaves = self._collect_leaves_device()
|
leaves = self._collect_leaves_device()
|
||||||
heapq.heapify(leaves)
|
eviction_heap = [
|
||||||
|
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
||||||
|
]
|
||||||
|
heapq.heapify(eviction_heap)
|
||||||
|
|
||||||
num_evicted = 0
|
num_evicted = 0
|
||||||
write_back_nodes = []
|
write_back_nodes = []
|
||||||
while num_evicted < num_tokens and len(leaves):
|
while num_evicted < num_tokens and len(eviction_heap):
|
||||||
x = heapq.heappop(leaves)
|
_priority, x = heapq.heappop(eviction_heap)
|
||||||
|
|
||||||
if x.lock_ref > 0:
|
if x.lock_ref > 0:
|
||||||
continue
|
continue
|
||||||
@@ -285,7 +294,8 @@ class HiRadixCache(RadixCache):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# all children are evicted or no children
|
# all children are evicted or no children
|
||||||
heapq.heappush(leaves, x.parent)
|
new_priority = self.eviction_strategy.get_priority(x.parent)
|
||||||
|
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
||||||
|
|
||||||
if self.cache_controller.write_policy == "write_back":
|
if self.cache_controller.write_policy == "write_back":
|
||||||
self.writing_check(write_back=True)
|
self.writing_check(write_back=True)
|
||||||
@@ -310,11 +320,14 @@ class HiRadixCache(RadixCache):
|
|||||||
|
|
||||||
def evict_host(self, num_tokens: int):
|
def evict_host(self, num_tokens: int):
|
||||||
leaves = self._collect_leaves()
|
leaves = self._collect_leaves()
|
||||||
heapq.heapify(leaves)
|
eviction_heap = [
|
||||||
|
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
||||||
|
]
|
||||||
|
heapq.heapify(eviction_heap)
|
||||||
|
|
||||||
num_evicted = 0
|
num_evicted = 0
|
||||||
while num_evicted < num_tokens and len(leaves):
|
while num_evicted < num_tokens and len(eviction_heap):
|
||||||
x = heapq.heappop(leaves)
|
_priority, x = heapq.heappop(eviction_heap)
|
||||||
if x == self.root_node:
|
if x == self.root_node:
|
||||||
break
|
break
|
||||||
# only evict the host value of evicted nodes
|
# only evict the host value of evicted nodes
|
||||||
@@ -333,7 +346,8 @@ class HiRadixCache(RadixCache):
|
|||||||
del x.parent.children[k]
|
del x.parent.children[k]
|
||||||
|
|
||||||
if len(x.parent.children) == 0 and x.parent.evicted:
|
if len(x.parent.children) == 0 and x.parent.evicted:
|
||||||
heapq.heappush(leaves, x.parent)
|
new_priority = self.eviction_strategy.get_priority(x.parent)
|
||||||
|
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
||||||
|
|
||||||
def load_back(
|
def load_back(
|
||||||
self, node: TreeNode, mem_quota: Optional[int] = None
|
self, node: TreeNode, mem_quota: Optional[int] = None
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.kv_events import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||||
|
from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -122,6 +123,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
page_size: int,
|
page_size: int,
|
||||||
disable: bool = False,
|
disable: bool = False,
|
||||||
enable_kv_cache_events: bool = False,
|
enable_kv_cache_events: bool = False,
|
||||||
|
eviction_policy: str = "lru",
|
||||||
):
|
):
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
@@ -141,6 +143,15 @@ class RadixCache(BasePrefixCache):
|
|||||||
else:
|
else:
|
||||||
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
||||||
self.get_child_key_fn = lambda key: tuple(key[:page_size])
|
self.get_child_key_fn = lambda key: tuple(key[:page_size])
|
||||||
|
|
||||||
|
if eviction_policy.lower() == "lru":
|
||||||
|
self.eviction_strategy: EvictionStrategy = LRUStrategy()
|
||||||
|
elif eviction_policy.lower() == "lfu":
|
||||||
|
self.eviction_strategy: EvictionStrategy = LFUStrategy()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
|
||||||
|
)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
##### Public API #####
|
##### Public API #####
|
||||||
@@ -296,11 +307,14 @@ class RadixCache(BasePrefixCache):
|
|||||||
return
|
return
|
||||||
|
|
||||||
leaves = self._collect_leaves()
|
leaves = self._collect_leaves()
|
||||||
heapq.heapify(leaves)
|
eviction_heap = [
|
||||||
|
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
||||||
|
]
|
||||||
|
heapq.heapify(eviction_heap)
|
||||||
|
|
||||||
num_evicted = 0
|
num_evicted = 0
|
||||||
while num_evicted < num_tokens and len(leaves):
|
while num_evicted < num_tokens and len(eviction_heap):
|
||||||
x = heapq.heappop(leaves)
|
_priority, x = heapq.heappop(eviction_heap)
|
||||||
|
|
||||||
if x == self.root_node:
|
if x == self.root_node:
|
||||||
break
|
break
|
||||||
@@ -312,7 +326,8 @@ class RadixCache(BasePrefixCache):
|
|||||||
self._delete_leaf(x)
|
self._delete_leaf(x)
|
||||||
|
|
||||||
if len(x.parent.children) == 0:
|
if len(x.parent.children) == 0:
|
||||||
heapq.heappush(leaves, x.parent)
|
new_priority = self.eviction_strategy.get_priority(x.parent)
|
||||||
|
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
||||||
|
|
||||||
self._record_remove_event(x)
|
self._record_remove_event(x)
|
||||||
|
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class LMCRadixCache(RadixCache):
|
|||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
tp_group: Optional[torch.distributed.ProcessGroup] = None,
|
tp_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||||
|
eviction_policy: str = "lru",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
req_to_token_pool=req_to_token_pool,
|
req_to_token_pool=req_to_token_pool,
|
||||||
@@ -85,6 +86,7 @@ class LMCRadixCache(RadixCache):
|
|||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
disable=disable,
|
disable=disable,
|
||||||
enable_kv_cache_events=enable_kv_cache_events,
|
enable_kv_cache_events=enable_kv_cache_events,
|
||||||
|
eviction_policy=eviction_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
kvcache = self.token_to_kv_pool_allocator.get_kvcache()
|
kvcache = self.token_to_kv_pool_allocator.get_kvcache()
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ class ServerArgs:
|
|||||||
hybrid_kvcache_ratio: Optional[float] = None
|
hybrid_kvcache_ratio: Optional[float] = None
|
||||||
swa_full_tokens_ratio: float = 0.8
|
swa_full_tokens_ratio: float = 0.8
|
||||||
disable_hybrid_swa_memory: bool = False
|
disable_hybrid_swa_memory: bool = False
|
||||||
|
radix_eviction_policy: str = "lru"
|
||||||
|
|
||||||
# Runtime options
|
# Runtime options
|
||||||
device: Optional[str] = None
|
device: Optional[str] = None
|
||||||
@@ -1907,6 +1908,13 @@ class ServerArgs:
|
|||||||
default=ServerArgs.hicache_write_policy,
|
default=ServerArgs.hicache_write_policy,
|
||||||
help="The write policy of hierarchical cache.",
|
help="The write policy of hierarchical cache.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--radix-eviction-policy",
|
||||||
|
type=str,
|
||||||
|
choices=["lru", "lfu"],
|
||||||
|
default=ServerArgs.radix_eviction_policy,
|
||||||
|
help="The eviction policy of radix trees. 'lru' stands for Least Recently Used, 'lfu' stands for Least Frequently Used.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hicache-io-backend",
|
"--hicache-io-backend",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user