fix get_memory_pool_size deadlock for DP (#1830)

This commit is contained in:
Byron Hsu
2024-10-28 23:07:14 -07:00
committed by GitHub
parent 0a24eb850a
commit 680cad2023
3 changed files with 34 additions and 5 deletions

View File

@@ -539,9 +539,22 @@ class TokenizerManager:
self.create_handle_loop()
req = GetMemPoolSizeReq()
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
return await self.mem_pool_size
ret = None
if self.server_args.dp_size == 1:
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
res = await self.mem_pool_size
ret = res.size
else: # self.server_args.dp_size > 1
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
self.mem_pool_size_tmp = []
res = await self.mem_pool_size
ret = [r.size for r in res]
return ret
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -634,7 +647,13 @@ class TokenizerManager:
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
self.mem_pool_size.set_result(recv_obj)
if self.server_args.dp_size == 1:
self.mem_pool_size.set_result(recv_obj)
else: # self.sever_args.dp_size > 1
self.mem_pool_size_tmp.append(recv_obj)
# set future if the all results are received
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
continue
assert isinstance(

View File

@@ -177,7 +177,8 @@ async def get_memory_pool_size():
"""Get the memory pool size in number of tokens"""
try:
ret = await tokenizer_manager.get_memory_pool_size()
return ret.size
return ret
except Exception as e:
return JSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST