diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 347e7ad1d..428bf10d7 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -554,18 +554,43 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
 
         if not self.model_update_lock.locked():
-            async with self.model_update_lock:
-                # wait for the previous generation requests to finish
-                while len(self.rid_to_state) > 0:
-                    await asyncio.sleep(0.001)
-                self.send_to_scheduler.send_pyobj(obj)
-                self.model_update_result = asyncio.Future()
-                result = await self.model_update_result
-                if result.success:
-                    self.server_args.model_path = obj.model_path
-                    self.server_args.load_format = obj.load_format
-                    self.model_path = obj.model_path
-                return result.success, result.message
+
+            if self.server_args.dp_size == 1:
+                async with self.model_update_lock:
+                    # wait for the previous generation requests to finish
+                    while len(self.rid_to_state) > 0:
+                        await asyncio.sleep(0.001)
+                    self.send_to_scheduler.send_pyobj(obj)
+                    self.model_update_result = asyncio.Future()
+                    result = await self.model_update_result
+                    if result.success:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    return result.success, result.message
+
+            else:  # self.server_args.dp_size > 1
+
+                # There will be dp_size responses from the detokenizer
+                async with self.model_update_lock:
+                    # wait for the previous generation requests to finish
+                    while len(self.rid_to_state) > 0:
+                        await asyncio.sleep(0.001)
+                    self.send_to_scheduler.send_pyobj(obj)
+                    self.model_update_result = asyncio.Future()
+                    self.model_update_tmp = []
+                    result = await self.model_update_result
+
+                    all_success = all(r.success for r in result)
+                    if all_success:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    all_message = [r.message for r in result]
+                    all_message = " | ".join(all_message)
+
+                    return all_success, all_message
+
         else:
             return False, "Another update is in progress. Please try again later."
 
@@ -600,7 +625,13 @@ class TokenizerManager:
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
             if isinstance(recv_obj, UpdateWeightReqOutput):
-                self.model_update_result.set_result(recv_obj)
+                if self.server_args.dp_size == 1:
+                    self.model_update_result.set_result(recv_obj)
+                else:  # self.server_args.dp_size > 1
+                    self.model_update_tmp.append(recv_obj)
+                    # set the future once all results are received
+                    if len(self.model_update_tmp) == self.server_args.dp_size:
+                        self.model_update_result.set_result(self.model_update_tmp)
                 continue
             elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                 self.mem_pool_size.set_result(recv_obj)
diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py
index 5f17994a2..00bae0a88 100644
--- a/test/srt/test_data_parallelism.py
+++ b/test/srt/test_data_parallelism.py
@@ -1,6 +1,9 @@
+import time
 import unittest
 from types import SimpleNamespace
 
+import requests
+
 from sglang.srt.utils import kill_child_process
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -39,6 +42,26 @@ class TestDataParallelism(unittest.TestCase):
         metrics = run_eval(args)
         assert metrics["score"] >= 0.65
 
+    def test_update_weight(self):
+        response = requests.post(
+            self.base_url + "/update_weights",
+            json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
+        )
+
+        # check if the response is 200
+        assert response.status_code == 200
+
+        # pause a few seconds then send again
+        time.sleep(5)
+
+        response = requests.post(
+            self.base_url + "/update_weights",
+            json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
+        )
+
+        # check if the response is 200
+        assert response.status_code == 200
+
 
 if __name__ == "__main__":
     unittest.main()