Large page size aligned hierarchical caching (#4581)

This commit is contained in:
Zhiqiang Xie
2025-04-01 22:38:15 -07:00
committed by GitHub
parent 9eb49e878b
commit e119f04215
8 changed files with 242 additions and 71 deletions

View File

@@ -16,7 +16,6 @@ from sglang.srt.mem_cache.memory_pool import (
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
logger = logging.getLogger(__name__)
@@ -31,29 +30,25 @@ class HiRadixCache(RadixCache):
page_size: int,
hicache_ratio: float,
):
if page_size != 1:
raise ValueError(
"Page size larger than 1 is not yet supported in HiRadixCache."
)
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
self.kv_cache, hicache_ratio
self.kv_cache, hicache_ratio, page_size
)
elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
self.kv_cache, hicache_ratio
self.kv_cache, hicache_ratio, page_size
)
else:
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
self.tp_group = tp_cache_group
self.page_size = page_size
self.load_cache_event = threading.Event()
self.cache_controller = HiCacheController(
token_to_kv_pool_allocator,
self.token_to_kv_pool_host,
page_size,
load_cache_event=self.load_cache_event,
)
@@ -65,7 +60,7 @@ class HiRadixCache(RadixCache):
self.write_through_threshold = 1
self.load_back_threshold = 10
super().__init__(
req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
)
def reset(self):
@@ -299,18 +294,26 @@ class HiRadixCache(RadixCache):
return last_node, prefix_indices
def read_to_load_cache(self):
def ready_to_load_cache(self):
self.load_cache_event.set()
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
if self.disable:
return [], self.root_node
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
if self.disable or len(key) == 0:
if include_evicted:
return empty_value, self.root_node, self.root_node
else:
return empty_value, self.root_node
if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.cat(value)
else:
value = torch.tensor([], dtype=torch.int64)
value = empty_value
last_node_global = last_node
while last_node.evicted:
@@ -323,11 +326,13 @@ class HiRadixCache(RadixCache):
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
child_key = self.get_child_key_fn(key)
value = []
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
if not new_node.evicted:
@@ -339,12 +344,16 @@ class HiRadixCache(RadixCache):
value.append(child.value)
node = child
key = key[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
return value, node
def _split_node(self, key, child: TreeNode, split_len: int):
# child node split into new_node -> child
new_node = TreeNode()
new_node.children = {key[split_len]: child}
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len]
@@ -361,7 +370,7 @@ class HiRadixCache(RadixCache):
child.host_value = child.host_value[split_len:]
child.parent = new_node
child.key = child.key[split_len:]
new_node.parent.children[key[0]] = new_node
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node
def _insert_helper(self, node: TreeNode, key: List, value):
@@ -369,52 +378,53 @@ class HiRadixCache(RadixCache):
if len(key) == 0:
return 0
if key[0] in node.children.keys():
child = node.children[key[0]]
prefix_len = _key_match(child.key, key)
child_key = self.get_child_key_fn(key)
total_prefix_length = 0
if prefix_len == len(child.key):
if child.evicted:
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.time()
prefix_len = self.key_match_fn(node.key, key)
if prefix_len == len(node.key):
if node.evicted:
# change the reference if the node is evicted
# this often happens in the case of KV cache recomputation
child.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(child.host_value)
self.evictable_size_ += len(value[:prefix_len])
return self._insert_helper(
child, key[prefix_len:], value[prefix_len:]
)
node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(node.host_value)
self.evictable_size_ += len(node.value)
else:
self.inc_hit_count(child)
return prefix_len + self._insert_helper(
child, key[prefix_len:], value[prefix_len:]
)
# partial match, split the node
new_node = self._split_node(child.key, child, prefix_len)
if new_node.evicted:
new_node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value)
return self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
self.inc_hit_count(node)
total_prefix_length += prefix_len
else:
self.inc_hit_count(new_node)
return prefix_len + self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
# partial match, split the node
new_node = self._split_node(node.key, node, prefix_len)
if new_node.evicted:
new_node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value)
else:
self.inc_hit_count(new_node)
total_prefix_length += prefix_len
node = new_node
key = key[prefix_len:]
value = value[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = value
node.children[key[0]] = new_node
node.children[child_key] = new_node
self.evictable_size_ += len(value)
if self.cache_controller.write_policy == "write_through":
self.write_backup(new_node)
return 0
return total_prefix_length
def _collect_leaves_device(self):
def is_leaf(node):