From 00974e4f6ebf3489f33909020a9fb922159407a9 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Sat, 6 Sep 2025 22:14:46 +0800 Subject: [PATCH] [CI] Refactor disaggregation tests (#10068) Signed-off-by: Shangming Cai --- .../sglang/test/test_disaggregation_utils.py | 66 ++++++ test/srt/run_suite.py | 1 + test/srt/test_disaggregation.py | 213 ++---------------- test/srt/test_disaggregation_different_tp.py | 117 +--------- test/srt/test_disaggregation_pp.py | 56 +---- 5 files changed, 100 insertions(+), 353 deletions(-) create mode 100644 python/sglang/test/test_disaggregation_utils.py diff --git a/python/sglang/test/test_disaggregation_utils.py b/python/sglang/test/test_disaggregation_utils.py new file mode 100644 index 000000000..f61b71a9d --- /dev/null +++ b/python/sglang/test/test_disaggregation_utils.py @@ -0,0 +1,66 @@ +import time + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + CustomTestCase, + popen_with_error_check, +) + + +class TestDisaggregationBase(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None + pass + + @classmethod + def launch_lb(cls): + lb_command = [ + "python3", + "-m", + "sglang_router.launch_router", + "--pd-disaggregation", + "--mini-lb", # FIXME: remove this + "--prefill", + cls.prefill_url, + "--decode", + cls.decode_url, + "--host", + cls.base_host, + "--port", + cls.lb_port, + ] + print("Starting load balancer:", " ".join(lb_command)) + cls.process_lb = popen_with_error_check(lb_command) + cls.wait_server_ready(cls.lb_url + "/health") + + @classmethod + def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH): + start_time = time.perf_counter() + while True: + try: + response = requests.get(url) + if response.status_code == 200: + print(f"Server {url} is ready") + return + except Exception: + pass + + if time.perf_counter() - start_time > 
timeout: + raise RuntimeError(f"Server {url} failed to start in {timeout}s") + time.sleep(1) + + @classmethod + def tearDownClass(cls): + for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: + if process: + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process {process.pid}: {e}") + + # wait for 5 seconds + time.sleep(5) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index bfe867f17..f4e5871de 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -139,6 +139,7 @@ suites = { TestFile("lora/test_lora_llama4.py", 600), TestFile("test_disaggregation.py", 499), TestFile("test_disaggregation_different_tp.py", 155), + TestFile("test_disaggregation_pp.py", 60), TestFile("test_full_deepseek_v3.py", 333), ], "per-commit-8-gpu-b200": [ diff --git a/test/srt/test_disaggregation.py b/test/srt/test_disaggregation.py index 1a7cb99ed..827bfc3b8 100644 --- a/test/srt/test_disaggregation.py +++ b/test/srt/test_disaggregation.py @@ -7,21 +7,19 @@ from urllib.parse import urlparse import requests -from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_utils import ( DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, - CustomTestCase, popen_launch_pd_server, - popen_with_error_check, ) -class TestDisaggregationAccuracy(CustomTestCase): +class TestDisaggregationAccuracy(TestDisaggregationBase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST @@ -44,25 +42,7 @@ class TestDisaggregationAccuracy(CustomTestCase): cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") - lb_command = [ - "python3", - "-m", - "sglang_router.launch_router", - "--pd-disaggregation", - 
"--mini-lb", # FIXME: remove this - "--prefill", - cls.prefill_url, - "--decode", - cls.decode_url, - "--host", - cls.base_host, - "--port", - cls.lb_port, - ] - - print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = popen_with_error_check(lb_command) - cls.wait_server_ready(cls.lb_url + "/health") + cls.launch_lb() @classmethod def start_prefill(cls): @@ -102,34 +82,6 @@ class TestDisaggregationAccuracy(CustomTestCase): other_args=decode_args, ) - @classmethod - def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH): - start_time = time.perf_counter() - while True: - try: - response = requests.get(url) - if response.status_code == 200: - print(f"Server {url} is ready") - return - except Exception: - pass - - if time.perf_counter() - start_time > timeout: - raise RuntimeError(f"Server {url} failed to start in {timeout}s") - time.sleep(1) - - @classmethod - def tearDownClass(cls): - for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: - if process: - try: - kill_process_tree(process.pid) - except Exception as e: - print(f"Error killing process {process.pid}: {e}") - - # wait for 5 seconds - time.sleep(5) - def test_gsm8k(self): args = SimpleNamespace( num_shots=5, @@ -199,7 +151,7 @@ class TestDisaggregationAccuracy(CustomTestCase): json.loads(output) -class TestDisaggregationMooncakeFailure(CustomTestCase): +class TestDisaggregationMooncakeFailure(TestDisaggregationBase): @classmethod def setUpClass(cls): # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure @@ -225,25 +177,12 @@ class TestDisaggregationMooncakeFailure(CustomTestCase): cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") - lb_command = [ - "python3", - "-m", - "sglang_router.launch_router", - "--pd-disaggregation", - "--mini-lb", # FIXME: remove this - "--prefill", - cls.prefill_url, - "--decode", - cls.decode_url, - "--host", - cls.base_host, - "--port", - cls.lb_port, - ] + 
cls.launch_lb() - print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = popen_with_error_check(lb_command) - cls.wait_server_ready(cls.lb_url + "/health") + @classmethod + def tearDownClass(cls): + os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB") + super().tearDownClass() @classmethod def start_prefill(cls): @@ -283,36 +222,6 @@ class TestDisaggregationMooncakeFailure(CustomTestCase): other_args=decode_args, ) - @classmethod - def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH): - start_time = time.perf_counter() - while True: - try: - response = requests.get(url) - if response.status_code == 200: - print(f"Server {url} is ready") - return - except Exception: - pass - - if time.perf_counter() - start_time > timeout: - raise RuntimeError(f"Server {url} failed to start in {timeout}s") - time.sleep(1) - - @classmethod - def tearDownClass(cls): - # unset DISAGGREGATION_TEST_FAILURE_PROB - os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB") - for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: - if process: - try: - kill_process_tree(process.pid) - except Exception as e: - print(f"Error killing process {process.pid}: {e}") - - # wait for 5 seconds - time.sleep(5) - def test_gsm8k(self): args = SimpleNamespace( num_shots=5, @@ -341,7 +250,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase): raise e from health_check_error -class TestDisaggregationMooncakeSpec(CustomTestCase): +class TestDisaggregationMooncakeSpec(TestDisaggregationBase): @classmethod def setUpClass(cls): @@ -380,41 +289,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase): cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") - lb_command = [ - "python3", - "-m", - "sglang_router.launch_router", - "--pd-disaggregation", - "--mini-lb", # FIXME: remove this - "--prefill", - cls.prefill_url, - "--decode", - cls.decode_url, - "--host", - cls.base_host, - "--port", - cls.lb_port, - ] 
- - print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = popen_with_error_check(lb_command) - cls.wait_server_ready(cls.lb_url + "/health") - - @classmethod - def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH): - start_time = time.perf_counter() - while True: - try: - response = requests.get(url) - if response.status_code == 200: - print(f"Server {url} is ready") - return - except Exception: - pass - - if time.perf_counter() - start_time > timeout: - raise RuntimeError(f"Server {url} failed to start in {timeout}s") - time.sleep(1) + cls.launch_lb() @classmethod def start_prefill(cls): @@ -454,18 +329,6 @@ class TestDisaggregationMooncakeSpec(CustomTestCase): other_args=decode_args, ) - @classmethod - def tearDownClass(cls): - for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: - if process: - try: - kill_process_tree(process.pid) - except Exception as e: - print(f"Error killing process {process.pid}: {e}") - - # wait for 5 seconds - time.sleep(5) - def test_gsm8k(self): args = SimpleNamespace( num_shots=5, @@ -482,7 +345,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.20) -class TestDisaggregationSimulatedRetract(CustomTestCase): +class TestDisaggregationSimulatedRetract(TestDisaggregationBase): @classmethod def setUpClass(cls): os.environ["SGLANG_TEST_RETRACT"] = "true" @@ -506,25 +369,12 @@ class TestDisaggregationSimulatedRetract(CustomTestCase): cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") - lb_command = [ - "python3", - "-m", - "sglang_router.launch_router", - "--pd-disaggregation", - "--mini-lb", # FIXME: remove this - "--prefill", - cls.prefill_url, - "--decode", - cls.decode_url, - "--host", - cls.base_host, - "--port", - cls.lb_port, - ] + cls.launch_lb() - print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = popen_with_error_check(lb_command) - 
cls.wait_server_ready(cls.lb_url + "/health") + @classmethod + def tearDownClass(cls): + os.environ.pop("SGLANG_TEST_RETRACT") + super().tearDownClass() @classmethod def start_prefill(cls): @@ -564,35 +414,6 @@ class TestDisaggregationSimulatedRetract(CustomTestCase): other_args=decode_args, ) - @classmethod - def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH): - start_time = time.perf_counter() - while True: - try: - response = requests.get(url) - if response.status_code == 200: - print(f"Server {url} is ready") - return - except Exception: - pass - - if time.perf_counter() - start_time > timeout: - raise RuntimeError(f"Server {url} failed to start in {timeout}s") - time.sleep(1) - - @classmethod - def tearDownClass(cls): - os.environ.pop("SGLANG_TEST_RETRACT") - for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: - if process: - try: - kill_process_tree(process.pid) - except Exception as e: - print(f"Error killing process {process.pid}: {e}") - - # wait for 5 seconds - time.sleep(5) - def test_gsm8k(self): args = SimpleNamespace( num_shots=5, diff --git a/test/srt/test_disaggregation_different_tp.py b/test/srt/test_disaggregation_different_tp.py index 911afbe9b..67a3afcbe 100644 --- a/test/srt/test_disaggregation_different_tp.py +++ b/test/srt/test_disaggregation_different_tp.py @@ -1,25 +1,20 @@ import os -import subprocess import time import unittest from types import SimpleNamespace from urllib.parse import urlparse -import requests - -from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, - CustomTestCase, popen_launch_pd_server, - popen_with_error_check, ) -class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): +class 
TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): +class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): @classmethod def setUpClass(cls): # Temporarily disable JIT DeepGEMM @@ -46,25 +41,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") - lb_command = [ - "python3", - "-m", - "sglang_router.launch_router", - "--pd-disaggregation", - "--mini-lb", # FIXME: remove this - "--prefill", - cls.prefill_url, - "--decode", - cls.decode_url, - "--host", - cls.base_host, - "--port", - cls.lb_port, - ] - - print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = popen_with_error_check(lb_command) - cls.wait_server_ready(cls.lb_url + "/health") + cls.launch_lb() @classmethod def start_prefill(cls): @@ -104,39 +81,6 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): other_args=decode_args, ) - @classmethod - def wait_server_ready(cls, url, timeout=60): - start_time = time.perf_counter() - while True: - try: - response = requests.get(url) - if response.status_code == 200: - print(f"Server {url} is ready") - return - except Exception: - pass - - if time.perf_counter() - start_time > timeout: - raise RuntimeError(f"Server {url} failed to start in {timeout}s") - time.sleep(1) - - @classmethod - def tearDownClass(cls): - # Restore JIT DeepGEMM environment variable - if cls.original_jit_deepgemm is not None: - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm - else: - os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None) - - for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: - if process: - try: - kill_process_tree(process.pid) - except Exception as e: - print(f"Error killing process {process.pid}: {e}") - # wait for 5 seconds - time.sleep(5) - def test_gsm8k(self): args = SimpleNamespace( num_shots=5, @@ -153,7 +97,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.60) -class 
TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase): +class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): @classmethod def setUpClass(cls): # Temporarily disable JIT DeepGEMM @@ -180,25 +124,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase): cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") - lb_command = [ - "python3", - "-m", - "sglang_router.launch_router", - "--pd-disaggregation", - "--mini-lb", # FIXME: remove this - "--prefill", - cls.prefill_url, - "--decode", - cls.decode_url, - "--host", - cls.base_host, - "--port", - cls.lb_port, - ] - - print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = popen_with_error_check(lb_command) - cls.wait_server_ready(cls.lb_url + "/health") + cls.launch_lb() @classmethod def start_prefill(cls): @@ -238,39 +164,6 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase): other_args=decode_args, ) - @classmethod - def wait_server_ready(cls, url, timeout=60): - start_time = time.perf_counter() - while True: - try: - response = requests.get(url) - if response.status_code == 200: - print(f"Server {url} is ready") - return - except Exception: - pass - - if time.perf_counter() - start_time > timeout: - raise RuntimeError(f"Server {url} failed to start in {timeout}s") - time.sleep(1) - - @classmethod - def tearDownClass(cls): - # Restore JIT DeepGEMM environment variable - if cls.original_jit_deepgemm is not None: - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm - else: - os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None) - - for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: - if process: - try: - kill_process_tree(process.pid) - except Exception as e: - print(f"Error killing process {process.pid}: {e}") - # wait for 5 seconds - time.sleep(5) - def test_gsm8k(self): args = SimpleNamespace( num_shots=5, diff --git a/test/srt/test_disaggregation_pp.py 
b/test/srt/test_disaggregation_pp.py index ece959a7d..a8bab8f81 100644 --- a/test/srt/test_disaggregation_pp.py +++ b/test/srt/test_disaggregation_pp.py @@ -1,29 +1,19 @@ -import json -import os -import random import time import unittest -from concurrent.futures import ThreadPoolExecutor from types import SimpleNamespace -from typing import List, Optional +from urllib.parse import urlparse -import requests - -from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval -from sglang.test.runners import DEFAULT_PROMPTS +from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_utils import ( - DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, - DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, + popen_launch_pd_server, ) -class TestPDPPAccuracy(unittest.TestCase): +class TestDisaggregationPPAccuracy(TestDisaggregationBase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST @@ -46,27 +36,7 @@ class TestPDPPAccuracy(unittest.TestCase): cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") - lb_command = [ - "python3", - "-m", - "sglang_router.launch_router", - "--pd-disaggregation", - "--mini-lb", # FIXME: remove this - "--prefill", - cls.prefill_url, - "--decode", - cls.decode_url, - "--host", - cls.base_host, - "--port", - cls.lb_port, - ] - - print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = subprocess.Popen( - lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - cls.wait_server_ready(cls.lb_url + "/health") + cls.launch_lb() @classmethod def start_prefill(cls): @@ -75,11 +45,11 @@ class TestPDPPAccuracy(unittest.TestCase): "--disaggregation-mode", "prefill", "--tp-size", - "2", + "1", "--pp-size", "2", "--disaggregation-ib-device", - "mlx5_roce0", + "mlx5_roce0,mlx5_roce1", 
"--disable-overlap-schedule", ] cls.process_prefill = popen_launch_pd_server( @@ -98,9 +68,9 @@ class TestPDPPAccuracy(unittest.TestCase): "--tp", "1", "--base-gpu-id", - "1", + "2", "--disaggregation-ib-device", - "mlx5_roce1", + "mlx5_roce2", ] cls.process_decode = popen_launch_pd_server( cls.model, @@ -109,10 +79,6 @@ class TestPDPPAccuracy(unittest.TestCase): other_args=decode_args, ) - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - def test_gsm8k(self): args = SimpleNamespace( num_shots=5, @@ -120,8 +86,8 @@ class TestPDPPAccuracy(unittest.TestCase): num_questions=200, max_new_tokens=512, parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + host=f"http://{self.base_host}", + port=int(self.lb_port), ) metrics = run_eval(args) print(f"{metrics=}")