[Revision] Add fast decode plan for flashinfer mla (#4012)

This commit is contained in:
Baizhou Zhang
2025-03-05 11:20:41 -08:00
committed by GitHub
parent 71ab0dabe0
commit fc91d08a8f
9 changed files with 145 additions and 34 deletions

View File

@@ -353,6 +353,7 @@ class FlashInferAttnBackend(AttentionBackend):
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
@@ -1058,6 +1059,7 @@ class FlashInferMultiStepDraftBackend:
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=None,
)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)