From 84f2e4a0f847b0664d90dab3b6af58b7a667d37c Mon Sep 17 00:00:00 2001
From: AniZpZ
Date: Thu, 3 Jul 2025 13:56:57 +0800
Subject: [PATCH] fix awq and dsv3 fused gemm compatible (#7735)

---
 python/sglang/srt/models/deepseek_v2.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 1646a2858..d4b2d72f7 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -336,11 +336,17 @@ class DeepseekV2MoE(nn.Module):
                     else {}
                 ),
             )
+            is_packed_weight = (
+                self.shared_experts.gate_up_proj.quant_method.quant_config.get_name()
+                in ["awq", "moe_wna16"]
+            )
             self.shared_experts_is_int8 = (
-                self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
             )
             self.shared_experts_is_fp8 = (
-                self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
             )
             if self.shared_experts_is_fp8:
                 assert (
@@ -894,8 +900,13 @@ class DeepseekV2AttentionMLA(nn.Module):
             weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
         )
 
+        is_packed_weight = (
+            self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+            in ["awq", "moe_wna16"]
+        )
         self.use_min_latency_fused_a_gemm = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
             and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
@@ -905,10 +916,12 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         self.qkv_proj_with_rope_is_int8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
         )
         self.qkv_proj_with_rope_is_fp8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
         )
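
Illustrative sketch (not part of the patch): the guard added in both hunks can be read in isolation. Only `quant_method.quant_config.get_name()` and the `["awq", "moe_wna16"]` list come from the diff; `FakeQuantConfig`, `FakeQuantMethod`, `FakeLinear`, and `weight_flags` below are hypothetical stand-ins. The idea is that weight-only formats such as AWQ and moe_wna16 store quantized weights packed into integer tensors, so checking `.weight.dtype` alone would wrongly enable the int8/fp8 fused GEMM paths for them.

# Sketch only, assuming a quant config object that exposes get_name();
# not the actual sglang classes.
import torch


class FakeQuantConfig:
    def __init__(self, name: str):
        self._name = name

    def get_name(self) -> str:
        return self._name


class FakeQuantMethod:
    def __init__(self, name: str):
        self.quant_config = FakeQuantConfig(name)


class FakeLinear:
    """Stand-in for a projection such as gate_up_proj."""

    def __init__(self, name: str, dtype: torch.dtype):
        self.quant_method = FakeQuantMethod(name)
        self.weight = torch.empty(0, dtype=dtype)


def weight_flags(proj: FakeLinear) -> tuple[bool, bool]:
    # AWQ / moe_wna16 keep weights in a packed integer layout, so the raw
    # .weight dtype no longer identifies an int8 or fp8 kernel path.
    is_packed_weight = proj.quant_method.quant_config.get_name() in [
        "awq",
        "moe_wna16",
    ]
    is_int8 = not is_packed_weight and proj.weight.dtype == torch.int8
    is_fp8 = not is_packed_weight and proj.weight.dtype == torch.float8_e4m3fn
    return is_int8, is_fp8


if __name__ == "__main__":
    # Packed AWQ weights: both fused-path flags stay False.
    print(weight_flags(FakeLinear("awq", torch.int32)))          # (False, False)
    # Unpacked fp8 weights: the fp8 fused path remains eligible.
    print(weight_flags(FakeLinear("fp8", torch.float8_e4m3fn)))  # (False, True)

Gating on the quant-config name rather than the tensor dtype keeps the bf16/int8/fp8 fast paths from firing on packed integer storage, which is the compatibility issue this commit addresses.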