Add fast decode plan for flashinfer mla (#3987)

Author: Baizhou Zhang
Date: 2025-03-02 19:16:37 -08:00
Committed by: GitHub
Parent: 7fbab730bd
Commit: fa56106731
9 changed files with 156 additions and 52 deletions

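For context on the commit title: FlashInfer-style attention wrappers split each batch into a host-side plan() step that builds scheduling metadata and a device-side run() step that launches the kernels. During CUDA-graph decoding the batch shape is fixed and only the per-step KV lengths change, so a trimmed "fast" decode plan can skip the one-time validation and allocation work on every re-plan. Below is a minimal sketch of that idea with hypothetical names throughout (DecodeWrapper, fast_decode_plan); the MLA-specific implementation added by this commit lives in changed files not shown on this page.

import torch


class DecodeWrapper:
    """Stand-in for a FlashInfer-style wrapper (hypothetical)."""

    def __init__(self, max_bs: int):
        # One-time buffer, sized once for the captured batch.
        self.kv_indptr = torch.zeros(max_bs + 1, dtype=torch.int64)
        self.planned = False

    def plan(self, kv_lens: torch.Tensor) -> None:
        # Full plan: validation plus metadata build. Fine for the first
        # call, wasteful when repeated on every decode step.
        assert kv_lens.dtype == torch.int64
        assert kv_lens.numel() < self.kv_indptr.numel()
        self._refresh(kv_lens)
        self.planned = True

    def fast_decode_plan(self, kv_lens: torch.Tensor) -> None:
        # Fast plan: the one-time setup already ran, so only refresh the
        # length-dependent metadata for this step.
        assert self.planned
        self._refresh(kv_lens)

    def _refresh(self, kv_lens: torch.Tensor) -> None:
        # kv_indptr[i + 1] = cumulative KV length of requests 0..i.
        torch.cumsum(kv_lens, dim=0, out=self.kv_indptr[1 : kv_lens.numel() + 1])


w = DecodeWrapper(max_bs=8)
w.plan(torch.tensor([5, 9, 3]))               # capture-time full plan
w.fast_decode_plan(torch.tensor([6, 10, 4]))  # per-step cheap re-plan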

@@ -269,9 +269,10 @@ class FlashInferAttnBackend(AttentionBackend):
         num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
+        encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
+        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
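The capture-path hunk above does two things: it reorders encoder_lens relative to forward_mode, and it adds **kwargs. The kwargs let the CUDA-graph runner thread backend-specific planning arguments through one shared interface, while backends that do not need them simply ignore them. A hedged sketch of that pattern, with hypothetical class names and a hypothetical seq_lens_cpu hint:

from typing import Any, Optional

import torch


class AttentionBackendBase:
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        forward_mode: Any,                      # ForwardMode in sglang
        encoder_lens: Optional[torch.Tensor],
        spec_info: Any,                         # Optional[SpecInfo] in sglang
        **kwargs,                               # extra args flow through
    ) -> None:
        pass  # most backends ignore the extras


class MLABackend(AttentionBackendBase):
    def init_forward_metadata_capture_cuda_graph(self, *args, **kwargs) -> None:
        # Hypothetical: only this backend consumes the planning hint.
        seq_lens_cpu = kwargs.pop("seq_lens_cpu", None)
        if seq_lens_cpu is not None:
            pass  # would feed host-side lengths into a fast plan here
        super().init_forward_metadata_capture_cuda_graph(*args, **kwargs)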
@@ -339,9 +340,10 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
+        encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
+        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
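For orientation, the replay path pairs with the capture path: each decode step refreshes the attention metadata for the live sequence lengths and then replays the captured kernels, so all remaining host-side cost sits in this re-plan call, which is exactly what a fast decode plan trims. A condensed, hypothetical driver loop, assuming the signature shown above plus a leading bs parameter from the elided context:

import torch


def decode_step_with_cuda_graph(backend, graph: "torch.cuda.CUDAGraph",
                                bs: int, req_pool_indices: torch.Tensor,
                                seq_lens: torch.Tensor, forward_mode) -> None:
    # Hypothetical driver: re-plan for the live lengths, then replay.
    backend.init_forward_metadata_replay_cuda_graph(
        bs,
        req_pool_indices,
        seq_lens,
        int(seq_lens.sum().item()),  # seq_lens_sum (forces a host sync)
        forward_mode,
        None,  # encoder_lens: plain decode batches carry no encoder
        None,  # spec_info: speculative decoding not modeled here
    )
    graph.replay()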