fix dsv3 fused proj check (#7738)
@@ -336,10 +336,12 @@ class DeepseekV2MoE(nn.Module):
                     else {}
                 ),
             )
-            is_packed_weight = (
-                self.shared_experts.gate_up_proj.quant_method.quant_config.get_name()
-                in ["awq", "moe_wna16"]
-            )
+            is_packed_weight = hasattr(
+                self.shared_experts.gate_up_proj.quant_method, "quant_config"
+            ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
+                "awq",
+                "moe_wna16",
+            }
             self.shared_experts_is_int8 = (
                 not is_packed_weight
                 and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
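The change above guards the packed-weight probe: `quant_method` only exposes a `quant_config` attribute on quantized layers, so on an unquantized checkpoint the old expression dereferenced a missing attribute and raised AttributeError. Below is a minimal, self-contained sketch of the guarded check; the stand-in classes are hypothetical and only mirror the attribute layout, not sglang's real quantization classes.

    # Hypothetical stand-ins that only mirror the attribute layout, not sglang's
    # real quantization classes.
    class UnquantizedMethod:
        pass  # no quant_config attribute, like an unquantized linear layer

    class AwqConfig:
        def get_name(self) -> str:
            return "awq"

    class AwqMethod:
        quant_config = AwqConfig()

    def is_packed_weight(quant_method) -> bool:
        # Probe for quant_config before dereferencing it, as the patched check does.
        return hasattr(quant_method, "quant_config") and (
            quant_method.quant_config.get_name() in {"awq", "moe_wna16"}
        )

    assert is_packed_weight(AwqMethod())
    assert not is_packed_weight(UnquantizedMethod())  # the old check raised AttributeError here
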
@@ -891,21 +893,20 @@ class DeepseekV2AttentionMLA(nn.Module):
         # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
         # which requires self.w_kc and self.w_vc to be packed.
         # If not, we will use torch.bmm and weight shouldn't be packed in this case
-        if (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
-            and _is_cpu
-            and _is_cpu_amx_available
-        ):
+        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
+        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(
                 weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
             )
 
         is_packed_weight = (
-            self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
-            in ["awq", "moe_wna16"]
+            has_fused_proj
+            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
+            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+            in {"awq", "moe_wna16"}
         )
         self.use_min_latency_fused_a_gemm = (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            has_fused_proj
             and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
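Hoisting the repeated hasattr probe into `has_fused_proj` does two things: the AMX branch reads more directly, and the rewritten `is_packed_weight` expression can short-circuit before it ever touches `self.fused_qkv_a_proj_with_mqa` when the fused projection is absent. A rough sketch of that evaluation order, using a hypothetical bare object in place of the attention module:

    class _Module:
        pass  # hypothetical stand-in with no fused_qkv_a_proj_with_mqa attribute

    m = _Module()
    has_fused_proj = hasattr(m, "fused_qkv_a_proj_with_mqa")  # False for this stand-in

    # `and` short-circuits: with has_fused_proj False, the attribute chain on the
    # right is never evaluated, so the missing projection is never dereferenced.
    is_packed_weight = (
        has_fused_proj
        and m.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
        in {"awq", "moe_wna16"}
    )

    assert is_packed_weight is False
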
@@ -915,12 +916,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
 
         self.qkv_proj_with_rope_is_int8 = (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            has_fused_proj
             and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
         )
         self.qkv_proj_with_rope_is_fp8 = (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            has_fused_proj
             and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
         )
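The two flags above just record the fused weight's storage dtype once the packed-weight case has been ruled out; downstream code uses them to pick the matching kernel. A hedged illustration of that kind of dtype-based dispatch follows; the function and the path names are illustrative, not sglang's actual API.

    import torch

    def pick_fused_qkv_path(weight: torch.Tensor, is_packed_weight: bool) -> str:
        # Illustrative dispatch on the fused projection's dtype.
        if is_packed_weight:
            return "bmm_fallback"  # packed layouts skip the fused CPU kernel
        if weight.dtype == torch.int8:
            return "qkv_proj_with_rope_int8"
        if weight.dtype == torch.float8_e4m3fn:
            return "qkv_proj_with_rope_fp8"
        return "qkv_proj_with_rope_bf16"

    w = torch.zeros(4, 4, dtype=torch.int8)
    assert pick_fused_qkv_path(w, is_packed_weight=False) == "qkv_proj_with_rope_int8"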
|