Fuse MLA set kv cache kernel (#5748)

This commit is contained in:
Ke Bao
2025-04-27 09:42:22 +08:00
committed by GitHub
parent 02723e1b0d
commit 799c4bb502
4 changed files with 100 additions and 9 deletions

View File

@@ -757,14 +757,13 @@ class DeepseekV2AttentionMLA(nn.Module):
 q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-k = torch.cat([k_nope, k_pe], dim=-1)
 if self.attention_backend == "fa3":
 attn_output = self.attn_mqa(
-q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
+q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
 )
 else:
 q = torch.cat([q_nope_out, q_pe], dim=-1)
+k = torch.cat([k_nope, k_pe], dim=-1)
 attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
 attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)