Tiny refactor weight loading logic (#5232)
This commit is contained in:
@@ -557,12 +557,7 @@ class ModelRunner:
|
|||||||
return iter
|
return iter
|
||||||
|
|
||||||
def model_load_weights(model, iter):
|
def model_load_weights(model, iter):
|
||||||
model.load_weights(iter)
|
DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
|
||||||
for _, module in self.model.named_modules():
|
|
||||||
quant_method = getattr(module, "quant_method", None)
|
|
||||||
if quant_method is not None:
|
|
||||||
with device_loading_context(module, target_device):
|
|
||||||
quant_method.process_weights_after_loading(module)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
with set_default_torch_dtype(self.model_config.dtype):
|
with set_default_torch_dtype(self.model_config.dtype):
|
||||||
|
|||||||
@@ -374,20 +374,27 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
self.load_config,
|
self.load_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(self._get_all_weights(model_config, model))
|
self.load_weights_and_postprocess(
|
||||||
|
model, self._get_all_weights(model_config, model), target_device
|
||||||
|
)
|
||||||
|
|
||||||
for _, module in model.named_modules():
|
|
||||||
quant_method = getattr(module, "quant_method", None)
|
|
||||||
if quant_method is not None:
|
|
||||||
# When quant methods need to process weights after loading
|
|
||||||
# (for repacking, quantizing, etc), they expect parameters
|
|
||||||
# to be on the global target device. This scope is for the
|
|
||||||
# case where cpu offloading is used, where we will move the
|
|
||||||
# parameters onto device for processing and back off after.
|
|
||||||
with device_loading_context(module, target_device):
|
|
||||||
quant_method.process_weights_after_loading(module)
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_weights_and_postprocess(model, weights, target_device):
|
||||||
|
model.load_weights(weights)
|
||||||
|
|
||||||
|
for _, module in model.named_modules():
|
||||||
|
quant_method = getattr(module, "quant_method", None)
|
||||||
|
if quant_method is not None:
|
||||||
|
# When quant methods need to process weights after loading
|
||||||
|
# (for repacking, quantizing, etc), they expect parameters
|
||||||
|
# to be on the global target device. This scope is for the
|
||||||
|
# case where cpu offloading is used, where we will move the
|
||||||
|
# parameters onto device for processing and back off after.
|
||||||
|
with device_loading_context(module, target_device):
|
||||||
|
quant_method.process_weights_after_loading(module)
|
||||||
|
|
||||||
|
|
||||||
class LayeredModelLoader(DefaultModelLoader):
|
class LayeredModelLoader(DefaultModelLoader):
|
||||||
"""Model loader that loads weights layer by layer so that one can quantize a
|
"""Model loader that loads weights layer by layer so that one can quantize a
|
||||||
|
|||||||
Reference in New Issue
Block a user