diff --git a/test/srt/test_get_parameter_by_name.py b/test/srt/test_get_parameter_by_name.py
index 73b0a3f74..8dce1ac2c 100644
--- a/test/srt/test_get_parameter_by_name.py
+++ b/test/srt/test_get_parameter_by_name.py
@@ -16,7 +16,7 @@ from sglang.test.test_utils import (
 from sglang.utils import terminate_process
 
 
-class TestUpdateWeights(unittest.TestCase):
+class TestGetParameterByName(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -64,7 +64,7 @@ class TestUpdateWeights(unittest.TestCase):
         if self.process:
             terminate_process(self.process)
 
-    def assert_update_weights_all_close(self, param_name, truncate_size):
+    def assert_weights_all_close(self, param_name, truncate_size):
         print(
             f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
         )
@@ -87,12 +87,12 @@ class TestUpdateWeights(unittest.TestCase):
     @staticmethod
     def _process_return(ret):
         if isinstance(ret, list) and len(ret) == 2:
-            print(f"running assert_allclose on data parallel")
+            print("running assert_allclose on data parallel")
             np.testing.assert_allclose(ret[0], ret[1])
             return np.array(ret[0])
         return np.array(ret)
 
-    def test_update_weights_unexist_model(self):
+    def test_get_parameters_by_name(self):
         test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)]
 
         if torch.cuda.device_count() >= 2:
@@ -120,7 +120,7 @@ class TestUpdateWeights(unittest.TestCase):
 
         for test_suit in test_suits:
             self.init_backend(*test_suit)
             for param_name in parameters:
-                self.assert_update_weights_all_close(param_name, 100)
+                self.assert_weights_all_close(param_name, 100)
             self.close_engine_and_server()