From 528bd1ed856e4a9225eef3a4e9eeddff41c8a940 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Sat, 26 Jul 2025 23:13:16 -0700 Subject: [PATCH] HiCache, check before terminate prefetching (#8372) --- python/sglang/srt/managers/cache_controller.py | 16 +++++++++------- python/sglang/srt/mem_cache/hiradix_cache.py | 7 +++++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 9ef860f63..fb7ad794f 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -201,8 +201,9 @@ class PrefetchOperation(StorageOperation): def increment(self, num_tokens: int): with self._lock: if self._done_flag: - return + return False self.completed_tokens += num_tokens + return True def mark_done(self): with self._lock: @@ -528,12 +529,12 @@ class HiCacheController: f"Prefetch operation {operation.request_id} failed to retrieve page {h}." ) break - self.mem_pool_host.set_from_flat_data_page( - operation.host_indices[operation.completed_tokens], - page_data, - ) - operation.increment(self.page_size) - if operation.is_done(): + if operation.increment(self.page_size): + self.mem_pool_host.set_from_flat_data_page( + operation.host_indices[operation.completed_tokens], + page_data, + ) + else: # operation terminated by controller, release pre-allocated memory self.mem_pool_host.free( operation.host_indices[operation.completed_tokens :] @@ -589,6 +590,7 @@ class HiCacheController: if storage_hit_count < self.prefetch_threshold: # not to prefetch if not enough benefits self.prefetch_revoke_queue.put(operation.request_id) + self.mem_pool_host.free(operation.host_indices) logger.debug( f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})." ) diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 05248a1de..e6acbe9cc 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -365,10 +365,12 @@ class HiRadixCache(RadixCache): for _ in range(queue_size.item()): req_id = self.cache_controller.prefetch_revoke_queue.get() if req_id in self.ongoing_prefetch: - last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id] + last_host_node, _, _, _ = self.ongoing_prefetch[req_id] last_host_node.release_host() - self.cache_controller.mem_pool_host.free(host_indices) del self.ongoing_prefetch[req_id] + else: + # the revoked operation already got terminated + pass def check_backup_progress(self): queue_size = torch.tensor( @@ -403,6 +405,7 @@ class HiRadixCache(RadixCache): last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[ req_id ] + completed_tokens, hash_value = self.cache_controller.terminate_prefetch( operation )