diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 796bca849..c124a0d5d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -456,7 +456,12 @@ class Scheduler: and not self.last_batch.is_empty() ): if self.current_inflight_req: - self.last_batch.filter_batch(self.current_inflight_req) + self.last_batch.filter_batch( + current_inflight_req=self.current_inflight_req + ) + self.tree_cache.cache_unfinished_req(self.current_inflight_req) + # Inflight request keeps its rid but will get a new req_pool_idx. + self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx) self.batch_is_full = False if not self.last_batch.is_empty(): if self.running_batch is None: @@ -728,27 +733,24 @@ class Scheduler: # Check finish conditions logprob_pt = 0 for i, req in enumerate(batch.reqs): - if req is not self.current_inflight_req: + if req.is_inflight_req > 0: + req.is_inflight_req -= 1 + else: # Inflight reqs' prefill is not finished req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_ids[i]) req.check_finished() + if req.finished(): + self.tree_cache.cache_finished_req(req) + elif not batch.decoding_reqs or req not in batch.decoding_reqs: + self.tree_cache.cache_unfinished_req(req) + if req.regex_fsm is not None: req.regex_fsm_state = req.regex_fsm.get_next_state( req.regex_fsm_state, next_token_ids[i] ) - if req.finished(): - self.tree_cache.cache_finished_req(req) - elif not batch.decoding_reqs or req not in batch.decoding_reqs: - self.tree_cache.cache_unfinished_req(req) - - if req.is_inflight_req > 0: - # Inflight request would get a new req idx - req.is_inflight_req -= 1 - self.req_to_token_pool.free(req.req_pool_idx) - if req.return_logprob: logprob_pt += self.add_logprob_return_values( i, req, logprob_pt, next_token_ids, logits_output @@ -760,7 +762,9 @@ class Scheduler: # Check finish conditions for i, req in enumerate(batch.reqs): req.embedding = embeddings[i] - if req is not self.current_inflight_req: + if req.is_inflight_req > 0: + req.is_inflight_req -= 1 + else: # Inflight reqs' prefill is not finished # dummy output token for embedding models req.output_ids.append(0) @@ -771,11 +775,6 @@ class Scheduler: else: self.tree_cache.cache_unfinished_req(req) - if req.is_inflight_req > 0: - # Inflight request would get a new req idx - req.is_inflight_req -= 1 - self.req_to_token_pool.free(req.req_pool_idx) - self.stream_output(batch) def process_batch_result_decode(self, batch: ScheduleBatch, result):