Fix weight loading bug for Deepseek v3+nextn (#5684)
This commit is contained in:
@@ -242,6 +242,12 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
|
||||
)
|
||||
|
||||
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
||||
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
||||
self.config.q_lora_rank is not None
|
||||
)
|
||||
cached_a_proj = {} if fuse_qkv_a_proj else None
|
||||
|
||||
nextn_layer_prefix = "model.layers.0"
|
||||
nextn_spec_weight_names = [
|
||||
"shared_head.norm",
|
||||
@@ -313,11 +319,51 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
# Handle fused_qkv_a_proj
|
||||
if fuse_qkv_a_proj and (
|
||||
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
||||
):
|
||||
cached_a_proj[name] = loaded_weight
|
||||
q_a_proj_name = (
|
||||
name
|
||||
if "q_a_proj" in name
|
||||
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
|
||||
)
|
||||
kv_a_proj_name = (
|
||||
name
|
||||
if "kv_a_proj_with_mqa" in name
|
||||
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
|
||||
)
|
||||
|
||||
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
|
||||
if (
|
||||
q_a_proj_name in cached_a_proj
|
||||
and kv_a_proj_name in cached_a_proj
|
||||
):
|
||||
|
||||
q_a_proj_weight = cached_a_proj[q_a_proj_name]
|
||||
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
|
||||
fused_weight = torch.cat(
|
||||
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
||||
)
|
||||
|
||||
param_name = name.replace(
|
||||
"q_a_proj", "fused_qkv_a_proj_with_mqa"
|
||||
)
|
||||
param = params_dict[param_name]
|
||||
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, fused_weight)
|
||||
cached_a_proj.pop(q_a_proj_name)
|
||||
cached_a_proj.pop(kv_a_proj_name)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
self_attn = self.model.decoder.self_attn
|
||||
if hasattr(self_attn.kv_b_proj, "qweight"):
|
||||
|
||||
Reference in New Issue
Block a user