[Feat] Support updating weights without restarting the server (#1157)
This commit is contained in:
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
|
UpdateWeightReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
|
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
@@ -84,6 +85,10 @@ class DetokenizerManager:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if isinstance(recv_obj, UpdateWeightReqOutput):
|
||||||
|
self.send_to_tokenizer.send_pyobj(recv_obj)
|
||||||
|
continue
|
||||||
|
|
||||||
assert isinstance(recv_obj, BatchTokenIDOut)
|
assert isinstance(recv_obj, BatchTokenIDOut)
|
||||||
bs = len(recv_obj.rids)
|
bs = len(recv_obj.rids)
|
||||||
|
|
||||||
|
|||||||
@@ -278,6 +278,20 @@ class FlushCacheReq:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UpdateWeightReqInput:
    """Request asking the server to load a new set of model weights."""

    # Model path that holds the new weights.
    model_path: str
    # Format used to load the weights; may be left as None by the caller.
    load_format: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UpdateWeightReqOutput:
    """Result reported back for a weight-update request."""

    # True when the update went through.
    success: bool
    # Status or error text accompanying the result.
    message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AbortReq:
|
class AbortReq:
|
||||||
# The request id
|
# The request id
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
|
UpdateWeightReqInput,
|
||||||
|
UpdateWeightReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
@@ -121,6 +123,10 @@ class TokenizerManager:
|
|||||||
self.to_create_loop = True
|
self.to_create_loop = True
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
|
|
||||||
|
# for update model weights
|
||||||
|
self.model_update_lock = asyncio.Lock()
|
||||||
|
self.model_update_result = None
|
||||||
|
|
||||||
async def get_pixel_values(self, image_data):
|
async def get_pixel_values(self, image_data):
|
||||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||||
grid_pinpoints = (
|
grid_pinpoints = (
|
||||||
@@ -146,6 +152,9 @@ class TokenizerManager:
|
|||||||
if self.to_create_loop:
|
if self.to_create_loop:
|
||||||
self.create_handle_loop()
|
self.create_handle_loop()
|
||||||
|
|
||||||
|
while self.model_update_lock.locked():
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
obj.post_init()
|
obj.post_init()
|
||||||
is_single = obj.is_single
|
is_single = obj.is_single
|
||||||
|
|
||||||
@@ -513,6 +522,30 @@ class TokenizerManager:
|
|||||||
req = FlushCacheReq()
|
req = FlushCacheReq()
|
||||||
self.send_to_router.send_pyobj(req)
|
self.send_to_router.send_pyobj(req)
|
||||||
|
|
||||||
|
async def update_weights(self, obj: UpdateWeightReqInput, request):
    """Forward a weight-update request to the router and wait for its result.

    Only one update may run at a time: if another update currently holds
    ``model_update_lock`` the call returns immediately with a failure.

    Args:
        obj: The update request. ``obj.load_format`` is filled in from
            ``self.server_args.load_format`` when it is None.
        request: The incoming HTTP request (not used here).

    Returns:
        A ``(success, message)`` tuple taken from the received result, or
        ``(False, <busy message>)`` when another update is in progress.
    """
    if self.to_create_loop:
        self.create_handle_loop()

    # Default the load format to the one in the server args.
    if obj.load_format is None:
        obj.load_format = self.server_args.load_format

    # Reject concurrent updates instead of queueing them behind the lock.
    if self.model_update_lock.locked():
        return False, "Another update is in progress. Please try again later."

    async with self.model_update_lock:
        # Wait for the previous generation requests to finish. A short real
        # sleep is used instead of sleep(0) so that this loop does not spin
        # the event loop at full speed while requests drain.
        while self.rid_to_state:
            await asyncio.sleep(0.001)

        # Create the future before sending the request so that the reply
        # can never arrive while the future is missing or stale.
        self.model_update_result = asyncio.Future()
        self.send_to_router.send_pyobj(obj)
        result = await self.model_update_result

        if result.success:
            # Keep the locally tracked model information in sync.
            self.server_args.model_path = obj.model_path
            self.server_args.load_format = obj.load_format
            self.model_path = obj.model_path
        return result.success, result.message
|
||||||
|
|
||||||
def abort_request(self, rid: str):
|
def abort_request(self, rid: str):
|
||||||
if rid not in self.rid_to_state:
|
if rid not in self.rid_to_state:
|
||||||
return
|
return
|
||||||
@@ -541,12 +574,18 @@ class TokenizerManager:
|
|||||||
|
|
||||||
async def handle_loop(self):
|
async def handle_loop(self):
|
||||||
while True:
|
while True:
|
||||||
recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
|
recv_obj: Union[
|
||||||
await self.recv_from_detokenizer.recv_pyobj()
|
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
|
||||||
)
|
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
|
|
||||||
|
if isinstance(recv_obj, UpdateWeightReqOutput):
|
||||||
|
self.model_update_result.set_result(recv_obj)
|
||||||
|
continue
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
||||||
), f"Unexpected obj received: {type(recv_obj)}"
|
), f"Unexpected obj received: {type(recv_obj)}"
|
||||||
|
|
||||||
for i, rid in enumerate(recv_obj.rids):
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
state = self.rid_to_state.get(rid, None)
|
state = self.rid_to_state.get(rid, None)
|
||||||
if state is None:
|
if state is None:
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
|
UpdateWeightReqInput,
|
||||||
|
UpdateWeightReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
|
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
@@ -214,6 +216,9 @@ class ModelTpServer:
|
|||||||
self.flush_cache()
|
self.flush_cache()
|
||||||
elif isinstance(recv_req, AbortReq):
|
elif isinstance(recv_req, AbortReq):
|
||||||
self.abort_request(recv_req)
|
self.abort_request(recv_req)
|
||||||
|
elif isinstance(recv_req, UpdateWeightReqInput):
|
||||||
|
success, message = self.update_weights(recv_req)
|
||||||
|
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid request: {recv_req}")
|
raise ValueError(f"Invalid request: {recv_req}")
|
||||||
|
|
||||||
@@ -773,12 +778,15 @@ class ModelTpServer:
|
|||||||
self.token_to_kv_pool.clear()
|
self.token_to_kv_pool.clear()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
logger.info("Cache flushed successfully!")
|
logger.info("Cache flushed successfully!")
|
||||||
|
if_success = True
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Cache not flushed because there are pending requests. "
|
f"Cache not flushed because there are pending requests. "
|
||||||
f"#queue-req: {len(self.waiting_queue)}, "
|
f"#queue-req: {len(self.waiting_queue)}, "
|
||||||
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||||
)
|
)
|
||||||
|
if_success = False
|
||||||
|
return if_success
|
||||||
|
|
||||||
def abort_request(self, recv_req):
|
def abort_request(self, recv_req):
|
||||||
# Delete requests in the waiting queue
|
# Delete requests in the waiting queue
|
||||||
@@ -798,6 +806,15 @@ class ModelTpServer:
|
|||||||
req.finished_reason = FINISH_ABORT()
|
req.finished_reason = FINISH_ABORT()
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def update_weights(self, recv_req):
    """Update the model weights via the model runner and flush the cache.

    Returns the ``(success, message)`` pair produced by
    ``self.model_runner.update_weights``. The cache is flushed only when the
    update succeeded, and that flush is asserted to succeed.
    """
    outcome = self.model_runner.update_weights(
        recv_req.model_path, recv_req.load_format
    )
    success, message = outcome
    if not success:
        return success, message

    flushed = self.flush_cache()
    assert flushed, "Cache flush failed after updating weights"
    return success, message
