diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3052e924c..fc3372624 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -515,13 +515,7 @@ class ModelRunner: def get_weight_iter(config): iter = loader._get_weights_iterator( - DefaultModelLoader.Source( - config.model_path, - revision=config.revision, - fall_back_to_pt=getattr( - self.model, "fall_back_to_pt_during_load", True - ), - ) + DefaultModelLoader.Source.init_new(config, model) ) return iter diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 94b02c6f5..cc69d3d78 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -197,6 +197,15 @@ class DefaultModelLoader(BaseModelLoader): fall_back_to_pt: bool = True """Whether .pt weights can be used.""" + @classmethod + def init_new(cls, model_config: ModelConfig, model): + return cls( + model_config.model_path, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + ) + def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: @@ -341,12 +350,7 @@ class DefaultModelLoader(BaseModelLoader): model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: - primary_weights = DefaultModelLoader.Source( - model_config.model_path, - model_config.revision, - prefix="", - fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), - ) + primary_weights = DefaultModelLoader.Source.init_new(model_config, model) yield from self._get_weights_iterator(primary_weights) secondary_weights = cast(