Tiny add Engine.flush_cache API (#5241)
This commit is contained in:
@@ -279,6 +279,10 @@ class Engine(EngineBase):
|
|||||||
self.shutdown()
|
self.shutdown()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def flush_cache(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(self.tokenizer_manager.flush_cache())
|
||||||
|
|
||||||
def start_profile(self):
|
def start_profile(self):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
loop.run_until_complete(self.tokenizer_manager.start_profile())
|
loop.run_until_complete(self.tokenizer_manager.start_profile())
|
||||||
|
|||||||
@@ -315,11 +315,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
|
|||||||
@app.api_route("/flush_cache", methods=["GET", "POST"])
|
@app.api_route("/flush_cache", methods=["GET", "POST"])
|
||||||
async def flush_cache():
|
async def flush_cache():
|
||||||
"""Flush the radix cache."""
|
"""Flush the radix cache."""
|
||||||
_global_state.tokenizer_manager.flush_cache()
|
ret = await _global_state.tokenizer_manager.flush_cache()
|
||||||
return Response(
|
return Response(
|
||||||
content="Cache flushed.\nPlease check backend logs for more details. "
|
content="Cache flushed.\nPlease check backend logs for more details. "
|
||||||
"(When there are running or waiting requests, the operation will not be performed.)\n",
|
"(When there are running or waiting requests, the operation will not be performed.)\n",
|
||||||
status_code=200,
|
status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -671,10 +671,15 @@ class BatchEmbeddingOut:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlushCacheReq:
|
class FlushCacheReqInput:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlushCacheReqOutput:
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightFromDiskReqInput:
|
class UpdateWeightFromDiskReqInput:
|
||||||
# The model path with the new weights
|
# The model path with the new weights
|
||||||
|
|||||||
@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
ExpertDistributionReq,
|
ExpertDistributionReq,
|
||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
FlushCacheReq,
|
FlushCacheReqInput,
|
||||||
|
FlushCacheReqOutput,
|
||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
@@ -402,7 +403,7 @@ class Scheduler(
|
|||||||
[
|
[
|
||||||
(TokenizedGenerateReqInput, self.handle_generate_request),
|
(TokenizedGenerateReqInput, self.handle_generate_request),
|
||||||
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
||||||
(FlushCacheReq, self.flush_cache_wrapped),
|
(FlushCacheReqInput, self.flush_cache_wrapped),
|
||||||
(AbortReq, self.abort_request),
|
(AbortReq, self.abort_request),
|
||||||
(OpenSessionReqInput, self.open_session),
|
(OpenSessionReqInput, self.open_session),
|
||||||
(CloseSessionReqInput, self.close_session),
|
(CloseSessionReqInput, self.close_session),
|
||||||
@@ -1596,8 +1597,9 @@ class Scheduler(
|
|||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
self.parent_process.send_signal(signal.SIGQUIT)
|
self.parent_process.send_signal(signal.SIGQUIT)
|
||||||
|
|
||||||
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
|
def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
|
||||||
self.flush_cache()
|
success = self.flush_cache()
|
||||||
|
return FlushCacheReqOutput(success=success)
|
||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
"""Flush the memory pool and cache."""
|
"""Flush the memory pool and cache."""
|
||||||
|
|||||||
@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
ExpertDistributionReq,
|
ExpertDistributionReq,
|
||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
FlushCacheReq,
|
FlushCacheReqInput,
|
||||||
|
FlushCacheReqOutput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
@@ -264,6 +265,9 @@ class TokenizerManager:
|
|||||||
self.resume_memory_occupation_communicator = _Communicator(
|
self.resume_memory_occupation_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.flush_cache_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
self.start_profile_communicator = _Communicator(
|
self.start_profile_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
@@ -314,6 +318,10 @@ class TokenizerManager:
|
|||||||
ResumeMemoryOccupationReqOutput,
|
ResumeMemoryOccupationReqOutput,
|
||||||
self.resume_memory_occupation_communicator.handle_recv,
|
self.resume_memory_occupation_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
FlushCacheReqOutput,
|
||||||
|
self.flush_cache_communicator.handle_recv,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
ProfileReqOutput,
|
ProfileReqOutput,
|
||||||
self.start_profile_communicator.handle_recv,
|
self.start_profile_communicator.handle_recv,
|
||||||
@@ -707,9 +715,8 @@ class TokenizerManager:
|
|||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def flush_cache(self):
|
async def flush_cache(self) -> FlushCacheReqOutput:
|
||||||
req = FlushCacheReq()
|
return await self.flush_cache_communicator(FlushCacheReqInput())
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
|
||||||
|
|
||||||
def abort_request(self, rid: str):
|
def abort_request(self, rid: str):
|
||||||
if rid not in self.rid_to_state:
|
if rid not in self.rid_to_state:
|
||||||
|
|||||||
Reference in New Issue
Block a user