Fix layernorm input shape (#1066)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
Ke Bao
2024-08-13 13:48:07 +08:00
committed by GitHub
parent 65e89baea9
commit 162f3ccb01

View File

@@ -445,11 +445,12 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_input[..., : self.kv_lora_rank]
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
k_pe = k_input[..., self.kv_lora_rank :]
v_input = k_input[..., : self.kv_lora_rank]
v_input = self.kv_a_layernorm(v_input.contiguous())
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
v_input = latent_cache[..., : self.kv_lora_rank]
v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
k_input = latent_cache.unsqueeze(1)
k_input[..., : self.kv_lora_rank] = v_input
k_pe = k_input[..., self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q_input[..., self.kv_lora_rank :] = q_pe