fix get_memory_pool_size deadlock for DP (#1830)
This commit is contained in:
@@ -539,9 +539,22 @@ class TokenizerManager:
|
|||||||
self.create_handle_loop()
|
self.create_handle_loop()
|
||||||
|
|
||||||
req = GetMemPoolSizeReq()
|
req = GetMemPoolSizeReq()
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
ret = None
|
||||||
self.mem_pool_size = asyncio.Future()
|
|
||||||
return await self.mem_pool_size
|
if self.server_args.dp_size == 1:
|
||||||
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
self.mem_pool_size = asyncio.Future()
|
||||||
|
res = await self.mem_pool_size
|
||||||
|
ret = res.size
|
||||||
|
|
||||||
|
else: # self.server_args.dp_size > 1
|
||||||
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
self.mem_pool_size = asyncio.Future()
|
||||||
|
self.mem_pool_size_tmp = []
|
||||||
|
res = await self.mem_pool_size
|
||||||
|
ret = [r.size for r in res]
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
async def update_weights(
|
async def update_weights(
|
||||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||||
@@ -634,7 +647,13 @@ class TokenizerManager:
|
|||||||
self.model_update_result.set_result(self.model_update_tmp)
|
self.model_update_result.set_result(self.model_update_tmp)
|
||||||
continue
|
continue
|
||||||
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
||||||
self.mem_pool_size.set_result(recv_obj)
|
if self.server_args.dp_size == 1:
|
||||||
|
self.mem_pool_size.set_result(recv_obj)
|
||||||
|
else: # self.server_args.dp_size > 1
|
||||||
|
self.mem_pool_size_tmp.append(recv_obj)
|
||||||
|
# set the future once all results have been received
|
||||||
|
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
|
||||||
|
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
|
|||||||
@@ -177,7 +177,8 @@ async def get_memory_pool_size():
|
|||||||
"""Get the memory pool size in number of tokens"""
|
"""Get the memory pool size in number of tokens"""
|
||||||
try:
|
try:
|
||||||
ret = await tokenizer_manager.get_memory_pool_size()
|
ret = await tokenizer_manager.get_memory_pool_size()
|
||||||
return ret.size
|
|
||||||
|
return ret
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
|||||||
@@ -62,6 +62,15 @@ class TestDataParallelism(unittest.TestCase):
|
|||||||
# check if the response is 200
|
# check if the response is 200
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_get_memory_pool_size(self):
|
||||||
|
response = requests.get(self.base_url + "/get_memory_pool_size")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
response = requests.get(self.base_url + "/get_memory_pool_size")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user