Hicache Storage Layer Prototype (#7704)
This commit is contained in:
@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
|
||||
hicache_size: int,
|
||||
hicache_write_policy: str,
|
||||
hicache_io_backend: str,
|
||||
hicache_storage_backend: Optional[str] = None,
|
||||
):
|
||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||
@@ -49,6 +50,9 @@ class HiRadixCache(RadixCache):
|
||||
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
|
||||
|
||||
self.tp_group = tp_cache_group
|
||||
self.enable_storage = hicache_storage_backend is not None
|
||||
# todo: customizable storage prefetch threshold
|
||||
self.prefetch_threshold = 256
|
||||
|
||||
self.load_cache_event = threading.Event()
|
||||
self.cache_controller = HiCacheController(
|
||||
@@ -58,16 +62,22 @@ class HiRadixCache(RadixCache):
|
||||
load_cache_event=self.load_cache_event,
|
||||
write_policy=hicache_write_policy,
|
||||
io_backend=hicache_io_backend,
|
||||
storage_backend=hicache_storage_backend,
|
||||
prefetch_threshold=self.prefetch_threshold,
|
||||
)
|
||||
|
||||
# record the nodes with ongoing write through
|
||||
self.ongoing_write_through = {}
|
||||
# record the node segments with ongoing load back
|
||||
self.ongoing_load_back = {}
|
||||
# record the ongoing prefetch requests
|
||||
self.ongoing_prefetch = {}
|
||||
self.ongoing_backup = {}
|
||||
# todo: dynamically adjust the threshold
|
||||
self.write_through_threshold = (
|
||||
1 if hicache_write_policy == "write_through" else 3
|
||||
)
|
||||
self.write_through_threshold_storage = 3
|
||||
self.load_back_threshold = 10
|
||||
super().__init__(
|
||||
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
||||
@@ -108,13 +118,30 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return len(host_indices)
|
||||
|
||||
def write_backup_storage(self, node: TreeNode):
    """Start an asynchronous backup of *node*'s host-resident data to storage.

    The controller returns an operation id that is later matched against
    the ack queue in `check_backup_progress`; until that ack arrives the
    node's host memory is pinned so it cannot be evicted mid-backup.
    """
    parent_hash = node.parent.get_last_hash_value()
    operation_id = self.cache_controller.write_storage(
        node.host_value, node.key, parent_hash
    )
    # Track the in-flight operation and pin the host pages it reads from.
    self.ongoing_backup[operation_id] = node
    node.protect_host()
|
||||
|
||||
def inc_hit_count(self, node: TreeNode):
    """Bump *node*'s hit counter and trigger threshold-based write-through.

    Under the "write_back" policy this is a no-op: backup is driven by
    eviction rather than hit frequency. Otherwise, a node that is not yet
    backed up to host memory is written through once it reaches
    `write_through_threshold`; a node already on host (but not yet on the
    storage backend, when storage is enabled) is backed up to storage once
    it reaches `write_through_threshold_storage`.

    NOTE(review): the rendered diff contained both the pre- and post-change
    variants of this method flattened together; this is the post-change
    control flow, with the duplicated guard and threshold block removed.
    """
    if self.cache_controller.write_policy == "write_back":
        return
    node.hit_count += 1

    if not node.backuped:
        if node.hit_count >= self.write_through_threshold:
            # write to host if the node is not backuped
            self.write_backup(node)
            node.hit_count = 0
    else:
        if (
            self.enable_storage
            and (not node.backuped_storage)
            and node.hit_count >= self.write_through_threshold_storage
        ):
            # if the node is backuped on host memory but not on storage
            self.write_backup_storage(node)
|
||||
|
||||
def writing_check(self, write_back=False):
|
||||
if write_back:
|
||||
@@ -221,6 +248,10 @@ class HiRadixCache(RadixCache):
|
||||
if not x.evicted:
|
||||
continue
|
||||
|
||||
# node is protected from eviction as it has ongoing prefetch or backup to storage
|
||||
if x.host_ref_counter > 0:
|
||||
continue
|
||||
|
||||
num_evicted += self.cache_controller.evict_host(x.host_value)
|
||||
|
||||
for k, v in x.parent.children.items():
|
||||
@@ -314,6 +345,85 @@ class HiRadixCache(RadixCache):
|
||||
def check_hicache_events(self):
    """Poll all background hierarchical-cache activity.

    Host write-through and load-back progress are always checked; the
    storage-related queues (prefetch revocations and backup acks) are
    polled only when a storage backend is configured.
    """
    self.writing_check()
    self.loading_check()
    if not self.enable_storage:
        return
    self.check_revoked_prefetch()
    self.check_backup_progress()
|
||||
|
||||
def check_revoked_prefetch(self):
    """Drain acknowledged prefetch revocations and roll back their bookkeeping."""
    revoke_queue = self.cache_controller.prefetch_revoke_queue
    num_ready = torch.tensor(revoke_queue.qsize(), dtype=torch.int)
    if torch.distributed.get_world_size(group=self.tp_group) > 1:
        # Synchronize TP workers so every rank applies the same update to
        # its hiradix cache: only drain the minimum count ready on all ranks.
        torch.distributed.all_reduce(
            num_ready,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
    for _ in range(num_ready.item()):
        req_id = revoke_queue.get()
        entry = self.ongoing_prefetch.pop(req_id, None)
        if entry is None:
            # Already finalized (or never tracked) on this rank.
            continue
        last_host_node, _, host_indices, _ = entry
        # Unpin the anchor node and return the reserved host memory.
        last_host_node.release_host()
        self.cache_controller.mem_pool_host.free(host_indices)
|
||||
|
||||
def check_backup_progress(self):
    """Consume storage-backup acknowledgements and finalize the backed-up nodes."""
    ack_queue = self.cache_controller.ack_backup_queue
    num_ready = torch.tensor(ack_queue.qsize(), dtype=torch.int)
    if torch.distributed.get_world_size(group=self.tp_group) > 1:
        # Synchronize TP workers so every rank applies the same update to
        # its hiradix cache: only drain the minimum count ready on all ranks.
        torch.distributed.all_reduce(
            num_ready,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
    for _ in range(num_ready.item()):
        ack_id, hash_value = ack_queue.get()
        node = self.ongoing_backup.pop(ack_id)
        # Record the storage hashes on the node and unpin its host memory.
        node.hash_value = hash_value
        node.release_host()
|
||||
|
||||
def check_prefetch_progress(self, req_id: str):
    """Finalize the prefetch for *req_id*: terminate the operation, insert the
    fetched tokens into the host radix tree, and free any unused host memory.
    """
    if req_id not in self.ongoing_prefetch:
        # there is no ongoing prefetch for this request or it has been revoked
        return

    # todo: more policies for prefetch progress such as timeout
    # the current policy is to prefetch with best effort and terminate when queuing is over
    last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
        req_id
    ]
    completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
        operation
    )
    logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

    # Each rank may have completed a different number of tokens; take the
    # TP-wide minimum so all ranks make the identical tree update.
    min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
    if torch.distributed.get_world_size(group=self.tp_group) > 1:
        # synchronize TP workers to make the same update to hiradix cache
        torch.distributed.all_reduce(
            min_completed_tokens,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
    min_completed_tokens = min_completed_tokens.item()
    fetched_token_ids = token_ids[:min_completed_tokens]
    written_indices = host_indices[:min_completed_tokens]
    # matched_length counts the leading tokens that already existed in the
    # host tree; _insert_helper_host does not adopt their indices.
    matched_length = self._insert_helper_host(
        last_host_node,
        fetched_token_ids,
        written_indices,
        hash_value[:min_completed_tokens],
    )

    # Return the indices that overlapped existing host nodes...
    self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
    # ...and the tail this rank fetched beyond the TP-wide minimum.
    self.cache_controller.mem_pool_host.free(
        host_indices[min_completed_tokens:completed_tokens]
    )
    # Unpin the anchor node and drop the completed bookkeeping entry.
    last_host_node.release_host()
    del self.ongoing_prefetch[req_id]
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
@@ -348,6 +458,71 @@ class HiRadixCache(RadixCache):
|
||||
host_hit_length=host_hit_length,
|
||||
)
|
||||
|
||||
def prefetch_from_storage(
    self,
    req_id: str,
    last_host_node: TreeNode,
    new_input_tokens: List[int],
    last_hash: Optional[str] = None,
):
    """Start an asynchronous prefetch of *new_input_tokens* from storage.

    The fetched data lands in freshly allocated host memory anchored at
    *last_host_node*; progress is tracked in `ongoing_prefetch` and later
    finalized by `check_prefetch_progress` (or rolled back by
    `check_revoked_prefetch`).
    """
    # Prefetching only pays off for long-enough missing suffixes.
    if not self.enable_storage:
        return
    num_tokens = len(new_input_tokens)
    if num_tokens < self.prefetch_threshold:
        return

    # Pin the anchor node so it cannot be evicted while the prefetch runs.
    last_host_node.protect_host()
    mem_pool_host = self.cache_controller.mem_pool_host
    host_indices = mem_pool_host.alloc(num_tokens)
    if host_indices is None:
        # Retry once after making room in the host cache.
        self.evict_host(num_tokens)
        host_indices = mem_pool_host.alloc(num_tokens)
    if host_indices is None:
        # no sufficient host memory to prefetch
        last_host_node.release_host()
        return

    operation = self.cache_controller.prefetch(
        req_id, host_indices, new_input_tokens, last_hash
    )
    self.ongoing_prefetch[req_id] = (
        last_host_node,
        new_input_tokens,
        host_indices,
        operation,
    )
|
||||
|
||||
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value) -> int:
    """Insert a host-resident KV segment under *node* in the radix tree.

    Walks down existing children, consuming the matched prefix of *key*
    (and the corresponding slices of *host_value* / *hash_value*), splitting
    a child when the match ends inside it. Any remaining unmatched suffix
    becomes a new evicted node (value=None) holding the host indices and
    hashes.

    Returns the number of leading tokens that already existed in the tree;
    the caller owns (and should free) the host indices for that prefix.
    """
    node.last_access_time = time.monotonic()
    if len(key) == 0:
        return 0

    child_key = self.get_child_key_fn(key)

    matched_length = 0
    while len(key) > 0 and child_key in node.children.keys():
        node = node.children[child_key]
        node.last_access_time = time.monotonic()
        prefix_len = self.key_match_fn(node.key, key)
        # Consume the matched prefix from all three parallel sequences.
        key = key[prefix_len:]
        host_value = host_value[prefix_len:]
        hash_value = hash_value[prefix_len:]
        matched_length += prefix_len

        if prefix_len < len(node.key):
            # The match ended inside this child: split it so the remainder
            # can branch off cleanly below the split point.
            new_node = self._split_node(node.key, node, prefix_len)
            node = new_node

        if len(key):
            child_key = self.get_child_key_fn(key)

    if len(key):
        # Attach the unmatched suffix as a new node that exists only on
        # host memory (no device value yet).
        new_node = TreeNode()
        new_node.parent = node
        new_node.key = key
        new_node.value = None
        new_node.host_value = host_value
        new_node.hash_value = hash_value
        node.children[child_key] = new_node
    return matched_length
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.monotonic()
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
Reference in New Issue
Block a user