Tiny refactor DefaultModelLoader.Source (#5482)
This commit is contained in:
@@ -515,13 +515,7 @@ class ModelRunner:
|
||||
|
||||
def get_weight_iter(config):
|
||||
iter = loader._get_weights_iterator(
|
||||
DefaultModelLoader.Source(
|
||||
config.model_path,
|
||||
revision=config.revision,
|
||||
fall_back_to_pt=getattr(
|
||||
self.model, "fall_back_to_pt_during_load", True
|
||||
),
|
||||
)
|
||||
DefaultModelLoader.Source.init_new(config, model)
|
||||
)
|
||||
return iter
|
||||
|
||||
|
||||
@@ -197,6 +197,15 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
fall_back_to_pt: bool = True
|
||||
"""Whether .pt weights can be used."""
|
||||
|
||||
@classmethod
|
||||
def init_new(cls, model_config: ModelConfig, model):
|
||||
return cls(
|
||||
model_config.model_path,
|
||||
model_config.revision,
|
||||
prefix="",
|
||||
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
|
||||
)
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
@@ -341,12 +350,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
model: nn.Module,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
|
||||
primary_weights = DefaultModelLoader.Source(
|
||||
model_config.model_path,
|
||||
model_config.revision,
|
||||
prefix="",
|
||||
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
|
||||
)
|
||||
primary_weights = DefaultModelLoader.Source.init_new(model_config, model)
|
||||
yield from self._get_weights_iterator(primary_weights)
|
||||
|
||||
secondary_weights = cast(
|
||||
|
||||
Reference in New Issue
Block a user