feat(hicache): Support passing prefix keys for l3 store. (#9045)

Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
hzh0425
2025-10-10 15:22:05 +08:00
committed by GitHub
parent d8467db727
commit ee3bd8a1c8
11 changed files with 107 additions and 24 deletions

View File

@@ -84,12 +84,14 @@ class HiRadixCache(RadixCache):
prefetch_threshold,
prefetch_timeout_base,
prefetch_timeout_per_ki_token,
hicache_storage_pass_prefix_keys,
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
self.prefetch_threshold = prefetch_threshold
self.prefetch_timeout_base = prefetch_timeout_base
self.prefetch_timeout_per_page = (
page_size / 1024 * prefetch_timeout_per_ki_token
)
self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys
# TODO: support more timeout check functions
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
self.prefetch_stop_policy = hicache_storage_prefetch_policy
@@ -149,7 +151,7 @@ class HiRadixCache(RadixCache):
storage_backend_extra_config: JSON string containing extra configuration
Returns:
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys)
"""
# Parse extra config JSON if provided
extra_config = {}
@@ -165,6 +167,9 @@ class HiRadixCache(RadixCache):
prefetch_timeout_per_ki_token = extra_config.pop(
"prefetch_timeout_per_ki_token", 0.25
) # seconds per 1024 tokens
hicache_storage_pass_prefix_keys = extra_config.pop(
"hicache_storage_pass_prefix_keys", False
)
if not isinstance(prefetch_threshold, int):
raise ValueError(
@@ -184,6 +189,7 @@ class HiRadixCache(RadixCache):
prefetch_threshold,
float(prefetch_timeout_base),
float(prefetch_timeout_per_ki_token),
hicache_storage_pass_prefix_keys,
)
def reset(self):
@@ -245,8 +251,14 @@ class HiRadixCache(RadixCache):
return len(host_indices)
def write_backup_storage(self, node: TreeNode):
    """Submit ``node``'s host-resident data to the storage backend.

    The write is issued through ``self.cache_controller.write_storage``.
    When ``self.hicache_storage_pass_prefix_keys`` is truthy, the value
    returned by ``node.get_prefix_hash_values(node.parent)`` is forwarded
    as the fourth argument; otherwise ``None`` is forwarded in its place.

    After the write is submitted, the node is recorded in
    ``self.ongoing_backup`` under the operation id returned by the
    controller, and ``node.protect_host()`` is called.

    Args:
        node: The tree node whose ``host_value``, ``key`` and
            ``hash_value`` are handed to the cache controller.
    """
    # Only compute the prefix hashes when the feature is switched on, so
    # the default path does no extra work on the node.
    if self.hicache_storage_pass_prefix_keys:
        prefix_keys = node.get_prefix_hash_values(node.parent)
    else:
        prefix_keys = None

    # Single, four-argument call: the stale three-argument line that was
    # left next to it made the call expression syntactically invalid.
    operation_id = self.cache_controller.write_storage(
        node.host_value, node.key, node.hash_value, prefix_keys
    )
    self.ongoing_backup[operation_id] = node
    node.protect_host()
@@ -700,6 +712,7 @@ class HiRadixCache(RadixCache):
last_host_node: TreeNode,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
prefix_keys: Optional[List[str]] = None,
):
# align the number of fetching tokens to the page size
prefetch_length = len(new_input_tokens) - (
@@ -723,7 +736,7 @@ class HiRadixCache(RadixCache):
# no sufficient host memory for prefetch
return
operation = self.cache_controller.prefetch(
req_id, host_indices, new_input_tokens, last_hash
req_id, host_indices, new_input_tokens, last_hash, prefix_keys
)
self.ongoing_prefetch[req_id] = (
last_host_node,