[HiCache] Configurable and Dynamic Prefetch Timeout (#10512)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
ykwd
2025-10-01 21:44:10 +08:00
committed by GitHub
parent 86cb4db058
commit bfa274380b
2 changed files with 75 additions and 21 deletions

View File

@@ -250,7 +250,7 @@ class HiCacheController:
storage_backend: Optional[str] = None, storage_backend: Optional[str] = None,
prefetch_threshold: int = 256, prefetch_threshold: int = 256,
model_name: Optional[str] = None, model_name: Optional[str] = None,
storage_backend_extra_config: Optional[str] = None, storage_backend_extra_config: Optional[dict] = None,
): ):
self.mem_pool_device_allocator = token_to_kv_pool_allocator self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache() self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -361,7 +361,7 @@ class HiCacheController:
def _generate_storage_config( def _generate_storage_config(
self, self,
model_name: Optional[str] = None, model_name: Optional[str] = None,
storage_backend_extra_config: Optional[str] = None, storage_backend_extra_config: Optional[dict] = None,
): ):
if is_dp_attention_enabled(): if is_dp_attention_enabled():
@@ -376,23 +376,13 @@ class HiCacheController:
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool. # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool) is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
# Parse extra config JSON if provided
extra_config = None
if storage_backend_extra_config:
try:
import json
extra_config = json.loads(storage_backend_extra_config)
except Exception as e:
logger.error(f"Invalid backend extra config JSON: {e}")
return HiCacheStorageConfig( return HiCacheStorageConfig(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
tp_size=self.tp_size, tp_size=self.tp_size,
is_mla_model=is_mla_backend, is_mla_model=is_mla_backend,
is_page_first_layout=self.mem_pool_host.layout == "page_first", is_page_first_layout=self.mem_pool_host.layout == "page_first",
model_name=model_name, model_name=model_name,
extra_config=extra_config, extra_config=storage_backend_extra_config,
) )
def reset(self): def reset(self):

View File

@@ -1,8 +1,8 @@
import heapq import heapq
import json
import logging import logging
import threading import threading
import time import time
from queue import Queue
from typing import List, Optional from typing import List, Optional
import torch import torch
@@ -78,9 +78,19 @@ class HiRadixCache(RadixCache):
self.enable_storage = hicache_storage_backend is not None self.enable_storage = hicache_storage_backend is not None
self.enable_storage_metrics = self.enable_storage and enable_metrics self.enable_storage_metrics = self.enable_storage and enable_metrics
# todo: customizable storage prefetch threshold and timeout (
self.prefetch_threshold = 256 extra_config,
self.prefetch_timeout = 3 # seconds prefetch_threshold,
prefetch_timeout_base,
prefetch_timeout_per_ki_token,
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
self.prefetch_threshold = prefetch_threshold
self.prefetch_timeout_base = prefetch_timeout_base
self.prefetch_timeout_per_page = (
page_size / 1024 * prefetch_timeout_per_ki_token
)
# TODO: support more timeout check functions
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
self.prefetch_stop_policy = hicache_storage_prefetch_policy self.prefetch_stop_policy = hicache_storage_prefetch_policy
self.load_cache_event = threading.Event() self.load_cache_event = threading.Event()
@@ -95,7 +105,7 @@ class HiRadixCache(RadixCache):
storage_backend=hicache_storage_backend, storage_backend=hicache_storage_backend,
prefetch_threshold=self.prefetch_threshold, prefetch_threshold=self.prefetch_threshold,
model_name=model_name, model_name=model_name,
storage_backend_extra_config=storage_backend_extra_config, storage_backend_extra_config=extra_config,
) )
if self.enable_storage_metrics: if self.enable_storage_metrics:
# TODO: support pp # TODO: support pp
@@ -127,6 +137,53 @@ class HiRadixCache(RadixCache):
eviction_policy=eviction_policy, eviction_policy=eviction_policy,
) )
def _parse_storage_backend_extra_config(
    self, storage_backend_extra_config: Optional[str]
) -> tuple:
    """
    Parse storage backend extra config JSON and extract specific parameters.

    The prefetch-related keys are removed from the parsed dict, so the
    returned ``extra_config_dict`` only holds the remaining entries.

    Args:
        storage_backend_extra_config: JSON string containing extra
            configuration. ``None`` or an empty string selects the defaults.

    Returns:
        tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base,
        prefetch_timeout_per_ki_token). The two timeout values are always
        returned as floats.

    Raises:
        ValueError: if the string is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass), if the top-level JSON value is not
            an object, or if one of the prefetch parameters has the wrong type.
    """
    # Parse extra config JSON if provided
    extra_config = {}
    if storage_backend_extra_config:
        try:
            extra_config = json.loads(storage_backend_extra_config)
        except Exception as e:
            logger.error(f"Invalid backend extra config JSON: {e}")
            # Bare raise keeps the original traceback intact.
            raise
        # Valid JSON such as "[1, 2]" or "null" is not a mapping; fail with a
        # clear message instead of an unrelated error from ``.pop`` below.
        if not isinstance(extra_config, dict):
            raise ValueError(
                "storage backend extra config must be a JSON object, "
                f"got {type(extra_config).__name__}"
            )

    prefetch_threshold = extra_config.pop("prefetch_threshold", 256)  # tokens
    prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1)  # seconds
    prefetch_timeout_per_ki_token = extra_config.pop(
        "prefetch_timeout_per_ki_token", 0.25
    )  # seconds per 1024 tokens

    # bool is a subclass of int, so JSON true/false must be rejected explicitly.
    if isinstance(prefetch_threshold, bool) or not isinstance(
        prefetch_threshold, int
    ):
        raise ValueError(
            f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
        )
    if isinstance(prefetch_timeout_base, bool) or not isinstance(
        prefetch_timeout_base, (int, float)
    ):
        raise ValueError(
            f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
        )
    if isinstance(prefetch_timeout_per_ki_token, bool) or not isinstance(
        prefetch_timeout_per_ki_token, (int, float)
    ):
        raise ValueError(
            f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
        )

    return (
        extra_config,
        prefetch_threshold,
        float(prefetch_timeout_base),
        float(prefetch_timeout_per_ki_token),
    )
