diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 4ae31ecc8..d0d399363 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import ( BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, + GetMemPoolSizeReqOutput, UpdateWeightReqOutput, ) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN @@ -111,6 +112,9 @@ class DetokenizerManager: # If it is a weight update request, no detokenization is needed. self.send_to_tokenizer.send_pyobj(recv_obj) continue + elif isinstance(recv_obj, GetMemPoolSizeReqOutput): + self.send_to_tokenizer.send_pyobj(recv_obj) + continue elif self.tokenizer is None: # If the tokenizer is skipped, no detokenization is needed self.send_to_tokenizer.send_pyobj(recv_obj) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 9625ff44e..2cdc3f478 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -353,3 +353,13 @@ class AbortReq: class ProfileReq(Enum): START_PROFILE = 1 STOP_PROFILE = 2 + + +@dataclass +class GetMemPoolSizeReq: + pass + + +@dataclass +class GetMemPoolSizeReqOutput: + size: int diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 16f4196bd..60531ce25 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import ( BatchEmbeddingOut, BatchTokenIDOut, FlushCacheReq, + GetMemPoolSizeReq, + GetMemPoolSizeReqOutput, ProfileReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -363,6 +365,10 @@ class Scheduler: self.start_profile() else: self.stop_profile() + elif isinstance(recv_req, GetMemPoolSizeReq): + self.send_to_detokenizer.send_pyobj( + GetMemPoolSizeReqOutput(self.max_total_num_tokens) + ) else: raise ValueError(f"Invalid request: {recv_req}") diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index fc9e23519..875239a94 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import ( EmbeddingReqInput, FlushCacheReq, GenerateReqInput, + GetMemPoolSizeReq, + GetMemPoolSizeReqOutput, ProfileReq, RewardReqInput, TokenizedEmbeddingReqInput, @@ -531,6 +533,15 @@ class TokenizerManager: req = ProfileReq.STOP_PROFILE self.send_to_scheduler.send_pyobj(req) + async def get_memory_pool_size(self): + if self.to_create_loop: + self.create_handle_loop() + + req = GetMemPoolSizeReq() + self.send_to_scheduler.send_pyobj(req) + self.mem_pool_size = asyncio.Future() + return await self.mem_pool_size + async def update_weights( self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None ): @@ -590,6 +601,9 @@ class TokenizerManager: if isinstance(recv_obj, UpdateWeightReqOutput): self.model_update_result.set_result(recv_obj) continue + elif isinstance(recv_obj, GetMemPoolSizeReqOutput): + self.mem_pool_size.set_result(recv_obj) + continue assert isinstance( recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index ceb2d55c2..8912c5583 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -172,6 +172,18 @@ async def stop_profile(): ) +@app.api_route("/get_memory_pool_size", methods=["GET", "POST"]) +async def get_memory_pool_size(): + """Get the memory pool size in number of tokens""" + try: + ret = await tokenizer_manager.get_memory_pool_size() + return ret.size + except Exception as e: + return JSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + @app.post("/update_weights") async def update_weights(obj: UpdateWeightReqInput, request: Request): """Update the weights inplace without re-launching the server.""" diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 9a0a37c60..c4c8e844d 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -119,6 +119,10 @@ class TestSRTEndpoint(unittest.TestCase): [x[-1] for x in res["meta_info"]["output_token_logprobs"]] ) + def test_get_memory_pool_size(self): + response = requests.post(self.base_url + "/get_memory_pool_size") + assert isinstance(response.json(), int) + if __name__ == "__main__": unittest.main()