From 9d33fcfb8e93c4a01fb39c6609c71f7104cb3371 Mon Sep 17 00:00:00 2001
From: Zhiqiang Xie
Date: Fri, 18 Jul 2025 00:20:19 -0700
Subject: [PATCH] Hicache Storage Layer Prototype (#7704)

---
 .../sglang/srt/managers/cache_controller.py  | 241 ++++++++++++++++++
 python/sglang/srt/managers/scheduler.py      |  14 +
 .../sglang/srt/mem_cache/hicache_storage.py  | 152 +++++++++++
 python/sglang/srt/mem_cache/hiradix_cache.py | 183 ++++++++++++-
 .../sglang/srt/mem_cache/memory_pool_host.py |  38 +++
 python/sglang/srt/mem_cache/radix_cache.py   |  26 ++
 python/sglang/srt/server_args.py             |   8 +
 test/srt/run_suite.py                        |   1 +
 test/srt/test_hicache_storage.py             |  55 ++++
 9 files changed, 714 insertions(+), 4 deletions(-)
 create mode 100644 python/sglang/srt/mem_cache/hicache_storage.py
 create mode 100644 test/srt/test_hicache_storage.py

diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index cad1d74b7..5f43a5e9a 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+
 logger = logging.getLogger(__name__)
 
 
@@ -159,6 +161,57 @@ class TransferBuffer:
         self.buffers.queue.clear()
 
 
+class StorageOperation:
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.host_indices = host_indices
+        self.token_ids = token_ids
+        self.last_hash = last_hash
+        self.completed_tokens = 0
+        self.hash_value = []
+
+        self.id = StorageOperation.counter
+        StorageOperation.counter += 1
+
+    def __lt__(self, other: "StorageOperation"):
+        return self.id < other.id
+
+
+class PrefetchOperation(StorageOperation):
+    def __init__(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.request_id = request_id
+
+        self._done_flag = False
+        self._lock = threading.Lock()
+
+        super().__init__(host_indices, token_ids, last_hash)
+
+    def increment(self, num_tokens: int):
+        with self._lock:
+            if self._done_flag:
+                return
+            self.completed_tokens += num_tokens
+
+    def mark_done(self):
+        with self._lock:
+            self._done_flag = True
+
+    def is_done(self) -> bool:
+        return self._done_flag
+
+
 class HiCacheController:
 
     def __init__(
@@ -169,6 +222,8 @@ class HiCacheController:
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
+        storage_backend: Optional[str] = None,
+        prefetch_threshold: int = 256,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -186,6 +241,19 @@ class HiCacheController:
         else:
             self.io_backend = io_backend
 
+        self.enable_storage = False
+        # todo: move backend initialization to storage backend module
+        if storage_backend is not None:
+            if storage_backend == "file":
+                self.storage_backend = HiCacheFile()
+                self.enable_storage = True
+                # todo: threshold policy for prefetching
+                self.prefetch_threshold = prefetch_threshold
+            else:
+                raise NotImplementedError(
+                    f"Unsupported storage backend: {storage_backend}"
+                )
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -218,9 +286,26 @@ class HiCacheController:
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
+
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_queue = Queue()
+            self.backup_queue = Queue()
+
+            self.prefetch_revoke_queue = Queue()
+            self.ack_backup_queue = Queue()
+
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
@@ -232,6 +317,13 @@ class HiCacheController:
         self.load_buffer.clear()
         self.ack_write_queue.queue.clear()
         self.ack_load_queue.queue.clear()
+        if self.enable_storage:
+            self.prefetch_thread.join()
+            self.backup_thread.join()
+            self.prefetch_queue.queue.clear()
+            self.backup_queue.queue.clear()
+            self.prefetch_revoke_queue.queue.clear()
+            self.ack_backup_queue.queue.clear()
 
         self.write_thread = threading.Thread(
             target=self.write_thread_func_direct, daemon=True
@@ -243,6 +335,16 @@ class HiCacheController:
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -383,3 +485,142 @@ class HiCacheController:
             raise ValueError(
                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
             )
+
+    def prefetch(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ) -> PrefetchOperation:
+        """
+        Queue a prefetch of KV caches from storage to host memory and return its operation handle.
+        """
+        operation = PrefetchOperation(
+            request_id, host_indices, new_input_tokens, last_hash
+        )
+        self.prefetch_queue.put(operation)
+        return operation
+
+    def terminate_prefetch(self, operation):
+        operation.mark_done()
+        return operation.completed_tokens, operation.hash_value
+
+    def prefetch_io_aux_func(self):
+        """
+        Auxiliary function conducting IO operations for prefetching.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.prefetch_buffer.get(block=True, timeout=1)
+                for h in operation.hash_value:
+                    page_data = self.storage_backend.get(h)
+                    if page_data is None:
+                        logger.warning(
+                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
+                        )
+                        break
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                    operation.increment(self.page_size)
+                    if operation.is_done():
+                        # operation terminated by controller, release pre-allocated memory
+                        self.mem_pool_host.free(
+                            operation.host_indices[operation.completed_tokens :]
+                        )
+                        break
+            except Empty:
+                continue
+
+    def prefetch_thread_func(self):
+        """
+        Manage prefetching operations from storage backend to host memory.
+        """
+        self.prefetch_buffer = Queue()
+        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
+        aux_thread.start()
+        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
+            try:
+                operation = self.prefetch_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                storage_hit_count = 0
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+                    if self.storage_backend.exists(last_hash):
+                        storage_hit_count += self.page_size
+                        hash_value.append(last_hash)
+                        remaining_tokens -= self.page_size
+                    else:
+                        break
+
+                if storage_hit_count < self.prefetch_threshold:
+                    # too few pages hit in storage for prefetching to pay off; revoke
+                    self.prefetch_revoke_queue.put(operation.request_id)
+                else:
+                    operation.hash_value = hash_value
+                    logger.debug(
+                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                    )
+                    self.prefetch_buffer.put(operation)
+
+            except Empty:
+                continue
+
+    def write_storage(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Queue a backup of KV caches from host memory to storage and return the operation id.
+        """
+        operation = StorageOperation(host_indices, token_ids, last_hash)
+        self.backup_queue.put(operation)
+        return operation.id
+
+    def backup_thread_func(self):
+        """
+        Manage backup operations from host memory to storage backend.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.backup_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_backup = operation.token_ids
+
+                for i in range(0, len(tokens_to_backup), self.page_size):
+                    last_hash = get_hash_str(
+                        tokens_to_backup[i : i + self.page_size], last_hash
+                    )
+                    # todo, handle failures in storage backend
+                    self.storage_backend.set(
+                        last_hash,
+                        self.mem_pool_host.get_flat_data_page(
+                            operation.host_indices[i]
+                        ),
+                    )
+                    operation.completed_tokens += self.page_size
+                    operation.hash_value.append(last_hash)
+
+                self.ack_backup_queue.put((operation.id, operation.hash_value))
+
+            except Empty:
+                continue
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 874ed60f0..c79e296f6 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -262,6 +262,7 @@ class Scheduler(
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
+        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
         self.page_size = server_args.page_size
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
@@ -614,6 +615,7 @@ class Scheduler(
                         == "fa3"  # hot fix for incompatibility
                         else server_args.hicache_io_backend
                     ),
+                    hicache_storage_backend=server_args.hicache_storage_backend,
                 )
self.tp_worker.register_hicache_layer_transfer_counter( self.tree_cache.cache_controller.layer_done_counter @@ -1258,6 +1260,15 @@ class Scheduler( elif self.disaggregation_mode == DisaggregationMode.DECODE: self.disagg_decode_prealloc_queue.add(req) else: + if self.enable_hicache_storage: + req.init_next_round_input(self.tree_cache) + last_hash = req.last_host_node.get_last_hash_value() + matched_len = len(req.prefix_indices) + req.host_hit_length + if (matched_len > 0 and last_hash is not None) or matched_len == 0: + new_input_tokens = req.fill_ids[matched_len:] + self.tree_cache.prefetch_from_storage( + req.rid, req.last_host_node, new_input_tokens, last_hash + ) self.waiting_queue.append(req) def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False): @@ -1731,6 +1742,9 @@ class Scheduler( self.running_batch.batch_is_full = True break + if self.enable_hicache_storage: + self.tree_cache.check_prefetch_progress(req.rid) + req.init_next_round_input(self.tree_cache) res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None)) diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py new file mode 100644 index 000000000..1dfe661ab --- /dev/null +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -0,0 +1,152 @@ +import hashlib +import logging +import os +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch + +logger = logging.getLogger(__name__) + + +def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str: + hasher = hashlib.sha256() + + if prior_hash: + hasher.update(bytes.fromhex(prior_hash)) + + for t in token_ids: + hasher.update(t.to_bytes(4, byteorder="little", signed=False)) + + return hasher.hexdigest() + + +class HiCacheStorage(ABC): + """ + HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache. 
+ It abstracts the underlying storage mechanism, allowing different implementations to be used. + """ + + # todo, translate tensor object access for different TP ranks + # potentially pass model and TP configs into storage backend + # todo, the page size of storage backend does not have to be the same as the same as host memory pool + + @abstractmethod + def get( + self, key: str, target_location: Optional[torch.Tensor] = None + ) -> torch.Tensor | None: + """ + Retrieve the value associated with the given key. + Returns None if the key does not exist. + """ + pass + + @abstractmethod + def batch_get( + self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None + ) -> List[torch.Tensor | None]: + """ + Retrieve values for multiple keys. + Returns a list of tensors or None for each key. + """ + pass + + @abstractmethod + def set(self, key, value) -> bool: + """ + Store the value associated with the given key. + Returns True if the operation was successful, False otherwise. + """ + pass + + @abstractmethod + def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: + """ + Store multiple key-value pairs. + Returns True if all operations were successful, False otherwise. + """ + pass + + @abstractmethod + def exists(self, key: str) -> bool: + """ + Check if the key exists in the storage. + Returns True if the key exists, False otherwise. 
+ """ + pass + + +class HiCacheFile(HiCacheStorage): + + def __init__(self, file_path: str = "/tmp/hicache"): + self.file_path = file_path + if not os.path.exists(self.file_path): + os.makedirs(self.file_path) + logger.info(f"Created HiCacheFile storage directory at {self.file_path}") + + def get( + self, key: str, target_location: Optional[torch.Tensor] = None + ) -> torch.Tensor | None: + tensor_path = os.path.join(self.file_path, f"{key}.bin") + try: + # todo: fixing the target_location logic to enable in-place loading + loaded_tensor = torch.load(tensor_path) + if isinstance(loaded_tensor, torch.Tensor): + return loaded_tensor + else: + logger.error(f"Loaded data for key {key} is not a tensor.") + return None + except FileNotFoundError: + return None + + def batch_get( + self, + keys: List[str], + target_locations: Optional[List[torch.Tensor]] = None, + ) -> List[torch.Tensor | None]: + return [ + self.get(key, target_location) + for key, target_location in zip( + keys, target_locations or [None] * len(keys) + ) + ] + + def set(self, key: str, value: torch.Tensor) -> bool: + tensor_path = os.path.join(self.file_path, f"{key}.bin") + if self.exists(key): + logger.debug(f"Key {key} already exists. Skipped.") + return True + try: + torch.save(value, tensor_path) + return True + except Exception as e: + logger.error(f"Failed to save tensor {key}: {e}") + return False + + def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: + for key, value in zip(keys, values): + if not self.set(key, value): + return False + return True + + def exists(self, key: str) -> bool: + tensor_path = os.path.join(self.file_path, f"{key}.bin") + return os.path.exists(tensor_path) + + def delete(self, key: str) -> None: + tensor_path = os.path.join(self.file_path, f"{key}.bin") + try: + os.remove(tensor_path) + except FileNotFoundError: + logger.warning(f"Key {key} does not exist. 
Cannot delete.") + return + + def clear(self) -> None: + try: + for filename in os.listdir(self.file_path): + file_path = os.path.join(self.file_path, filename) + if os.path.isfile(file_path): + os.remove(file_path) + logger.info("Cleared all entries in HiCacheFile storage.") + except Exception as e: + logger.error(f"Failed to clear HiCacheFile storage: {e}") diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index cb7d95558..796f0553c 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -35,6 +35,7 @@ class HiRadixCache(RadixCache): hicache_size: int, hicache_write_policy: str, hicache_io_backend: str, + hicache_storage_backend: Optional[str] = None, ): self.kv_cache = token_to_kv_pool_allocator.get_kvcache() if isinstance(self.kv_cache, MHATokenToKVPool): @@ -49,6 +50,9 @@ class HiRadixCache(RadixCache): raise ValueError(f"HiRadixCache only supports MHA and MLA yet") self.tp_group = tp_cache_group + self.enable_storage = hicache_storage_backend is not None + # todo: customizable storage prefetch threshold + self.prefetch_threshold = 256 self.load_cache_event = threading.Event() self.cache_controller = HiCacheController( @@ -58,16 +62,22 @@ class HiRadixCache(RadixCache): load_cache_event=self.load_cache_event, write_policy=hicache_write_policy, io_backend=hicache_io_backend, + storage_backend=hicache_storage_backend, + prefetch_threshold=self.prefetch_threshold, ) # record the nodes with ongoing write through self.ongoing_write_through = {} # record the node segments with ongoing load back self.ongoing_load_back = {} + # record the ongoing prefetch requests + self.ongoing_prefetch = {} + self.ongoing_backup = {} # todo: dynamically adjust the threshold self.write_through_threshold = ( 1 if hicache_write_policy == "write_through" else 3 ) + self.write_through_threshold_storage = 3 self.load_back_threshold = 10 super().__init__( req_to_token_pool, 
token_to_kv_pool_allocator, page_size, disable=False @@ -108,13 +118,30 @@ class HiRadixCache(RadixCache): return len(host_indices) + def write_backup_storage(self, node: TreeNode): + operation_id = self.cache_controller.write_storage( + node.host_value, node.key, node.parent.get_last_hash_value() + ) + self.ongoing_backup[operation_id] = node + node.protect_host() + def inc_hit_count(self, node: TreeNode): - if node.backuped or self.cache_controller.write_policy == "write_back": + if self.cache_controller.write_policy == "write_back": return node.hit_count += 1 - if node.hit_count >= self.write_through_threshold: - self.write_backup(node) - node.hit_count = 0 + + if not node.backuped: + if node.hit_count >= self.write_through_threshold: + # write to host if the node is not backuped + self.write_backup(node) + else: + if ( + self.enable_storage + and (not node.backuped_storage) + and node.hit_count >= self.write_through_threshold_storage + ): + # if the node is backuped on host memory but not on storage + self.write_backup_storage(node) def writing_check(self, write_back=False): if write_back: @@ -221,6 +248,10 @@ class HiRadixCache(RadixCache): if not x.evicted: continue + # node is protected from eviction as it has ongoing prefetch or backup to storage + if x.host_ref_counter > 0: + continue + num_evicted += self.cache_controller.evict_host(x.host_value) for k, v in x.parent.children.items(): @@ -314,6 +345,85 @@ class HiRadixCache(RadixCache): def check_hicache_events(self): self.writing_check() self.loading_check() + if self.enable_storage: + self.check_revoked_prefetch() + self.check_backup_progress() + + def check_revoked_prefetch(self): + queue_size = torch.tensor( + self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int + ) + if torch.distributed.get_world_size(group=self.tp_group) > 1: + # synchrnoize TP workers to make the same update to hiradix cache + torch.distributed.all_reduce( + queue_size, + op=torch.distributed.ReduceOp.MIN, + 
group=self.tp_group, + ) + for _ in range(queue_size.item()): + req_id = self.cache_controller.prefetch_revoke_queue.get() + if req_id in self.ongoing_prefetch: + last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id] + last_host_node.release_host() + self.cache_controller.mem_pool_host.free(host_indices) + del self.ongoing_prefetch[req_id] + + def check_backup_progress(self): + queue_size = torch.tensor( + self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int + ) + if torch.distributed.get_world_size(group=self.tp_group) > 1: + # synchrnoize TP workers to make the same update to hiradix cache + torch.distributed.all_reduce( + queue_size, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + ) + for _ in range(queue_size.item()): + ack_id, hash_value = self.cache_controller.ack_backup_queue.get() + self.ongoing_backup[ack_id].hash_value = hash_value + self.ongoing_backup[ack_id].release_host() + del self.ongoing_backup[ack_id] + + def check_prefetch_progress(self, req_id: str): + if req_id not in self.ongoing_prefetch: + # there is no ongoing prefetch for this request or it has been revoked + return + + # todo: more policies for prefetch progress such as timeout + # the current policy is to prefetch with best effort and terminate when queuing is over + last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[ + req_id + ] + completed_tokens, hash_value = self.cache_controller.terminate_prefetch( + operation + ) + logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens") + + min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int) + if torch.distributed.get_world_size(group=self.tp_group) > 1: + # synchrnoize TP workers to make the same update to hiradix cache + torch.distributed.all_reduce( + min_completed_tokens, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + ) + min_completed_tokens = min_completed_tokens.item() + fetched_token_ids = token_ids[:min_completed_tokens] + 
written_indices = host_indices[:min_completed_tokens] + matched_length = self._insert_helper_host( + last_host_node, + fetched_token_ids, + written_indices, + hash_value[:min_completed_tokens], + ) + + self.cache_controller.mem_pool_host.free(host_indices[:matched_length]) + self.cache_controller.mem_pool_host.free( + host_indices[min_completed_tokens:completed_tokens] + ) + last_host_node.release_host() + del self.ongoing_prefetch[req_id] def match_prefix(self, key: List[int], **kwargs): empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) @@ -348,6 +458,71 @@ class HiRadixCache(RadixCache): host_hit_length=host_hit_length, ) + def prefetch_from_storage( + self, + req_id: str, + last_host_node: TreeNode, + new_input_tokens: List[int], + last_hash: Optional[str] = None, + ): + if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold: + return + + last_host_node.protect_host() + host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens)) + if host_indices is None: + self.evict_host(len(new_input_tokens)) + host_indices = self.cache_controller.mem_pool_host.alloc( + len(new_input_tokens) + ) + if host_indices is None: + last_host_node.release_host() + # no sufficient host memory to prefetch + return + operation = self.cache_controller.prefetch( + req_id, host_indices, new_input_tokens, last_hash + ) + self.ongoing_prefetch[req_id] = ( + last_host_node, + new_input_tokens, + host_indices, + operation, + ) + + def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value): + node.last_access_time = time.monotonic() + if len(key) == 0: + return 0 + + child_key = self.get_child_key_fn(key) + + matched_length = 0 + while len(key) > 0 and child_key in node.children.keys(): + node = node.children[child_key] + node.last_access_time = time.monotonic() + prefix_len = self.key_match_fn(node.key, key) + key = key[prefix_len:] + host_value = host_value[prefix_len:] + hash_value = hash_value[prefix_len:] 
+ matched_length += prefix_len + + if prefix_len < len(node.key): + new_node = self._split_node(node.key, node, prefix_len) + node = new_node + + if len(key): + child_key = self.get_child_key_fn(key) + + if len(key): + new_node = TreeNode() + new_node.parent = node + new_node.key = key + new_node.value = None + new_node.host_value = host_value + new_node.hash_value = hash_value + node.children[child_key] = new_node + return matched_length + def _match_prefix_helper(self, node: TreeNode, key: List): node.last_access_time = time.monotonic() child_key = self.get_child_key_fn(key) diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 1bc2ddf7e..f50347962 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -99,6 +99,20 @@ class HostKVCache(abc.ABC): def init_kv_buffer(self): raise NotImplementedError() + @abc.abstractmethod + def get_flat_data_page(self, index) -> torch.Tensor: + """ + Get a flat data page from the host memory pool. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None: + """ + Set a flat data page to the host memory pool. + """ + raise NotImplementedError() + @synchronized() def clear(self): # Initialize memory states and tracking structures. 
@@ -227,6 +241,19 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )
 
+    # todo, page first memory layout
+    def get_flat_data_page(self, index: int) -> torch.Tensor:
+        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
+            2,
+            self.layer_num,
+            self.page_size,
+            self.head_num,
+            self.head_dim,
+        )
+
     @property
     def k_buffer(self):
         return self.kv_buffer[0]
@@ -276,3 +303,14 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
+
+    def get_flat_data_page(self, index: int) -> torch.Tensor:
+        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+            self.layer_num,
+            self.page_size,
+            1,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+        )
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index 706432209..0826990c2 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -55,8 +55,13 @@ class TreeNode:
         self.hit_count = 0
         # indicating the node is loading KV cache from host
         self.loading = False
+        # indicating the node is locked to protect from eviction
+        # incremented when the node is referenced by a storage operation
+        self.host_ref_counter = 0
         # store the host indices of KV cache
         self.host_value: Optional[torch.Tensor] = None
+        # store the hash value of each page
+        self.hash_value: Optional[List[str]] = None
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -69,6 +74,27 @@ class TreeNode:
     def backuped(self):
         return self.host_value is not None
 
+    @property
+    def backuped_storage(self):
+        return self.hash_value is not None and len(self.hash_value) > 0
+
+    def protect_host(self):
+        """Protect the host value from eviction."""
+        self.host_ref_counter += 1
+
+    def release_host(self):
+        """Release the host value, allowing it to be evicted."""
+        if self.host_ref_counter > 0:
+            self.host_ref_counter -= 1
+        else:
+            raise RuntimeError("Host reference counter is already zero.")
+
+    def get_last_hash_value(self) -> Optional[str]:
+        """Returns the hash value of the last page in this node."""
+        if self.hash_value is None or len(self.hash_value) == 0:
+            return None
+        return self.hash_value[-1]
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e475039d7..cb8038d33 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -222,6 +222,7 @@ class ServerArgs:
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
     hicache_io_backend: str = ""
+    hicache_storage_backend: Optional[str] = None
     flashinfer_mla_disable_ragged: bool = False
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
@@ -1604,6 +1605,13 @@ class ServerArgs:
             default=ServerArgs.hicache_io_backend,
             help="The IO backend for KV cache transfer between CPU and GPU",
         )
+        parser.add_argument(
+            "--hicache-storage-backend",
+            type=str,
+            choices=["file"],  # todo, mooncake
+            default=ServerArgs.hicache_storage_backend,
+            help="The storage backend for hierarchical KV cache.",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 059955f33..41564869e 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -64,6 +64,7 @@ suites = {
         TestFile("test_fused_moe.py", 30),
         TestFile("test_hicache.py", 116),
         TestFile("test_hicache_mla.py", 127),
+        TestFile("test_hicache_storage.py", 127),
         TestFile("test_hidden_states.py", 55),
TestFile("test_int8_kernel.py", 8), TestFile("test_input_embeddings.py", 38), diff --git a/test/srt/test_hicache_storage.py b/test/srt/test_hicache_storage.py new file mode 100644 index 000000000..aadc9529d --- /dev/null +++ b/test/srt/test_hicache_storage.py @@ -0,0 +1,55 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestHiCache(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-hierarchical-cache", + "--mem-fraction-static", + 0.7, + "--hicache-size", + 100, + "--page-size", + "64", + "--hicache-storage-backend", + "file", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.65) + + +if __name__ == "__main__": + unittest.main()