diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1847af151..6ef23af24 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -93,6 +93,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( MultiprocessingSerializer, cpu_has_amx_support, + dynamic_import, enable_show_time_cost, get_available_gpu_memory, get_bool_env_var, @@ -761,6 +762,9 @@ class ModelRunner: ] if load_format == "direct": _model_load_weights_direct(self.model, named_tensors) + elif load_format in self.server_args.custom_weight_loader: + custom_loader = dynamic_import(load_format) + custom_loader(self.model, named_tensors) elif load_format is None: self.model.load_weights(named_tensors) else: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 04b6f96cb..54e92d0bf 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -234,6 +234,9 @@ class ServerArgs: num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD pdlb_url: Optional[str] = None + # For model weight update + custom_weight_loader: Optional[List[str]] = None + def __post_init__(self): # Expert parallelism if self.enable_ep_moe: @@ -538,6 +541,9 @@ class ServerArgs: "1" if self.disable_outlines_disk_cache else "0" ) + if self.custom_weight_loader is None: + self.custom_weight_loader = [] + def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int): larger_tp = max(decode_tp, prefill_tp) smaller_tp = min(decode_tp, prefill_tp) @@ -1576,6 +1582,13 @@ class ServerArgs: default=None, help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.", ) + parser.add_argument( + "--custom-weight-loader", + type=str, + nargs="*", + default=None, + help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index fe8dcbf8e..2184a4a94 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2340,3 +2340,16 @@ class LazyValue: self._value = self._creator() self._creator = None return self._value + + +def dynamic_import(func_path: str): + parts = func_path.split(".") + if len(parts) < 2: + raise ValueError( + "func_path should contain both module name and func name (such as 'module.func')" + ) + module_path = ".".join(parts[:-1]) + func_name = parts[-1] + module = importlib.import_module(module_path) + func = getattr(module, func_name) + return func diff --git a/test/srt/test_update_weights_from_tensor.py b/test/srt/test_update_weights_from_tensor.py index 38187652b..a1ca7f4b0 100644 --- a/test/srt/test_update_weights_from_tensor.py +++ b/test/srt/test_update_weights_from_tensor.py @@ -78,6 +78,40 @@ class TestUpdateWeightsFromTensor(CustomTestCase): engine.shutdown() + def test_update_weights_from_tensor_load_format_custom(self): + custom_loader_name = ( + "sglang.srt.model_executor.model_runner._model_load_weights_direct" + ) + engine = sgl.Engine( + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + custom_weight_loader=[custom_loader_name], + ) + + write_param_names = [ + f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16) + ] + read_param_names = [ + f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16) + ] + + _check_param( + engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178] + ) + + new_tensor = torch.full((3072, 2048), 1.5) + engine.update_weights_from_tensor( + [ + (write_param_name, new_tensor.clone()) + for write_param_name in write_param_names + ], + load_format=custom_loader_name, + ) + + for read_param_name in read_param_names[:3]: + _check_param(engine, read_param_name, [1.5] * 5) + + engine.shutdown() + def _check_param(engine, param_name, expect_values): actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]