Clean up model loader (#1440)
This commit is contained in:
@@ -44,7 +44,6 @@ class TestReplaceWeights(unittest.TestCase):
|
||||
)
|
||||
print(json.dumps(response.json()))
|
||||
print("=" * 100)
|
||||
# return the "text" in response
|
||||
text = response.json()["text"]
|
||||
return text
|
||||
|
||||
@@ -61,7 +60,9 @@ class TestReplaceWeights(unittest.TestCase):
|
||||
"model_path": model_path,
|
||||
},
|
||||
)
|
||||
ret = response.json()
|
||||
print(json.dumps(response.json()))
|
||||
return ret
|
||||
|
||||
def test_replace_weights(self):
|
||||
origin_model_path = self.get_model_info()
|
||||
@@ -70,7 +71,8 @@ class TestReplaceWeights(unittest.TestCase):
|
||||
|
||||
# update weights
|
||||
new_model_path = "meta-llama/Meta-Llama-3.1-8B"
|
||||
self.run_update_weights(new_model_path)
|
||||
ret = self.run_update_weights(new_model_path)
|
||||
assert ret["success"]
|
||||
|
||||
updated_model_path = self.get_model_info()
|
||||
print(f"updated_model_path: {updated_model_path}")
|
||||
@@ -81,7 +83,9 @@ class TestReplaceWeights(unittest.TestCase):
|
||||
assert origin_response[:32] != updated_response[:32]
|
||||
|
||||
# update weights back
|
||||
self.run_update_weights(origin_model_path)
|
||||
ret = self.run_update_weights(origin_model_path)
|
||||
assert ret["success"]
|
||||
|
||||
updated_model_path = self.get_model_info()
|
||||
assert updated_model_path == origin_model_path
|
||||
|
||||
@@ -95,7 +99,8 @@ class TestReplaceWeights(unittest.TestCase):
|
||||
|
||||
# update weights
|
||||
new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
|
||||
self.run_update_weights(new_model_path)
|
||||
ret = self.run_update_weights(new_model_path)
|
||||
assert not ret["success"]
|
||||
|
||||
updated_model_path = self.get_model_info()
|
||||
print(f"updated_model_path: {updated_model_path}")
|
||||
|
||||
Reference in New Issue
Block a user