diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9b990f11c..785b18165 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -549,22 +549,18 @@ class TokenizerManager: self.create_handle_loop() req = GetMemPoolSizeReq() - ret = None + + self.send_to_scheduler.send_pyobj(req) + self.mem_pool_size = asyncio.Future() if self.server_args.dp_size == 1: - self.send_to_scheduler.send_pyobj(req) - self.mem_pool_size = asyncio.Future() res = await self.mem_pool_size - ret = res.size - + return res.size else: # self.server_args.dp_size > 1 - self.send_to_scheduler.send_pyobj(req) - self.mem_pool_size = asyncio.Future() self.mem_pool_size_tmp = [] res = await self.mem_pool_size ret = [r.size for r in res] - - return ret + return ret async def update_weights( self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None @@ -578,29 +574,21 @@ class TokenizerManager: if not self.model_update_lock.locked(): - if self.server_args.dp_size == 1: - async with self.model_update_lock: - # wait for the previous generation requests to finish - while len(self.rid_to_state) > 0: - await asyncio.sleep(0.001) - self.send_to_scheduler.send_pyobj(obj) - self.model_update_result = asyncio.Future() + async with self.model_update_lock: + # wait for the previous generation requests to finish + while len(self.rid_to_state) > 0: + await asyncio.sleep(0.001) + self.send_to_scheduler.send_pyobj(obj) + self.model_update_result = asyncio.Future() + + if self.server_args.dp_size == 1: result = await self.model_update_result if result.success: self.server_args.model_path = obj.model_path self.server_args.load_format = obj.load_format self.model_path = obj.model_path - return result.success, result.message - - else: # self.server_args.dp_size > 1 - - # There will be dp_size number of response from the detokenizer - async with self.model_update_lock: - # wait for the previous generation requests to finish - while len(self.rid_to_state) > 0: - await asyncio.sleep(0.001) - self.send_to_scheduler.send_pyobj(obj) - self.model_update_result = asyncio.Future() + return result.success, result.message + else: # self.server_args.dp_size > 1 self.model_update_tmp = [] result = await self.model_update_result @@ -611,8 +599,7 @@ class TokenizerManager: self.model_path = obj.model_path all_message = [r.message for r in result] all_message = " | ".join(all_message) - - return all_success, all_message + return all_success, all_message else: return False, "Another update is in progress. Please try again later."