diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 4afd9f2a3..46b62f837 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -379,8 +379,6 @@ class Qwen2ForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
-            if name.startswith("lm_head"):
-                continue
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
diff --git a/python/sglang/srt/models/qwen2_rm.py b/python/sglang/srt/models/qwen2_rm.py
index c7aaa7697..39ed15fa5 100644
--- a/python/sglang/srt/models/qwen2_rm.py
+++ b/python/sglang/srt/models/qwen2_rm.py
@@ -62,7 +62,11 @@ class Qwen2ForRewardModel(nn.Module):
         return EmbeddingPoolerOutput(pooled_logits)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        return Qwen2ForCausalLM.load_weights(self, weights)
+        # Filter out lm_head weights of Qwen2ForCausalLM
+        filtered_weights = [
+            (name, w) for name, w in weights if not name.startswith("lm_head")
+        ]
+        return Qwen2ForCausalLM.load_weights(self, filtered_weights)
 
 
 EntryClass = [