From 530ae1bdc80f8740975977d4a347b62760fd381d Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 11 Nov 2024 17:52:42 -0800 Subject: [PATCH] Fix weight loading for tied word embedding when TP > 1 (#2009) --- python/sglang/srt/models/llama.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index a6f460846..284334396 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -380,6 +380,12 @@ class LlamaForCausalLM(nn.Module): ] params_dict = dict(self.named_parameters()) + load_tie_word_embeddings = ( + hasattr(self.config, "tie_word_embeddings") + and self.config.tie_word_embeddings + and "lm_head.weight" in params_dict + ) + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: continue @@ -412,15 +418,14 @@ class LlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if ( - hasattr(self.config, "tie_word_embeddings") - and self.config.tie_word_embeddings - and "lm_head.weight" in params_dict - ): + if load_tie_word_embeddings and name == "model.embed_tokens.weight": + embed_tokens_weight = loaded_weight + + if load_tie_word_embeddings: # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing param = self.lm_head.weight weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, self.model.embed_tokens.weight) + weight_loader(param, embed_tokens_weight) apply_torchao_config_(self, params_dict, set(["proj.weight"]))