upgrade flashinfer v0.2.0.post2 (#3288)

Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
This commit is contained in:
Yineng Zhang
2025-02-04 21:41:40 +08:00
committed by GitHub
parent 70817a7eae
commit d39899e85c
8 changed files with 42 additions and 51 deletions

View File

@@ -69,6 +69,7 @@ class EagleDraftInput:
accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
seq_lens_cpu = batch.seq_lens.tolist()
pt = 0
@@ -353,8 +354,12 @@ class EagleVerifyInput:
]
if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
unfinished_index
]
else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return (

View File

@@ -269,6 +269,7 @@ class EAGLEWorker(TpModelWorker):
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
seq_lens_backup = batch.seq_lens
req_pool_indices_backup = batch.req_pool_indices
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
@@ -284,6 +285,7 @@ class EAGLEWorker(TpModelWorker):
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch