Move filter_batch out of stream_output (#1663)
@@ -446,31 +446,41 @@ class Scheduler:
             exit(1) if crash_on_warning else None

     def get_next_batch_to_run(self):
-        # Merge prefill to the running batch
+        # Merge the prefill batch into the running batch
         if (
             self.last_batch
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.running_batch is None:
-                self.running_batch = self.last_batch
-            else:
-                self.running_batch.merge_batch(self.last_batch)
+            if self.current_inflight_req:
+                self.last_batch.filter_batch(self.current_inflight_req)
+                self.batch_is_full = False
+            if not self.last_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = self.last_batch
+                else:
+                    self.running_batch.merge_batch(self.last_batch)

         # Prefill first
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch

-        # Run decode
-        if self.running_batch is not None:
-            self.update_running_batch()
-            if not self.running_batch:
-                return None
-            return self.running_batch
-        else:
+        # Check memory
+        if self.running_batch is None:
             self.check_memory()
             self.new_token_ratio = global_config.init_new_token_ratio
             return
+
+        # Run decode
+        before_bs = self.running_batch.batch_size()
+        self.update_running_batch()
+        if not self.running_batch:
+            self.batch_is_full = False
+            return None
+        if before_bs != self.running_batch.batch_size():
+            self.batch_is_full = False
+        return self.running_batch

     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
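Taken together, the hunk above changes get_next_batch_to_run in two ways: the in-flight (chunked-prefill) request is filtered out of last_batch before it is merged into the running batch, so a partially prefilled request never enters decode, and batch_is_full is reset whenever the running batch shrinks. A minimal runnable sketch of the new merge step follows; Batch, merge_prefill, and the string-valued requests are hypothetical stand-ins for illustration, not the real ScheduleBatch API.

# Hypothetical stand-ins that mimic just enough of ScheduleBatch to trace
# the new merge step; these are NOT the real sglang classes.
from typing import List, Optional


class Batch:
    def __init__(self, reqs: List[str]) -> None:
        self.reqs = reqs

    def is_empty(self) -> bool:
        return not self.reqs

    def filter_batch(self, drop_req: str) -> None:
        # Drop the in-flight (chunked-prefill) request before merging.
        self.reqs = [r for r in self.reqs if r != drop_req]

    def merge_batch(self, other: "Batch") -> None:
        self.reqs.extend(other.reqs)


def merge_prefill(running: Optional[Batch], last: Batch,
                  inflight_req: Optional[str]) -> Optional[Batch]:
    # Mirrors the new logic: filter first, then merge whatever remains.
    if inflight_req:
        last.filter_batch(inflight_req)
    if not last.is_empty():
        if running is None:
            running = last
        else:
            running.merge_batch(last)
    return running


running = merge_prefill(Batch(["a", "b"]), Batch(["c", "inflight"]), "inflight")
assert running is not None and running.reqs == ["a", "b", "c"]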
@@ -617,6 +627,11 @@ class Scheduler:
         global test_retract
         batch = self.running_batch

+        batch.filter_batch()
+        if batch.is_empty():
+            self.running_batch = None
+            return
+
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
             old_ratio = self.new_token_ratio
@@ -640,8 +655,6 @@ class Scheduler:
         if not self.disable_regex_jump_forward:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
-            if jump_forward_reqs:
-                self.batch_is_full = False
             if batch.is_empty():
                 self.running_batch = None
                 return
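These two hunks are the destination of the move: update_running_batch now calls filter_batch itself, before the decode-memory check, so retraction and jump-forward decisions only ever see live requests, and the now-redundant batch_is_full reset after jump-forward is dropped (the caller resets it whenever the batch size changes). A toy model of the reordering, with hypothetical Req/Batch stand-ins rather than the real sglang classes:

# Toy model of the reordered update_running_batch(): finished requests are
# filtered out *before* the decode-memory check, so only live requests count
# against the token budget.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Req:
    rid: str
    done: bool = False


@dataclass
class Batch:
    reqs: List[Req] = field(default_factory=list)

    def filter_batch(self) -> None:
        # Keep only unfinished requests (the call moved out of stream_output).
        self.reqs = [r for r in self.reqs if not r.done]

    def is_empty(self) -> bool:
        return not self.reqs

    def check_decode_mem(self, budget: int) -> bool:
        # Pretend each live request needs exactly one token slot.
        return len(self.reqs) <= budget


def update_running_batch(running: Optional[Batch], budget: int) -> Optional[Batch]:
    if running is None:
        return None
    running.filter_batch()      # 1. prune finished requests first
    if running.is_empty():
        return None             # 2. nothing left to decode
    if not running.check_decode_mem(budget):
        raise MemoryError("retract requests here")  # 3. retraction path
    return running


batch = Batch([Req("a"), Req("b", done=True)])
out = update_running_batch(batch, budget=1)
assert out is not None and [r.rid for r in out.reqs] == ["a"]

Had the memory check run before the filter, the same batch would have counted two requests against a budget of one and taken the retraction path needlessly.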
@@ -892,14 +905,8 @@ class Scheduler:
             output_no_stop_trim = []
         else:  # embedding or reward model
             output_embeddings = []
-        unfinished_indices = []
-
-        for i, req in enumerate(batch.reqs):
-            if not req.finished() and req is not self.current_inflight_req:
-                unfinished_indices.append(i)
-            else:
-                self.batch_is_full = False

         for req in batch.reqs:
             if req.finished() or (
                 req.stream
                 and (
@@ -955,9 +962,6 @@ class Scheduler:
                     }
                     output_meta_info.append(meta_info)

-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
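The last two hunks are the source of the move: stream_output no longer builds unfinished_indices, resets batch_is_full, or prunes the batch tensors; it is reduced to a read-only pass that collects finished or streaming outputs for the detokenizer. A small sketch of the resulting shape; Req and stream_output_after are illustrative stand-ins, not the real sglang signatures.

# Sketch of the slimmed-down output pass: with pruning handled by
# update_running_batch(), stream_output only reads the batch.
from typing import List


class Req:
    def __init__(self, rid: str, done: bool) -> None:
        self.rid = rid
        self.done = done

    def finished(self) -> bool:
        return self.done


def stream_output_after(reqs: List[Req]) -> List[str]:
    # Read-only pass: no unfinished_indices, no filter_batch, no
    # batch_is_full bookkeeping.
    return [r.rid for r in reqs if r.finished()]


assert stream_output_after([Req("a", True), Req("b", False)]) == ["a"]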