Simplify batch update (#2154)
This commit is contained in:
@@ -467,6 +467,7 @@ class ScheduleBatch:
|
||||
extend_lens: List[int] = None
|
||||
extend_num_tokens: int = None
|
||||
decoding_reqs: List[Req] = None
|
||||
extend_logprob_start_lens: List[int] = None
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
@@ -722,7 +723,6 @@ class ScheduleBatch:
|
||||
self.merge_batch(running_batch)
|
||||
self.input_ids = input_ids
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.extend_num_tokens += running_bs
|
||||
|
||||
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
||||
self.prefix_lens.extend(
|
||||
@@ -732,6 +732,8 @@ class ScheduleBatch:
|
||||
]
|
||||
)
|
||||
self.extend_lens.extend([1] * running_bs)
|
||||
self.extend_num_tokens += running_bs
|
||||
# TODO (lianmin): Revisit this. The logprob start length set below should be seq_len - 1, not 0.
|
||||
self.extend_logprob_start_lens.extend([0] * running_bs)
|
||||
|
||||
def check_decode_mem(self):
|
||||
|
||||
Reference in New Issue
Block a user