Fix weight loading for tied word embedding when TP > 1 (#2009)
This commit is contained in:
@@ -380,6 +380,12 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
]
|
]
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
|
||||||
|
load_tie_word_embeddings = (
|
||||||
|
hasattr(self.config, "tie_word_embeddings")
|
||||||
|
and self.config.tie_word_embeddings
|
||||||
|
and "lm_head.weight" in params_dict
|
||||||
|
)
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||||
continue
|
continue
|
||||||
@@ -412,15 +418,14 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
if (
|
if load_tie_word_embeddings and name == "model.embed_tokens.weight":
|
||||||
hasattr(self.config, "tie_word_embeddings")
|
embed_tokens_weight = loaded_weight
|
||||||
and self.config.tie_word_embeddings
|
|
||||||
and "lm_head.weight" in params_dict
|
if load_tie_word_embeddings:
|
||||||
):
|
|
||||||
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
||||||
param = self.lm_head.weight
|
param = self.lm_head.weight
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, self.model.embed_tokens.weight)
|
weight_loader(param, embed_tokens_weight)
|
||||||
|
|
||||||
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user