support custom weight loader for model runner (#7122)
Co-authored-by: kavioyu <kavioyu@tencent.com>
This commit is contained in:
@@ -93,6 +93,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import (
|
||||
MultiprocessingSerializer,
|
||||
cpu_has_amx_support,
|
||||
dynamic_import,
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
get_bool_env_var,
|
||||
@@ -761,6 +762,9 @@ class ModelRunner:
|
||||
]
|
||||
if load_format == "direct":
|
||||
_model_load_weights_direct(self.model, named_tensors)
|
||||
elif load_format in self.server_args.custom_weight_loader:
|
||||
custom_loader = dynamic_import(load_format)
|
||||
custom_loader(self.model, named_tensors)
|
||||
elif load_format is None:
|
||||
self.model.load_weights(named_tensors)
|
||||
else:
|
||||
|
||||
@@ -234,6 +234,9 @@ class ServerArgs:
|
||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||
pdlb_url: Optional[str] = None
|
||||
|
||||
# For model weight update
|
||||
custom_weight_loader: Optional[List[str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Expert parallelism
|
||||
if self.enable_ep_moe:
|
||||
@@ -538,6 +541,9 @@ class ServerArgs:
|
||||
"1" if self.disable_outlines_disk_cache else "0"
|
||||
)
|
||||
|
||||
if self.custom_weight_loader is None:
|
||||
self.custom_weight_loader = []
|
||||
|
||||
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
|
||||
larger_tp = max(decode_tp, prefill_tp)
|
||||
smaller_tp = min(decode_tp, prefill_tp)
|
||||
@@ -1576,6 +1582,13 @@ class ServerArgs:
|
||||
default=None,
|
||||
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--custom-weight-loader",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
@@ -2340,3 +2340,16 @@ class LazyValue:
|
||||
self._value = self._creator()
|
||||
self._creator = None
|
||||
return self._value
|
||||
|
||||
|
||||
def dynamic_import(func_path: str):
|
||||
parts = func_path.split(".")
|
||||
if len(parts) < 2:
|
||||
raise ValueError(
|
||||
"func_path should contain both module name and func name (such as 'module.func')"
|
||||
)
|
||||
module_path = ".".join(parts[:-1])
|
||||
func_name = parts[-1]
|
||||
module = importlib.import_module(module_path)
|
||||
func = getattr(module, func_name)
|
||||
return func
|
||||
|
||||
Reference in New Issue
Block a user