From 693723d1f77725078a5175c469ff951fa95a8e36 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 28 Apr 2025 01:18:57 -0700 Subject: [PATCH] Revert "Tiny refactor DefaultModelLoader.Source" (#5825) --- python/sglang/srt/model_executor/model_runner.py | 8 +++++++- python/sglang/srt/model_loader/loader.py | 16 ++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fc3372624..3052e924c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -515,7 +515,13 @@ class ModelRunner: def get_weight_iter(config): iter = loader._get_weights_iterator( - DefaultModelLoader.Source.init_new(config, model) + DefaultModelLoader.Source( + config.model_path, + revision=config.revision, + fall_back_to_pt=getattr( + self.model, "fall_back_to_pt_during_load", True + ), + ) ) return iter diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index cc69d3d78..94b02c6f5 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -197,15 +197,6 @@ class DefaultModelLoader(BaseModelLoader): fall_back_to_pt: bool = True """Whether .pt weights can be used.""" - @classmethod - def init_new(cls, model_config: ModelConfig, model): - return cls( - model_config.model_path, - model_config.revision, - prefix="", - fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), - ) - def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: @@ -350,7 +341,12 @@ class DefaultModelLoader(BaseModelLoader): model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: - primary_weights = DefaultModelLoader.Source.init_new(model_config, model) + primary_weights = DefaultModelLoader.Source( + model_config.model_path, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + ) yield from self._get_weights_iterator(primary_weights) secondary_weights = cast(