feat: Add flexible validation for partial weight updates (#9663)
Co-authored-by: RichardW <rich-junwang@users.noreply.github.com> Co-authored-by: Zhuorany <yzr1914001753@gmail.com> Co-authored-by: Stefan He <hebiaobuaa@gmail.com> Co-authored-by: Yineng Zhang <me@zhyncs.com> Co-authored-by: Night <32424487+PrinsYin@users.noreply.github.com> Co-authored-by:zhaochenyang20 <zhaochen20@outlook.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
@@ -1029,10 +1029,6 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
params_checker = {k: False for k, v in params_dict.items()}
|
|
||||||
|
|
||||||
for other_loaded_param_name in other_loaded_param_names:
|
|
||||||
params_checker[other_loaded_param_name] = True
|
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
|
loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
|
||||||
@@ -1069,7 +1065,6 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
params_checker[name] = True
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for mapping in expert_params_mapping:
|
for mapping in expert_params_mapping:
|
||||||
@@ -1092,7 +1087,6 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
name,
|
name,
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
)
|
)
|
||||||
params_checker[name] = True
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
@@ -1111,17 +1105,9 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
param, "weight_loader", default_weight_loader
|
param, "weight_loader", default_weight_loader
|
||||||
)
|
)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
params_checker[name] = True
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Parameter {name} not found in params_dict")
|
logger.warning(f"Parameter {name} not found in params_dict")
|
||||||
|
|
||||||
not_loaded_params = [k for k, v in params_checker.items() if not v]
|
|
||||||
if tp_rank == 0:
|
|
||||||
if len(not_loaded_params) > 0:
|
|
||||||
raise Exception(f"Not all parameters loaded: {not_loaded_params}")
|
|
||||||
else:
|
|
||||||
logging.info("All parameters loaded successfully.")
|
|
||||||
|
|
||||||
def get_embed_and_head(self):
|
def get_embed_and_head(self):
|
||||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user