From 162f3ccb01d9b31d21f1a1ae3d6cabbfe4079838 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Tue, 13 Aug 2024 13:48:07 +0800
Subject: [PATCH] Fix layernorm input shape (#1066)

Co-authored-by: Yineng Zhang
---
 python/sglang/srt/models/deepseek_v2.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 2198428b8..13dd47739 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -445,11 +445,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_input[..., : self.kv_lora_rank]
         torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
 
-        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
-        k_pe = k_input[..., self.kv_lora_rank :]
-        v_input = k_input[..., : self.kv_lora_rank]
-        v_input = self.kv_a_layernorm(v_input.contiguous())
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
         k_input[..., : self.kv_lora_rank] = v_input
+        k_pe = k_input[..., self.kv_lora_rank :]
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q_input[..., self.kv_lora_rank :] = q_pe
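
Note: below is a minimal, standalone sketch of the reordering in this hunk,
not the sglang module itself. It assumes the fused norm kernel behind
kv_a_layernorm expects a 2D [num_tokens, hidden] input, which is the premise
of this fix: the old code unsqueezed first and fed the norm a 3D
[num_tokens, 1, kv_lora_rank] slice, while the new code normalizes the 2D
latent and only then adds the MQA head dim. The norm_2d helper and all shape
constants here are hypothetical stand-ins for illustration only.

    import torch

    num_tokens, kv_lora_rank, rope_dim = 4, 512, 64
    # Stand-in for the kv_a_proj_with_mqa output: [num_tokens, rank + rope].
    latent_cache = torch.randn(num_tokens, kv_lora_rank + rope_dim)

    def norm_2d(x: torch.Tensor) -> torch.Tensor:
        # Toy RMSNorm standing in for kv_a_layernorm; assumed (per this
        # patch) to require a 2D [num_tokens, hidden] input.
        assert x.dim() == 2, "fused kernel sees [num_tokens, hidden]"
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)

    # New order: normalize the 2D slice, then unsqueeze to the head dim.
    v_input = latent_cache[..., :kv_lora_rank]            # [tokens, rank], 2D
    v_input = norm_2d(v_input.contiguous()).unsqueeze(1)  # [tokens, 1, rank]
    k_input = latent_cache.unsqueeze(1)                   # [tokens, 1, rank+rope]
    k_input[..., :kv_lora_rank] = v_input                 # write normalized part back
    k_pe = k_input[..., kv_lora_rank:]                    # slice after the write
    print(v_input.shape, k_input.shape, k_pe.shape)

Taking k_pe after the in-place write (rather than before, as the old code
did) also keeps the rope slice consistent with the final k_input buffer.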