Remove 200us slow concat kernel (part 2: srt) (#7020)

This commit is contained in:
fzyzcjy
2025-06-14 06:19:31 +08:00
committed by GitHub
parent 0f1dfa1efe
commit c49c1d9226
2 changed files with 39 additions and 11 deletions

View File

@@ -1013,7 +1013,11 @@ class DeepseekV2AttentionMLA(nn.Module):
def forward_absorb_core(
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
):
if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
if (
self.attention_backend == "fa3"
or self.attention_backend == "flashinfer"
or self.attention_backend == "cutlass_mla"
):
attn_output = self.attn_mqa(
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
)