opt flashinfer mla cat (#5822)

Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>
This commit is contained in:
xu-yfei
2025-05-09 14:17:14 +08:00
committed by GitHub
parent 0ab3f437ab
commit e30c273bc9
2 changed files with 60 additions and 14 deletions

View File

@@ -777,7 +777,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope_out.transpose(0, 1)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
attn_output = self.attn_mqa(
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
)