diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 49ee74671..f36d61ee0 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -22,7 +22,10 @@ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple import torch -from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig +from sglang.srt.mem_cache.hicache_storage import ( + HiCacheStorageConfig, + HiCacheStorageExtraInfo, +) if TYPE_CHECKING: from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator @@ -191,12 +194,14 @@ class StorageOperation: token_ids: List[int], last_hash: Optional[str] = None, hash_value: Optional[List[str]] = None, + prefix_keys: Optional[List[str]] = None, ): self.host_indices = host_indices self.token_ids = token_ids self.last_hash = last_hash self.completed_tokens = 0 self.hash_value = hash_value if hash_value is not None else [] + self.prefix_keys = prefix_keys self.id = StorageOperation.counter StorageOperation.counter += 1 @@ -212,6 +217,7 @@ class PrefetchOperation(StorageOperation): host_indices: torch.Tensor, token_ids: List[int], last_hash: Optional[str] = None, + prefix_keys: Optional[List[str]] = None, ): self.request_id = request_id @@ -219,7 +225,7 @@ class PrefetchOperation(StorageOperation): self._terminated_flag = False self.start_time = time.monotonic() - super().__init__(host_indices, token_ids, last_hash) + super().__init__(host_indices, token_ids, last_hash, prefix_keys=prefix_keys) def increment(self, num_tokens: int): with self._lock: @@ -550,12 +556,13 @@ class HiCacheController: host_indices: torch.Tensor, new_input_tokens: List[int], last_hash: Optional[str] = None, + prefix_keys: Optional[List[str]] = None, ) -> PrefetchOperation: """ Prefetch KV caches from storage backend to host memory. 
""" operation = PrefetchOperation( - request_id, host_indices, new_input_tokens, last_hash + request_id, host_indices, new_input_tokens, last_hash, prefix_keys ) self.prefetch_queue.put(operation) return operation @@ -571,8 +578,12 @@ class HiCacheController: for page in pages: self.host_mem_release_queue.put(page) - def _page_get_zero_copy(self, operation, hash_values, host_indices): - results = self.storage_backend.batch_get_v1(hash_values, host_indices) + def _page_get_zero_copy( + self, operation, hash_values, host_indices, extra_info=None + ): + results = self.storage_backend.batch_get_v1( + hash_values, host_indices, extra_info + ) inc = 0 for i in range(len(hash_values)): if not results[i]: @@ -584,7 +595,7 @@ class HiCacheController: operation.increment(inc) # todo: deprecate - def _generic_page_get(self, operation, hash_values, host_indices): + def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None): dummy_page_dst = [ self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values ] @@ -608,6 +619,7 @@ class HiCacheController: def _page_transfer(self, operation): # Transfer batch by batch + prefix_keys = operation.prefix_keys for i in range(0, len(operation.hash_value), self.storage_batch_size): batch_hashes = operation.hash_value[i : i + self.storage_batch_size] batch_host_indices = operation.host_indices[ @@ -615,7 +627,8 @@ class HiCacheController: ] prev_completed_tokens = operation.completed_tokens # Get one batch token, and update the completed_tokens if succeed - self.page_get_func(operation, batch_hashes, batch_host_indices) + extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys) + self.page_get_func(operation, batch_hashes, batch_host_indices, extra_info) # Check termination if ( operation.completed_tokens @@ -623,6 +636,10 @@ class HiCacheController: ): operation.mark_terminate() break # Some operations fail or operation terminated by controller + + if prefix_keys and len(prefix_keys) > 0: + prefix_keys += 
batch_hashes + # release pre-allocated memory self.append_host_mem_release( operation.host_indices[operation.completed_tokens :] @@ -656,6 +673,7 @@ class HiCacheController: def _storage_hit_query(self, operation) -> tuple[list[str], int]: last_hash = operation.last_hash tokens_to_fetch = operation.token_ids + prefix_keys = operation.prefix_keys.copy() if operation.prefix_keys else None storage_query_count = 0 hash_value = [] @@ -673,11 +691,15 @@ class HiCacheController: batch_tokens[i : i + self.page_size], last_hash ) batch_hashes.append(last_hash) - hit_page_num = self.storage_backend.batch_exists(batch_hashes) + extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys) + hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info) hash_value.extend(batch_hashes[:hit_page_num]) storage_query_count += hit_page_num * self.page_size if hit_page_num < len(batch_hashes): break + if prefix_keys and len(prefix_keys) > 0: + prefix_keys += batch_hashes + return hash_value, storage_query_count def prefetch_thread_func(self): @@ -734,28 +756,34 @@ class HiCacheController: host_indices: torch.Tensor, token_ids: List[int], hash_value: Optional[List[str]] = None, + prefix_keys: Optional[List[str]] = None, ) -> int: """ Write KV caches from host memory to storage backend. 
""" - operation = StorageOperation(host_indices, token_ids, hash_value=hash_value) + operation = StorageOperation( + host_indices, token_ids, hash_value=hash_value, prefix_keys=prefix_keys + ) self.backup_queue.put(operation) return operation.id # todo: deprecate - def _generic_page_set(self, hash_values, host_indices) -> bool: + def _generic_page_set(self, hash_values, host_indices, extra_info=None) -> bool: data = [ self.mem_pool_host.get_data_page(host_indices[i * self.page_size]) for i in range(len(hash_values)) ] return self.storage_backend.batch_set(hash_values, data) - def _page_set_zero_copy(self, hash_values, host_indices) -> bool: - return all(self.storage_backend.batch_set_v1(hash_values, host_indices)) + def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> bool: + return all( + self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info) + ) # Backup batch by batch def _page_backup(self, operation): # Backup batch by batch + prefix_keys = operation.prefix_keys for i in range(0, len(operation.hash_value), self.storage_batch_size): batch_hashes = operation.hash_value[i : i + self.storage_batch_size] batch_host_indices = operation.host_indices[ @@ -763,12 +791,16 @@ class HiCacheController: ] # Set one batch token, and record if success. # todo: allow partial success - success = self.page_set_func(batch_hashes, batch_host_indices) + extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys) + success = self.page_set_func(batch_hashes, batch_host_indices, extra_info) if not success: logger.warning( f"Write page to storage: {len(batch_hashes)} pages failed." 
) break + + if prefix_keys and len(prefix_keys) > 0: + prefix_keys += batch_hashes operation.completed_tokens += self.page_size * len(batch_hashes) def backup_thread_func(self): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d2bc3c056..d4c8d5902 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1491,8 +1491,18 @@ class Scheduler( last_hash = req.last_host_node.get_last_hash_value() matched_len = len(req.prefix_indices) + req.host_hit_length new_input_tokens = req.fill_ids[matched_len:] + + prefix_keys = ( + req.last_node.get_prefix_hash_values(req.last_node.parent) + if self.tree_cache.hicache_storage_pass_prefix_keys + else None + ) self.tree_cache.prefetch_from_storage( - req.rid, req.last_host_node, new_input_tokens, last_hash + req.rid, + req.last_host_node, + new_input_tokens, + last_hash, + prefix_keys, ) def _add_request_to_queue(self, req: Req, is_retracted: bool = False): diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 8b21446b9..ac9cb2917 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -36,6 +36,7 @@ class HiCacheStorageConfig: @dataclass class HiCacheStorageExtraInfo: + prefix_keys: Optional[List[str]] = None extra_info: Optional[dict] = None @@ -139,7 +140,9 @@ class HiCacheStorage(ABC): pass # TODO: Use a finer-grained return type (e.g., List[bool]) - def batch_exists(self, keys: List[str]) -> int: + def batch_exists( + self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None + ) -> int: """ Check if the keys exist in the storage. return the number of consecutive existing keys from the start. 
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 1a8b6accc..6ea4e1ba9 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -84,12 +84,14 @@ class HiRadixCache(RadixCache): prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, + hicache_storage_pass_prefix_keys, ) = self._parse_storage_backend_extra_config(storage_backend_extra_config) self.prefetch_threshold = prefetch_threshold self.prefetch_timeout_base = prefetch_timeout_base self.prefetch_timeout_per_page = ( page_size / 1024 * prefetch_timeout_per_ki_token ) + self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys # TODO: support more timeout check functions self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func self.prefetch_stop_policy = hicache_storage_prefetch_policy @@ -149,7 +151,7 @@ class HiRadixCache(RadixCache): storage_backend_extra_config: JSON string containing extra configuration Returns: - tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token) + tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys) """ # Parse extra config JSON if provided extra_config = {} @@ -165,6 +167,9 @@ class HiRadixCache(RadixCache): prefetch_timeout_per_ki_token = extra_config.pop( "prefetch_timeout_per_ki_token", 0.25 ) # seconds per 1024 tokens + hicache_storage_pass_prefix_keys = extra_config.pop( + "hicache_storage_pass_prefix_keys", False + ) if not isinstance(prefetch_threshold, int): raise ValueError( @@ -184,6 +189,7 @@ class HiRadixCache(RadixCache): prefetch_threshold, float(prefetch_timeout_base), float(prefetch_timeout_per_ki_token), + hicache_storage_pass_prefix_keys, ) def reset(self): @@ -245,8 +251,14 @@ class HiRadixCache(RadixCache): return len(host_indices) def write_backup_storage(self, node: TreeNode): + 
prefix_keys = ( + node.get_prefix_hash_values(node.parent) + if self.hicache_storage_pass_prefix_keys + else None + ) + operation_id = self.cache_controller.write_storage( - node.host_value, node.key, node.hash_value + node.host_value, node.key, node.hash_value, prefix_keys ) self.ongoing_backup[operation_id] = node node.protect_host() @@ -700,6 +712,7 @@ class HiRadixCache(RadixCache): last_host_node: TreeNode, new_input_tokens: List[int], last_hash: Optional[str] = None, + prefix_keys: Optional[List[str]] = None, ): # align the number of fetching tokens to the page size prefetch_length = len(new_input_tokens) - ( @@ -723,7 +736,7 @@ class HiRadixCache(RadixCache): # no sufficient host memory for prefetch return operation = self.cache_controller.prefetch( - req_id, host_indices, new_input_tokens, last_hash + req_id, host_indices, new_input_tokens, last_hash, prefix_keys ) self.ongoing_prefetch[req_id] = ( last_host_node, diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index dac120016..bed7923f6 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache. 
import heapq import time from collections import defaultdict from functools import partial from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union import torch @@ -114,6 +114,12 @@ class TreeNode: return None return self.hash_value[-1] + def get_prefix_hash_values(self, node: "TreeNode") -> List[str]: + if node is None or node.hash_value is None: + return [] + + return node.get_prefix_hash_values(node.parent) + node.hash_value + def __lt__(self, other: "TreeNode"): return self.last_access_time < other.last_access_time diff --git a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py index 59aacc11d..bcc827109 100644 --- a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +++ b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py @@ -13,7 +13,11 @@ from aibrix_kvcache import ( ) from aibrix_kvcache.common.absl_logging import log_every_n_seconds -from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig +from sglang.srt.mem_cache.hicache_storage import ( + HiCacheStorage, + HiCacheStorageConfig, + HiCacheStorageExtraInfo, +) from sglang.srt.mem_cache.memory_pool_host import HostKVCache logger = logging.getLogger(__name__) @@ -140,7 +144,9 @@ class AibrixKVCacheStorage(HiCacheStorage): ) -> bool: return self.batch_set([key], [value], [target_location], [target_size]) - def batch_exists(self, keys: List[str]) -> int: + def batch_exists( + self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None + ) -> int: block_hash = BlockHashes(keys, self.page_size) status = self.kv_cache_manager.exists(None, block_hash) if status.is_ok(): diff --git a/python/sglang/srt/mem_cache/storage/eic/eic_storage.py b/python/sglang/srt/mem_cache/storage/eic/eic_storage.py index 4c4ea89eb..0acd5b65f 100--- 
a/python/sglang/srt/mem_cache/storage/eic/eic_storage.py +++ b/python/sglang/srt/mem_cache/storage/eic/eic_storage.py @@ -408,7 +408,9 @@ class EICStorage(HiCacheStorage): exist_num = self.batch_exists([key]) return exist_num == 1 - def batch_exists(self, keys) -> int: + def batch_exists( + self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None + ) -> int: if len(keys) == 0: return 0 if self.use_zero_copy and not self.is_mla_model: diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index 70b2203e0..1f8c58dbd 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -454,7 +454,9 @@ class HiCacheHF3FS(HiCacheStorage): result = self.metadata_client.exists(self.rank, [key]) return result[0] if result else False - def batch_exists(self, keys: List[str]) -> int: + def batch_exists( + self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None + ) -> int: factor = 1 if self.is_zero_copy and not self.is_mla_model: keys = self._get_mha_zero_copy_keys(keys) diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index 0b9db07f7..764f256ef 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -399,7 +399,9 @@ class MooncakeStore(HiCacheStorage): exist_result = self._batch_exist([key]) return exist_result[0] == 1 - def batch_exists(self, keys) -> int: + def batch_exists( + self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None + ) -> int: if self.is_mla_backend: query_keys = [f"{key}_k" for key in keys] key_multiplier = 1 diff --git a/test/srt/hicache/test_hicache_storage_3fs_backend.py b/test/srt/hicache/test_hicache_storage_3fs_backend.py index 493ec05da..362da4b73 100644 --- 
a/test/srt/hicache/test_hicache_storage_3fs_backend.py +++ b/test/srt/hicache/test_hicache_storage_3fs_backend.py @@ -29,6 +29,7 @@ class HiCacheStorage3FSBackendBaseMixin(HiCacheStorageBaseMixin): "numjobs": 2, "entries": 8, "use_mock_hf3fs_client": True, + "hicache_storage_pass_prefix_keys": True, } # Write config to temporary file diff --git a/test/srt/hicache/test_hicache_storage_file_backend.py b/test/srt/hicache/test_hicache_storage_file_backend.py index 8708fe4a0..382db07b3 100644 --- a/test/srt/hicache/test_hicache_storage_file_backend.py +++ b/test/srt/hicache/test_hicache_storage_file_backend.py @@ -4,6 +4,7 @@ Usage: python3 -m pytest test/srt/hicache/test_hicache_storage_e2e.py -v """ +import json import os import random import tempfile @@ -70,6 +71,9 @@ class HiCacheStorageBaseMixin: @classmethod def _get_base_server_args(cls): """Get base server arguments - can be extended in subclasses""" + extra_config = { + "hicache_storage_pass_prefix_keys": True, + } return { "--enable-hierarchical-cache": True, "--mem-fraction-static": 0.6, @@ -78,6 +82,7 @@ class HiCacheStorageBaseMixin: "--enable-cache-report": True, "--hicache-storage-prefetch-policy": "wait_complete", "--hicache-storage-backend": "file", + "--hicache-storage-backend-extra-config": json.dumps(extra_config), } @classmethod