Integrate trtllm ragged attention for prefill self-attention (#9801)

This commit is contained in:
Elfie Guo
2025-09-05 02:18:00 -07:00
committed by GitHub
parent f98366604b
commit bebd0576e5
4 changed files with 300 additions and 44 deletions

View File

@@ -1050,7 +1050,6 @@ class DeepseekV2AttentionMLA(nn.Module):
attention_backend == "flashinfer"
or attention_backend == "fa3"
or attention_backend == "flashmla"
-or attention_backend == "trtllm_mla"
or attention_backend == "cutlass_mla"
):
# Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1079,6 +1078,15 @@ class DeepseekV2AttentionMLA(nn.Module):
return AttnForwardMethod.MHA_CHUNKED_KV
else:
return _dispatch_mla_subtype()
+elif attention_backend == "trtllm_mla":
+if (
+forward_batch.forward_mode.is_extend()
+and not forward_batch.forward_mode.is_target_verify()
+and not forward_batch.forward_mode.is_draft_extend()
+):
+return AttnForwardMethod.MHA_CHUNKED_KV
+else:
+return _dispatch_mla_subtype()
elif attention_backend == "aiter":
if (
forward_batch.forward_mode.is_extend()