Bug: fix lm head weights in Qwen models (#3777)
This commit is contained in:
@@ -379,8 +379,6 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
continue
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
continue
|
||||
if name.startswith("lm_head"):
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
|
||||
@@ -62,7 +62,11 @@ class Qwen2ForRewardModel(nn.Module):
|
||||
return EmbeddingPoolerOutput(pooled_logits)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights, delegating to Qwen2ForCausalLM.load_weights.

    The reward model replaces the LM head with a pooling/score head, so any
    ``lm_head.*`` tensors present in a Qwen2ForCausalLM checkpoint have no
    matching parameter here and must be dropped before delegation.

    Args:
        weights: Iterable of (parameter name, tensor) pairs from the checkpoint.

    Returns:
        Whatever Qwen2ForCausalLM.load_weights returns for the filtered weights.
    """
    # BUG FIX: the previous unconditional
    #     return Qwen2ForCausalLM.load_weights(self, weights)
    # made the filtering below dead code, so lm_head weights were still
    # passed through. Filter first, then delegate.
    filtered_weights = [
        (name, w) for name, w in weights if not name.startswith("lm_head")
    ]
    return Qwen2ForCausalLM.load_weights(self, filtered_weights)
|
||||
|
||||
|
||||
EntryClass = [
|
||||
|
||||
(page chrome: "Reference in New Issue" / "Block a user" buttons — not part of the diff)