[HiCache] Storage Refactoring (#9797)
Co-authored-by: pansicheng <27603155+pansicheng@users.noreply.github.com>
This commit is contained in:
@@ -104,9 +104,6 @@ class HiRadixCache(RadixCache):
|
||||
self.write_through_threshold = (
|
||||
1 if hicache_write_policy == "write_through" else 2
|
||||
)
|
||||
self.write_through_threshold_storage = (
|
||||
1 if hicache_write_policy == "write_through" else 3
|
||||
)
|
||||
self.load_back_threshold = 10
|
||||
super().__init__(
|
||||
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
||||
@@ -174,14 +171,6 @@ class HiRadixCache(RadixCache):
|
||||
if node.hit_count >= self.write_through_threshold:
|
||||
# write to host if the node is not backuped
|
||||
self.write_backup(node)
|
||||
else:
|
||||
if (
|
||||
self.enable_storage
|
||||
and (not node.backuped_storage)
|
||||
and node.hit_count >= self.write_through_threshold_storage
|
||||
):
|
||||
# if the node is backuped on host memory but not on storage
|
||||
self.write_backup_storage(node)
|
||||
|
||||
def writing_check(self, write_back=False):
|
||||
if write_back:
|
||||
@@ -202,8 +191,11 @@ class HiRadixCache(RadixCache):
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
ack_id = self.cache_controller.ack_write_queue.get()
|
||||
self.dec_lock_ref(self.ongoing_write_through[ack_id])
|
||||
backuped_node = self.ongoing_write_through[ack_id]
|
||||
self.dec_lock_ref(backuped_node)
|
||||
del self.ongoing_write_through[ack_id]
|
||||
if self.enable_storage:
|
||||
self.write_backup_storage(backuped_node)
|
||||
|
||||
def loading_check(self):
|
||||
while not self.cache_controller.ack_load_queue.empty():
|
||||
@@ -386,57 +378,54 @@ class HiRadixCache(RadixCache):
|
||||
self.writing_check()
|
||||
self.loading_check()
|
||||
if self.enable_storage:
|
||||
self.check_revoked_prefetch()
|
||||
self.check_backup_progress()
|
||||
self.drain_storage_control_queues()
|
||||
|
||||
def check_revoked_prefetch(self):
|
||||
queue_size = torch.tensor(
|
||||
self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
|
||||
def drain_storage_control_queues(self):
|
||||
"""
|
||||
Combine prefetch revoke, backup ack, and host mem release checks
|
||||
to minimize TP synchronization and Python overhead.
|
||||
"""
|
||||
cc = self.cache_controller
|
||||
|
||||
qsizes = torch.tensor(
|
||||
[
|
||||
cc.prefetch_revoke_queue.qsize(),
|
||||
cc.ack_backup_queue.qsize(),
|
||||
cc.host_mem_release_queue.qsize(),
|
||||
],
|
||||
dtype=torch.int,
|
||||
)
|
||||
if self.tp_world_size > 1:
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
torch.distributed.all_reduce(
|
||||
queue_size,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
||||
if req_id in self.ongoing_prefetch:
|
||||
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
|
||||
|
||||
n_revoke, n_backup, n_release = map(int, qsizes.tolist())
|
||||
|
||||
# process prefetch revokes
|
||||
for _ in range(n_revoke):
|
||||
req_id = cc.prefetch_revoke_queue.get()
|
||||
info = self.ongoing_prefetch.pop(req_id, None)
|
||||
if info is not None:
|
||||
last_host_node, token_ids, _, _ = info
|
||||
last_host_node.release_host()
|
||||
del self.ongoing_prefetch[req_id]
|
||||
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
||||
else:
|
||||
# the revoked operation already got terminated
|
||||
pass
|
||||
cc.prefetch_tokens_occupied -= len(token_ids)
|
||||
# else: the revoked operation already got terminated, nothing to do
|
||||
|
||||
def check_backup_progress(self):
|
||||
queue_size = torch.tensor(
|
||||
self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
|
||||
)
|
||||
if self.tp_world_size > 1:
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
torch.distributed.all_reduce(
|
||||
queue_size,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
|
||||
host_node = self.ongoing_backup[ack_id]
|
||||
# process backup acks
|
||||
for _ in range(n_backup):
|
||||
ack_id = cc.ack_backup_queue.get()
|
||||
entry = self.ongoing_backup.pop(ack_id, None)
|
||||
if entry is not None:
|
||||
entry.release_host()
|
||||
|
||||
if completed_tokens > 0:
|
||||
if completed_tokens < len(host_node.key):
|
||||
# backup is only partially successful, split the node
|
||||
new_node = self._split_node(
|
||||
host_node.key, host_node, completed_tokens
|
||||
)
|
||||
new_node.backuped_storage = True
|
||||
else:
|
||||
host_node.backuped_storage = True
|
||||
host_node.release_host()
|
||||
del self.ongoing_backup[ack_id]
|
||||
# release host memory
|
||||
host_indices_list = []
|
||||
for _ in range(n_release):
|
||||
host_indices_list.append(cc.host_mem_release_queue.get())
|
||||
if host_indices_list:
|
||||
host_indices = torch.cat(host_indices_list, dim=0)
|
||||
cc.mem_pool_host.free(host_indices)
|
||||
|
||||
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
||||
can_terminate = True
|
||||
@@ -519,7 +508,7 @@ class HiRadixCache(RadixCache):
|
||||
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
|
||||
|
||||
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
||||
self.cache_controller.mem_pool_host.free(
|
||||
self.cache_controller.append_host_mem_release(
|
||||
host_indices[min_completed_tokens:completed_tokens]
|
||||
)
|
||||
last_host_node.release_host()
|
||||
@@ -575,7 +564,11 @@ class HiRadixCache(RadixCache):
|
||||
len(new_input_tokens) % self.page_size
|
||||
)
|
||||
new_input_tokens = new_input_tokens[:prefetch_length]
|
||||
if not self.enable_storage or prefetch_length < self.prefetch_threshold:
|
||||
if (
|
||||
not self.enable_storage
|
||||
or prefetch_length < self.prefetch_threshold
|
||||
or self.cache_controller.prefetch_rate_limited()
|
||||
):
|
||||
return
|
||||
|
||||
last_host_node.protect_host()
|
||||
@@ -583,6 +576,10 @@ class HiRadixCache(RadixCache):
|
||||
if host_indices is None:
|
||||
self.evict_host(prefetch_length)
|
||||
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
||||
if host_indices is None:
|
||||
last_host_node.release_host()
|
||||
# no sufficient host memory for prefetch
|
||||
return
|
||||
operation = self.cache_controller.prefetch(
|
||||
req_id, host_indices, new_input_tokens, last_hash
|
||||
)
|
||||
|
||||
@@ -62,7 +62,6 @@ class TreeNode:
|
||||
self.host_value: Optional[torch.Tensor] = None
|
||||
# store hash values of each pages
|
||||
self.hash_value: Optional[List[str]] = None
|
||||
self.backuped_storage = False
|
||||
|
||||
self.id = TreeNode.counter if id is None else id
|
||||
TreeNode.counter += 1
|
||||
|
||||
Reference in New Issue
Block a user