Simplify flashinfer utilities (#1704)
This commit is contained in:
@@ -744,7 +744,6 @@ class ScheduleBatch:
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
|
||||
self.input_ids = self.output_ids
|
||||
self.seq_lens.add_(1)
|
||||
self.output_ids = None
|
||||
if self.sampling_info.penalizer_orchestrator:
|
||||
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
@@ -755,9 +754,10 @@ class ScheduleBatch:
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
|
||||
self.req_to_token_pool.req_to_token[
|
||||
self.req_pool_indices, self.seq_lens - 1
|
||||
] = self.out_cache_loc
|
||||
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
|
||||
self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
def filter_batch(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user