Fix AWQ dequant and weight loading of DeepSeek V2 (#6842)

This commit is contained in:
AniZpZ
2025-06-18 04:45:10 +08:00
committed by GitHub
parent e726131523
commit 3eb4a800e8
3 changed files with 18 additions and 11 deletions

View File

@@ -2137,8 +2137,14 @@ class DeepseekV2ForCausalLM(nn.Module):
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
cat_dim = 0
if (
self.quant_config.get_name() == "awq"
or self.quant_config.get_name() == "moe_wna16"
):
cat_dim = 1
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
[q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
)
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")