fix issue of loading weight

This commit is contained in:
2026-06-30 09:55:13 +08:00
parent f89bc60d59
commit 1902c81fdd

View File

@@ -1322,6 +1322,35 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLM):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
continue continue
# --- Individual expert weights (FT checkpoint: experts.{i}.{proj}.weight) ---
# Standard transformers fine-tuning saves each expert separately instead of
# the pre-merged (num_experts, ...) tensors in the original checkpoint.
if ".mlp.experts." in name:
parts = name.split(".mlp.experts.", 1)
expert_rest = parts[1] # e.g. "0.gate_proj.weight"
dot_pos = expert_rest.find(".")
if dot_pos > 0 and expert_rest[:dot_pos].isdigit():
eid = int(expert_rest[:dot_pos])
proj_raw = expert_rest[dot_pos + 1:]
proj = proj_raw[:-7] if proj_raw.endswith(".weight") else proj_raw
prefix = parts[0] # e.g. "model.layers.0"
if proj == "gate_proj":
w13_name = f"{prefix}.mlp.experts.w13_weight"
if w13_name in params_dict:
param = params_dict[w13_name]
param.weight_loader(param, loaded_weight, "w1_weight", "w1", eid)
elif proj == "up_proj":
w13_name = f"{prefix}.mlp.experts.w13_weight"
if w13_name in params_dict:
param = params_dict[w13_name]
param.weight_loader(param, loaded_weight, "w3_weight", "w3", eid)
elif proj == "down_proj":
w2_name = f"{prefix}.mlp.experts.w2_weight"
if w2_name in params_dict:
param = params_dict[w2_name]
param.weight_loader(param, loaded_weight, "w2_weight", "w2", eid)
continue
# --- Stacked / standard weights --- # --- Stacked / standard weights ---
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name: