Simplify flashinfer utilities (#1704)

This commit is contained in:
Lianmin Zheng
2024-10-17 22:54:14 -07:00
committed by GitHub
parent 9e0dac1ad7
commit 6d0fa73ece
8 changed files with 391 additions and 337 deletions

View File

@@ -744,7 +744,6 @@ class ScheduleBatch:
self.forward_mode = ForwardMode.DECODE
self.input_ids = self.output_ids
-self.seq_lens.add_(1)
self.output_ids = None
if self.sampling_info.penalizer_orchestrator:
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
@@ -755,9 +754,10 @@ class ScheduleBatch:
bs = len(self.reqs)
self.out_cache_loc = self.alloc_token_slots(bs)
-self.req_to_token_pool.req_to_token[
-self.req_pool_indices, self.seq_lens - 1
-] = self.out_cache_loc
+self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+self.out_cache_loc
+)
+self.seq_lens.add_(1)
def filter_batch(
self,