Bug: fix lm head weights in Qwen models (#3777)
This commit is contained in:
@@ -379,8 +379,6 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
if name.startswith("lm_head"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
|
|||||||
@@ -62,7 +62,11 @@ class Qwen2ForRewardModel(nn.Module):
|
|||||||
return EmbeddingPoolerOutput(pooled_logits)
|
return EmbeddingPoolerOutput(pooled_logits)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
return Qwen2ForCausalLM.load_weights(self, weights)
|
# Filter out lm_head weights of Qwen2ForCausalLM
|
||||||
|
filtered_weights = [
|
||||||
|
(name, w) for name, w in weights if not name.startswith("lm_head")
|
||||||
|
]
|
||||||
|
return Qwen2ForCausalLM.load_weights(self, filtered_weights)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = [
|
EntryClass = [
|
||||||
|
|||||||
Reference in New Issue
Block a user