[HiCache] Configurable and Dynamic Prefetch Timeout (#10512)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -250,7 +250,7 @@ class HiCacheController:
|
|||||||
storage_backend: Optional[str] = None,
|
storage_backend: Optional[str] = None,
|
||||||
prefetch_threshold: int = 256,
|
prefetch_threshold: int = 256,
|
||||||
model_name: Optional[str] = None,
|
model_name: Optional[str] = None,
|
||||||
storage_backend_extra_config: Optional[str] = None,
|
storage_backend_extra_config: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||||
@@ -361,7 +361,7 @@ class HiCacheController:
|
|||||||
def _generate_storage_config(
|
def _generate_storage_config(
|
||||||
self,
|
self,
|
||||||
model_name: Optional[str] = None,
|
model_name: Optional[str] = None,
|
||||||
storage_backend_extra_config: Optional[str] = None,
|
storage_backend_extra_config: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if is_dp_attention_enabled():
|
if is_dp_attention_enabled():
|
||||||
@@ -376,23 +376,13 @@ class HiCacheController:
|
|||||||
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
||||||
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
||||||
|
|
||||||
# Parse extra config JSON if provided
|
|
||||||
extra_config = None
|
|
||||||
if storage_backend_extra_config:
|
|
||||||
try:
|
|
||||||
import json
|
|
||||||
|
|
||||||
extra_config = json.loads(storage_backend_extra_config)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Invalid backend extra config JSON: {e}")
|
|
||||||
|
|
||||||
return HiCacheStorageConfig(
|
return HiCacheStorageConfig(
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
is_mla_model=is_mla_backend,
|
is_mla_model=is_mla_backend,
|
||||||
is_page_first_layout=self.mem_pool_host.layout == "page_first",
|
is_page_first_layout=self.mem_pool_host.layout == "page_first",
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
extra_config=extra_config,
|
extra_config=storage_backend_extra_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import heapq
|
import heapq
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from queue import Queue
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -78,9 +78,19 @@ class HiRadixCache(RadixCache):
|
|||||||
self.enable_storage = hicache_storage_backend is not None
|
self.enable_storage = hicache_storage_backend is not None
|
||||||
self.enable_storage_metrics = self.enable_storage and enable_metrics
|
self.enable_storage_metrics = self.enable_storage and enable_metrics
|
||||||
|
|
||||||
# todo: customizable storage prefetch threshold and timeout
|
(
|
||||||
self.prefetch_threshold = 256
|
extra_config,
|
||||||
self.prefetch_timeout = 3 # seconds
|
prefetch_threshold,
|
||||||
|
prefetch_timeout_base,
|
||||||
|
prefetch_timeout_per_ki_token,
|
||||||
|
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
|
||||||
|
self.prefetch_threshold = prefetch_threshold
|
||||||
|
self.prefetch_timeout_base = prefetch_timeout_base
|
||||||
|
self.prefetch_timeout_per_page = (
|
||||||
|
page_size / 1024 * prefetch_timeout_per_ki_token
|
||||||
|
)
|
||||||
|
# TODO: support more timeout check functions
|
||||||
|
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
|
||||||
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
||||||
|
|
||||||
self.load_cache_event = threading.Event()
|
self.load_cache_event = threading.Event()
|
||||||
@@ -95,7 +105,7 @@ class HiRadixCache(RadixCache):
|
|||||||
storage_backend=hicache_storage_backend,
|
storage_backend=hicache_storage_backend,
|
||||||
prefetch_threshold=self.prefetch_threshold,
|
prefetch_threshold=self.prefetch_threshold,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
storage_backend_extra_config=storage_backend_extra_config,
|
storage_backend_extra_config=extra_config,
|
||||||
)
|
)
|
||||||
if self.enable_storage_metrics:
|
if self.enable_storage_metrics:
|
||||||
# TODO: support pp
|
# TODO: support pp
|
||||||
@@ -127,6 +137,53 @@ class HiRadixCache(RadixCache):
|
|||||||
eviction_policy=eviction_policy,
|
eviction_policy=eviction_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _parse_storage_backend_extra_config(
|
||||||
|
self, storage_backend_extra_config: Optional[str]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Parse storage backend extra config JSON and extract specific parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_backend_extra_config: JSON string containing extra configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
|
||||||
|
"""
|
||||||
|
# Parse extra config JSON if provided
|
||||||
|
extra_config = {}
|
||||||
|
if storage_backend_extra_config:
|
||||||
|
try:
|
||||||
|
extra_config = json.loads(storage_backend_extra_config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Invalid backend extra config JSON: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens
|
||||||
|
prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds
|
||||||
|
prefetch_timeout_per_ki_token = extra_config.pop(
|
||||||
|
"prefetch_timeout_per_ki_token", 0.25
|
||||||
|
) # seconds per 1024 tokens
|
||||||
|
|
||||||
|
if not isinstance(prefetch_threshold, int):
|
||||||
|
raise ValueError(
|
||||||
|
f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
|
||||||
|
)
|
||||||
|
if not isinstance(prefetch_timeout_base, (int, float)):
|
||||||
|
raise ValueError(
|
||||||
|
f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
|
||||||
|
)
|
||||||
|
if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
|
||||||
|
raise ValueError(
|
||||||
|
f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
extra_config,
|
||||||
|
prefetch_threshold,
|
||||||
|
float(prefetch_timeout_base),
|
||||||
|
float(prefetch_timeout_per_ki_token),
|
||||||
|
)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
TreeNode.counter = 0
|
TreeNode.counter = 0
|
||||||
self.cache_controller.reset()
|
self.cache_controller.reset()
|
||||||
@@ -490,6 +547,15 @@ class HiRadixCache(RadixCache):
|
|||||||
host_indices = torch.cat(host_indices_list, dim=0)
|
host_indices = torch.cat(host_indices_list, dim=0)
|
||||||
cc.mem_pool_host.free(host_indices)
|
cc.mem_pool_host.free(host_indices)
|
||||||
|
|
||||||
|
# Timeout is linearly increasing with the number of pages
|
||||||
|
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
|
||||||
|
# If hash_value has not been computed in timeout_base seconds, terminate it.
|
||||||
|
return (
|
||||||
|
time.monotonic() - operation.start_time
|
||||||
|
> self.prefetch_timeout_base
|
||||||
|
+ len(operation.hash_value) * self.prefetch_timeout_per_page
|
||||||
|
)
|
||||||
|
|
||||||
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
||||||
can_terminate = True
|
can_terminate = True
|
||||||
|
|
||||||
@@ -506,9 +572,7 @@ class HiRadixCache(RadixCache):
|
|||||||
if self.prefetch_stop_policy == "wait_complete":
|
if self.prefetch_stop_policy == "wait_complete":
|
||||||
can_terminate = completed
|
can_terminate = completed
|
||||||
elif self.prefetch_stop_policy == "timeout":
|
elif self.prefetch_stop_policy == "timeout":
|
||||||
can_terminate = completed or (
|
can_terminate = completed or self.is_prefetch_timeout(operation)
|
||||||
time.monotonic() - operation.start_time > self.prefetch_timeout
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# unknown prefetch stop policy, just return True
|
# unknown prefetch stop policy, just return True
|
||||||
return True
|
return True
|
||||||
|
|||||||
Reference in New Issue
Block a user