From 1195182040c492637e9a9200143b9f46359c6993 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Mon, 21 Apr 2025 09:15:03 +0800
Subject: [PATCH] Tiny add Engine.flush_cache API (#5241)

---
Review notes (free-form text placed here is ignored by `git am`):

* TokenizerManager.flush_cache now folds the scheduler replies into a
  single FlushCacheReqOutput. The communicator is created with
  server_args.dp_size; NOTE(review): awaiting it appears to resolve to a
  list of replies (confirm against _Communicator.__call__), in which case
  returning it unchanged would make `ret.success` in http_server.py fail
  with AttributeError.
* Docstrings were added to the two new flush_cache entry points; the
  affected hunk headers and the diffstat are updated to match.
* Leading whitespace of context lines was reconstructed from a collapsed
  copy of this patch. Verify it against the tree, or apply with
  `git apply --ignore-whitespace`.
* The post-image blob ids on the engine.py and tokenizer_manager.py
  "index" lines still describe the original commit.

 python/sglang/srt/entrypoints/engine.py         |  5 +++++
 python/sglang/srt/entrypoints/http_server.py    |  4 ++--
 python/sglang/srt/managers/io_struct.py         |  7 ++++++-
 python/sglang/srt/managers/scheduler.py         | 10 ++++++----
 python/sglang/srt/managers/tokenizer_manager.py | 19 +++++++++++++++----
 5 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 9738e466c..32407fe24 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -279,6 +279,11 @@ class Engine(EngineBase):
         self.shutdown()
         return False
 
+    def flush_cache(self):
+        """Flush the cache via the tokenizer manager and return its result."""
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(self.tokenizer_manager.flush_cache())
+
     def start_profile(self):
         loop = asyncio.get_event_loop()
         loop.run_until_complete(self.tokenizer_manager.start_profile())
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 83ee9d403..750409ee9 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -315,11 +315,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
-    _global_state.tokenizer_manager.flush_cache()
+    ret = await _global_state.tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
-        status_code=200,
+        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,
     )
 
 
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index e8590c950..e6ddb03f7 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -671,10 +671,15 @@ class BatchEmbeddingOut:
 
 
 @dataclass
-class FlushCacheReq:
+class FlushCacheReqInput:
     pass
 
 
+@dataclass
+class FlushCacheReqOutput:
+    success: bool
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 218b6743c..d2a601f91 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-    FlushCacheReq,
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -402,7 +403,7 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
-                (FlushCacheReq, self.flush_cache_wrapped),
+                (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -1596,8 +1597,9 @@ class Scheduler(
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
 
-    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
-        self.flush_cache()
+    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
+        success = self.flush_cache()
+        return FlushCacheReqOutput(success=success)
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 92a6bbafc..e144781dd 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-    FlushCacheReq,
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GenerateReqInput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -264,6 +265,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.flush_cache_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -314,6 +318,10 @@ class TokenizerManager:
                     ResumeMemoryOccupationReqOutput,
                     self.resume_memory_occupation_communicator.handle_recv,
                 ),
+                (
+                    FlushCacheReqOutput,
+                    self.flush_cache_communicator.handle_recv,
+                ),
                 (
                     ProfileReqOutput,
                     self.start_profile_communicator.handle_recv,
@@ -707,9 +715,12 @@ class TokenizerManager:
                     except StopAsyncIteration:
                         pass
 
-    def flush_cache(self):
-        req = FlushCacheReq()
-        self.send_to_scheduler.send_pyobj(req)
+    async def flush_cache(self) -> FlushCacheReqOutput:
+        """Flush the scheduler cache; succeed only if every reply succeeded."""
+        # NOTE(review): the communicator is built with dp_size and appears to
+        # resolve to a list of replies -- confirm in _Communicator.__call__.
+        outputs = await self.flush_cache_communicator(FlushCacheReqInput())
+        return FlushCacheReqOutput(success=all(out.success for out in outputs))
 
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state: