[Feat] Support updating weights without restarting the server (#1157)

commit cd10654e7e (parent 350a81609b)
Author: Shan Yu
Date: 2024-08-20 13:48:24 -07:00
Committed by: GitHub
7 changed files with 303 additions and 13 deletions
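
This change lets a running server hot-swap its model weights in place. A minimal usage sketch, assuming the companion server changes in this commit expose the feature as an /update_weights HTTP endpoint on sglang's default port; the endpoint name, JSON field names, and checkpoint path are assumptions inferred from the update_weights(model_path, load_format) signature in the diff below, not confirmed by the hunks shown:

import requests

# Ask a running sglang server to reload weights from a new checkpoint.
# URL and field names are assumptions; the path is a placeholder.
response = requests.post(
    "http://localhost:30000/update_weights",
    json={
        "model_path": "/checkpoints/step_2000",
        "load_format": "auto",
    },
)
print(response.json())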

@@ -15,6 +15,7 @@ limitations under the License.
 
 """ModelRunner runs the forward passes of the models."""
 
+import gc
 import importlib
 import importlib.resources
 import logging
@@ -157,9 +158,9 @@ class ModelRunner:
             self.server_args.dtype = "float16"
 
         monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
             quantization=self.server_args.quantization,
             tokenizer=None,
@@ -173,17 +174,19 @@ class ModelRunner:
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()
 
-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )
 
         self.model = get_model(
-            model_config=vllm_model_config,
-            device_config=device_config,
-            load_config=load_config,
+            model_config=self.vllm_model_config,
+            device_config=self.device_config,
+            load_config=self.load_config,
             lora_config=None,
             multimodal_config=None,
             parallel_config=None,
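
Promoting device_config, load_config, and vllm_model_config to instance attributes is what enables the rollback path in the new update_weights method (next hunk): after a failed load, self.vllm_model_config still describes the last known-good checkpoint. Condensed to its essence (a sketch only; get_weight_iter and model_load_weights are the helpers defined in the hunk below, and new_config stands for the freshly built config):

# Sketch of the rollback pattern used by update_weights below.
try:
    self.model = model_load_weights(self.model, get_weight_iter(new_config))
except Exception:
    # self.vllm_model_config still points at the previous checkpoint.
    self.model = model_load_weights(
        self.model, get_weight_iter(self.vllm_model_config)
    )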
@@ -206,6 +209,91 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
+    def update_weights(self, model_path, load_format):
+        from vllm.model_executor.model_loader.loader import (
+            DefaultModelLoader,
+            device_loading_context,
+            get_model_loader,
+        )
+        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+        logger.info(
+            f"[gpu={self.gpu_id}] Update weights begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        target_device = torch.device(self.device_config.device)
+
+        try:
+            vllm_model_config = VllmModelConfig(
+                model=model_path,
+                quantization=self.server_args.quantization,
+                tokenizer=None,
+                tokenizer_mode=None,
+                trust_remote_code=self.server_args.trust_remote_code,
+                dtype=self.server_args.dtype,
+                seed=42,
+                skip_tokenizer_init=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model config: {e}")
+            return False, "Failed to update model weights"
+
+        load_config = LoadConfig(load_format=load_format)
+
+        # Only support vllm DefaultModelLoader for now
+        loader = get_model_loader(load_config)
+        if not isinstance(loader, DefaultModelLoader):
+            logger.error("Failed to get weights iterator: Unsupported loader")
+            return False, "Failed to update model weights"
+
+        def get_weight_iter(config):
+            iter = loader._get_weights_iterator(
+                config.model,
+                config.revision,
+                fall_back_to_pt=getattr(
+                    self.model, "fall_back_to_pt_during_load", True
+                ),
+            )
+            return iter
+
+        def model_load_weights(model, iter):
+            model.load_weights(iter)
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+            return model
+
+        with set_default_torch_dtype(vllm_model_config.dtype):
+            try:
+                iter = get_weight_iter(vllm_model_config)
+            except Exception as e:
+                message = f"Failed to get weights iterator: {e}"
+                logger.error(message)
+                return False, message
+            try:
+                model = model_load_weights(self.model, iter)
+            except Exception as e:
+                message = f"Failed to update weights: {e}. \n Rolling back to original weights"
+                logger.error(message)
+                del iter
+                gc.collect()
+                iter = get_weight_iter(self.vllm_model_config)
+                self.model = model_load_weights(self.model, iter)
+                return False, message
+
+        self.model = model
+        self.server_args.model_path = model_path
+        self.server_args.load_format = load_format
+        self.vllm_model_config = vllm_model_config
+        self.load_config = load_config
+        self.model_config.path = model_path
+
+        logger.info(f"[gpu={self.gpu_id}] Update weights end.")
+        return True, "Succeeded to update model weights"
+
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
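
Note that update_weights reports failure through its (success, message) return value rather than raising, and it restores the previous weights before returning. A sketch of how a caller might consume it (runner stands for an already-initialized ModelRunner; the path is a placeholder):

# Hypothetical caller; only update_weights itself comes from this commit.
success, message = runner.update_weights(
    model_path="/checkpoints/step_2000",
    load_format="auto",
)
if not success:
    # The method already rolled back to the previous weights on failure.
    raise RuntimeError(message)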