feat(hicache): Support passing prefix keys for l3 store. (#9045)
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -84,12 +84,14 @@ class HiRadixCache(RadixCache):
|
||||
prefetch_threshold,
|
||||
prefetch_timeout_base,
|
||||
prefetch_timeout_per_ki_token,
|
||||
hicache_storage_pass_prefix_keys,
|
||||
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
|
||||
self.prefetch_threshold = prefetch_threshold
|
||||
self.prefetch_timeout_base = prefetch_timeout_base
|
||||
self.prefetch_timeout_per_page = (
|
||||
page_size / 1024 * prefetch_timeout_per_ki_token
|
||||
)
|
||||
self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys
|
||||
# TODO: support more timeout check functions
|
||||
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
|
||||
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
||||
@@ -149,7 +151,7 @@ class HiRadixCache(RadixCache):
|
||||
storage_backend_extra_config: JSON string containing extra configuration
|
||||
|
||||
Returns:
|
||||
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
|
||||
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys)
|
||||
"""
|
||||
# Parse extra config JSON if provided
|
||||
extra_config = {}
|
||||
@@ -165,6 +167,9 @@ class HiRadixCache(RadixCache):
|
||||
prefetch_timeout_per_ki_token = extra_config.pop(
|
||||
"prefetch_timeout_per_ki_token", 0.25
|
||||
) # seconds per 1024 tokens
|
||||
hicache_storage_pass_prefix_keys = extra_config.pop(
|
||||
"hicache_storage_pass_prefix_keys", False
|
||||
)
|
||||
|
||||
if not isinstance(prefetch_threshold, int):
|
||||
raise ValueError(
|
||||
@@ -184,6 +189,7 @@ class HiRadixCache(RadixCache):
|
||||
prefetch_threshold,
|
||||
float(prefetch_timeout_base),
|
||||
float(prefetch_timeout_per_ki_token),
|
||||
hicache_storage_pass_prefix_keys,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
@@ -245,8 +251,14 @@ class HiRadixCache(RadixCache):
|
||||
return len(host_indices)
|
||||
|
||||
def write_backup_storage(self, node: TreeNode):
|
||||
prefix_keys = (
|
||||
node.get_prefix_hash_values(node.parent)
|
||||
if self.hicache_storage_pass_prefix_keys
|
||||
else None
|
||||
)
|
||||
|
||||
operation_id = self.cache_controller.write_storage(
|
||||
node.host_value, node.key, node.hash_value
|
||||
node.host_value, node.key, node.hash_value, prefix_keys
|
||||
)
|
||||
self.ongoing_backup[operation_id] = node
|
||||
node.protect_host()
|
||||
@@ -700,6 +712,7 @@ class HiRadixCache(RadixCache):
|
||||
last_host_node: TreeNode,
|
||||
new_input_tokens: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
prefix_keys: Optional[List[str]] = None,
|
||||
):
|
||||
# align the number of fetching tokens to the page size
|
||||
prefetch_length = len(new_input_tokens) - (
|
||||
@@ -723,7 +736,7 @@ class HiRadixCache(RadixCache):
|
||||
# no sufficient host memory for prefetch
|
||||
return
|
||||
operation = self.cache_controller.prefetch(
|
||||
req_id, host_indices, new_input_tokens, last_hash
|
||||
req_id, host_indices, new_input_tokens, last_hash, prefix_keys
|
||||
)
|
||||
self.ongoing_prefetch[req_id] = (
|
||||
last_host_node,
|
||||
|
||||
Reference in New Issue
Block a user