diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index a6f460846..284334396 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -380,6 +380,12 @@ class LlamaForCausalLM(nn.Module): ] params_dict = dict(self.named_parameters()) + load_tie_word_embeddings = ( + hasattr(self.config, "tie_word_embeddings") + and self.config.tie_word_embeddings + and "lm_head.weight" in params_dict + ) + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: continue @@ -412,15 +418,14 @@ class LlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if ( - hasattr(self.config, "tie_word_embeddings") - and self.config.tie_word_embeddings - and "lm_head.weight" in params_dict - ): + if load_tie_word_embeddings and name == "model.embed_tokens.weight": + embed_tokens_weight = loaded_weight + + if load_tie_word_embeddings: # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing param = self.lm_head.weight weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, self.model.embed_tokens.weight) + weight_loader(param, embed_tokens_weight) apply_torchao_config_(self, params_dict, set(["proj.weight"]))