Fix layernorm input shape (#1066)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
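
Before this hunk, the output of kv_a_proj_with_mqa was unsqueezed to [num_tokens, 1, kv_lora_rank + qk_rope_head_dim] first, so kv_a_layernorm received a slice of a 3D view, and the normalized result was never written back into k_input. The fix runs the layernorm on the 2D latent slice straight out of the projection, adds the head dimension afterwards, and copies the normalized values back into k_input before the rotary embedding is applied.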
@@ -445,11 +445,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_input[..., : self.kv_lora_rank]
         torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
 
-        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
-        k_pe = k_input[..., self.kv_lora_rank :]
-        v_input = k_input[..., : self.kv_lora_rank]
-        v_input = self.kv_a_layernorm(v_input.contiguous())
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
+        k_input[..., : self.kv_lora_rank] = v_input
+        k_pe = k_input[..., self.kv_lora_rank :]
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q_input[..., self.kv_lora_rank :] = q_pe
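
A minimal standalone sketch of the fixed data flow, for illustration only: torch.nn.RMSNorm (PyTorch >= 2.4) stands in for sglang's fused kv_a_layernorm kernel, whose input this commit changes from a sliced 3D view to a contiguous 2D tensor; tensor names mirror the diff, the sizes follow the DeepSeek-V2 config (kv_lora_rank=512, qk_rope_head_dim=64), and the token count is made up.

import torch
from torch import nn

num_tokens, kv_lora_rank, qk_rope_head_dim = 4, 512, 64

# Stand-in for the fused kv_a_layernorm kernel in sglang.
kv_a_layernorm = nn.RMSNorm(kv_lora_rank)

# Projection output: one compressed KV latent per token, no head dim yet.
latent_cache = torch.randn(num_tokens, kv_lora_rank + qk_rope_head_dim)

# Normalize the 2D latent slice first, then add the head dimension.
v_input = kv_a_layernorm(latent_cache[..., :kv_lora_rank].contiguous())
v_input = v_input.unsqueeze(1)       # [num_tokens, 1, kv_lora_rank]

# Write the normalized latent back so the key path sees it too.
k_input = latent_cache.unsqueeze(1)  # [num_tokens, 1, kv_lora_rank + rope]
k_input[..., :kv_lora_rank] = v_input
k_pe = k_input[..., kv_lora_rank:]   # [num_tokens, 1, qk_rope_head_dim]

print(v_input.shape, k_pe.shape)     # [4, 1, 512] and [4, 1, 64]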