[HiCache] Configurable and Dynamic Prefetch Timeout (#10512)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -250,7 +250,7 @@ class HiCacheController:
|
||||
storage_backend: Optional[str] = None,
|
||||
prefetch_threshold: int = 256,
|
||||
model_name: Optional[str] = None,
|
||||
storage_backend_extra_config: Optional[str] = None,
|
||||
storage_backend_extra_config: Optional[dict] = None,
|
||||
):
|
||||
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||
@@ -361,7 +361,7 @@ class HiCacheController:
|
||||
def _generate_storage_config(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
storage_backend_extra_config: Optional[str] = None,
|
||||
storage_backend_extra_config: Optional[dict] = None,
|
||||
):
|
||||
|
||||
if is_dp_attention_enabled():
|
||||
@@ -376,23 +376,13 @@ class HiCacheController:
|
||||
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
||||
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
||||
|
||||
# Parse extra config JSON if provided
|
||||
extra_config = None
|
||||
if storage_backend_extra_config:
|
||||
try:
|
||||
import json
|
||||
|
||||
extra_config = json.loads(storage_backend_extra_config)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid backend extra config JSON: {e}")
|
||||
|
||||
return HiCacheStorageConfig(
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
is_mla_model=is_mla_backend,
|
||||
is_page_first_layout=self.mem_pool_host.layout == "page_first",
|
||||
model_name=model_name,
|
||||
extra_config=extra_config,
|
||||
extra_config=storage_backend_extra_config,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import heapq
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from queue import Queue
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
@@ -78,9 +78,19 @@ class HiRadixCache(RadixCache):
|
||||
self.enable_storage = hicache_storage_backend is not None
|
||||
self.enable_storage_metrics = self.enable_storage and enable_metrics
|
||||
|
||||
# todo: customizable storage prefetch threshold and timeout
|
||||
self.prefetch_threshold = 256
|
||||
self.prefetch_timeout = 3 # seconds
|
||||
(
|
||||
extra_config,
|
||||
prefetch_threshold,
|
||||
prefetch_timeout_base,
|
||||
prefetch_timeout_per_ki_token,
|
||||
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
|
||||
self.prefetch_threshold = prefetch_threshold
|
||||
self.prefetch_timeout_base = prefetch_timeout_base
|
||||
self.prefetch_timeout_per_page = (
|
||||
page_size / 1024 * prefetch_timeout_per_ki_token
|
||||
)
|
||||
# TODO: support more timeout check functions
|
||||
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
|
||||
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
||||
|
||||
self.load_cache_event = threading.Event()
|
||||
@@ -95,7 +105,7 @@ class HiRadixCache(RadixCache):
|
||||
storage_backend=hicache_storage_backend,
|
||||
prefetch_threshold=self.prefetch_threshold,
|
||||
model_name=model_name,
|
||||
storage_backend_extra_config=storage_backend_extra_config,
|
||||
storage_backend_extra_config=extra_config,
|
||||
)
|
||||
if self.enable_storage_metrics:
|
||||
# TODO: support pp
|
||||
@@ -127,6 +137,53 @@ class HiRadixCache(RadixCache):
|
||||
eviction_policy=eviction_policy,
|
||||
)
|
||||
|
||||
def _parse_storage_backend_extra_config(
    self, storage_backend_extra_config: Optional[str]
):
    """
    Parse storage backend extra config JSON and extract specific parameters.

    The prefetch tuning keys ("prefetch_threshold", "prefetch_timeout_base",
    "prefetch_timeout_per_ki_token") are removed from the parsed dict, so the
    returned dict only holds the remaining entries.

    Args:
        storage_backend_extra_config: JSON string containing extra configuration.
            None or an empty string yields an empty dict and the defaults.

    Returns:
        tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
        The two timeout values are always returned as float.

    Raises:
        ValueError: if the JSON document is not an object, or if one of the
            tuning keys has the wrong type (booleans are rejected).
        Exception: any error raised by json.loads is logged and re-raised as is.
    """
    # Parse extra config JSON if provided
    extra_config = {}
    if storage_backend_extra_config:
        try:
            extra_config = json.loads(storage_backend_extra_config)
        except Exception as e:
            logger.error(f"Invalid backend extra config JSON: {e}")
            raise

    # Valid JSON is not necessarily an object ("[1, 2]", "null", "3", ...);
    # fail with a clear message instead of an unrelated error from .pop().
    if not isinstance(extra_config, dict):
        raise ValueError(
            f"storage backend extra config must be a JSON object, got {type(extra_config).__name__}"
        )

    prefetch_threshold = extra_config.pop("prefetch_threshold", 256)  # tokens
    prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1)  # seconds
    prefetch_timeout_per_ki_token = extra_config.pop(
        "prefetch_timeout_per_ki_token", 0.25
    )  # seconds per 1024 tokens

    # bool is a subclass of int, so JSON true/false must be excluded explicitly;
    # otherwise they would be silently treated as 1/0.
    if isinstance(prefetch_threshold, bool) or not isinstance(
        prefetch_threshold, int
    ):
        raise ValueError(
            f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
        )
    if isinstance(prefetch_timeout_base, bool) or not isinstance(
        prefetch_timeout_base, (int, float)
    ):
        raise ValueError(
            f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
        )
    if isinstance(prefetch_timeout_per_ki_token, bool) or not isinstance(
        prefetch_timeout_per_ki_token, (int, float)
    ):
        raise ValueError(
            f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
        )

    return (
        extra_config,
        prefetch_threshold,
        float(prefetch_timeout_base),
        float(prefetch_timeout_per_ki_token),
    )
|
||||
|
||||
def reset(self):
|
||||
TreeNode.counter = 0
|
||||
self.cache_controller.reset()
|
||||
@@ -490,6 +547,15 @@ class HiRadixCache(RadixCache):
|
||||
host_indices = torch.cat(host_indices_list, dim=0)
|
||||
cc.mem_pool_host.free(host_indices)
|
||||
|
||||
# Timeout is linearly increasing with the number of pages
|
||||
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
    """Return True once the prefetch operation has outlived its time budget.

    The budget is ``self.prefetch_timeout_base`` seconds plus one
    ``self.prefetch_timeout_per_page`` allowance for every entry currently in
    ``operation.hash_value``, measured on the monotonic clock from
    ``operation.start_time``.
    """
    elapsed = time.monotonic() - operation.start_time
    # With an empty hash_value only the base timeout applies.
    budget = (
        self.prefetch_timeout_base
        + len(operation.hash_value) * self.prefetch_timeout_per_page
    )
    return elapsed > budget
|
||||
|
||||
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
||||
can_terminate = True
|
||||
|
||||
@@ -506,9 +572,7 @@ class HiRadixCache(RadixCache):
|
||||
if self.prefetch_stop_policy == "wait_complete":
|
||||
can_terminate = completed
|
||||
elif self.prefetch_stop_policy == "timeout":
|
||||
can_terminate = completed or (
|
||||
time.monotonic() - operation.start_time > self.prefetch_timeout
|
||||
)
|
||||
can_terminate = completed or self.is_prefetch_timeout(operation)
|
||||
else:
|
||||
# unknown prefetch stop policy, just return True
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user