fix: resolve qwen2 moe weight loader (#1252)
This commit is contained in:
@@ -401,24 +401,12 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
expert_params_mapping = [
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
# These are the weights for the experts
|
ckpt_gate_proj_name="gate_proj",
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
ckpt_down_proj_name="down_proj",
|
||||||
(
|
ckpt_up_proj_name="up_proj",
|
||||||
(
|
num_experts=self.config.num_experts,
|
||||||
"experts.w13_weight"
|
|
||||||
if weight_name in ["gate_proj", "up_proj"]
|
|
||||||
else "experts.w2_weight"
|
|
||||||
),
|
|
||||||
f"experts.{expert_id}.{weight_name}.weight",
|
|
||||||
expert_id,
|
|
||||||
shard_id,
|
|
||||||
)
|
)
|
||||||
for expert_id in range(self.config.num_experts)
|
|
||||||
for shard_id, weight_name in enumerate(
|
|
||||||
["gate_proj", "down_proj", "up_proj"]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
@@ -458,7 +446,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
weight_loader(
|
weight_loader(
|
||||||
param,
|
param,
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
weight_name,
|
name,
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
expert_id=expert_id,
|
expert_id=expert_id,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user