upgrade flashinfer v0.2.0.post2 (#3288)
Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
This commit is contained in:
@@ -269,6 +269,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
seq_lens_backup = batch.seq_lens
|
||||
req_pool_indices_backup = batch.req_pool_indices
|
||||
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
@@ -284,6 +285,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.seq_lens = seq_lens_backup
|
||||
batch.req_pool_indices = req_pool_indices_backup
|
||||
|
||||
def capture_for_decode(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
|
||||
Reference in New Issue
Block a user