fix: resolve qwen2 moe weight loader (#1252)

This commit is contained in:
Yineng Zhang
2024-08-29 04:25:46 +10:00
committed by GitHub
parent 0a97d7962d
commit 492143bf32

View File

@@ -401,24 +401,12 @@ class Qwen2MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
-expert_params_mapping = [
-# These are the weights for the experts
-# (param_name, weight_name, expert_id, shard_id)
-(
-(
-"experts.w13_weight"
-if weight_name in ["gate_proj", "up_proj"]
-else "experts.w2_weight"
-),
-f"experts.{expert_id}.{weight_name}.weight",
-expert_id,
-shard_id,
-)
-for expert_id in range(self.config.num_experts)
-for shard_id, weight_name in enumerate(
-["gate_proj", "down_proj", "up_proj"]
-)
-]
+expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ckpt_gate_proj_name="gate_proj",
+ckpt_down_proj_name="down_proj",
+ckpt_up_proj_name="up_proj",
+num_experts=self.config.num_experts,
+)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
@@ -458,7 +446,7 @@ class Qwen2MoeForCausalLM(nn.Module):
weight_loader(
param,
loaded_weight,
-weight_name,
+name,
shard_id=shard_id,
expert_id=expert_id,
)