From 6e0b6468325c9be4f05d9baa7d9578a65d375bb9 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Sat, 9 Aug 2025 01:16:51 -0700 Subject: [PATCH] HiCache Storage tp fix (#8878) --- .../sglang/srt/managers/cache_controller.py | 19 +++++++++++++------ python/sglang/srt/mem_cache/hiradix_cache.py | 8 ++++---- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index b518f42c5..35874bb18 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -570,10 +570,6 @@ class HiCacheController: ) completed_tokens += self.page_size else: - # operation terminated by controller, release pre-allocated memory - self.mem_pool_host.free( - operation.host_indices[operation.completed_tokens :] - ) break def mooncake_page_transfer(self, operation): @@ -599,6 +595,14 @@ class HiCacheController: self.generic_page_transfer(operation, batch_size=128) else: self.generic_page_transfer(operation) + + if self.tp_world_size > 1: + # to ensure all TP workers release the host memory at the same time + torch.distributed.barrier(group=self.prefetch_tp_group) + # operation terminated by controller, release pre-allocated memory + self.mem_pool_host.free( + operation.host_indices[operation.completed_tokens :] + ) except Empty: continue @@ -626,7 +630,9 @@ class HiCacheController: continue storage_hit_count = 0 - if self.prefetch_rate_limit_check(): + if ( + operation.host_indices is not None + ) and self.prefetch_rate_limit_check(): last_hash = operation.last_hash tokens_to_fetch = operation.token_ids @@ -670,7 +676,8 @@ class HiCacheController: if storage_hit_count < self.prefetch_threshold: # not to prefetch if not enough benefits self.prefetch_revoke_queue.put(operation.request_id) - self.mem_pool_host.free(operation.host_indices) + if operation.host_indices is not None: + self.mem_pool_host.free(operation.host_indices) logger.debug( f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})." ) diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 0f51712eb..e11b9e64d 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -471,6 +471,10 @@ class HiRadixCache(RadixCache): req_id ] + if operation.host_indices is None: + # prefetch has not been issued due to insufficient host memory + return True + if not self.can_terminate_prefetch(operation): return False @@ -565,10 +569,6 @@ class HiRadixCache(RadixCache): if host_indices is None: self.evict_host(prefetch_length) host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length) - if host_indices is None: - last_host_node.release_host() - # no sufficient host memory to prefetch - return operation = self.cache_controller.prefetch( req_id, host_indices, new_input_tokens, last_hash )