Fix AWQ Dequant and Weight Loading of deepseek v2 (#6842)
This commit is contained in:
@@ -2137,8 +2137,14 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
):
|
||||
q_a_proj_weight = cached_a_proj[q_a_proj_name]
|
||||
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
|
||||
cat_dim = 0
|
||||
if (
|
||||
self.quant_config.get_name() == "awq"
|
||||
or self.quant_config.get_name() == "moe_wna16"
|
||||
):
|
||||
cat_dim = 1
|
||||
fused_weight = torch.cat(
|
||||
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
||||
[q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
|
||||
)
|
||||
param_name = (
|
||||
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
|
||||
|
||||
Reference in New Issue
Block a user