Refactor tokenizer manager (#1846)
This commit is contained in:
@@ -549,22 +549,18 @@ class TokenizerManager:
|
||||
self.create_handle_loop()
|
||||
|
||||
req = GetMemPoolSizeReq()
|
||||
ret = None
|
||||
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
self.mem_pool_size = asyncio.Future()
|
||||
|
||||
if self.server_args.dp_size == 1:
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
self.mem_pool_size = asyncio.Future()
|
||||
res = await self.mem_pool_size
|
||||
ret = res.size
|
||||
|
||||
return res.size
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
self.mem_pool_size = asyncio.Future()
|
||||
self.mem_pool_size_tmp = []
|
||||
res = await self.mem_pool_size
|
||||
ret = [r.size for r in res]
|
||||
|
||||
return ret
|
||||
return ret
|
||||
|
||||
async def update_weights(
|
||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||
@@ -578,29 +574,21 @@ class TokenizerManager:
|
||||
|
||||
if not self.model_update_lock.locked():
|
||||
|
||||
if self.server_args.dp_size == 1:
|
||||
async with self.model_update_lock:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0.001)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
async with self.model_update_lock:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0.001)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
|
||||
if self.server_args.dp_size == 1:
|
||||
result = await self.model_update_result
|
||||
if result.success:
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
return result.success, result.message
|
||||
|
||||
else: # self.server_args.dp_size > 1
|
||||
|
||||
# There will be dp_size number of response from the detokenizer
|
||||
async with self.model_update_lock:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0.001)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
return result.success, result.message
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp = []
|
||||
result = await self.model_update_result
|
||||
|
||||
@@ -611,8 +599,7 @@ class TokenizerManager:
|
||||
self.model_path = obj.model_path
|
||||
all_message = [r.message for r in result]
|
||||
all_message = " | ".join(all_message)
|
||||
|
||||
return all_success, all_message
|
||||
return all_success, all_message
|
||||
|
||||
else:
|
||||
return False, "Another update is in progress. Please try again later."
|
||||
|
||||
Reference in New Issue
Block a user