Integrate trtllm ragged attention for prefill self-attention (#9801)
@@ -1050,7 +1050,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1079,6 +1078,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
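The added branch gives trtllm_mla its own dispatch rule: plain prefill (extend) passes use MHA over a chunked KV cache, while target-verify and draft-extend passes (speculative decoding) still fall through to the regular MLA subtype dispatch. Below is a minimal, self-contained sketch of that rule; the ForwardMode and AttnForwardMethod enums and the dispatch_trtllm_mla helper are simplified stand-ins for illustration, not SGLang's real types beyond the names visible in the diff.

# Sketch of the new trtllm_mla dispatch rule (stand-in types, not SGLang's API).
from enum import Enum, auto


class ForwardMode(Enum):
    EXTEND = auto()          # normal prefill
    TARGET_VERIFY = auto()   # speculative-decoding verify pass
    DRAFT_EXTEND = auto()    # speculative-decoding draft pass
    DECODE = auto()


class AttnForwardMethod(Enum):
    MHA_CHUNKED_KV = auto()
    MLA = auto()


def dispatch_trtllm_mla(mode: ForwardMode) -> AttnForwardMethod:
    # In SGLang, is_extend() is also true for target-verify and draft-extend,
    # which is why the diff excludes those two modes explicitly.
    is_extend = mode in (
        ForwardMode.EXTEND,
        ForwardMode.TARGET_VERIFY,
        ForwardMode.DRAFT_EXTEND,
    )
    if (
        is_extend
        and mode is not ForwardMode.TARGET_VERIFY
        and mode is not ForwardMode.DRAFT_EXTEND
    ):
        # Prefill self-attention runs as MHA over a chunked KV cache.
        return AttnForwardMethod.MHA_CHUNKED_KV
    # Stands in for the _dispatch_mla_subtype() fallback in the diff.
    return AttnForwardMethod.MLA


assert dispatch_trtllm_mla(ForwardMode.EXTEND) is AttnForwardMethod.MHA_CHUNKED_KV
assert dispatch_trtllm_mla(ForwardMode.TARGET_VERIFY) is AttnForwardMethod.MLA
assert dispatch_trtllm_mla(ForwardMode.DECODE) is AttnForwardMethod.MLA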