Clean up model loader (#1440)
@@ -54,11 +54,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )
 
 logger = logging.getLogger(__name__)
@@ -166,10 +164,13 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
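Note: the hunk above moves the torch.set_num_threads(1) call below the initial log line and documents why it exists. As a minimal sketch (not part of this PR), the same idea can be scoped with a context manager so the thread cap does not leak past weight loading; all names below are illustrative.

import torch

class limit_torch_threads:
    """Temporarily cap torch intra-op threads, e.g. while loading weights."""

    def __init__(self, num_threads: int = 1):
        self.num_threads = num_threads

    def __enter__(self):
        # Remember the previous setting so it can be restored on exit.
        self.prev = torch.get_num_threads()
        torch.set_num_threads(self.num_threads)

    def __exit__(self, *exc):
        torch.set_num_threads(self.prev)

# Usage sketch:
# with limit_torch_threads(1):
#     model = load_weights(...)  # hypothetical loader call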
@@ -178,6 +179,7 @@ class ModelRunner:
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")
 
         # Prepare the vllm model config
+        monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
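Note: taken together, the two capability checks above implement a simple policy: sm80+ may use bfloat16, sm75 falls back to float16, and anything older is rejected. A standalone sketch of that policy (assuming a CUDA device is present; this is not code from the PR):

import torch

def pick_default_dtype() -> torch.dtype:
    # bfloat16 needs sm80 (Ampere) or newer; sm75 (Turing) still works
    # with float16; anything below sm75 is unsupported.
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) < (7, 5):
        raise RuntimeError("SGLang only supports sm75 and above.")
    return torch.bfloat16 if major >= 8 else torch.float16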
@@ -188,23 +190,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=42,
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
 
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
+        self.dtype = self.vllm_model_config.dtype
 
         # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
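Note: besides replacing the hard-coded seed=42 with self.server_args.random_seed and dropping the expired Llama-3.1-405B-FP8 head-count hack, the hunk above reads self.dtype only after model_override_args is applied. A toy illustration of why that ordering matters (SimpleNamespace stands in for the HF config; the values are hypothetical):

from types import SimpleNamespace

# Stand-in for a HF config whose fields can be overridden at startup.
hf_config = SimpleNamespace(torch_dtype="bfloat16")
model_override_args = {"torch_dtype": "float16"}

if model_override_args is not None:
    hf_config.__dict__.update(model_override_args)

# Reading dtype before the update would have returned "bfloat16".
dtype = hf_config.torch_dtype  # "float16"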
@@ -255,20 +250,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=42,
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-            logger.error(f"Failed to load model config: {e}")
-            return False, "Failed to update model weights"
+            message = f"Failed to load model config: {e}."
+            return False, message
 
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-            logger.error("Failed to get weights iterator: Unsupported loader")
-            return False, "Failed to update model weights"
+            message = f"Failed to get model loader: {loader}."
+            return False, message
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
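Note: the hunk above replaces the old pattern of logging one string and returning a different, generic one with building a single message that carries the real exception and returning it to the caller. A small self-contained sketch of that convention (the helper name is hypothetical):

def load_model_config(path: str) -> dict:
    # Stand-in for constructing a vllm ModelConfig; raises on a bad path.
    if not path:
        raise ValueError("empty model path")
    return {"path": path}

def try_update(model_path: str) -> tuple[bool, str]:
    try:
        load_model_config(model_path)
    except Exception as e:
        # One message, containing the actual cause, returned to the caller.
        return False, f"Failed to load model config: {e}."
    return True, "Succeeded to update model weights."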
@@ -293,14 +288,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message = f"Failed to update weights: {e}. \n Rolling back to original weights"
-            logger.error(message)
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
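Note: the second try block above implements a rollback: if loading the new weights fails, the failed iterator is dropped, garbage is collected to free its memory, and a fresh iterator over the original weights restores the model. A schematic version of that flow, with the callables as placeholders for the loader helpers in this file:

import gc

def swap_weights(model, new_config, old_config, get_weight_iter, load_weights):
    """Try new weights; on failure, roll back to the original checkpoint."""
    weight_iter = get_weight_iter(new_config)
    try:
        load_weights(model, weight_iter)
        return True, "Succeeded to update model weights."
    except Exception as e:
        # Free the failed iterator before opening one over the old weights.
        del weight_iter
        gc.collect()
        weight_iter = get_weight_iter(old_config)
        load_weights(model, weight_iter)
        return False, f"Failed to update weights: {e}.\nRolling back to original weights."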
@@ -315,7 +310,7 @@ class ModelRunner:
         self.model_config.path = model_path
 
         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."
 
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(