Refactor tokenizer manager (#1846)
@@ -549,22 +549,18 @@ class TokenizerManager:
             self.create_handle_loop()
 
         req = GetMemPoolSizeReq()
-        ret = None
+        self.send_to_scheduler.send_pyobj(req)
+        self.mem_pool_size = asyncio.Future()
 
         if self.server_args.dp_size == 1:
-            self.send_to_scheduler.send_pyobj(req)
-            self.mem_pool_size = asyncio.Future()
             res = await self.mem_pool_size
-            ret = res.size
+            return res.size
 
         else:  # self.server_args.dp_size > 1
-            self.send_to_scheduler.send_pyobj(req)
-            self.mem_pool_size = asyncio.Future()
             self.mem_pool_size_tmp = []
             res = await self.mem_pool_size
             ret = [r.size for r in res]
-
-        return ret
+            return ret
 
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
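Both branches of get_memory_pool_size previously sent the request and created the reply future themselves; the hunk above hoists that shared setup out of the branch, so only the shape of the reply (a single size, or one size per data-parallel replica) still depends on dp_size. Below is a minimal, self-contained sketch of the same pattern; the Manager class and the call_soon stub standing in for the scheduler are illustrative assumptions, not SGLang code:

import asyncio

class Manager:
    def __init__(self, dp_size: int):
        self.dp_size = dp_size       # stand-in for server_args.dp_size
        self.mem_pool_size = None    # future resolved by the receive loop

    async def get_memory_pool_size(self):
        # Shared setup, done once regardless of dp_size, as in the hunk above.
        self.mem_pool_size = asyncio.get_running_loop().create_future()
        self._send_request()
        if self.dp_size == 1:
            res = await self.mem_pool_size
            return res               # a single size
        else:  # self.dp_size > 1
            res = await self.mem_pool_size
            return list(res)         # one size per replica

    def _send_request(self):
        # Stub: pretend the scheduler answers asynchronously.
        reply = 42 if self.dp_size == 1 else [42, 43]
        asyncio.get_running_loop().call_soon(self.mem_pool_size.set_result, reply)

async def main():
    print(await Manager(1).get_memory_pool_size())   # 42
    print(await Manager(2).get_memory_pool_size())   # [42, 43]

asyncio.run(main())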
@@ -578,29 +574,21 @@ class TokenizerManager:
 
         if not self.model_update_lock.locked():
 
-            if self.server_args.dp_size == 1:
-                async with self.model_update_lock:
-                    # wait for the previous generation requests to finish
-                    while len(self.rid_to_state) > 0:
-                        await asyncio.sleep(0.001)
-                    self.send_to_scheduler.send_pyobj(obj)
-                    self.model_update_result = asyncio.Future()
+            async with self.model_update_lock:
+                # wait for the previous generation requests to finish
+                while len(self.rid_to_state) > 0:
+                    await asyncio.sleep(0.001)
+                self.send_to_scheduler.send_pyobj(obj)
+                self.model_update_result = asyncio.Future()
+
+                if self.server_args.dp_size == 1:
                     result = await self.model_update_result
                     if result.success:
                         self.server_args.model_path = obj.model_path
                         self.server_args.load_format = obj.load_format
                         self.model_path = obj.model_path
                     return result.success, result.message
-            else:  # self.server_args.dp_size > 1
-
-                # There will be dp_size number of response from the detokenizer
-                async with self.model_update_lock:
-                    # wait for the previous generation requests to finish
-                    while len(self.rid_to_state) > 0:
-                        await asyncio.sleep(0.001)
-                    self.send_to_scheduler.send_pyobj(obj)
-                    self.model_update_result = asyncio.Future()
+                else:  # self.server_args.dp_size > 1
 
                     self.model_update_tmp = []
                     result = await self.model_update_result
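The second hunk applies the same deduplication to update_weights: instead of two nearly identical async with self.model_update_lock blocks (one per dp_size case), the lock is acquired once, in-flight generation requests are drained, the update is sent, and only the handling of the reply branches. A hedged sketch of the merged flow follows; WeightUpdater, the dict replies, and the call_soon stub are assumptions for illustration, not the real TokenizerManager:

import asyncio

class WeightUpdater:
    def __init__(self, dp_size: int):
        self.dp_size = dp_size
        self.model_update_lock = asyncio.Lock()
        self.rid_to_state = {}            # in-flight generation requests
        self.model_update_result = None

    async def update_weights(self, obj):
        if self.model_update_lock.locked():
            return False, "Another update is in progress. Please try again later."

        async with self.model_update_lock:
            # Wait for the previous generation requests to finish.
            while len(self.rid_to_state) > 0:
                await asyncio.sleep(0.001)

            self.model_update_result = asyncio.get_running_loop().create_future()
            self._send_to_scheduler(obj)

            if self.dp_size == 1:
                result = await self.model_update_result
                return result["success"], result["message"]
            else:  # self.dp_size > 1
                results = await self.model_update_result
                all_success = all(r["success"] for r in results)
                all_message = " | ".join(r["message"] for r in results)
                return all_success, all_message

    def _send_to_scheduler(self, obj):
        # Stub: one acknowledgement per data-parallel replica.
        if self.dp_size == 1:
            reply = {"success": True, "message": "ok"}
        else:
            reply = [{"success": True, "message": "ok"}] * self.dp_size
        asyncio.get_running_loop().call_soon(
            self.model_update_result.set_result, reply
        )

async def main():
    print(await WeightUpdater(1).update_weights(None))   # (True, 'ok')
    print(await WeightUpdater(2).update_weights(None))   # (True, 'ok | ok')

asyncio.run(main())

The locked() pre-check keeps a second caller from silently queueing behind an ongoing update: it fails fast with an error message instead of blocking on the lock.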
@@ -611,8 +599,7 @@ class TokenizerManager:
                         self.model_path = obj.model_path
                     all_message = [r.message for r in result]
                     all_message = " | ".join(all_message)
-
-            return all_success, all_message
+                    return all_success, all_message
 
         else:
             return False, "Another update is in progress. Please try again later."
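The last hunk only moves the final aggregation: with dp_size > 1 the awaited future yields one result per data-parallel replica, all_success is the conjunction of the per-replica flags, and the per-replica messages are joined with " | ". The gathering side is not part of this diff; the sketch below shows what the handle loop presumably does with model_update_tmp, an assumption inferred from the buffer name and the removed "dp_size number of response" comment rather than from code shown here:

import asyncio

def on_update_reply(tmp, fut, reply, dp_size):
    # Buffer replies until every replica has answered, then resolve the
    # future that update_weights is awaiting.
    tmp.append(reply)
    if len(tmp) == dp_size:
        fut.set_result(list(tmp))

async def main():
    fut = asyncio.get_running_loop().create_future()
    tmp = []
    for i in range(2):
        on_update_reply(tmp, fut, {"success": True, "message": f"rank{i}"}, 2)
    results = await fut
    print(all(r["success"] for r in results))            # True
    print(" | ".join(r["message"] for r in results))     # rank0 | rank1

asyncio.run(main())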