|
||||||
|
|
||||||
|
|
||||||
def run_tp_server(
|
def run_tp_server(
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
|
|
||||||
|
import gc
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.resources
|
import importlib.resources
|
||||||
import logging
|
import logging
|
||||||
@@ -157,9 +158,9 @@ class ModelRunner:
|
|||||||
self.server_args.dtype = "float16"
|
self.server_args.dtype = "float16"
|
||||||
|
|
||||||
monkey_patch_vllm_dummy_weight_loader()
|
monkey_patch_vllm_dummy_weight_loader()
|
||||||
device_config = DeviceConfig()
|
self.device_config = DeviceConfig()
|
||||||
load_config = LoadConfig(load_format=self.server_args.load_format)
|
self.load_config = LoadConfig(load_format=self.server_args.load_format)
|
||||||
vllm_model_config = VllmModelConfig(
|
self.vllm_model_config = VllmModelConfig(
|
||||||
model=self.server_args.model_path,
|
model=self.server_args.model_path,
|
||||||
quantization=self.server_args.quantization,
|
quantization=self.server_args.quantization,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
@@ -173,17 +174,19 @@ class ModelRunner:
|
|||||||
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
|
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
|
||||||
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
|
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
|
||||||
self.model_config.hf_config.num_key_value_heads = 8
|
self.model_config.hf_config.num_key_value_heads = 8
|
||||||
vllm_model_config.hf_config.num_key_value_heads = 8
|
self.vllm_model_config.hf_config.num_key_value_heads = 8
|
||||||
monkey_patch_vllm_qvk_linear_loader()
|
monkey_patch_vllm_qvk_linear_loader()
|
||||||
|
|
||||||
self.dtype = vllm_model_config.dtype
|
self.dtype = self.vllm_model_config.dtype
|
||||||
if self.model_config.model_overide_args is not None:
|
if self.model_config.model_overide_args is not None:
|
||||||
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
self.vllm_model_config.hf_config.update(
|
||||||
|
self.model_config.model_overide_args
|
||||||
|
)
|
||||||
|
|
||||||
self.model = get_model(
|
self.model = get_model(
|
||||||
model_config=vllm_model_config,
|
model_config=self.vllm_model_config,
|
||||||
device_config=device_config,
|
device_config=self.device_config,
|
||||||
load_config=load_config,
|
load_config=self.load_config,
|
||||||
lora_config=None,
|
lora_config=None,
|
||||||
multimodal_config=None,
|
multimodal_config=None,
|
||||||
parallel_config=None,
|
parallel_config=None,
|
||||||
@@ -206,6 +209,91 @@ class ModelRunner:
|
|||||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_weights(self, model_path, load_format):
    """Load new weights into the existing model object.

    Args:
        model_path: Model path whose weights should be loaded.
        load_format: Load format handed to ``LoadConfig``.

    Returns:
        A ``(success, message)`` tuple. When loading the new weights fails,
        the weights described by the current ``self.vllm_model_config`` are
        loaded again before the failure is returned.
    """
    from vllm.model_executor.model_loader.loader import (
        DefaultModelLoader,
        device_loading_context,
        get_model_loader,
    )
    from vllm.model_executor.model_loader.utils import set_default_torch_dtype

    logger.info(
        f"[gpu={self.gpu_id}] Update weights begin. "
        f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
    )

    target_device = torch.device(self.device_config.device)

    try:
        vllm_model_config = VllmModelConfig(
            model=model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=42,
            skip_tokenizer_init=True,
        )
    except Exception as e:
        # Report the actual reason to the caller, like the other failure
        # paths below, instead of a generic message.
        message = f"Failed to load model config: {e}"
        logger.error(message)
        return False, message

    load_config = LoadConfig(load_format=load_format)

    # Only support vllm DefaultModelLoader for now
    loader = get_model_loader(load_config)
    if not isinstance(loader, DefaultModelLoader):
        message = "Failed to get weights iterator: Unsupported loader"
        logger.error(message)
        return False, message

    def get_weight_iter(config):
        """Build the weights iterator for the checkpoint described by *config*."""
        return loader._get_weights_iterator(
            config.model,
            config.revision,
            fall_back_to_pt=getattr(
                self.model, "fall_back_to_pt_during_load", True
            ),
        )

    def model_load_weights(model, weights_iter):
        """Load *weights_iter* into *model* and run quantization post-processing."""
        model.load_weights(weights_iter)
        # Walk the modules of the model that was passed in (the original
        # ignored the argument and always walked self.model).
        for _, module in model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                with device_loading_context(module, target_device):
                    quant_method.process_weights_after_loading(module)
        return model

    with set_default_torch_dtype(vllm_model_config.dtype):
        try:
            weights_iter = get_weight_iter(vllm_model_config)
        except Exception as e:
            message = f"Failed to get weights iterator: {e}"
            logger.error(message)
            return False, message
        try:
            model = model_load_weights(self.model, weights_iter)
        except Exception as e:
            message = f"Failed to update weights: {e}. \n Rolling back to original weights"
            logger.error(message)
            # Drop the failed iterator before loading the previous weights
            # again from the current configuration.
            del weights_iter
            gc.collect()
            weights_iter = get_weight_iter(self.vllm_model_config)
            self.model = model_load_weights(self.model, weights_iter)
            return False, message

    # Commit the new state only after the weights were loaded successfully.
    self.model = model
    self.server_args.model_path = model_path
    self.server_args.load_format = load_format
    self.vllm_model_config = vllm_model_config
    self.load_config = load_config
    self.model_config.path = model_path

    logger.info(f"[gpu={self.gpu_id}] Update weights end.")
    return True, "Succeeded to update model weights"
|
||||||
|
|
||||||
def profile_max_num_token(self, total_gpu_memory):
|
def profile_max_num_token(self, total_gpu_memory):
|
||||||
available_gpu_memory = get_available_gpu_memory(
|
available_gpu_memory = get_available_gpu_memory(
|
||||||
self.gpu_id, distributed=self.tp_size > 1
|
self.gpu_id, distributed=self.tp_size > 1
|
||||||
|
|||||||
@@ -51,7 +51,11 @@ from sglang.srt.managers.controller_single import (
|
|||||||
start_controller_process as start_controller_process_single,
|
start_controller_process as start_controller_process_single,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
from sglang.srt.managers.io_struct import (
|
||||||
|
EmbeddingReqInput,
|
||||||
|
GenerateReqInput,
|
||||||
|
UpdateWeightReqInput,
|
||||||
|
)
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.openai_api.adapter import (
|
from sglang.srt.openai_api.adapter import (
|
||||||
load_chat_template_for_openai_api,
|
load_chat_template_for_openai_api,
|
||||||
@@ -136,6 +140,23 @@ async def flush_cache():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    """Handle an update-weights request."""
    success, message = await tokenizer_manager.update_weights(obj, request)
    # 200 on success, 400 otherwise; the payload has the same shape either way.
    status = HTTPStatus.OK if success else HTTPStatus.BAD_REQUEST
    return JSONResponse(
        {"message": message, "success": str(success)},
        status_code=status,
    )
|
||||||
|
|
||||||
|
|
||||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||||
"""Handle a generate request."""
|
"""Handle a generate request."""
|
||||||
if obj.stream:
|
if obj.stream:
|
||||||
|
|||||||
106
test/srt/test_update_weights.py
Normal file
106
test/srt/test_update_weights.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_child_process
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_URL_FOR_UNIT_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestReplaceWeights(unittest.TestCase):
    """Tests for the ``/update_weights`` endpoint against a launched server."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_UNIT_TEST
        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(self):
        """Run one greedy generation and return the generated text."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                    "n": 1,
                },
                "stream": False,
                "return_logprob": False,
                "top_logprobs_num": 0,
                "return_text_in_logprobs": False,
                "logprob_start_len": 0,
            },
        )
        payload = response.json()
        print(json.dumps(payload))
        print("=" * 100)
        return payload["text"]

    def get_model_info(self):
        """Return the model path currently reported by the server."""
        response = requests.get(self.base_url + "/get_model_info")
        payload = response.json()
        model_path = payload["model_path"]
        print(json.dumps(payload))
        return model_path

    def run_update_weights(self, model_path):
        """Post an update-weights request and return the HTTP response."""
        response = requests.post(
            self.base_url + "/update_weights",
            json={
                "model_path": model_path,
            },
        )
        print(json.dumps(response.json()))
        return response

    def test_replace_weights(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights
        new_model_path = "meta-llama/Meta-Llama-3.1-8B"
        response = self.run_update_weights(new_model_path)
        self.assertEqual(response.status_code, 200)

        updated_model_path = self.get_model_info()
        print(f"updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])

        # update weights back
        response = self.run_update_weights(origin_model_path)
        self.assertEqual(response.status_code, 200)
        updated_model_path = self.get_model_info()
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])

    def test_replace_weights_unexist_model(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights with a model path that does not exist
        new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
        response = self.run_update_weights(new_model_path)
        # A failed update is reported as a bad request by the endpoint.
        self.assertEqual(response.status_code, 400)

        # The server must keep serving the original model.
        updated_model_path = self.get_model_info()
        print(f"updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user