[HiCache] Storage Refactoring (#9797)

Co-authored-by: pansicheng <27603155+pansicheng@users.noreply.github.com>

Author:       Zhiqiang Xie
Committed by: GitHub
Date:         2025-08-31 07:58:21 -07:00
Parent:       a391f73adc
Commit:       8b6966d020

3 changed files with 114 additions and 154 deletions
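This refactor simplifies how HiRadixCache propagates KV pages from host memory to backend storage and trims control-plane overhead:

- Storage write-through no longer has its own hit-count threshold (`write_through_threshold_storage`); a node is handed to `write_backup_storage` as soon as its host write-through is acknowledged in `writing_check`, which makes the `TreeNode.backuped_storage` flag unnecessary.
- The separate `check_revoked_prefetch` and `check_backup_progress` polls are merged into one `drain_storage_control_queues`, so TP workers issue a single `all_reduce` over three queue sizes instead of one collective per queue.
- Host memory is no longer freed inline on the prefetch-completion path; indices go through `append_host_mem_release` into a release queue and are freed in one batched call per drain.
- Prefetch requests are rejected up front when `cache_controller.prefetch_rate_limited()` signals backpressure, and prefetch now bails out cleanly if host memory allocation fails even after eviction.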


@@ -104,9 +104,6 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 2
         )
-        self.write_through_threshold_storage = (
-            1 if hicache_write_policy == "write_through" else 3
-        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -174,14 +171,6 @@ class HiRadixCache(RadixCache):
             if node.hit_count >= self.write_through_threshold:
                 # write to host if the node is not backuped
                 self.write_backup(node)
-            else:
-                if (
-                    self.enable_storage
-                    and (not node.backuped_storage)
-                    and node.hit_count >= self.write_through_threshold_storage
-                ):
-                    # if the node is backuped on host memory but not on storage
-                    self.write_backup_storage(node)
 
     def writing_check(self, write_back=False):
         if write_back:
@@ -202,8 +191,11 @@ class HiRadixCache(RadixCache):
             )
         for _ in range(queue_size.item()):
             ack_id = self.cache_controller.ack_write_queue.get()
-            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            backuped_node = self.ongoing_write_through[ack_id]
+            self.dec_lock_ref(backuped_node)
             del self.ongoing_write_through[ack_id]
+            if self.enable_storage:
+                self.write_backup_storage(backuped_node)
 
     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
@@ -386,57 +378,54 @@ class HiRadixCache(RadixCache):
         self.writing_check()
         self.loading_check()
         if self.enable_storage:
-            self.check_revoked_prefetch()
-            self.check_backup_progress()
+            self.drain_storage_control_queues()
 
-    def check_revoked_prefetch(self):
-        queue_size = torch.tensor(
-            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+    def drain_storage_control_queues(self):
+        """
+        Combine prefetch revoke, backup ack, and host mem release checks
+        to minimize TP synchronization and Python overhead.
+        """
+        cc = self.cache_controller
+        qsizes = torch.tensor(
+            [
+                cc.prefetch_revoke_queue.qsize(),
+                cc.ack_backup_queue.qsize(),
+                cc.host_mem_release_queue.qsize(),
+            ],
+            dtype=torch.int,
         )
         if self.tp_world_size > 1:
-            # synchronize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
             )
-        for _ in range(queue_size.item()):
-            req_id = self.cache_controller.prefetch_revoke_queue.get()
-            if req_id in self.ongoing_prefetch:
-                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
+
+        n_revoke, n_backup, n_release = map(int, qsizes.tolist())
+
+        # process prefetch revokes
+        for _ in range(n_revoke):
+            req_id = cc.prefetch_revoke_queue.get()
+            info = self.ongoing_prefetch.pop(req_id, None)
+            if info is not None:
+                last_host_node, token_ids, _, _ = info
                 last_host_node.release_host()
-                del self.ongoing_prefetch[req_id]
-                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
-            else:
-                # the revoked operation already got terminated
-                pass
+                cc.prefetch_tokens_occupied -= len(token_ids)
+            # else: the revoked operation already got terminated, nothing to do
 
-    def check_backup_progress(self):
-        queue_size = torch.tensor(
-            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
-        )
-        if self.tp_world_size > 1:
-            # synchronize TP workers to make the same update to hiradix cache
-            torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
-            )
-        for _ in range(queue_size.item()):
-            ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
-            host_node = self.ongoing_backup[ack_id]
+        # process backup acks
+        for _ in range(n_backup):
+            ack_id = cc.ack_backup_queue.get()
+            entry = self.ongoing_backup.pop(ack_id, None)
+            if entry is not None:
+                entry.release_host()
-            if completed_tokens > 0:
-                if completed_tokens < len(host_node.key):
-                    # backup is only partially successful, split the node
-                    new_node = self._split_node(
-                        host_node.key, host_node, completed_tokens
-                    )
-                    new_node.backuped_storage = True
-                else:
-                    host_node.backuped_storage = True
-            host_node.release_host()
-            del self.ongoing_backup[ack_id]
+
+        # release host memory
+        host_indices_list = []
+        for _ in range(n_release):
+            host_indices_list.append(cc.host_mem_release_queue.get())
+        if host_indices_list:
+            host_indices = torch.cat(host_indices_list, dim=0)
+            cc.mem_pool_host.free(host_indices)
 
     def can_terminate_prefetch(self, operation: PrefetchOperation):
         can_terminate = True
@@ -519,7 +508,7 @@ class HiRadixCache(RadixCache):
             self.cache_controller.mem_pool_host.update_prefetch(written_indices)
 
         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
-        self.cache_controller.mem_pool_host.free(
+        self.cache_controller.append_host_mem_release(
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
@@ -575,7 +564,11 @@ class HiRadixCache(RadixCache):
             len(new_input_tokens) % self.page_size
         )
         new_input_tokens = new_input_tokens[:prefetch_length]
-        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
+        if (
+            not self.enable_storage
+            or prefetch_length < self.prefetch_threshold
+            or self.cache_controller.prefetch_rate_limited()
+        ):
             return
 
         last_host_node.protect_host()
@@ -583,6 +576,10 @@ class HiRadixCache(RadixCache):
         if host_indices is None:
             self.evict_host(prefetch_length)
             host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+        if host_indices is None:
+            last_host_node.release_host()
+            # insufficient host memory for prefetch
+            return
         operation = self.cache_controller.prefetch(
             req_id, host_indices, new_input_tokens, last_hash
         )
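The batched queue drain above is the heart of the change: every TP rank must pop the same number of items from each control queue, and packing the three queue sizes into one tensor costs a single collective per scheduler tick instead of three. Below is a minimal, self-contained sketch of the pattern; `drain_same_count` and the toy queues are illustrative stand-ins, not part of the SGLang API.

```python
import queue

import torch
import torch.distributed as dist


def drain_same_count(queues, group=None):
    """Pop the same number of items from each control queue on every rank.

    One all_reduce over a packed size tensor replaces one collective per
    queue; ReduceOp.MIN guarantees no rank pops an item that a peer has
    not yet observed in its local queue.
    """
    qsizes = torch.tensor([q.qsize() for q in queues], dtype=torch.int)
    if dist.is_initialized() and dist.get_world_size(group) > 1:
        dist.all_reduce(qsizes, op=dist.ReduceOp.MIN, group=group)
    return [[q.get() for _ in range(n)] for q, n in zip(queues, qsizes.tolist())]


if __name__ == "__main__":
    # single-process demo: three control queues with uneven backlogs
    revoke_q, ack_q, release_q = queue.Queue(), queue.Queue(), queue.Queue()
    revoke_q.put("req-0")
    ack_q.put("ack-7")
    batches = drain_same_count([revoke_q, ack_q, release_q])
    print([len(b) for b in batches])  # -> [1, 1, 0]
```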


@@ -62,7 +62,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each page
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
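For reference, the deferred-release pattern introduced by `append_host_mem_release` can be sketched in isolation: producers enqueue index tensors cheaply on the ack and prefetch paths, and the scheduler frees them in one batched call when it drains the control queues. The `HostMemReleaser` class and its `mem_pool` argument are hypothetical illustrations under stated assumptions, not SGLang's actual interfaces.

```python
from queue import Queue

import torch


class HostMemReleaser:
    """Toy model of deferred host-memory release (illustrative only)."""

    def __init__(self, mem_pool):
        # mem_pool is assumed to be any object exposing free(indices)
        self.mem_pool = mem_pool
        self.release_queue: Queue = Queue()

    def append_host_mem_release(self, host_indices: torch.Tensor):
        # cheap enqueue on the hot path; the allocator is not touched here
        self.release_queue.put(host_indices)

    def drain(self):
        # one torch.cat + one free() per tick instead of many small frees
        pending = []
        while not self.release_queue.empty():
            pending.append(self.release_queue.get())
        if pending:
            self.mem_pool.free(torch.cat(pending, dim=0))


class _PrintPool:
    def free(self, indices: torch.Tensor):
        print(f"freed {indices.numel()} host slots")


if __name__ == "__main__":
    releaser = HostMemReleaser(_PrintPool())
    releaser.append_host_mem_release(torch.arange(4))
    releaser.append_host_mem_release(torch.arange(8))
    releaser.drain()  # -> freed 12 host slots
```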