fix: resolve qwen2 moe weight loader (#1252)
This commit is contained in:
@@ -401,24 +401,12 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
expert_params_mapping = [
|
||||
# These are the weights for the experts
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
(
|
||||
(
|
||||
"experts.w13_weight"
|
||||
if weight_name in ["gate_proj", "up_proj"]
|
||||
else "experts.w2_weight"
|
||||
),
|
||||
f"experts.{expert_id}.{weight_name}.weight",
|
||||
expert_id,
|
||||
shard_id,
|
||||
)
|
||||
for expert_id in range(self.config.num_experts)
|
||||
for shard_id, weight_name in enumerate(
|
||||
["gate_proj", "down_proj", "up_proj"]
|
||||
)
|
||||
]
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts,
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
@@ -458,7 +446,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user