Simplify batch update (#2154)

This commit is contained in:
Lianmin Zheng
2024-11-24 04:47:10 -08:00
committed by GitHub
parent d90c3d6b8b
commit c211e7b669
7 changed files with 47 additions and 46 deletions

View File

@@ -467,6 +467,7 @@ class ScheduleBatch:
extend_lens: List[int] = None
extend_num_tokens: int = None
decoding_reqs: List[Req] = None
extend_logprob_start_lens: List[int] = None
# For encoder-decoder
encoder_cached: Optional[List[bool]] = None
@@ -722,7 +723,6 @@ class ScheduleBatch:
self.merge_batch(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens += running_bs
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend(
@@ -732,6 +732,8 @@ class ScheduleBatch:
]
)
self.extend_lens.extend([1] * running_bs)
self.extend_num_tokens += running_bs
# TODO (lianmin): Revisit this. It should be seq_len - 1
self.extend_logprob_start_lens.extend([0] * running_bs)
def check_decode_mem(self):