diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index c7d0218e5..49ee74671 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -250,7 +250,7 @@ class HiCacheController: storage_backend: Optional[str] = None, prefetch_threshold: int = 256, model_name: Optional[str] = None, - storage_backend_extra_config: Optional[str] = None, + storage_backend_extra_config: Optional[dict] = None, ): self.mem_pool_device_allocator = token_to_kv_pool_allocator self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache() @@ -361,7 +361,7 @@ class HiCacheController: def _generate_storage_config( self, model_name: Optional[str] = None, - storage_backend_extra_config: Optional[str] = None, + storage_backend_extra_config: Optional[dict] = None, ): if is_dp_attention_enabled(): @@ -376,23 +376,13 @@ class HiCacheController: # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool. 
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool) - # Parse extra config JSON if provided - extra_config = None - if storage_backend_extra_config: - try: - import json - - extra_config = json.loads(storage_backend_extra_config) - except Exception as e: - logger.error(f"Invalid backend extra config JSON: {e}") - return HiCacheStorageConfig( tp_rank=self.tp_rank, tp_size=self.tp_size, is_mla_model=is_mla_backend, is_page_first_layout=self.mem_pool_host.layout == "page_first", model_name=model_name, - extra_config=extra_config, + extra_config=storage_backend_extra_config, ) def reset(self): diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 75ff08fd6..c3d6342d9 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -1,8 +1,8 @@ import heapq +import json import logging import threading import time -from queue import Queue from typing import List, Optional import torch @@ -78,9 +78,19 @@ class HiRadixCache(RadixCache): self.enable_storage = hicache_storage_backend is not None self.enable_storage_metrics = self.enable_storage and enable_metrics - # todo: customizable storage prefetch threshold and timeout - self.prefetch_threshold = 256 - self.prefetch_timeout = 3 # seconds + ( + extra_config, + prefetch_threshold, + prefetch_timeout_base, + prefetch_timeout_per_ki_token, + ) = self._parse_storage_backend_extra_config(storage_backend_extra_config) + self.prefetch_threshold = prefetch_threshold + self.prefetch_timeout_base = prefetch_timeout_base + self.prefetch_timeout_per_page = ( + page_size / 1024 * prefetch_timeout_per_ki_token + ) + # TODO: support more timeout check functions + self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func self.prefetch_stop_policy = hicache_storage_prefetch_policy self.load_cache_event = threading.Event() @@ -95,7 +105,7 @@ class HiRadixCache(RadixCache): 
storage_backend=hicache_storage_backend, prefetch_threshold=self.prefetch_threshold, model_name=model_name, - storage_backend_extra_config=storage_backend_extra_config, + storage_backend_extra_config=extra_config, ) if self.enable_storage_metrics: # TODO: support pp @@ -127,6 +137,53 @@ class HiRadixCache(RadixCache): eviction_policy=eviction_policy, ) + def _parse_storage_backend_extra_config( + self, storage_backend_extra_config: Optional[str] + ): + """ + Parse storage backend extra config JSON and extract specific parameters. + + Args: + storage_backend_extra_config: JSON string containing extra configuration + + Returns: + tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token) + """ + # Parse extra config JSON if provided + extra_config = {} + if storage_backend_extra_config: + try: + extra_config = json.loads(storage_backend_extra_config) + except Exception as e: + logger.error(f"Invalid backend extra config JSON: {e}") + raise e + + prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens + prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds + prefetch_timeout_per_ki_token = extra_config.pop( + "prefetch_timeout_per_ki_token", 0.25 + ) # seconds per 1024 tokens + + if not isinstance(prefetch_threshold, int): + raise ValueError( + f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}" + ) + if not isinstance(prefetch_timeout_base, (int, float)): + raise ValueError( + f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}" + ) + if not isinstance(prefetch_timeout_per_ki_token, (int, float)): + raise ValueError( + f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}" + ) + + return ( + extra_config, + prefetch_threshold, + float(prefetch_timeout_base), + float(prefetch_timeout_per_ki_token), + ) + def reset(self): TreeNode.counter = 0 self.cache_controller.reset() @@ -490,6 
# Timeout is linearly increasing with the number of pages
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
    """Return True once *operation* has exceeded its linear time budget.

    The budget is ``prefetch_timeout_base`` seconds (time allowed for the
    hash values to be computed) plus a per-page allowance for every entry
    in ``operation.hash_value``, measured from ``operation.start_time``.
    """
    elapsed = time.monotonic() - operation.start_time
    budget = (
        self.prefetch_timeout_base
        + len(operation.hash_value) * self.prefetch_timeout_per_page
    )
    return elapsed > budget