From a3339d8cac8ab0172a55c1bcc231f53323c5040f Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 22 Feb 2025 09:53:12 +0800 Subject: [PATCH] Bug: Fix weight loader error when LM head weights are tied (#3766) --- python/sglang/srt/models/llama.py | 2 ++ python/sglang/srt/models/minicpm.py | 2 ++ python/sglang/srt/models/minicpm3.py | 2 ++ python/sglang/srt/models/olmo.py | 2 ++ python/sglang/srt/models/phi3_small.py | 2 ++ python/sglang/srt/models/qwen2.py | 2 ++ python/sglang/srt/models/qwen2_vl.py | 2 ++ python/sglang/srt/models/torch_native_llama.py | 2 ++ 8 files changed, 16 insertions(+) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 9cb13e944..27b4277cf 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -458,6 +458,8 @@ class LlamaForCausalLM(nn.Module): continue if name.startswith("model.vision_tower") and name not in params_dict: continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue # Handle FP8 kv-scale remapping if "scale" in name: name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index f5e69411a..6f8b500a4 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -339,6 +339,8 @@ class MiniCPMForCausalLM(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 31ea7cd9f..f7b331bab 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -603,6 +603,8 @@ class MiniCPM3ForCausalLM(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 4d8a79900..9f118ea6b 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -325,6 +325,8 @@ class OlmoForCausalLM(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index b7195dbaa..fa365b98c 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -433,6 +433,8 @@ class Phi3SmallForCausalLM(nn.Module): continue if name.endswith(".bias") and name not in params_dict: continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 46b62f837..d53d9561f 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -377,6 +377,8 @@ class Qwen2ForCausalLM(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue if name.startswith("model.vision_tower") and name not in params_dict: continue diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index add9019bd..d8e190deb 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -586,6 +586,8 @@ class Qwen2VLForConditionalGeneration(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 7b3e5bc5d..0612e3e7d 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -486,6 +486,8 @@ class TorchNativeLlamaForCausalLM(nn.Module): continue if name.startswith("model.vision_tower") and name not in params_dict: continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: