support custom weight loader for model runner (#7122)
Co-authored-by: kavioyu <kavioyu@tencent.com>
This commit is contained in:
@@ -93,6 +93,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import (
|
||||
MultiprocessingSerializer,
|
||||
cpu_has_amx_support,
|
||||
dynamic_import,
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
get_bool_env_var,
|
||||
@@ -761,6 +762,9 @@ class ModelRunner:
|
||||
]
|
||||
if load_format == "direct":
|
||||
_model_load_weights_direct(self.model, named_tensors)
|
||||
elif load_format in self.server_args.custom_weight_loader:
|
||||
custom_loader = dynamic_import(load_format)
|
||||
custom_loader(self.model, named_tensors)
|
||||
elif load_format is None:
|
||||
self.model.load_weights(named_tensors)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user