[Revision] Add fast decode plan for flashinfer mla (#4012)
@@ -353,6 +353,7 @@ class FlashInferAttnBackend(AttentionBackend):
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -1058,6 +1059,7 @@ class FlashInferMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
             )

         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
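
In short, the commit threads an optional seq_lens_cpu tensor through the decode/CUDA-graph replay path of FlashInferAttnBackend, while FlashInferMultiStepDraftBackend passes seq_lens_cpu=None to keep its existing behaviour. Below is a minimal sketch, not the SGLang implementation, of why a host-resident copy of the sequence lengths helps when planning a flashinfer-style decode: the plan needs host integers to build the paged-KV indptr, so reusing seq_lens_cpu avoids a blocking device-to-host copy on each step. The helper name build_decode_indptr and the page_size parameter are illustrative assumptions.

# Minimal sketch (assumed helper, not SGLang code): build the paged-KV indptr
# that a flashinfer-style decode plan needs from host-resident sequence lengths.
from typing import Optional

import torch


def build_decode_indptr(
    seq_lens: torch.Tensor,
    seq_lens_cpu: Optional[torch.Tensor],
    page_size: int = 1,
) -> torch.Tensor:
    # Prefer the CPU copy when the caller already has one; falling back to
    # seq_lens.cpu() forces a device-to-host sync on every decode step.
    lens = seq_lens_cpu if seq_lens_cpu is not None else seq_lens.cpu()
    pages_per_seq = (lens + page_size - 1) // page_size
    indptr = torch.zeros(pages_per_seq.numel() + 1, dtype=torch.int32)
    indptr[1:] = torch.cumsum(pages_per_seq, dim=0)
    return indptr


if __name__ == "__main__":
    seq_lens = torch.tensor([17, 3, 64])    # stand-in for the device tensor
    seq_lens_cpu = seq_lens.clone()         # what the new argument would carry
    print(build_decode_indptr(seq_lens, seq_lens_cpu))
    # tensor([ 0, 17, 20, 84], dtype=torch.int32)

With seq_lens_cpu=None, as the draft backend passes above, such a helper would fall back to the synchronizing path, matching the previous behaviour.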