Support updating weights at once by stopping all requests (#6698)

Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>
Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
This commit is contained in:
Albert
2025-07-03 13:26:06 +08:00
committed by GitHub
parent b044400dd3
commit d3c275b117
7 changed files with 190 additions and 13 deletions

View File

@@ -1,6 +1,8 @@
import json
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
@@ -153,6 +155,82 @@ class TestServerUpdateWeightsFromDisk(CustomTestCase):
self.assertEqual(origin_response[:32], updated_response[:32])
class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--max-running-requests", 8],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(self, max_new_tokens=32):
response = requests.post(
self.base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"ignore_eos": True,
},
},
)
return response.json()
def get_model_info(self):
response = requests.get(self.base_url + "/get_model_info")
model_path = response.json()["model_path"]
print(json.dumps(response.json()))
return model_path
def run_update_weights(self, model_path, abort_all_requests=False):
response = requests.post(
self.base_url + "/update_weights_from_disk",
json={
"model_path": model_path,
"abort_all_requests": abort_all_requests,
},
)
ret = response.json()
print(json.dumps(ret))
return ret
def test_update_weights_abort_all_requests(self):
origin_model_path = self.get_model_info()
print(f"[Server Mode] origin_model_path: {origin_model_path}")
num_requests = 32
with ThreadPoolExecutor(num_requests) as executor:
futures = [
executor.submit(self.run_decode, 16000) for _ in range(num_requests)
]
# ensure the decode has been started
time.sleep(2)
new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
ret = self.run_update_weights(new_model_path, abort_all_requests=True)
self.assertTrue(ret["success"])
for future in as_completed(futures):
self.assertEqual(
future.result()["meta_info"]["finish_reason"]["type"], "abort"
)
updated_model_path = self.get_model_info()
print(f"[Server Mode] updated_model_path: {updated_model_path}")
self.assertEqual(updated_model_path, new_model_path)
self.assertNotEqual(updated_model_path, origin_model_path)
###############################################################################
# Parameterized Tests for update_weights_from_disk
# Test coverage is determined based on the value of is_in_ci: