From b9fd178f1b7bab721b384c017dcec30a3ba0f323 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 31 Oct 2024 18:27:42 -0700 Subject: [PATCH] Fix retraction + overlap (#1860) Co-authored-by: Lianmin Zheng --- .github/workflows/pr-test.yml | 4 ++-- python/sglang/srt/managers/schedule_batch.py | 15 ++++++++----- python/sglang/srt/managers/scheduler.py | 22 ++++++++++++++++---- test/srt/test_radix_attention.py | 21 +++++++++++++++++++ 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 1d66a5c7b..8f374b897 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -50,7 +50,7 @@ jobs: timeout-minutes: 20 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 0 --range-end 4 + python3 run_suite.py --suite minimal --range-begin 0 --range-end 6 unit-test-backend-part-2: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' @@ -67,7 +67,7 @@ jobs: timeout-minutes: 20 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 4 --range-end 14 + python3 run_suite.py --suite minimal --range-begin 6 --range-end 14 unit-test-backend-part-3: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 24e82993b..131d2982a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -211,9 +211,6 @@ class Req: # this does not include the jump forward tokens. self.completion_tokens_wo_jump_forward = 0 - # The number of cached tokens, that were already cached in the KV store - self.cached_tokens = 0 - # For vision inputs self.image_inputs: Optional[ImageInputs] = None @@ -223,6 +220,9 @@ class Req: self.last_node = None self.is_being_chunked = 0 + # For retraction + self.is_retracted = False + # Logprobs (arguments) self.return_logprob = False self.logprob_start_len = 0 @@ -242,12 +242,15 @@ class Req: # The relative logprob_start_len in an extend batch self.extend_logprob_start_len = 0 - # Embedding + # Embedding (return values) self.embedding = None # Constrained decoding self.grammar: Optional[Grammar] = None + # The number of cached tokens, that were already cached in the KV cache + self.cached_tokens = 0 + # For Qwen2-VL self.mrope_position_delta = [] # use mutable object @@ -561,7 +564,7 @@ class ScheduleBatch: seq_lens[i] -= encoder_len if len(req.prefix_indices) < encoder_len: - # NOTE: the encoder part should considered as a whole + # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) @@ -648,6 +651,7 @@ class ScheduleBatch: req.extend_logprob_start_len = extend_logprob_start_len pt += req.extend_input_len + req.is_retracted = False # Set fields self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( @@ -780,6 +784,7 @@ class ScheduleBatch: req.prefix_indices = [] req.last_node = None req.extend_input_len = 0 + req.is_retracted = True # For incremental logprobs req.last_update_decode_tokens = 0 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9682b7260..74cd1f3ea 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -79,6 +79,7 @@ from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) + # Crash on warning if we are running CI tests crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true" @@ -831,9 +832,10 @@ class Scheduler: # Check finish conditions logprob_pt = 0 for i, req in enumerate(batch.reqs): - if req.is_being_chunked > 0: - req.is_being_chunked -= 1 - else: + if req.is_retracted: + continue + + if req.is_being_chunked <= 0: # Inflight reqs' prefill is not finished req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_ids[i]) @@ -851,12 +853,18 @@ class Scheduler: logprob_pt += self.add_logprob_return_values( i, req, logprob_pt, next_token_ids, logits_output ) + else: + req.is_being_chunked -= 1 + else: # embedding or reward model embeddings, bid = result embeddings = embeddings.tolist() # Check finish conditions for i, req in enumerate(batch.reqs): + if req.is_retracted: + continue + req.embedding = embeddings[i] if req.is_being_chunked > 0: req.is_being_chunked -= 1 @@ -893,7 +901,12 @@ class Scheduler: # Check finish condition for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): - if self.server_args.enable_overlap_schedule and req.finished(): + if req.is_retracted: + continue + + if self.server_args.enable_overlap_schedule and ( + req.finished() + ): self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue @@ -1015,6 +1028,7 @@ class Scheduler: is_stream_iter = self.forward_ct_decode % self.stream_interval == 0 for req in reqs: + # TODO(lianmin): revisit this for overlap + retract + stream if req.finished() or ( req.stream and (is_stream_iter or len(req.output_ids) == 1) ): diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py index 1c20d95eb..e858ba9ee 100644 --- a/test/srt/test_radix_attention.py +++ b/test/srt/test_radix_attention.py @@ -107,6 +107,27 @@ class TestRadixCacheLPM(TestRadixCacheFCFS): ) +class TestRadixCacheOverlapLPM(TestRadixCacheFCFS): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-overlap-schedule", + "--chunked-prefill-size", + "128", + "--max-total-tokens", + "20000", + "--schedule-policy", + "lpm", + ], + ) + + if __name__ == "__main__": os.environ["SGLANG_TEST_RETRACT"] = "true" unittest.main()