Remove q concat in FA3 backend for DeepSeek decode (#5638)

This commit is contained in:
Ke Bao
2025-04-23 02:43:12 +08:00
committed by GitHub
parent 917324862e
commit 6b6e748775
4 changed files with 40 additions and 9 deletions

View File

@@ -751,10 +751,15 @@ class DeepseekV2AttentionMLA(nn.Module):
 q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-q = torch.cat([q_nope_out, q_pe], dim=-1)
 k = torch.cat([k_nope, k_pe], dim=-1)
-attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+if self.attention_backend == "fa3":
+    attn_output = self.attn_mqa(
+        q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
+    )
+else:
+    q = torch.cat([q_nope_out, q_pe], dim=-1)
+    attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
 attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 if self.use_deep_gemm_bmm: