Large page size aligned hierarchical caching (#4581)
This commit is contained in:
@@ -16,7 +16,6 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,29 +30,25 @@ class HiRadixCache(RadixCache):
|
||||
page_size: int,
|
||||
hicache_ratio: float,
|
||||
):
|
||||
if page_size != 1:
|
||||
raise ValueError(
|
||||
"Page size larger than 1 is not yet supported in HiRadixCache."
|
||||
)
|
||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||
self.kv_cache, hicache_ratio
|
||||
self.kv_cache, hicache_ratio, page_size
|
||||
)
|
||||
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
||||
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
|
||||
self.kv_cache, hicache_ratio
|
||||
self.kv_cache, hicache_ratio, page_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
|
||||
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
|
||||
|
||||
self.tp_group = tp_cache_group
|
||||
self.page_size = page_size
|
||||
|
||||
self.load_cache_event = threading.Event()
|
||||
self.cache_controller = HiCacheController(
|
||||
token_to_kv_pool_allocator,
|
||||
self.token_to_kv_pool_host,
|
||||
page_size,
|
||||
load_cache_event=self.load_cache_event,
|
||||
)
|
||||
|
||||
@@ -65,7 +60,7 @@ class HiRadixCache(RadixCache):
|
||||
self.write_through_threshold = 1
|
||||
self.load_back_threshold = 10
|
||||
super().__init__(
|
||||
req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
|
||||
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
@@ -299,18 +294,26 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return last_node, prefix_indices
|
||||
|
||||
def read_to_load_cache(self):
|
||||
def ready_to_load_cache(self):
|
||||
self.load_cache_event.set()
|
||||
|
||||
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
if self.disable or len(key) == 0:
|
||||
if include_evicted:
|
||||
return empty_value, self.root_node, self.root_node
|
||||
else:
|
||||
return empty_value, self.root_node
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(key) // self.page_size * self.page_size
|
||||
key = key[:page_aligned_len]
|
||||
|
||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||
if value:
|
||||
value = torch.cat(value)
|
||||
else:
|
||||
value = torch.tensor([], dtype=torch.int64)
|
||||
value = empty_value
|
||||
|
||||
last_node_global = last_node
|
||||
while last_node.evicted:
|
||||
@@ -323,11 +326,13 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.time()
|
||||
child_key = self.get_child_key_fn(key)
|
||||
value = []
|
||||
while len(key) > 0 and key[0] in node.children.keys():
|
||||
child = node.children[key[0]]
|
||||
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
child = node.children[child_key]
|
||||
child.last_access_time = time.time()
|
||||
prefix_len = _key_match(child.key, key)
|
||||
prefix_len = self.key_match_fn(child.key, key)
|
||||
if prefix_len < len(child.key):
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
if not new_node.evicted:
|
||||
@@ -339,12 +344,16 @@ class HiRadixCache(RadixCache):
|
||||
value.append(child.value)
|
||||
node = child
|
||||
key = key[prefix_len:]
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
# child node split into new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {key[split_len]: child}
|
||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||
new_node.parent = child.parent
|
||||
new_node.lock_ref = child.lock_ref
|
||||
new_node.key = child.key[:split_len]
|
||||
@@ -361,7 +370,7 @@ class HiRadixCache(RadixCache):
|
||||
child.host_value = child.host_value[split_len:]
|
||||
child.parent = new_node
|
||||
child.key = child.key[split_len:]
|
||||
new_node.parent.children[key[0]] = new_node
|
||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||
return new_node
|
||||
|
||||
def _insert_helper(self, node: TreeNode, key: List, value):
|
||||
@@ -369,52 +378,53 @@ class HiRadixCache(RadixCache):
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
|
||||
if key[0] in node.children.keys():
|
||||
child = node.children[key[0]]
|
||||
prefix_len = _key_match(child.key, key)
|
||||
child_key = self.get_child_key_fn(key)
|
||||
total_prefix_length = 0
|
||||
|
||||
if prefix_len == len(child.key):
|
||||
if child.evicted:
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
node = node.children[child_key]
|
||||
node.last_access_time = time.time()
|
||||
prefix_len = self.key_match_fn(node.key, key)
|
||||
|
||||
if prefix_len == len(node.key):
|
||||
if node.evicted:
|
||||
# change the reference if the node is evicted
|
||||
# this often happens in the case of KV cache recomputation
|
||||
child.value = value[:prefix_len]
|
||||
self.token_to_kv_pool_host.update_synced(child.host_value)
|
||||
self.evictable_size_ += len(value[:prefix_len])
|
||||
return self._insert_helper(
|
||||
child, key[prefix_len:], value[prefix_len:]
|
||||
)
|
||||
node.value = value[:prefix_len]
|
||||
self.token_to_kv_pool_host.update_synced(node.host_value)
|
||||
self.evictable_size_ += len(node.value)
|
||||
else:
|
||||
self.inc_hit_count(child)
|
||||
return prefix_len + self._insert_helper(
|
||||
child, key[prefix_len:], value[prefix_len:]
|
||||
)
|
||||
|
||||
# partial match, split the node
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
if new_node.evicted:
|
||||
new_node.value = value[:prefix_len]
|
||||
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
||||
self.evictable_size_ += len(new_node.value)
|
||||
return self._insert_helper(
|
||||
new_node, key[prefix_len:], value[prefix_len:]
|
||||
)
|
||||
self.inc_hit_count(node)
|
||||
total_prefix_length += prefix_len
|
||||
else:
|
||||
self.inc_hit_count(new_node)
|
||||
return prefix_len + self._insert_helper(
|
||||
new_node, key[prefix_len:], value[prefix_len:]
|
||||
)
|
||||
# partial match, split the node
|
||||
new_node = self._split_node(node.key, node, prefix_len)
|
||||
if new_node.evicted:
|
||||
new_node.value = value[:prefix_len]
|
||||
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
||||
self.evictable_size_ += len(new_node.value)
|
||||
else:
|
||||
self.inc_hit_count(new_node)
|
||||
total_prefix_length += prefix_len
|
||||
node = new_node
|
||||
|
||||
key = key[prefix_len:]
|
||||
value = value[prefix_len:]
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
if len(key):
|
||||
new_node = TreeNode()
|
||||
new_node.parent = node
|
||||
new_node.key = key
|
||||
new_node.value = value
|
||||
node.children[key[0]] = new_node
|
||||
node.children[child_key] = new_node
|
||||
self.evictable_size_ += len(value)
|
||||
|
||||
if self.cache_controller.write_policy == "write_through":
|
||||
self.write_backup(new_node)
|
||||
return 0
|
||||
return total_prefix_length
|
||||
|
||||
def _collect_leaves_device(self):
|
||||
def is_leaf(node):
|
||||
|
||||
Reference in New Issue
Block a user