From 14d90617b080f652b1eed1307ab83c40af984a4d Mon Sep 17 00:00:00 2001 From: Chayenne Date: Fri, 21 Feb 2025 16:49:31 -0800 Subject: [PATCH] Bug: fix lm head weights in Qwen models (#3777) --- python/sglang/srt/models/qwen2.py | 2 -- python/sglang/srt/models/qwen2_rm.py | 6 +++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 4afd9f2a3..46b62f837 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -379,8 +379,6 @@ class Qwen2ForCausalLM(nn.Module): continue if name.startswith("model.vision_tower") and name not in params_dict: continue - if name.startswith("lm_head"): - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: diff --git a/python/sglang/srt/models/qwen2_rm.py b/python/sglang/srt/models/qwen2_rm.py index c7aaa7697..39ed15fa5 100644 --- a/python/sglang/srt/models/qwen2_rm.py +++ b/python/sglang/srt/models/qwen2_rm.py @@ -62,7 +62,11 @@ class Qwen2ForRewardModel(nn.Module): return EmbeddingPoolerOutput(pooled_logits) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - return Qwen2ForCausalLM.load_weights(self, weights) + # Filter out lm_head weights of Qwen2ForCausalLM + filtered_weights = [ + (name, w) for name, w in weights if not name.startswith("lm_head") + ] + return Qwen2ForCausalLM.load_weights(self, filtered_weights) EntryClass = [