diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index a1929cbe0..ff7c21e2a 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -50,7 +50,7 @@ jobs: timeout-minutes: 25 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 0 --range-end 6 + python3 run_suite.py --suite minimal --range-begin 0 --range-end 7 unit-test-backend-part-2: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' @@ -67,7 +67,7 @@ jobs: timeout-minutes: 25 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 6 --range-end 14 + python3 run_suite.py --suite minimal --range-begin 7 --range-end 14 unit-test-backend-part-3: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 8bfb8e8f7..97dec49c2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -729,10 +729,13 @@ class ScheduleBatch: self.input_ids = input_ids self.out_cache_loc = out_cache_loc + # For overlap scheduler, the output_ids has one step delay + delta = 0 if self.enable_overlap else -1 + # NOTE: prefix_indices is what has been cached, but we don't cache each decode step self.prefix_lens.extend( [ - len(r.origin_input_ids) + len(r.output_ids) - 1 + len(r.origin_input_ids) + len(r.output_ids) + delta for r in running_batch.reqs ] ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 165c7f66f..796164297 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -848,7 +848,12 @@ class Scheduler: new_batch.prepare_for_extend() # Mixed-style chunked prefill - if self.is_mixed_chunk and self.running_batch is not None: + if ( + self.is_mixed_chunk + and self.running_batch is not None + and not (new_batch.return_logprob or self.running_batch.return_logprob) + ): + # TODO (lianmin): support return_logprob + mixed chunked prefill self.running_batch.filter_batch() if not self.running_batch.is_empty(): self.running_batch.prepare_for_decode() @@ -979,7 +984,10 @@ class Scheduler: continue if self.is_mixed_chunk and self.enable_overlap and req.finished(): - raise ValueError("Unhandled error!") + # Free the one delayed token for the mixed decode batch + j = len(batch.out_cache_loc) - len(batch.reqs) + i + self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) + continue if req.is_being_chunked <= 0: req.completion_tokens_wo_jump_forward += 1 @@ -992,7 +1000,6 @@ class Scheduler: self.tree_cache.cache_unfinished_req(req) if req.return_logprob: - # TODO (lianmin): need to think the case w/ mixed chunked prefill logprob_pt += self.add_logprob_return_values( i, req, logprob_pt, next_token_ids, logits_output ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 5a5cca918..b545e00c0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -199,12 +199,6 @@ class ServerArgs: "Overlap schedule is disabled." ) - if self.enable_mixed_chunk: - logger.info( - "Overlap schedule is disabled because mixed-style chunked prefill is enabled." - ) - self.disable_overlap_schedule = True - @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args