[HiCache] resolve conflict between chunked-prefill and hicache hit count (#9776)

This commit is contained in:
Zhiqiang Xie
2025-08-29 10:30:54 -07:00
committed by GitHub
parent e5b29bf14e
commit 54e872d343
8 changed files with 20 additions and 17 deletions

View File

@@ -102,7 +102,7 @@ class HiRadixCache(RadixCache):
self.ongoing_backup = {}
# todo: dynamically adjust the threshold
self.write_through_threshold = (
1 if hicache_write_policy == "write_through" else 3
1 if hicache_write_policy == "write_through" else 2
)
self.write_through_threshold_storage = (
1 if hicache_write_policy == "write_through" else 3
@@ -155,8 +155,9 @@ class HiRadixCache(RadixCache):
self.ongoing_backup[operation_id] = node
node.protect_host()
def inc_hit_count(self, node: TreeNode):
if self.cache_controller.write_policy == "write_back":
def _inc_hit_count(self, node: TreeNode, chunked=False):
# skip the hit count update for chunked requests
if self.cache_controller.write_policy == "write_back" or chunked:
return
node.hit_count += 1
@@ -672,11 +673,11 @@ class HiRadixCache(RadixCache):
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node
def _insert_helper(self, node: TreeNode, key: List, value):
node.last_access_time = time.monotonic()
def insert(self, key: List, value, chunked=False):
if len(key) == 0:
return 0
node = self.root_node
child_key = self.get_child_key_fn(key)
total_prefix_length = 0
@@ -693,7 +694,7 @@ class HiRadixCache(RadixCache):
self.token_to_kv_pool_host.update_synced(node.host_value)
self.evictable_size_ += len(node.value)
else:
self.inc_hit_count(node)
self._inc_hit_count(node, chunked)
total_prefix_length += prefix_len
else:
# partial match, split the node
@@ -703,7 +704,7 @@ class HiRadixCache(RadixCache):
self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value)
else:
self.inc_hit_count(new_node)
self._inc_hit_count(new_node, chunked)
total_prefix_length += prefix_len
node = new_node
@@ -737,7 +738,7 @@ class HiRadixCache(RadixCache):
last_hash = new_node.hash_value[-1]
if self.cache_controller.write_policy != "write_back":
self.inc_hit_count(new_node)
self._inc_hit_count(new_node, chunked)
return total_prefix_length
def _collect_leaves_device(self):