Better unit tests for adding a new model (#1488)

This commit is contained in:
Lianmin Zheng
2024-09-22 01:50:37 -07:00
committed by GitHub
parent 441c22db8c
commit 167591e864
8 changed files with 157 additions and 126 deletions

View File

@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
)
class TestReplaceWeights(unittest.TestCase):
class TestUpdateWeights(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -33,13 +33,7 @@ class TestReplaceWeights(unittest.TestCase):
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
"n": 1,
},
"stream": False,
"return_logprob": False,
"top_logprobs_num": 0,
"return_text_in_logprobs": False,
"logprob_start_len": 0,
},
)
print(json.dumps(response.json()))
@@ -64,7 +58,7 @@ class TestReplaceWeights(unittest.TestCase):
print(json.dumps(response.json()))
return ret
def test_replace_weights(self):
def test_update_weights(self):
origin_model_path = self.get_model_info()
print(f"origin_model_path: {origin_model_path}")
origin_response = self.run_decode()
@@ -92,7 +86,7 @@ class TestReplaceWeights(unittest.TestCase):
updated_response = self.run_decode()
assert origin_response[:32] == updated_response[:32]
def test_replace_weights_unexist_model(self):
def test_update_weights_unexist_model(self):
origin_model_path = self.get_model_info()
print(f"origin_model_path: {origin_model_path}")
origin_response = self.run_decode()