Bug: Fix weight loader error when LM head weights are tied (#3766)
This commit is contained in:
@@ -458,6 +458,8 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
# Handle FP8 kv-scale remapping
|
# Handle FP8 kv-scale remapping
|
||||||
if "scale" in name:
|
if "scale" in name:
|
||||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
|||||||
@@ -339,6 +339,8 @@ class MiniCPMForCausalLM(nn.Module):
|
|||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
|
|||||||
@@ -603,6 +603,8 @@ class MiniCPM3ForCausalLM(nn.Module):
|
|||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
|
|||||||
@@ -325,6 +325,8 @@ class OlmoForCausalLM(nn.Module):
|
|||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -433,6 +433,8 @@ class Phi3SmallForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
|||||||
@@ -377,6 +377,8 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -586,6 +586,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
|
|||||||
@@ -486,6 +486,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
|
|||||||
Reference in New Issue
Block a user