upgrade flashinfer to v0.2.0.post2 (#3288)

Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
This commit is contained in:
Yineng Zhang
2025-02-04 21:41:40 +08:00
committed by GitHub
parent 70817a7eae
commit d39899e85c
8 changed files with 42 additions and 51 deletions

View File

@@ -269,6 +269,7 @@ class EAGLEWorker(TpModelWorker):
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
seq_lens_backup = batch.seq_lens
req_pool_indices_backup = batch.req_pool_indices
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
@@ -284,6 +285,7 @@ class EAGLEWorker(TpModelWorker):
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch