Tiny refactor weight loading logic (#5232)

This commit is contained in:
fzyzcjy
2025-05-08 16:02:56 +08:00
committed by GitHub
parent b6cf3532b5
commit 6450c1228c
2 changed files with 19 additions and 17 deletions

View File

@@ -557,12 +557,7 @@ class ModelRunner:
return iter return iter
def model_load_weights(model, iter): def model_load_weights(model, iter):
model.load_weights(iter) DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
for _, module in self.model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model return model
with set_default_torch_dtype(self.model_config.dtype): with set_default_torch_dtype(self.model_config.dtype):

View File

@@ -374,7 +374,15 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config, self.load_config,
) )
model.load_weights(self._get_all_weights(model_config, model)) self.load_weights_and_postprocess(
model, self._get_all_weights(model_config, model), target_device
)
return model.eval()
@staticmethod
def load_weights_and_postprocess(model, weights, target_device):
model.load_weights(weights)
for _, module in model.named_modules(): for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
@@ -386,7 +394,6 @@ class DefaultModelLoader(BaseModelLoader):
# parameters onto device for processing and back off after. # parameters onto device for processing and back off after.
with device_loading_context(module, target_device): with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module) quant_method.process_weights_after_loading(module)
return model.eval()
class LayeredModelLoader(DefaultModelLoader): class LayeredModelLoader(DefaultModelLoader):