Fix weight loading bug for Deepseek v3+nextn (#5684)
This commit is contained in:
@@ -242,6 +242,12 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
|
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Fuse q_a_proj and kv_a_proj_with_mqa along the output dimension when q_lora_rank is not None
|
||||||
|
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
||||||
|
self.config.q_lora_rank is not None
|
||||||
|
)
|
||||||
|
cached_a_proj = {} if fuse_qkv_a_proj else None
|
||||||
|
|
||||||
nextn_layer_prefix = "model.layers.0"
|
nextn_layer_prefix = "model.layers.0"
|
||||||
nextn_spec_weight_names = [
|
nextn_spec_weight_names = [
|
||||||
"shared_head.norm",
|
"shared_head.norm",
|
||||||
@@ -313,11 +319,51 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
# Handle fused_qkv_a_proj
|
||||||
weight_loader = getattr(
|
if fuse_qkv_a_proj and (
|
||||||
param, "weight_loader", default_weight_loader
|
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
||||||
)
|
):
|
||||||
weight_loader(param, loaded_weight)
|
cached_a_proj[name] = loaded_weight
|
||||||
|
q_a_proj_name = (
|
||||||
|
name
|
||||||
|
if "q_a_proj" in name
|
||||||
|
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
|
||||||
|
)
|
||||||
|
kv_a_proj_name = (
|
||||||
|
name
|
||||||
|
if "kv_a_proj_with_mqa" in name
|
||||||
|
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Once both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
|
||||||
|
if (
|
||||||
|
q_a_proj_name in cached_a_proj
|
||||||
|
and kv_a_proj_name in cached_a_proj
|
||||||
|
):
|
||||||
|
|
||||||
|
q_a_proj_weight = cached_a_proj[q_a_proj_name]
|
||||||
|
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
|
||||||
|
fused_weight = torch.cat(
|
||||||
|
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
param_name = name.replace(
|
||||||
|
"q_a_proj", "fused_qkv_a_proj_with_mqa"
|
||||||
|
)
|
||||||
|
param = params_dict[param_name]
|
||||||
|
|
||||||
|
weight_loader = getattr(
|
||||||
|
param, "weight_loader", default_weight_loader
|
||||||
|
)
|
||||||
|
weight_loader(param, fused_weight)
|
||||||
|
cached_a_proj.pop(q_a_proj_name)
|
||||||
|
cached_a_proj.pop(kv_a_proj_name)
|
||||||
|
else:
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(
|
||||||
|
param, "weight_loader", default_weight_loader
|
||||||
|
)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
self_attn = self.model.decoder.self_attn
|
self_attn = self.model.decoder.self_attn
|
||||||
if hasattr(self_attn.kv_b_proj, "qweight"):
|
if hasattr(self_attn.kv_b_proj, "qweight"):
|
||||||
|
|||||||
Reference in New Issue
Block a user