From 0c1e87964b87f201f1cc9d3bd6d54ae3280a9b31 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 14 Oct 2024 01:15:34 -0700 Subject: [PATCH] Move filter_batch out of stream_output (#1663) --- python/sglang/srt/managers/schedule_batch.py | 34 ++++++++----- python/sglang/srt/managers/scheduler.py | 52 +++++++++++--------- test/srt/test_json_constrained.py | 4 ++ 3 files changed, 54 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 9f02acbe1..0b2172b14 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -659,7 +659,7 @@ class ScheduleBatch: def check_for_jump_forward(self, pad_input_ids_func): jump_forward_reqs = [] - filter_indices = [i for i in range(len(self.reqs))] + keep_indices = set(i for i in range(len(self.reqs))) for i, req in enumerate(self.reqs): if req.jump_forward_map is not None: @@ -719,9 +719,9 @@ class ScheduleBatch: ) jump_forward_reqs.append(req) - filter_indices.remove(i) + keep_indices.remove(i) - self.filter_batch(filter_indices) + self.filter_batch(keep_indices=list(keep_indices)) return jump_forward_reqs @@ -740,19 +740,31 @@ class ScheduleBatch: self.req_pool_indices, self.seq_lens - 1 ] = self.out_cache_loc - def filter_batch(self, unfinished_indices: List[int]): - if unfinished_indices is None or len(unfinished_indices) == 0: + def filter_batch( + self, + current_inflight_req: Optional[Req] = None, + keep_indices: Optional[List[int]] = None, + ): + if keep_indices is None: + keep_indices = [ + i + for i in range(len(self.reqs)) + if not self.reqs[i].finished() + and self.reqs[i] is not current_inflight_req + ] + + if keep_indices is None or len(keep_indices) == 0: # Filter out all requests self.reqs = [] return - if len(unfinished_indices) == len(self.reqs): + if len(keep_indices) == len(self.reqs): # No need to filter return - self.reqs = [self.reqs[i] for i in unfinished_indices] + self.reqs = [self.reqs[i] for i in keep_indices] new_indices = torch.tensor( - unfinished_indices, dtype=torch.int32, device=self.seq_lens.device + keep_indices, dtype=torch.int32, device=self.seq_lens.device ) self.req_pool_indices = self.req_pool_indices[new_indices] self.seq_lens = self.seq_lens[new_indices] @@ -760,16 +772,14 @@ class ScheduleBatch: self.output_ids = self.output_ids[new_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) if self.return_logprob: - self.top_logprobs_nums = [ - self.top_logprobs_nums[i] for i in unfinished_indices - ] + self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices] else: self.top_logprobs_nums = None self.has_stream = any(req.stream for req in self.reqs) self.has_regex = any(req.regex_fsm for req in self.reqs) - self.sampling_info.filter_batch(unfinished_indices, new_indices) + self.sampling_info.filter_batch(keep_indices, new_indices) def merge_batch(self, other: "ScheduleBatch"): # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index bc47915f2..0c5049844 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -446,31 +446,41 @@ class Scheduler: exit(1) if crash_on_warning else None def get_next_batch_to_run(self): - # Merge prefill to the running batch + # Merge the prefill batch into the running batch if ( self.last_batch and not self.last_batch.forward_mode.is_decode() and not self.last_batch.is_empty() ): - if self.running_batch is None: - self.running_batch = self.last_batch - else: - self.running_batch.merge_batch(self.last_batch) + if self.current_inflight_req: + self.last_batch.filter_batch(self.current_inflight_req) + self.batch_is_full = False + if not self.last_batch.is_empty(): + if self.running_batch is None: + self.running_batch = self.last_batch + else: + self.running_batch.merge_batch(self.last_batch) # Prefill first new_batch = self.get_new_batch_prefill() if new_batch is not None: return new_batch - # Run decode - if self.running_batch is not None: - self.update_running_batch() - if not self.running_batch: - return None - return self.running_batch - else: + # Check memory + if self.running_batch is None: self.check_memory() self.new_token_ratio = global_config.init_new_token_ratio + return + + # Run decode + before_bs = self.running_batch.batch_size() + self.update_running_batch() + if not self.running_batch: + self.batch_is_full = False + return None + if before_bs != self.running_batch.batch_size(): + self.batch_is_full = False + return self.running_batch def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Handle the cases where prefill is not allowed @@ -617,6 +627,11 @@ class Scheduler: global test_retract batch = self.running_batch + batch.filter_batch() + if batch.is_empty(): + self.running_batch = None + return + # Check if decode out of memory if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10): old_ratio = self.new_token_ratio @@ -640,8 +655,6 @@ class Scheduler: if not self.disable_regex_jump_forward: jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) self.waiting_queue.extend(jump_forward_reqs) - if jump_forward_reqs: - self.batch_is_full = False if batch.is_empty(): self.running_batch = None return @@ -892,14 +905,8 @@ class Scheduler: output_no_stop_trim = [] else: # embedding or reward model output_embeddings = [] - unfinished_indices = [] - - for i, req in enumerate(batch.reqs): - if not req.finished() and req is not self.current_inflight_req: - unfinished_indices.append(i) - else: - self.batch_is_full = False + for req in batch.reqs: if req.finished() or ( req.stream and ( @@ -955,9 +962,6 @@ class Scheduler: } output_meta_info.append(meta_info) - # Remove finished reqs: update batch tensors - batch.filter_batch(unfinished_indices) - # Send to detokenizer if output_rids: if self.is_generation: diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 12cd51676..c054d7234 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -1,3 +1,7 @@ +""" +python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate +""" + import json import unittest from concurrent.futures import ThreadPoolExecutor