def reset(self): def reset(self):
TreeNode.counter = 0 TreeNode.counter = 0
self.cache_controller.reset() self.cache_controller.reset()
@@ -490,6 +547,15 @@ class HiRadixCache(RadixCache):
host_indices = torch.cat(host_indices_list, dim=0) host_indices = torch.cat(host_indices_list, dim=0)
cc.mem_pool_host.free(host_indices) cc.mem_pool_host.free(host_indices)
# Timeout is linearly increasing with the number of pages
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
    """
    Report whether a prefetch operation has exceeded its time allowance.

    The allowance is ``self.prefetch_timeout_base`` plus
    ``self.prefetch_timeout_per_page`` for every entry in
    ``operation.hash_value``. When ``hash_value`` is still empty, the
    allowance is the base timeout alone.

    Args:
        operation: prefetch operation exposing ``start_time`` (a
            ``time.monotonic()`` timestamp) and a sized ``hash_value``.

    Returns:
        bool: True when the elapsed time is strictly greater than the allowance.
    """
    elapsed_seconds = time.monotonic() - operation.start_time
    allowed_seconds = (
        self.prefetch_timeout_base
        + len(operation.hash_value) * self.prefetch_timeout_per_page
    )
    return elapsed_seconds > allowed_seconds
def can_terminate_prefetch(self, operation: PrefetchOperation): def can_terminate_prefetch(self, operation: PrefetchOperation):
can_terminate = True can_terminate = True
@@ -506,9 +572,7 @@ class HiRadixCache(RadixCache):
if self.prefetch_stop_policy == "wait_complete": if self.prefetch_stop_policy == "wait_complete":
can_terminate = completed can_terminate = completed
elif self.prefetch_stop_policy == "timeout": elif self.prefetch_stop_policy == "timeout":
can_terminate = completed or ( can_terminate = completed or self.is_prefetch_timeout(operation)
time.monotonic() - operation.start_time > self.prefetch_timeout
)
else: else:
# unknown prefetch stop policy, just return True # unknown prefetch stop policy, just return True
return True return True