From f05c68733ec8827fe3008866d23fc96735e76fb2 Mon Sep 17 00:00:00 2001 From: Teng Ma Date: Sun, 31 Aug 2025 17:41:44 +0800 Subject: [PATCH] [HiCache] Clear kvcache in storage backend with fastAPI (#9750) Co-authored-by: hzh0425 --- python/sglang/srt/entrypoints/http_server.py | 10 ++++++++++ python/sglang/srt/managers/io_struct.py | 10 ++++++++++ python/sglang/srt/managers/scheduler.py | 13 +++++++++++++ .../sglang/srt/managers/tokenizer_manager.py | 16 ++++++++++++++++ python/sglang/srt/mem_cache/hicache_storage.py | 18 +++++++++++++++++- python/sglang/srt/mem_cache/hiradix_cache.py | 9 +++++++++ .../mem_cache/storage/hf3fs/storage_hf3fs.py | 10 ++++++++-- .../storage/mooncake_store/mooncake_store.py | 2 +- 8 files changed, 84 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index aa496b754..5d6e03ac3 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -480,6 +480,16 @@ async def flush_cache(): ) +@app.api_route("/clear_hicache_storage_backend", methods=["GET", "POST"]) +async def clear_hicache_storage_backend(): + """Clear the hierarchical cache storage backend.""" + ret = await _global_state.tokenizer_manager.clear_hicache_storage() + return Response( + content="Hierarchical cache storage backend cleared.\n", + status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, + ) + + @app.api_route("/start_profile", methods=["GET", "POST"]) async def start_profile_async(obj: Optional[ProfileReqInput] = None): """Start profiling.""" diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 256868e4a..917d387fe 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -814,6 +814,16 @@ class BatchEmbeddingOut: cached_tokens: List[int] +@dataclass +class ClearHiCacheReqInput: + pass + + +@dataclass +class ClearHiCacheReqOutput: + success: bool + + @dataclass class FlushCacheReqInput: pass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f7de3275e..38ff0ef14 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -69,6 +69,8 @@ from sglang.srt.managers.io_struct import ( AbortReq, BatchTokenizedEmbeddingReqInput, BatchTokenizedGenerateReqInput, + ClearHiCacheReqInput, + ClearHiCacheReqOutput, CloseSessionReqInput, ExpertDistributionReq, ExpertDistributionReqOutput, @@ -515,6 +517,7 @@ class Scheduler( (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request), (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request), (FlushCacheReqInput, self.flush_cache_wrapped), + (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped), (AbortReq, self.abort_request), (OpenSessionReqInput, self.open_session), (CloseSessionReqInput, self.close_session), @@ -2207,6 +2210,16 @@ class Scheduler( success = self.flush_cache() return FlushCacheReqOutput(success=success) + def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput): + if self.enable_hierarchical_cache: + self.tree_cache.clear_storage_backend() + logger.info("Hierarchical cache cleared successfully!") + if_success = True + else: + logging.warning("Hierarchical cache is not enabled.") + if_success = False + return ClearHiCacheReqOutput(success=if_success) + def flush_cache(self): """Flush the memory pool and cache.""" if ( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 7c09379cd..a209567c4 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import ( BatchTokenIDOut, BatchTokenizedEmbeddingReqInput, BatchTokenizedGenerateReqInput, + ClearHiCacheReqInput, + ClearHiCacheReqOutput, CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, @@ -386,6 +388,9 @@ class TokenizerManager: self.flush_cache_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.clear_hicache_storage_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.profile_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -447,6 +452,10 @@ class TokenizerManager: SlowDownReqOutput, self.slow_down_communicator.handle_recv, ), + ( + ClearHiCacheReqOutput, + self.clear_hicache_storage_communicator.handle_recv, + ), ( FlushCacheReqOutput, self.flush_cache_communicator.handle_recv, @@ -988,6 +997,13 @@ class TokenizerManager: async def flush_cache(self) -> FlushCacheReqOutput: return (await self.flush_cache_communicator(FlushCacheReqInput()))[0] + async def clear_hicache_storage(self) -> ClearHiCacheReqOutput: + """Clear the hierarchical cache storage.""" + # Delegate to the scheduler to handle HiCacheStorage clearing + return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[ + 0 + ] + def abort_request(self, rid: str = "", abort_all: bool = False): if not abort_all and rid not in self.rid_to_state: return diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index aaaee0262..159c70012 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -102,6 +102,20 @@ class HiCacheStorage(ABC): """ pass + @abstractmethod + def delete(self, key: str) -> bool: + """ + Delete the entry associated with the given key. + """ + pass + + @abstractmethod + def clear(self) -> bool: + """ + Clear all entries in the storage. + """ + pass + def batch_exists(self, keys: List[str]) -> int: """ Check if the keys exist in the storage. @@ -214,12 +228,14 @@ class HiCacheFile(HiCacheStorage): logger.warning(f"Key {key} does not exist. Cannot delete.") return - def clear(self) -> None: + def clear(self) -> bool: try: for filename in os.listdir(self.file_path): file_path = os.path.join(self.file_path, filename) if os.path.isfile(file_path): os.remove(file_path) logger.info("Cleared all entries in HiCacheFile storage.") + return True except Exception as e: logger.error(f"Failed to clear HiCacheFile storage: {e}") + return False diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 611e94386..dbbdcc890 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -125,6 +125,15 @@ class HiRadixCache(RadixCache): height += 1 return height + def clear_storage_backend(self): + if self.enable_storage: + self.cache_controller.storage_backend.clear() + logger.info("Hierarchical cache storage backend cleared successfully!") + return True + else: + logger.warning("Hierarchical cache storage backend is not enabled.") + return False + def write_backup(self, node: TreeNode, write_back=False): host_indices = self.cache_controller.write( device_indices=node.value, diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index bf82dcd15..82e850d37 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -393,8 +393,14 @@ class HiCacheHF3FS(HiCacheStorage): return len(keys) - def clear(self) -> None: - self.metadata_client.clear(self.rank) + def clear(self) -> bool: + try: + self.metadata_client.clear(self.rank) + logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}") + return True + except Exception as e: + logger.error(f"Failed to clear HiCacheHF3FS: {e}") + return False def close(self) -> None: try: diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index bef26257b..ec9343f7e 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -254,7 +254,7 @@ class MooncakeStore(HiCacheStorage): pass def clear(self) -> None: - raise (NotImplementedError) + self.store.remove_all() def _put_batch_zero_copy_impl( self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]