fix awq and dsv3 fused gemm compatible (#7735)
@@ -336,11 +336,17 @@ class DeepseekV2MoE(nn.Module):
                 else {}
             ),
         )
+        is_packed_weight = (
+            self.shared_experts.gate_up_proj.quant_method.quant_config.get_name()
+            in ["awq", "moe_wna16"]
+        )
         self.shared_experts_is_int8 = (
-            self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+            not is_packed_weight
+            and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
         )
         self.shared_experts_is_fp8 = (
-            self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            not is_packed_weight
+            and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
         )
         if self.shared_experts_is_fp8:
             assert (
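This hunk and the two below apply the same fix: the int8/fp8 fast-path flags (and the min-latency fused a-GEMM path) are only enabled when the weight is not stored in a packed quantized format. A minimal sketch of that guard, for illustration only; _is_packed_weight and _shared_experts_flags are hypothetical helpers, assuming the layer exposes quant_method.quant_config exactly as in the diff:

import torch

def _is_packed_weight(layer) -> bool:
    # Packed formats (awq, moe_wna16) keep quantized weights in a packed
    # layout, so a raw dtype check on layer.weight must not route them
    # onto the plain int8 / fp8 fused-GEMM paths.
    return layer.quant_method.quant_config.get_name() in ("awq", "moe_wna16")

def _shared_experts_flags(gate_up_proj):
    # Mirrors the flag computation added in the hunk above.
    packed = _is_packed_weight(gate_up_proj)
    is_int8 = not packed and gate_up_proj.weight.dtype == torch.int8
    is_fp8 = not packed and gate_up_proj.weight.dtype == torch.float8_e4m3fn
    return is_int8, is_fp8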
@@ -894,8 +900,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                 weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
             )
 
+        is_packed_weight = (
+            self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+            in ["awq", "moe_wna16"]
+        )
         self.use_min_latency_fused_a_gemm = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
             and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
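For context, and only as an assumption not stated in the diff itself: the 2112 x 7168 shape gate in the hunk above appears to pin the fast path to DeepSeek-V3-sized fused q/kv a-projections, where the output width would be q_lora_rank + kv_lora_rank + qk_rope_head_dim and the input width is the hidden size. A quick sanity check under those assumed dimensions:

# Assumed DeepSeek-V3 dimensions; not part of the commit itself.
q_lora_rank, kv_lora_rank, qk_rope_head_dim, hidden_size = 1536, 512, 64, 7168
assert q_lora_rank + kv_lora_rank + qk_rope_head_dim == 2112  # weight.shape[0]
assert hidden_size == 7168                                    # weight.shape[1]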
@@ -905,10 +916,12 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         self.qkv_proj_with_rope_is_int8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
         )
         self.qkv_proj_with_rope_is_fp8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
         )
 
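Taken together, a self-contained illustration of how the new guard behaves; the stub objects below are stand-ins built for this example, not the real SGLang layers, and the quantization names and dtypes are only illustrative:

import torch
from types import SimpleNamespace

def make_stub(quant_name: str, dtype: torch.dtype):
    # Stand-in for a quantized linear layer: exposes only the attributes
    # the guard reads (quant_method.quant_config.get_name() and weight.dtype).
    quant_config = SimpleNamespace(get_name=lambda: quant_name)
    return SimpleNamespace(
        quant_method=SimpleNamespace(quant_config=quant_config),
        weight=torch.empty(0, dtype=dtype),
    )

awq_layer = make_stub("awq", torch.int32)              # packed; dtype is irrelevant once the guard trips
plain_int8_layer = make_stub("w8a8_int8", torch.int8)  # example name for an unpacked int8 scheme

for layer in (awq_layer, plain_int8_layer):
    packed = layer.quant_method.quant_config.get_name() in ("awq", "moe_wna16")
    is_int8 = not packed and layer.weight.dtype == torch.int8
    print(layer.quant_method.quant_config.get_name(), packed, is_int8)
# Expected: "awq True False" (skips the int8/fp8 fused paths) and "w8a8_int8 False True".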