From 70b3c6eeb1e178c04149a5af27101eb28417b6cb Mon Sep 17 00:00:00 2001 From: Jhin <47354855+jhinpan@users.noreply.github.com> Date: Wed, 5 Mar 2025 14:25:18 -0600 Subject: [PATCH] Add update_weights_from_disk endpoint to Engine (#4102) Co-authored-by: zhaochenyang20 --- python/sglang/srt/entrypoints/engine.py | 22 ++ test/srt/test_update_weights_from_disk.py | 244 +++++++++++++++--- .../test_update_weights_from_distributed.py | 4 +- 3 files changed, 239 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 65c1f1b85..074691a4f 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -44,6 +44,7 @@ from sglang.srt.managers.io_struct import ( InitWeightsUpdateGroupReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, + UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, ) @@ -302,6 +303,27 @@ class Engine: self.tokenizer_manager.update_weights_from_tensor(obj, None) ) + def update_weights_from_disk( + self, + model_path: str, + load_format: Optional[str] = None, + ): + """Update the weights from disk inplace without re-launching the engine. + + This method allows updating the model weights from disk without restarting + the engine. It can be used to load a different model or update weights with + new training. + """ + obj = UpdateWeightFromDiskReqInput( + model_path=model_path, + load_format=load_format, + ) + + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_weights_from_disk(obj, None) + ) + def get_weights_by_name(self, name: str, truncate_size: int = 100): """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/test/srt/test_update_weights_from_disk.py b/test/srt/test_update_weights_from_disk.py index 3b2dc0f6f..248525048 100644 --- a/test/srt/test_update_weights_from_disk.py +++ b/test/srt/test_update_weights_from_disk.py @@ -1,18 +1,76 @@ import json +import random import unittest import requests +import sglang as sgl from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, + is_in_ci, popen_launch_server, ) -class TestUpdateWeights(unittest.TestCase): +############################################################################### +# Engine Mode Tests (Single-configuration) +############################################################################### +class TestEngineUpdateWeightsFromDisk(unittest.TestCase): + def setUp(self): + self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + # Initialize the engine in offline (direct) mode. + self.engine = sgl.Engine(model_path=self.model) + + def tearDown(self): + self.engine.shutdown() + + def run_decode(self): + prompts = ["The capital of France is"] + sampling_params = {"temperature": 0, "max_new_tokens": 32} + outputs = self.engine.generate(prompts, sampling_params) + print("=" * 100) + print( + f"[Engine Mode] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}" + ) + return outputs[0]["text"] + + def run_update_weights(self, model_path): + ret = self.engine.update_weights_from_disk(model_path) + print(json.dumps(ret)) + return ret + + def test_update_weights(self): + origin_response = self.run_decode() + # Update weights: use new model (remove "-Instruct") + new_model_path = self.model.replace("-Instruct", "") + ret = self.run_update_weights(new_model_path) + self.assertTrue(ret[0]) # ret is a tuple; index 0 holds the success flag + + updated_response = self.run_decode() + self.assertNotEqual(origin_response[:32], updated_response[:32]) + + # Revert back to original weights + ret = self.run_update_weights(self.model) + self.assertTrue(ret[0]) + reverted_response = self.run_decode() + self.assertEqual(origin_response[:32], reverted_response[:32]) + + def test_update_weights_unexist_model(self): + origin_response = self.run_decode() + new_model_path = self.model.replace("-Instruct", "wrong") + ret = self.run_update_weights(new_model_path) + self.assertFalse(ret[0]) + updated_response = self.run_decode() + self.assertEqual(origin_response[:32], updated_response[:32]) + + +############################################################################### +# HTTP Server Mode Tests (Single-configuration) +############################################################################### +class TestServerUpdateWeightsFromDisk(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST @@ -30,16 +88,12 @@ class TestUpdateWeights(unittest.TestCase): self.base_url + "/generate", json={ "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 32, - }, + "sampling_params": {"temperature": 0, "max_new_tokens": 32}, }, ) - print(json.dumps(response.json())) print("=" * 100) - text = response.json()["text"] - return text + print(f"[Server Mode] Generated text: {response.json()['text']}") + return response.json()["text"] def get_model_info(self): response = requests.get(self.base_url + "/get_model_info") @@ -50,58 +104,188 @@ class TestUpdateWeights(unittest.TestCase): def run_update_weights(self, model_path): response = requests.post( self.base_url + "/update_weights_from_disk", - json={ - "model_path": model_path, - }, + json={"model_path": model_path}, ) ret = response.json() - print(json.dumps(response.json())) + print(json.dumps(ret)) return ret def test_update_weights(self): origin_model_path = self.get_model_info() - print(f"origin_model_path: {origin_model_path}") + print(f"[Server Mode] origin_model_path: {origin_model_path}") origin_response = self.run_decode() - # update weights new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "") ret = self.run_update_weights(new_model_path) - assert ret["success"] + self.assertTrue(ret["success"]) updated_model_path = self.get_model_info() - print(f"updated_model_path: {updated_model_path}") - assert updated_model_path == new_model_path - assert updated_model_path != origin_model_path + print(f"[Server Mode] updated_model_path: {updated_model_path}") + self.assertEqual(updated_model_path, new_model_path) + self.assertNotEqual(updated_model_path, origin_model_path) updated_response = self.run_decode() - assert origin_response[:32] != updated_response[:32] + self.assertNotEqual(origin_response[:32], updated_response[:32]) - # update weights back ret = self.run_update_weights(origin_model_path) - assert ret["success"] - + self.assertTrue(ret["success"]) updated_model_path = self.get_model_info() - assert updated_model_path == origin_model_path + self.assertEqual(updated_model_path, origin_model_path) updated_response = self.run_decode() - assert origin_response[:32] == updated_response[:32] + self.assertEqual(origin_response[:32], updated_response[:32]) def test_update_weights_unexist_model(self): origin_model_path = self.get_model_info() - print(f"origin_model_path: {origin_model_path}") + print(f"[Server Mode] origin_model_path: {origin_model_path}") origin_response = self.run_decode() - # update weights new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong") ret = self.run_update_weights(new_model_path) - assert not ret["success"] + self.assertFalse(ret["success"]) updated_model_path = self.get_model_info() - print(f"updated_model_path: {updated_model_path}") - assert updated_model_path == origin_model_path + print(f"[Server Mode] updated_model_path: {updated_model_path}") + self.assertEqual(updated_model_path, origin_model_path) updated_response = self.run_decode() - assert origin_response[:32] == updated_response[:32] + self.assertEqual(origin_response[:32], updated_response[:32]) + + +############################################################################### +# Parameterized Tests for update_weights_from_disk +# Test coverage is determined based on the value of is_in_ci: +# - In a CI environment: randomly select one mode (Engine or Server) and test only with tp=1, dp=1. +# - In a non-CI environment: test both Engine and Server modes, and enumerate all combinations +# with tp and dp ranging from 1 to 2. +############################################################################### +class TestUpdateWeightsFromDiskParameterized(unittest.TestCase): + def run_common_test(self, mode, tp, dp): + """ + Common test procedure for update_weights_from_disk. + For Engine mode, we instantiate the engine with tp_size=tp. + For Server mode, we launch the server with additional arguments for tp (dp is not used in server launch here). + """ + if mode == "Engine": + # Instantiate engine with additional parameter tp_size. + print(f"[Parameterized Engine] Testing with tp={tp}, dp={dp}") + engine = sgl.Engine( + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + random_seed=42, + tp_size=tp, + # dp parameter is not explicitly used in this API. + ) + try: + origin_response = self._engine_update_weights_test(engine) + finally: + engine.shutdown() + elif mode == "Server": + print(f"[Parameterized Server] Testing with tp={tp}, dp={dp}") + # Pass additional arguments to launch the server. + base_args = ["--tp-size", str(tp)] + process = popen_launch_server( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=base_args, + ) + try: + origin_response = self._server_update_weights_test(DEFAULT_URL_FOR_TEST) + finally: + kill_process_tree(process.pid) + else: + raise ValueError(f"Unknown mode: {mode}") + + def _engine_update_weights_test(self, engine): + # Run the update weights test on the given engine instance. + def run_decode(): + prompts = ["The capital of France is"] + sampling_params = {"temperature": 0, "max_new_tokens": 32} + outputs = engine.generate(prompts, sampling_params) + print("=" * 100) + print( + f"[Parameterized Engine] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}" + ) + return outputs[0]["text"] + + def run_update_weights(model_path): + ret = engine.update_weights_from_disk(model_path) + print(json.dumps(ret)) + return ret + + origin_response = run_decode() + new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "") + ret = run_update_weights(new_model_path) + self.assertTrue(ret[0]) + updated_response = run_decode() + self.assertNotEqual(origin_response[:32], updated_response[:32]) + ret = run_update_weights(DEFAULT_SMALL_MODEL_NAME_FOR_TEST) + self.assertTrue(ret[0]) + reverted_response = run_decode() + self.assertEqual(origin_response[:32], reverted_response[:32]) + return origin_response + + def _server_update_weights_test(self, base_url): + def run_decode(): + response = requests.post( + base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": {"temperature": 0, "max_new_tokens": 32}, + }, + ) + print("=" * 100) + print(f"[Parameterized Server] Generated text: {response.json()['text']}") + return response.json()["text"] + + def get_model_info(): + response = requests.get(base_url + "/get_model_info") + model_path = response.json()["model_path"] + print(json.dumps(response.json())) + return model_path + + def run_update_weights(model_path): + response = requests.post( + base_url + "/update_weights_from_disk", + json={"model_path": model_path}, + ) + ret = response.json() + print(json.dumps(ret)) + return ret + + origin_model_path = get_model_info() + origin_response = run_decode() + new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "") + ret = run_update_weights(new_model_path) + self.assertTrue(ret["success"]) + updated_model_path = get_model_info() + self.assertEqual(updated_model_path, new_model_path) + self.assertNotEqual(updated_model_path, origin_model_path) + updated_response = run_decode() + self.assertNotEqual(origin_response[:32], updated_response[:32]) + ret = run_update_weights(origin_model_path) + self.assertTrue(ret["success"]) + updated_model_path = get_model_info() + self.assertEqual(updated_model_path, origin_model_path) + reverted_response = run_decode() + self.assertEqual(origin_response[:32], reverted_response[:32]) + return origin_response + + def test_parameterized_update_weights(self): + if is_in_ci(): + # In CI, choose one random mode (Engine or Server) with tp=1, dp=1. + mode = random.choice(["Engine", "Server"]) + test_suits = [(1, 1, mode)] + else: + # Otherwise, test both modes and enumerate tp,dp combinations from 1 to 2. + test_suits = [] + for mode in ["Engine", "Server"]: + for tp in [1, 2]: + for dp in [1, 2]: + test_suits.append((tp, dp, mode)) + for tp, dp, mode in test_suits: + with self.subTest(mode=mode, tp=tp, dp=dp): + self.run_common_test(mode, tp, dp) if __name__ == "__main__": diff --git a/test/srt/test_update_weights_from_distributed.py b/test/srt/test_update_weights_from_distributed.py index 7acbe9fb3..fc15efcfe 100644 --- a/test/srt/test_update_weights_from_distributed.py +++ b/test/srt/test_update_weights_from_distributed.py @@ -15,6 +15,7 @@ distributed setup. import gc import os +import random import time import unittest @@ -529,8 +530,9 @@ class TestUpdateWeightsFromDistributed(unittest.TestCase): assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required" # test_suits : tp, dp, model_name, backend if is_in_ci(): + mode = random.choice(["Engine", "Server"]) test_suits = [ - (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"), + (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, mode), ] else: test_suits = [