Move filter_batch out of stream_output (#1663)
@@ -659,7 +659,7 @@ class ScheduleBatch:

     def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
-        filter_indices = [i for i in range(len(self.reqs))]
+        keep_indices = set(i for i in range(len(self.reqs)))

         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
@@ -719,9 +719,9 @@ class ScheduleBatch:
                 )

                 jump_forward_reqs.append(req)
-                filter_indices.remove(i)
+                keep_indices.remove(i)

-        self.filter_batch(filter_indices)
+        self.filter_batch(keep_indices=list(keep_indices))

         return jump_forward_reqs

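An aside on the two hunks above: switching from a list to a set changes the cost of pruning jump-forward requests, since list.remove(x) scans the list (O(n) per removal) while set.remove(x) is a hash lookup (O(1)). A toy sketch of the semantics, not code from the repository:

# list.remove scans for the value; set.remove hashes it directly.
keep_list = [i for i in range(5)]
keep_set = set(i for i in range(5))

keep_list.remove(3)  # O(n): linear scan
keep_set.remove(3)   # O(1): hash removal

# filter_batch still expects a list, hence the list(keep_indices)
# conversion at the call site in the diff above.
assert keep_list == sorted(keep_set)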
@@ -740,19 +740,31 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc

-    def filter_batch(self, unfinished_indices: List[int]):
-        if unfinished_indices is None or len(unfinished_indices) == 0:
+    def filter_batch(
+        self,
+        current_inflight_req: Optional[Req] = None,
+        keep_indices: Optional[List[int]] = None,
+    ):
+        if keep_indices is None:
+            keep_indices = [
+                i
+                for i in range(len(self.reqs))
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not current_inflight_req
+            ]
+
+        if keep_indices is None or len(keep_indices) == 0:
             # Filter out all requests
             self.reqs = []
             return

-        if len(unfinished_indices) == len(self.reqs):
+        if len(keep_indices) == len(self.reqs):
             # No need to filter
             return

-        self.reqs = [self.reqs[i] for i in unfinished_indices]
+        self.reqs = [self.reqs[i] for i in keep_indices]
         new_indices = torch.tensor(
-            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+            keep_indices, dtype=torch.int32, device=self.seq_lens.device
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
@@ -760,16 +772,14 @@ class ScheduleBatch:
         self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
-            self.top_logprobs_nums = [
-                self.top_logprobs_nums[i] for i in unfinished_indices
-            ]
+            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None

         self.has_stream = any(req.stream for req in self.reqs)
         self.has_regex = any(req.regex_fsm for req in self.reqs)

-        self.sampling_info.filter_batch(unfinished_indices, new_indices)
+        self.sampling_info.filter_batch(keep_indices, new_indices)

     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
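To make the reworked contract concrete, here is a minimal sketch of the two ways filter_batch can now be driven. FakeReq and the inline replica of the default keep-rule are invented for illustration; they are not code from the repository:

# Hypothetical stand-in for sglang's Req; only the one method the new
# default filter consults is modeled.
class FakeReq:
    def __init__(self, done: bool):
        self._done = done

    def finished(self) -> bool:
        return self._done


reqs = [FakeReq(False), FakeReq(True), FakeReq(False)]
inflight = reqs[2]

# Default path (no arguments): keep_indices is computed inside
# filter_batch, dropping finished requests and the inflight request.
keep_indices = [
    i
    for i in range(len(reqs))
    if not reqs[i].finished() and reqs[i] is not inflight
]
assert keep_indices == [0]

# Explicit path, as used by check_for_jump_forward:
#   batch.filter_batch(keep_indices=[0, 2])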
@@ -446,31 +446,41 @@ class Scheduler:
             exit(1) if crash_on_warning else None

     def get_next_batch_to_run(self):
-        # Merge prefill to the running batch
+        # Merge the prefill batch into the running batch
         if (
             self.last_batch
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.running_batch is None:
-                self.running_batch = self.last_batch
-            else:
-                self.running_batch.merge_batch(self.last_batch)
+            if self.current_inflight_req:
+                self.last_batch.filter_batch(self.current_inflight_req)
+                self.batch_is_full = False
+            if not self.last_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = self.last_batch
+                else:
+                    self.running_batch.merge_batch(self.last_batch)

         # Prefill first
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch

-        # Run decode
-        if self.running_batch is not None:
-            self.update_running_batch()
-            if not self.running_batch:
-                return None
-            return self.running_batch
-        else:
+        # Check memory
+        if self.running_batch is None:
             self.check_memory()
             self.new_token_ratio = global_config.init_new_token_ratio
+            return
+
+        # Run decode
+        before_bs = self.running_batch.batch_size()
+        self.update_running_batch()
+        if not self.running_batch:
+            self.batch_is_full = False
+            return None
+        if before_bs != self.running_batch.batch_size():
+            self.batch_is_full = False
+        return self.running_batch

     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
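The restructured get_next_batch_to_run now reads as three early-return stages: prefill takes priority, an idle scheduler only audits memory, and the decode path clears batch_is_full whenever the running batch shrinks or empties. A hypothetical, runnable condensation of that flow (the class and its list-backed batch are invented for illustration; names mirror the diff):

from typing import List, Optional


class SchedulerSketch:
    """Hypothetical condensation of the new get_next_batch_to_run flow."""

    def __init__(self) -> None:
        self.running_batch: Optional[List[str]] = None
        self.batch_is_full = False

    def get_next_batch_to_run(self) -> Optional[List[str]]:
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch  # prefill always takes priority

        if self.running_batch is None:
            self.check_memory()  # idle step: audit memory, reset ratios
            return None

        before_bs = len(self.running_batch)
        self.update_running_batch()  # may retract or drop requests
        if not self.running_batch:
            self.batch_is_full = False
            return None
        if before_bs != len(self.running_batch):
            # Requests left the batch, so there is room to prefill again.
            self.batch_is_full = False
        return self.running_batch

    # Stubs so the sketch runs standalone.
    def get_new_batch_prefill(self) -> Optional[List[str]]:
        return None

    def check_memory(self) -> None:
        pass

    def update_running_batch(self) -> None:
        pass


assert SchedulerSketch().get_next_batch_to_run() is None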
@@ -617,6 +627,11 @@ class Scheduler:
         global test_retract
         batch = self.running_batch

+        batch.filter_batch()
+        if batch.is_empty():
+            self.running_batch = None
+            return
+
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
             old_ratio = self.new_token_ratio
@@ -640,8 +655,6 @@ class Scheduler:
         if not self.disable_regex_jump_forward:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
-            if jump_forward_reqs:
-                self.batch_is_full = False
             if batch.is_empty():
                 self.running_batch = None
                 return
@@ -892,14 +905,8 @@ class Scheduler:
             output_no_stop_trim = []
         else:  # embedding or reward model
             output_embeddings = []
-        unfinished_indices = []
-
-        for i, req in enumerate(batch.reqs):
-            if not req.finished() and req is not self.current_inflight_req:
-                unfinished_indices.append(i)
-            else:
-                self.batch_is_full = False
-
+
+        for req in batch.reqs:
             if req.finished() or (
                 req.stream
                 and (
@@ -955,9 +962,6 @@ class Scheduler:
                 }
                 output_meta_info.append(meta_info)

-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
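The net effect of the last two hunks is that per-request liveness bookkeeping leaves the output path entirely. A hypothetical before/after condensation, not real sglang code:

def stream_output_before(batch, current_inflight_req):
    # Old: the output loop doubled as the liveness filter and then
    # pruned the batch tensors itself.
    unfinished_indices = [
        i
        for i, req in enumerate(batch.reqs)
        if not req.finished() and req is not current_inflight_req
    ]
    batch.filter_batch(unfinished_indices)


def stream_output_after(batch):
    # New: only emit outputs; update_running_batch calls
    # batch.filter_batch() and the default keep-rule prunes the batch.
    for req in batch.reqs:
        pass  # build meta info / detokenizer payloads here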
@@ -1,3 +1,7 @@
+"""
+python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate
+"""
+
 import json
 import unittest
 from concurrent.futures import ThreadPoolExecutor