From 3c93187cafd675ad8c05dcf4095513ce4ec0bae3 Mon Sep 17 00:00:00 2001 From: TianyiQ <34389237+TianyiQ@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:50:20 -0700 Subject: [PATCH] Add support for tie_word_embeddings when loading weights + support for SmolLM (#1508) --- README.md | 1 + python/sglang/srt/models/llama.py | 8 ++++++++ test/srt/models/test_generation_models.py | 1 + 3 files changed, 10 insertions(+) diff --git a/README.md b/README.md index f94279835..c4da865f2 100644 --- a/README.md +++ b/README.md @@ -263,6 +263,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct - BaiChuan2 - MiniCPM / MiniCPM 3 - XVERSE / XVERSE MoE +- SmolLM **Embedding Models** diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 447b548aa..b63aaf16f 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -403,6 +403,14 @@ class LlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if ( + hasattr(self.config, "tie_word_embeddings") + and self.config.tie_word_embeddings + ): + # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing + param = self.lm_head.weight + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, self.model.embed_tokens.weight) apply_torchao_config_(self, params_dict, set(["proj.weight"])) diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 67ef363d0..732b3d800 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -51,6 +51,7 @@ CI_MODELS = [ # All other models ALL_OTHER_MODELS = [ ModelCase("Qwen/Qwen2-1.5B"), + ModelCase("HuggingFaceTB/SmolLM-135M-Instruct"), ] TORCH_DTYPES = [torch.float16]