[API] add get memory pool size (#1760)

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
This commit is contained in:
Ying Sheng
2024-10-23 00:02:29 -07:00
committed by GitHub
parent ad4125d1a9
commit 2fce449b1c
6 changed files with 50 additions and 0 deletions

View File

@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
ProfileReq,
RewardReqInput,
TokenizedEmbeddingReqInput,
@@ -531,6 +533,15 @@ class TokenizerManager:
req = ProfileReq.STOP_PROFILE
self.send_to_scheduler.send_pyobj(req)
async def get_memory_pool_size(self):
if self.to_create_loop:
self.create_handle_loop()
req = GetMemPoolSizeReq()
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
return await self.mem_pool_size
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
):
@@ -590,6 +601,9 @@ class TokenizerManager:
if isinstance(recv_obj, UpdateWeightReqOutput):
self.model_update_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
self.mem_pool_size.set_result(recv_obj)
continue
assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)