Add support for tie_word_embeddings when loading weights + support for SmolLM (#1508)

This commit is contained in:
TianyiQ
2024-09-24 21:50:20 -07:00
committed by GitHub
parent fb2d0680e0
commit 3c93187caf
3 changed files with 10 additions and 0 deletions

View File

@@ -403,6 +403,14 @@ class LlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if (
hasattr(self.config, "tie_word_embeddings")
and self.config.tie_word_embeddings
):
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, self.model.embed_tokens.weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))