Fix update_weights deadlock for DP (#1825)
This commit is contained in:
@@ -554,18 +554,43 @@ class TokenizerManager:
|
|||||||
obj.load_format = self.server_args.load_format
|
obj.load_format = self.server_args.load_format
|
||||||
|
|
||||||
if not self.model_update_lock.locked():
|
if not self.model_update_lock.locked():
|
||||||
async with self.model_update_lock:
|
|
||||||
# wait for the previous generation requests to finish
|
if self.server_args.dp_size == 1:
|
||||||
while len(self.rid_to_state) > 0:
|
async with self.model_update_lock:
|
||||||
await asyncio.sleep(0.001)
|
# wait for the previous generation requests to finish
|
||||||
self.send_to_scheduler.send_pyobj(obj)
|
while len(self.rid_to_state) > 0:
|
||||||
self.model_update_result = asyncio.Future()
|
await asyncio.sleep(0.001)
|
||||||
result = await self.model_update_result
|
self.send_to_scheduler.send_pyobj(obj)
|
||||||
if result.success:
|
self.model_update_result = asyncio.Future()
|
||||||
self.server_args.model_path = obj.model_path
|
result = await self.model_update_result
|
||||||
self.server_args.load_format = obj.load_format
|
if result.success:
|
||||||
self.model_path = obj.model_path
|
self.server_args.model_path = obj.model_path
|
||||||
return result.success, result.message
|
self.server_args.load_format = obj.load_format
|
||||||
|
self.model_path = obj.model_path
|
||||||
|
return result.success, result.message
|
||||||
|
|
||||||
|
else: # self.server_args.dp_size > 1
|
||||||
|
|
||||||
|
# There will be dp_size number of response from the detokenizer
|
||||||
|
async with self.model_update_lock:
|
||||||
|
# wait for the previous generation requests to finish
|
||||||
|
while len(self.rid_to_state) > 0:
|
||||||
|
await asyncio.sleep(0.001)
|
||||||
|
self.send_to_scheduler.send_pyobj(obj)
|
||||||
|
self.model_update_result = asyncio.Future()
|
||||||
|
self.model_update_tmp = []
|
||||||
|
result = await self.model_update_result
|
||||||
|
|
||||||
|
all_success = all([r.success for r in result])
|
||||||
|
if all_success is True:
|
||||||
|
self.server_args.model_path = obj.model_path
|
||||||
|
self.server_args.load_format = obj.load_format
|
||||||
|
self.model_path = obj.model_path
|
||||||
|
all_message = [r.message for r in result]
|
||||||
|
all_message = " | ".join(all_message)
|
||||||
|
|
||||||
|
return all_success, all_message
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return False, "Another update is in progress. Please try again later."
|
return False, "Another update is in progress. Please try again later."
|
||||||
|
|
||||||
@@ -600,7 +625,13 @@ class TokenizerManager:
|
|||||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
|
|
||||||
if isinstance(recv_obj, UpdateWeightReqOutput):
|
if isinstance(recv_obj, UpdateWeightReqOutput):
|
||||||
self.model_update_result.set_result(recv_obj)
|
if self.server_args.dp_size == 1:
|
||||||
|
self.model_update_result.set_result(recv_obj)
|
||||||
|
else: # self.server_args.dp_size > 1
|
||||||
|
self.model_update_tmp.append(recv_obj)
|
||||||
|
# set future if the all results are recevied
|
||||||
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||||
|
self.model_update_result.set_result(self.model_update_tmp)
|
||||||
continue
|
continue
|
||||||
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
||||||
self.mem_pool_size.set_result(recv_obj)
|
self.mem_pool_size.set_result(recv_obj)
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
from sglang.srt.utils import kill_child_process
|
from sglang.srt.utils import kill_child_process
|
||||||
from sglang.test.run_eval import run_eval
|
from sglang.test.run_eval import run_eval
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
@@ -39,6 +42,26 @@ class TestDataParallelism(unittest.TestCase):
|
|||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
assert metrics["score"] >= 0.65
|
assert metrics["score"] >= 0.65
|
||||||
|
|
||||||
|
def test_update_weight(self):
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/update_weights",
|
||||||
|
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if the response is 200
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# pause a few seconds then send again
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/update_weights",
|
||||||
|
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if the response is 200
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user