diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index c2ceac39a..7c535a6ef 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -154,6 +154,7 @@ if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
         bmm_fp8,
+        concat_mla_k,
         dsv3_fused_a_gemm,
         dsv3_router_gemm,
         merge_state_v2,
@@ -1295,8 +1296,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
         k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
+
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        else:
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
 
         if not _is_npu:
             latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
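
Note for reviewers: the fused `concat_mla_k` launch replaces the two strided slice writes in the fallback branch, which remains the reference for correctness. Below is a minimal PyTorch sketch of the semantics the kernel is expected to match, using the DeepSeek V3/R1 head configuration the diff gates on (128 local heads, nope dim 128, rope dim 64); the per-token broadcast shape of `k_pe` is an assumption inferred from the fallback's slice assignment, not something this diff specifies.

```python
import torch

# DeepSeek V3/R1 configuration gated on in the diff.
num_tokens, num_heads = 4, 128
nope_dim, rope_dim = 128, 64  # qk_nope_head_dim, qk_rope_head_dim

k_nope = torch.randn(num_tokens, num_heads, nope_dim)
# Assumption: the rope part is shared across heads (shape [tokens, 1, rope_dim])
# and gets broadcast across the head dimension by the slice assignment below.
k_pe = torch.randn(num_tokens, 1, rope_dim)

# Fallback path from the diff: two strided slice writes into one output buffer.
k = torch.empty(num_tokens, num_heads, nope_dim + rope_dim)
k[..., :nope_dim] = k_nope
k[..., nope_dim:] = k_pe  # broadcasts across the head dimension

# The fused CUDA kernel should fill the same buffer in a single launch:
#   concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
assert torch.equal(k[..., :nope_dim], k_nope)
assert torch.equal(k[..., nope_dim:], k_pe.expand(num_tokens, num_heads, rope_dim))
```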