From 8b6966d0205abeaca143693c6f273dcacbfa779d Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Sun, 31 Aug 2025 07:58:21 -0700 Subject: [PATCH] [HiCache] Storage Refactoring (#9797) Co-authored-by: pansicheng <27603155+pansicheng@users.noreply.github.com> --- .../sglang/srt/managers/cache_controller.py | 156 +++++++----------- python/sglang/srt/mem_cache/hiradix_cache.py | 111 ++++++------- python/sglang/srt/mem_cache/radix_cache.py | 1 - 3 files changed, 114 insertions(+), 154 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 89fb00da4..8a8237c65 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -250,26 +250,21 @@ class HiCacheController: self.write_policy = write_policy self.page_size = page_size self.io_backend = io_backend - self.enable_storage = False - # todo: move backend initialization to storage backend module if storage_backend is not None: self.storage_backend_type = storage_backend from sglang.srt.mem_cache.hicache_storage import get_hash_str self.get_hash_str = get_hash_str - self.storage_config = self._generate_storage_config( model_name, storage_backend_extra_config ) - # In MLA backend, only one rank needs to backup the KV cache + # for MLA models, only one rank needs to backup the KV cache self.backup_skip = ( self.storage_config.is_mla_model - # todo: for load balancing, decide which rank to backup the KV cache by hash value + # todo: load balancing and self.storage_config.tp_rank != 0 - # todo: support other storage backends - and self.storage_backend_type in ["file", "mooncake"] ) if storage_backend == "file": @@ -309,12 +304,15 @@ class HiCacheController: raise NotImplementedError( f"Unsupported storage backend: {storage_backend}" ) + self.enable_storage = True # todo: threshold policy for prefetching self.prefetch_threshold = max(prefetch_threshold, self.page_size) self.prefetch_capacity_limit = int( 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size) ) + # granularity of batch storage IO operations, in number of pages + self.storage_batch_size = 128 # tracking the number of tokens locked in prefetching, updated by the main scheduler thread self.prefetch_tokens_occupied = 0 @@ -325,12 +323,6 @@ class HiCacheController: self.prefetch_tp_group = torch.distributed.new_group( group_ranks, backend="gloo" ) - self.prefetch_io_tp_group = torch.distributed.new_group( - group_ranks, backend="gloo" - ) - self.backup_tp_group = torch.distributed.new_group( - group_ranks, backend="gloo" - ) self.load_cache_event = load_cache_event self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num) @@ -380,6 +372,7 @@ class HiCacheController: self.prefetch_revoke_queue = Queue() self.ack_backup_queue = Queue() + self.host_mem_release_queue = Queue() self.prefetch_thread.start() self.backup_thread.start() @@ -618,7 +611,11 @@ class HiCacheController: operation.mark_done() return operation.completed_tokens, operation.hash_value - # zero copy + def append_host_mem_release(self, host_indices: torch.Tensor): + chunks = host_indices.split(self.mem_pool_host.page_size) + for chunk in chunks: + self.host_mem_release_queue.put(chunk) + def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices): hashes, dsts = self.mem_pool_host.get_buffer_with_hash( hash_values, host_indices @@ -631,7 +628,6 @@ class HiCacheController: f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}." ) - # zero copy def _mooncake_page_get(self, operation, hash_values, host_indices): key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta( hash_values, @@ -650,9 +646,7 @@ class HiCacheController: if get_result != 0: operation.increment(get_result * self.page_size) - # non-zero copy def _generic_page_get(self, operation, hash_values, host_indices): - # todo: zero copy dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len( hash_values ) @@ -675,22 +669,19 @@ class HiCacheController: def _page_transfer(self, operation): # Select the get function and batch size - if self.is_mooncake_backend(): + if self.storage_backend_type == "mooncake": get_func = self._mooncake_page_get - batch_size = 128 - elif self.storage_backend_type == "hf3fs": - if self.mem_pool_host.layout == "page_first": - get_func = self._3fs_zero_copy_page_get - elif self.mem_pool_host.layout == "layer_first": - get_func = self._generic_page_get - batch_size = 128 + elif ( + self.storage_backend_type == "hf3fs" + and self.mem_pool_host.layout == "page_first" + ): + get_func = self._3fs_zero_copy_page_get else: get_func = self._generic_page_get - batch_size = 8 # Transfer batch by batch - for i in range(0, len(operation.hash_value), batch_size): - batch_hashes = operation.hash_value[i : i + batch_size] + for i in range(0, len(operation.hash_value), self.storage_batch_size): + batch_hashes = operation.hash_value[i : i + self.storage_batch_size] batch_host_indices = operation.host_indices[ i * self.page_size : (i + len(batch_hashes)) * self.page_size ] @@ -704,10 +695,9 @@ class HiCacheController: ): break # Some operations fail or operation terminated by controller # release pre-allocated memory - self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :]) - - def is_mooncake_backend(self): - return self.storage_backend_type == "mooncake" + self.append_host_mem_release( + operation.host_indices[operation.completed_tokens :] + ) def prefetch_io_aux_func(self): """ @@ -717,47 +707,49 @@ class HiCacheController: try: operation = self.prefetch_buffer.get(block=True, timeout=1) self._page_transfer(operation) - - if self.tp_world_size > 1: - # to ensure all TP workers release the host memory at the same time - torch.distributed.barrier(group=self.prefetch_io_tp_group) # operation terminated by controller, release pre-allocated memory - self.mem_pool_host.free( + self.append_host_mem_release( operation.host_indices[operation.completed_tokens :] ) except Empty: continue - def prefetch_rate_limit_check(self) -> bool: + def prefetch_rate_limited(self) -> bool: """ Rate limit the prefetching operations to avoid overwhelming the storage backend. """ # cancel prefetch if too much memory is occupied if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit: - return False + return True # todo: more sophisticated rate limiting based on storage backend performance - return True + return False - def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]: + def _storage_hit_query(self, operation) -> tuple[list[str], int]: last_hash = operation.last_hash tokens_to_fetch = operation.token_ids storage_query_count = 0 - remaining_tokens = len(tokens_to_fetch) hash_value = [] - while remaining_tokens >= self.page_size: - last_hash = self.get_hash_str( - tokens_to_fetch[ - storage_query_count : storage_query_count + self.page_size - ], - last_hash, + + for start in range( + 0, len(tokens_to_fetch), self.page_size * self.storage_batch_size + ): + end = min( + start + self.page_size * self.storage_batch_size, len(tokens_to_fetch) ) - hash_value.append(last_hash) - storage_query_count += self.page_size - remaining_tokens -= self.page_size - # deferring to batch exists - hit_page_num = self.storage_backend.batch_exists(hash_value) - return hash_value[:hit_page_num], hit_page_num * self.page_size + batch_tokens = tokens_to_fetch[start:end] + batch_hashes = [] + for i in range(0, len(batch_tokens), self.page_size): + last_hash = self.get_hash_str( + batch_tokens[i : i + self.page_size], last_hash + ) + batch_hashes.append(last_hash) + hit_page_num = self.storage_backend.batch_exists(batch_hashes) + hash_value.extend(batch_hashes[:hit_page_num]) + storage_query_count += hit_page_num * self.page_size + if hit_page_num < len(batch_hashes): + break + return hash_value, storage_query_count def prefetch_thread_func(self): """ @@ -772,13 +764,7 @@ class HiCacheController: if operation is None: continue - if ( - operation.host_indices is not None - ) and self.prefetch_rate_limit_check(): - hash_value, storage_hit_count = self._generic_storage_hit_query( - operation - ) - + hash_value, storage_hit_count = self._storage_hit_query(operation) if self.tp_world_size > 1: storage_hit_count_tensor = torch.tensor( storage_hit_count, dtype=torch.int @@ -793,8 +779,7 @@ class HiCacheController: if storage_hit_count < self.prefetch_threshold: # not to prefetch if not enough benefits self.prefetch_revoke_queue.put(operation.request_id) - if operation.host_indices is not None: - self.mem_pool_host.free(operation.host_indices) + self.append_host_mem_release(operation.host_indices) logger.debug( f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})." ) @@ -803,7 +788,9 @@ class HiCacheController: : (storage_hit_count // self.page_size) ] # free the pre-allocated memory for pages that are not hit - self.mem_pool_host.free(operation.host_indices[storage_hit_count:]) + self.append_host_mem_release( + operation.host_indices[storage_hit_count:] + ) operation.host_indices = operation.host_indices[:storage_hit_count] logger.debug( f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}." @@ -858,21 +845,18 @@ class HiCacheController: # Backup batch by batch def _page_backup(self, operation): # Select the set function and batch size - if self.is_mooncake_backend(): + if self.storage_backend_type == "mooncake": backup_set_func = self._mooncake_page_set - batch_size = 128 - elif self.storage_backend_type == "hf3fs": - if self.mem_pool_host.layout == "page_first": - backup_set_func = self._3fs_zero_copy_page_set - elif self.mem_pool_host.layout == "layer_first": - backup_set_func = self._generic_page_set - batch_size = 128 + elif ( + self.storage_backend_type == "hf3fs" + and self.mem_pool_host.layout == "page_first" + ): + backup_set_func = self._3fs_zero_copy_page_set else: backup_set_func = self._generic_page_set - batch_size = 8 # Backup batch by batch - for i in range(0, len(operation.hash_value), batch_size): - batch_hashes = operation.hash_value[i : i + batch_size] + for i in range(0, len(operation.hash_value), self.storage_batch_size): + batch_hashes = operation.hash_value[i : i + self.storage_batch_size] batch_host_indices = operation.host_indices[ i * self.page_size : (i + len(batch_hashes)) * self.page_size ] @@ -898,27 +882,7 @@ class HiCacheController: if not self.backup_skip: self._page_backup(operation) - min_completed_tokens = operation.completed_tokens - else: - min_completed_tokens = len(operation.token_ids) - - if self.tp_world_size > 1: - completed_tokens_tensor = torch.tensor( - min_completed_tokens, dtype=torch.int - ) - torch.distributed.all_reduce( - completed_tokens_tensor, - op=torch.distributed.ReduceOp.MIN, - group=self.backup_tp_group, - ) - min_completed_tokens = completed_tokens_tensor.item() - - self.ack_backup_queue.put( - ( - operation.id, - min_completed_tokens, - ) - ) + self.ack_backup_queue.put(operation.id) except Empty: continue diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index dbbdcc890..2bd231ae6 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -104,9 +104,6 @@ class HiRadixCache(RadixCache): self.write_through_threshold = ( 1 if hicache_write_policy == "write_through" else 2 ) - self.write_through_threshold_storage = ( - 1 if hicache_write_policy == "write_through" else 3 - ) self.load_back_threshold = 10 super().__init__( req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False @@ -174,14 +171,6 @@ class HiRadixCache(RadixCache): if node.hit_count >= self.write_through_threshold: # write to host if the node is not backuped self.write_backup(node) - else: - if ( - self.enable_storage - and (not node.backuped_storage) - and node.hit_count >= self.write_through_threshold_storage - ): - # if the node is backuped on host memory but not on storage - self.write_backup_storage(node) def writing_check(self, write_back=False): if write_back: @@ -202,8 +191,11 @@ class HiRadixCache(RadixCache): ) for _ in range(queue_size.item()): ack_id = self.cache_controller.ack_write_queue.get() - self.dec_lock_ref(self.ongoing_write_through[ack_id]) + backuped_node = self.ongoing_write_through[ack_id] + self.dec_lock_ref(backuped_node) del self.ongoing_write_through[ack_id] + if self.enable_storage: + self.write_backup_storage(backuped_node) def loading_check(self): while not self.cache_controller.ack_load_queue.empty(): @@ -386,57 +378,54 @@ class HiRadixCache(RadixCache): self.writing_check() self.loading_check() if self.enable_storage: - self.check_revoked_prefetch() - self.check_backup_progress() + self.drain_storage_control_queues() - def check_revoked_prefetch(self): - queue_size = torch.tensor( - self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int + def drain_storage_control_queues(self): + """ + Combine prefetch revoke, backup ack, and host mem release checks + to minimize TP synchronization and Python overhead. + """ + cc = self.cache_controller + + qsizes = torch.tensor( + [ + cc.prefetch_revoke_queue.qsize(), + cc.ack_backup_queue.qsize(), + cc.host_mem_release_queue.qsize(), + ], + dtype=torch.int, ) if self.tp_world_size > 1: - # synchrnoize TP workers to make the same update to hiradix cache torch.distributed.all_reduce( - queue_size, - op=torch.distributed.ReduceOp.MIN, - group=self.tp_group, + qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group ) - for _ in range(queue_size.item()): - req_id = self.cache_controller.prefetch_revoke_queue.get() - if req_id in self.ongoing_prefetch: - last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id] + + n_revoke, n_backup, n_release = map(int, qsizes.tolist()) + + # process prefetch revokes + for _ in range(n_revoke): + req_id = cc.prefetch_revoke_queue.get() + info = self.ongoing_prefetch.pop(req_id, None) + if info is not None: + last_host_node, token_ids, _, _ = info last_host_node.release_host() - del self.ongoing_prefetch[req_id] - self.cache_controller.prefetch_tokens_occupied -= len(token_ids) - else: - # the revoked operation already got terminated - pass + cc.prefetch_tokens_occupied -= len(token_ids) + # else: the revoked operation already got terminated, nothing to do - def check_backup_progress(self): - queue_size = torch.tensor( - self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int - ) - if self.tp_world_size > 1: - # synchrnoize TP workers to make the same update to hiradix cache - torch.distributed.all_reduce( - queue_size, - op=torch.distributed.ReduceOp.MIN, - group=self.tp_group, - ) - for _ in range(queue_size.item()): - ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get() - host_node = self.ongoing_backup[ack_id] + # process backup acks + for _ in range(n_backup): + ack_id = cc.ack_backup_queue.get() + entry = self.ongoing_backup.pop(ack_id, None) + if entry is not None: + entry.release_host() - if completed_tokens > 0: - if completed_tokens < len(host_node.key): - # backup is only partially successful, split the node - new_node = self._split_node( - host_node.key, host_node, completed_tokens - ) - new_node.backuped_storage = True - else: - host_node.backuped_storage = True - host_node.release_host() - del self.ongoing_backup[ack_id] + # release host memory + host_indices_list = [] + for _ in range(n_release): + host_indices_list.append(cc.host_mem_release_queue.get()) + if host_indices_list: + host_indices = torch.cat(host_indices_list, dim=0) + cc.mem_pool_host.free(host_indices) def can_terminate_prefetch(self, operation: PrefetchOperation): can_terminate = True @@ -519,7 +508,7 @@ class HiRadixCache(RadixCache): self.cache_controller.mem_pool_host.update_prefetch(written_indices) self.cache_controller.mem_pool_host.free(host_indices[:matched_length]) - self.cache_controller.mem_pool_host.free( + self.cache_controller.append_host_mem_release( host_indices[min_completed_tokens:completed_tokens] ) last_host_node.release_host() @@ -575,7 +564,11 @@ class HiRadixCache(RadixCache): len(new_input_tokens) % self.page_size ) new_input_tokens = new_input_tokens[:prefetch_length] - if not self.enable_storage or prefetch_length < self.prefetch_threshold: + if ( + not self.enable_storage + or prefetch_length < self.prefetch_threshold + or self.cache_controller.prefetch_rate_limited() + ): return last_host_node.protect_host() @@ -583,6 +576,10 @@ class HiRadixCache(RadixCache): if host_indices is None: self.evict_host(prefetch_length) host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length) + if host_indices is None: + last_host_node.release_host() + # no sufficient host memory for prefetch + return operation = self.cache_controller.prefetch( req_id, host_indices, new_input_tokens, last_hash ) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index a586b8696..b0cf0bb9c 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -62,7 +62,6 @@ class TreeNode: self.host_value: Optional[torch.Tensor] = None # store hash values of each pages self.hash_value: Optional[List[str]] = None - self.backuped_storage = False self.id = TreeNode.counter if id is None else id TreeNode.counter += 1