Fix update_weights deadlock for DP (#1825)
This commit is contained in:
@@ -554,18 +554,43 @@ class TokenizerManager:
|
||||
obj.load_format = self.server_args.load_format
|
||||
|
||||
if not self.model_update_lock.locked():
|
||||
async with self.model_update_lock:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0.001)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
result = await self.model_update_result
|
||||
if result.success:
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
return result.success, result.message
|
||||
|
||||
if self.server_args.dp_size == 1:
|
||||
async with self.model_update_lock:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0.001)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
result = await self.model_update_result
|
||||
if result.success:
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
return result.success, result.message
|
||||
|
||||
else: # self.server_args.dp_size > 1
|
||||
|
||||
# There will be dp_size number of response from the detokenizer
|
||||
async with self.model_update_lock:
|
||||
# wait for the previous generation requests to finish
|
||||
while len(self.rid_to_state) > 0:
|
||||
await asyncio.sleep(0.001)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
self.model_update_tmp = []
|
||||
result = await self.model_update_result
|
||||
|
||||
all_success = all([r.success for r in result])
|
||||
if all_success is True:
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
all_message = [r.message for r in result]
|
||||
all_message = " | ".join(all_message)
|
||||
|
||||
return all_success, all_message
|
||||
|
||||
else:
|
||||
return False, "Another update is in progress. Please try again later."
|
||||
|
||||
@@ -600,7 +625,13 @@ class TokenizerManager:
|
||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||
|
||||
if isinstance(recv_obj, UpdateWeightReqOutput):
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
if self.server_args.dp_size == 1:
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp.append(recv_obj)
|
||||
# set future if the all results are recevied
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
self.model_update_result.set_result(self.model_update_tmp)
|
||||
continue
|
||||
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
||||
self.mem_pool_size.set_result(recv_obj)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_child_process
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
@@ -39,6 +42,26 @@ class TestDataParallelism(unittest.TestCase):
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.65
|
||||
|
||||
def test_update_weight(self):
|
||||
response = requests.post(
|
||||
self.base_url + "/update_weights",
|
||||
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
|
||||
)
|
||||
|
||||
# check if the response is 200
|
||||
assert response.status_code == 200
|
||||
|
||||
# pause a few seconds then send again
|
||||
time.sleep(5)
|
||||
|
||||
response = requests.post(
|
||||
self.base_url + "/update_weights",
|
||||
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
|
||||
)
|
||||
|
||||
# check if the response is 200
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user