diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index f1ab8c3e7..bf867407d 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -96,6 +96,7 @@ from sglang.srt.utils import (
     bind_or_assign,
     cpu_has_amx_support,
     get_bool_env_var,
+    get_device_sm,
     get_int_env_var,
     is_cpu,
     is_cuda,
@@ -112,7 +113,7 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+    from sgl_kernel import awq_dequantize, bmm_fp8, dsv3_fused_a_gemm, merge_state_v2
 elif _is_cpu and _is_cpu_amx_available:
     pass
 else:
@@ -875,6 +876,17 @@ class DeepseekV2AttentionMLA(nn.Module):
             weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
         )
 
+        # Check the `_is_cuda` flag, not the always-truthy `is_cuda` function:
+        # `dsv3_fused_a_gemm` is only imported when `_is_cuda` is set.
+        self.use_min_latency_fused_a_gemm = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+            and _is_cuda
+            and get_device_sm() >= 90
+        )
+
         self.qkv_proj_with_rope_is_int8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
@@ -1114,7 +1126,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+            # Inputs of at most 16 rows use the fused-A GEMM kernel when eligible.
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
             k_nope = latent_cache[..., : self.kv_lora_rank]