Remove 200us slow concat kernel (part 2: srt) (#7020)
@@ -1013,7 +1013,11 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_core(
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
-        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
+        if (
+            self.attention_backend == "fa3"
+            or self.attention_backend == "flashinfer"
+            or self.attention_backend == "cutlass_mla"
+        ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
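For context, the fast path here hands the RoPE and no-RoPE halves of q and k to the attention backend as separate tensors (q_rope=q_pe, k_rope=k_pe) instead of torch.cat-ing them into one contiguous buffer; this commit simply adds cutlass_mla to the list of backends that accept the split layout, so the slow concat kernel from the title is skipped there too. A minimal, self-contained sketch of the layout being avoided, where the tensor shapes are assumptions chosen only for illustration:

import torch

# Hypothetical MLA decode shapes, for illustration only.
NUM_TOKENS, NUM_HEADS, NOPE_DIM, ROPE_DIM = 8, 16, 512, 64

q_nope_out = torch.randn(NUM_TOKENS, NUM_HEADS, NOPE_DIM)
q_pe = torch.randn(NUM_TOKENS, NUM_HEADS, ROPE_DIM)

# Slow path: an extra kernel materializes the full-head_dim query before
# the attention call (the ~200us concat the commit title refers to).
q_full = torch.cat([q_nope_out, q_pe], dim=-1)

# Fast path: keep the halves separate and pass them to a backend that
# consumes them directly, as in the diff above, so the concat never runs.
# Both layouts carry the same data:
assert torch.equal(q_full[..., :NOPE_DIM], q_nope_out)
assert torch.equal(q_full[..., NOPE_DIM:], q_pe)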