From 680cad20233be46da97e92db0ba29d2b8fa41c03 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 28 Oct 2024 23:07:14 -0700 Subject: [PATCH] fix get_memory_pool_size deadlock for DP (#1830) --- .../sglang/srt/managers/tokenizer_manager.py | 27 ++++++++++++++++--- python/sglang/srt/server.py | 3 ++- test/srt/test_data_parallelism.py | 9 +++++++ 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 428bf10d7..9a3e90969 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -539,9 +539,22 @@ class TokenizerManager: self.create_handle_loop() req = GetMemPoolSizeReq() - self.send_to_scheduler.send_pyobj(req) - self.mem_pool_size = asyncio.Future() - return await self.mem_pool_size + ret = None + + if self.server_args.dp_size == 1: + self.send_to_scheduler.send_pyobj(req) + self.mem_pool_size = asyncio.Future() + res = await self.mem_pool_size + ret = res.size + + else: # self.server_args.dp_size > 1 + self.send_to_scheduler.send_pyobj(req) + self.mem_pool_size = asyncio.Future() + self.mem_pool_size_tmp = [] + res = await self.mem_pool_size + ret = [r.size for r in res] + + return ret async def update_weights( self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None @@ -634,7 +647,13 @@ class TokenizerManager: self.model_update_result.set_result(self.model_update_tmp) continue elif isinstance(recv_obj, GetMemPoolSizeReqOutput): - self.mem_pool_size.set_result(recv_obj) + if self.server_args.dp_size == 1: + self.mem_pool_size.set_result(recv_obj) + else: # self.sever_args.dp_size > 1 + self.mem_pool_size_tmp.append(recv_obj) + # set future if the all results are received + if len(self.mem_pool_size_tmp) == self.server_args.dp_size: + self.mem_pool_size.set_result(self.mem_pool_size_tmp) continue assert isinstance( diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 64f6c6f55..c9d9c7ee5 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -177,7 +177,8 @@ async def get_memory_pool_size(): """Get the memory pool size in number of tokens""" try: ret = await tokenizer_manager.get_memory_pool_size() - return ret.size + + return ret except Exception as e: return JSONResponse( {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py index 00bae0a88..0ac8b784c 100644 --- a/test/srt/test_data_parallelism.py +++ b/test/srt/test_data_parallelism.py @@ -62,6 +62,15 @@ class TestDataParallelism(unittest.TestCase): # check if the response is 200 assert response.status_code == 200 + def test_get_memory_pool_size(self): + response = requests.get(self.base_url + "/get_memory_pool_size") + assert response.status_code == 200 + + time.sleep(5) + + response = requests.get(self.base_url + "/get_memory_pool_size") + assert response.status_code == 200 + if __name__ == "__main__": unittest.main()