Hicache Storage Layer Prototype (#7704)

This commit is contained in:
Zhiqiang Xie
2025-07-18 00:20:19 -07:00
committed by GitHub
parent 7891bac16b
commit 9d33fcfb8e
9 changed files with 714 additions and 4 deletions

View File

@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
hicache_size: int,
hicache_write_policy: str,
hicache_io_backend: str,
hicache_storage_backend: Optional[str] = None,
):
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -49,6 +50,9 @@ class HiRadixCache(RadixCache):
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
self.tp_group = tp_cache_group
self.enable_storage = hicache_storage_backend is not None
# todo: customizable storage prefetch threshold
self.prefetch_threshold = 256
self.load_cache_event = threading.Event()
self.cache_controller = HiCacheController(
@@ -58,16 +62,22 @@ class HiRadixCache(RadixCache):
load_cache_event=self.load_cache_event,
write_policy=hicache_write_policy,
io_backend=hicache_io_backend,
storage_backend=hicache_storage_backend,
prefetch_threshold=self.prefetch_threshold,
)
# record the nodes with ongoing write through
self.ongoing_write_through = {}
# record the node segments with ongoing load back
self.ongoing_load_back = {}
# record the ongoing prefetch requests
self.ongoing_prefetch = {}
self.ongoing_backup = {}
# todo: dynamically adjust the threshold
self.write_through_threshold = (
1 if hicache_write_policy == "write_through" else 3
)
self.write_through_threshold_storage = 3
self.load_back_threshold = 10
super().__init__(
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -108,13 +118,30 @@ class HiRadixCache(RadixCache):
return len(host_indices)
def write_backup_storage(self, node: TreeNode):
operation_id = self.cache_controller.write_storage(
node.host_value, node.key, node.parent.get_last_hash_value()
)
self.ongoing_backup[operation_id] = node
node.protect_host()
def inc_hit_count(self, node: TreeNode):
if node.backuped or self.cache_controller.write_policy == "write_back":
if self.cache_controller.write_policy == "write_back":
return
node.hit_count += 1
if node.hit_count >= self.write_through_threshold:
self.write_backup(node)
node.hit_count = 0
if not node.backuped:
if node.hit_count >= self.write_through_threshold:
# write to host if the node is not backuped
self.write_backup(node)
else:
if (
self.enable_storage
and (not node.backuped_storage)
and node.hit_count >= self.write_through_threshold_storage
):
# if the node is backuped on host memory but not on storage
self.write_backup_storage(node)
def writing_check(self, write_back=False):
if write_back:
@@ -221,6 +248,10 @@ class HiRadixCache(RadixCache):
if not x.evicted:
continue
# node is protected from eviction as it has ongoing prefetch or backup to storage
if x.host_ref_counter > 0:
continue
num_evicted += self.cache_controller.evict_host(x.host_value)
for k, v in x.parent.children.items():
@@ -314,6 +345,85 @@ class HiRadixCache(RadixCache):
def check_hicache_events(self):
self.writing_check()
self.loading_check()
if self.enable_storage:
self.check_revoked_prefetch()
self.check_backup_progress()
def check_revoked_prefetch(self):
queue_size = torch.tensor(
self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
# synchrnoize TP workers to make the same update to hiradix cache
torch.distributed.all_reduce(
queue_size,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
for _ in range(queue_size.item()):
req_id = self.cache_controller.prefetch_revoke_queue.get()
if req_id in self.ongoing_prefetch:
last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
last_host_node.release_host()
self.cache_controller.mem_pool_host.free(host_indices)
del self.ongoing_prefetch[req_id]
def check_backup_progress(self):
queue_size = torch.tensor(
self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
# synchrnoize TP workers to make the same update to hiradix cache
torch.distributed.all_reduce(
queue_size,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
for _ in range(queue_size.item()):
ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
self.ongoing_backup[ack_id].hash_value = hash_value
self.ongoing_backup[ack_id].release_host()
del self.ongoing_backup[ack_id]
def check_prefetch_progress(self, req_id: str):
if req_id not in self.ongoing_prefetch:
# there is no ongoing prefetch for this request or it has been revoked
return
# todo: more policies for prefetch progress such as timeout
# the current policy is to prefetch with best effort and terminate when queuing is over
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
req_id
]
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
operation
)
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
# synchrnoize TP workers to make the same update to hiradix cache
torch.distributed.all_reduce(
min_completed_tokens,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
min_completed_tokens = min_completed_tokens.item()
fetched_token_ids = token_ids[:min_completed_tokens]
written_indices = host_indices[:min_completed_tokens]
matched_length = self._insert_helper_host(
last_host_node,
fetched_token_ids,
written_indices,
hash_value[:min_completed_tokens],
)
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
self.cache_controller.mem_pool_host.free(
host_indices[min_completed_tokens:completed_tokens]
)
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
def match_prefix(self, key: List[int], **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -348,6 +458,71 @@ class HiRadixCache(RadixCache):
host_hit_length=host_hit_length,
)
def prefetch_from_storage(
self,
req_id: str,
last_host_node: TreeNode,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
):
if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
return
last_host_node.protect_host()
host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
if host_indices is None:
self.evict_host(len(new_input_tokens))
host_indices = self.cache_controller.mem_pool_host.alloc(
len(new_input_tokens)
)
if host_indices is None:
last_host_node.release_host()
# no sufficient host memory to prefetch
return
operation = self.cache_controller.prefetch(
req_id, host_indices, new_input_tokens, last_hash
)
self.ongoing_prefetch[req_id] = (
last_host_node,
new_input_tokens,
host_indices,
operation,
)
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
node.last_access_time = time.monotonic()
if len(key) == 0:
return 0
child_key = self.get_child_key_fn(key)
matched_length = 0
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.monotonic()
prefix_len = self.key_match_fn(node.key, key)
key = key[prefix_len:]
host_value = host_value[prefix_len:]
hash_value = hash_value[prefix_len:]
matched_length += prefix_len
if prefix_len < len(node.key):
new_node = self._split_node(node.key, node, prefix_len)
node = new_node
if len(key):
child_key = self.get_child_key_fn(key)
if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = None
new_node.host_value = host_value
new_node.hash_value = hash_value
node.children[child_key] = new_node
return matched_length
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.monotonic()
child_key = self.get_child_key_fn(key)