Reland tiny refactor DefaultModelLoader.Source (#6041)
This commit is contained in:
@@ -553,13 +553,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
def get_weight_iter(config):
|
def get_weight_iter(config):
|
||||||
iter = loader._get_weights_iterator(
|
iter = loader._get_weights_iterator(
|
||||||
DefaultModelLoader.Source(
|
DefaultModelLoader.Source.init_new(config, self.model)
|
||||||
config.model_path,
|
|
||||||
revision=config.revision,
|
|
||||||
fall_back_to_pt=getattr(
|
|
||||||
self.model, "fall_back_to_pt_during_load", True
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return iter
|
return iter
|
||||||
|
|
||||||
|
|||||||
@@ -197,6 +197,15 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
fall_back_to_pt: bool = True
|
fall_back_to_pt: bool = True
|
||||||
"""Whether .pt weights can be used."""
|
"""Whether .pt weights can be used."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def init_new(cls, model_config: ModelConfig, model):
|
||||||
|
return cls(
|
||||||
|
model_config.model_path,
|
||||||
|
model_config.revision,
|
||||||
|
prefix="",
|
||||||
|
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, load_config: LoadConfig):
|
def __init__(self, load_config: LoadConfig):
|
||||||
super().__init__(load_config)
|
super().__init__(load_config)
|
||||||
if load_config.model_loader_extra_config:
|
if load_config.model_loader_extra_config:
|
||||||
@@ -341,12 +350,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
|
|
||||||
primary_weights = DefaultModelLoader.Source(
|
primary_weights = DefaultModelLoader.Source.init_new(model_config, model)
|
||||||
model_config.model_path,
|
|
||||||
model_config.revision,
|
|
||||||
prefix="",
|
|
||||||
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
|
|
||||||
)
|
|
||||||
yield from self._get_weights_iterator(primary_weights)
|
yield from self._get_weights_iterator(primary_weights)
|
||||||
|
|
||||||
secondary_weights = cast(
|
secondary_weights = cast(
|
||||||
|
|||||||
Reference in New Issue
Block a user