Fix update_weights deadlock for DP (#1825)
This commit is contained in:
@@ -554,18 +554,43 @@ class TokenizerManager:
|
||||
obj.load_format = self.server_args.load_format
|
||||
|
||||
if not self.model_update_lock.locked():
|
||||
async with self.model_update_lock:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0.001)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
result = await self.model_update_result
|
||||
if result.success:
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
return result.success, result.message
|
||||
|
||||
if self.server_args.dp_size == 1:
|
||||
async with self.model_update_lock:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0.001)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
result = await self.model_update_result
|
||||
if result.success:
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
return result.success, result.message
|
||||
|
||||
else: # self.server_args.dp_size > 1
|
||||
|
||||
# There will be dp_size number of response from the detokenizer
|
||||
async with self.model_update_lock:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0.001)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
self.model_update_tmp = []
|
||||
result = await self.model_update_result
|
||||
|
||||
all_success = all([r.success for r in result])
|
||||
if all_success is True:
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
all_message = [r.message for r in result]
|
||||
all_message = " | ".join(all_message)
|
||||
|
||||
return all_success, all_message
|
||||
|
||||
else:
|
||||
return False, "Another update is in progress. Please try again later."
|
||||
|
||||
@@ -600,7 +625,13 @@ class TokenizerManager:
|
||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||
|
||||
if isinstance(recv_obj, UpdateWeightReqOutput):
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
if self.server_args.dp_size == 1:
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp.append(recv_obj)
|
||||
# set future if the all results are recevied
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
self.model_update_result.set_result(self.model_update_tmp)
|
||||
continue
|
||||
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
||||
self.mem_pool_size.set_result(recv_obj)
|
||||
|
||||
Reference in New Issue
Block a user