Optimize FlashInfer MLA concat (#5822)
Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>
This commit is contained in:
@@ -777,7 +777,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_nope_out = q_nope_out.transpose(0, 1)
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
|
||||
- if self.attention_backend == "fa3":
|
||||
+ if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
|
||||
attn_output = self.attn_mqa(
|
||||
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user