From 869f1c02c4a7140c674ea92127a45eac0211bf74 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 13 Oct 2024 20:32:37 -0700 Subject: [PATCH] Add a test case to test retract (#1662) --- python/sglang/srt/managers/schedule_batch.py | 3 ++ python/sglang/srt/managers/scheduler.py | 6 ++- test/srt/run_suite.py | 1 + test/srt/test_retract_decode.py | 41 ++++++++++++++++++++ 4 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 test/srt/test_retract_decode.py diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b4248d5ec..9f02acbe1 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -590,9 +590,11 @@ class ScheduleBatch: retracted_reqs = [] seq_lens_cpu = self.seq_lens.cpu().numpy() + first_iter = True while ( self.token_to_kv_pool.available_size() < len(sorted_indices) * global_config.retract_decode_steps + or first_iter ): if len(sorted_indices) == 1: # Corner case: only one request left @@ -601,6 +603,7 @@ class ScheduleBatch: ), "No space left for only one request" break + first_iter = False idx = sorted_indices.pop() req = self.reqs[idx] retracted_reqs.append(req) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 42c2a2841..bc47915f2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -77,6 +77,9 @@ logger = logging.getLogger(__name__) # Crash on warning if we are running CI tests crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true" +# Test retract decode +test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true" + class Scheduler: """A scheduler that manages a tensor parallel GPU worker.""" @@ -611,10 +614,11 @@ class Scheduler: return new_batch def update_running_batch(self): + global test_retract batch = self.running_batch # Check if decode out of memory - if not batch.check_decode_mem(): + if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10): old_ratio = self.new_token_ratio retracted_reqs, new_token_ratio = batch.retract_decode() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index ffdaf0fe4..b9e561ff5 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -17,6 +17,7 @@ suites = { "test_large_max_new_tokens.py", "test_openai_server.py", "test_pytorch_sampling_backend.py", + "test_retract_decode.py", "test_server_args.py", "test_skip_tokenizer_init.py", "test_srt_engine.py", diff --git a/test/srt/test_retract_decode.py b/test/srt/test_retract_decode.py new file mode 100644 index 000000000..b16fd5163 --- /dev/null +++ b/test/srt/test_retract_decode.py @@ -0,0 +1,41 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestRetractDecode(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.65 + + +if __name__ == "__main__": + unittest.main()