upgrade flashinfer v0.2.0.post2 (#3288)
Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
This commit is contained in:
@@ -69,6 +69,7 @@ class EagleDraftInput:
|
||||
accept_length_cpu = batch.spec_info.accept_length_cpu
|
||||
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
||||
seq_lens_cpu = batch.seq_lens.tolist()
|
||||
|
||||
pt = 0
|
||||
@@ -353,8 +354,12 @@ class EagleVerifyInput:
|
||||
]
|
||||
if has_finished:
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
|
||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
|
||||
unfinished_index
|
||||
]
|
||||
else:
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
|
||||
|
||||
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
|
||||
return (
|
||||
|
||||
@@ -269,6 +269,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
seq_lens_backup = batch.seq_lens
|
||||
req_pool_indices_backup = batch.req_pool_indices
|
||||
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
@@ -284,6 +285,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.seq_lens = seq_lens_backup
|
||||
batch.req_pool_indices = req_pool_indices_backup
|
||||
|
||||
def capture_for_decode(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
|
||||
Reference in New Issue
Block a user