Bug: fix lm_head weights in Qwen models (#3777)

This commit is contained in:
Chayenne
2025-02-21 16:49:31 -08:00
committed by GitHub
parent d37f95511d
commit 14d90617b0
2 changed files with 5 additions and 3 deletions

View File

@@ -379,8 +379,6 @@ class Qwen2ForCausalLM(nn.Module):
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
if name.startswith("lm_head"):
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:

View File

@@ -62,7 +62,11 @@ class Qwen2ForRewardModel(nn.Module):
return EmbeddingPoolerOutput(pooled_logits)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights, excluding the language-model head.

    Filters out any entries whose name starts with "lm_head" before
    delegating to ``Qwen2ForCausalLM.load_weights``, since this model
    does not use the causal-LM head weights.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        Whatever ``Qwen2ForCausalLM.load_weights`` returns for the
        filtered weight list.
    """
    # Filter out lm_head weights of Qwen2ForCausalLM; everything else is
    # loaded by the shared causal-LM loader.
    filtered_weights = [
        (name, w) for name, w in weights if not name.startswith("lm_head")
    ]
    return Qwen2ForCausalLM.load_weights(self, filtered_weights)
EntryClass = [