diff --git a/test/srt/test_disaggregation_different_tp.py b/test/srt/test_disaggregation_different_tp.py index 9b045dd10..fdc332040 100644 --- a/test/srt/test_disaggregation_different_tp.py +++ b/test/srt/test_disaggregation_different_tp.py @@ -3,6 +3,7 @@ import subprocess import time import unittest from types import SimpleNamespace +from urllib.parse import urlparse import requests @@ -18,7 +19,7 @@ from sglang.test.test_utils import ( ) -class TestDisaggregationMooncakeDifferentTP(CustomTestCase): +class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): @classmethod def setUpClass(cls): # Temporarily disable JIT DeepGEMM @@ -26,15 +27,22 @@ class TestDisaggregationMooncakeDifferentTP(CustomTestCase): os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_host = "127.0.0.1" - cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1]) - cls.lb_url = DEFAULT_URL_FOR_TEST - cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}" - cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}" + parsed_url = urlparse(DEFAULT_URL_FOR_TEST) + cls.base_host = parsed_url.hostname + base_port = str(parsed_url.port) + cls.lb_port = base_port + cls.prefill_port = f"{int(base_port) + 100}" + cls.decode_port = f"{int(base_port) + 200}" + cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" + cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" + cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") - run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH) - run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH) + # Non blocking start servers + cls.start_prefill() + cls.start_decode() + # Block until both cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") @@ -49,7 +57,7 @@ class TestDisaggregationMooncakeDifferentTP(CustomTestCase): "--host", cls.base_host, "--port", - str(cls.base_port), + cls.lb_port, ] print("Starting load balancer:", " ".join(lb_command)) @@ -64,12 +72,10 @@ class TestDisaggregationMooncakeDifferentTP(CustomTestCase): "--trust-remote-code", "--disaggregation-mode", "prefill", - "--host", - cls.base_host, - "--port", - str(cls.base_port + 100), "--tp", - "4", + "2", + "--disaggregation-ib-device", + "mlx5_roce0,mlx5_roce1", ] cls.process_prefill = popen_launch_pd_server( cls.model, @@ -84,14 +90,12 @@ class TestDisaggregationMooncakeDifferentTP(CustomTestCase): "--trust-remote-code", "--disaggregation-mode", "decode", - "--host", - cls.base_host, - "--port", - str(cls.base_port + 200), "--tp", - "2", + "1", "--base-gpu-id", - "4", + "2", + "--disaggregation-ib-device", + "mlx5_roce2", ] cls.process_decode = popen_launch_pd_server( cls.model, @@ -130,6 +134,8 @@ class TestDisaggregationMooncakeDifferentTP(CustomTestCase): kill_process_tree(process.pid) except Exception as e: print(f"Error killing process {process.pid}: {e}") + # wait for 5 seconds + time.sleep(5) def test_gsm8k(self): args = SimpleNamespace( @@ -138,8 +144,142 @@ class TestDisaggregationMooncakeDifferentTP(CustomTestCase): num_questions=200, max_new_tokens=512, parallel=128, - host="http://127.0.0.1", - port=int(self.lb_url.split(":")[-1]), + host=f"http://{self.base_host}", + port=int(self.lb_port), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"Evaluation metrics: {metrics}") + + self.assertGreater(metrics["accuracy"], 0.60) + + +class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase): + @classmethod + def setUpClass(cls): + # Temporarily disable JIT DeepGEMM + cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") + os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + parsed_url = urlparse(DEFAULT_URL_FOR_TEST) + cls.base_host = parsed_url.hostname + base_port = str(parsed_url.port) + cls.lb_port = base_port + cls.prefill_port = f"{int(base_port) + 100}" + cls.decode_port = f"{int(base_port) + 200}" + cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" + cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" + cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") + + # Non blocking start servers + cls.start_prefill() + cls.start_decode() + + # Block until both + cls.wait_server_ready(cls.prefill_url + "/health") + cls.wait_server_ready(cls.decode_url + "/health") + + lb_command = [ + "python3", + "-m", + "sglang.srt.disaggregation.mini_lb", + "--prefill", + cls.prefill_url, + "--decode", + cls.decode_url, + "--host", + cls.base_host, + "--port", + cls.lb_port, + ] + + print("Starting load balancer:", " ".join(lb_command)) + cls.process_lb = subprocess.Popen( + lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + cls.wait_server_ready(cls.lb_url + "/health") + + @classmethod + def start_prefill(cls): + prefill_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "prefill", + "--tp", + "1", + "--disaggregation-ib-device", + "mlx5_roce0", + ] + cls.process_prefill = popen_launch_pd_server( + cls.model, + cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + @classmethod + def start_decode(cls): + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--tp", + "2", + "--base-gpu-id", + "1", + "--disaggregation-ib-device", + "mlx5_roce1,mlx5_roce2", + ] + cls.process_decode = popen_launch_pd_server( + cls.model, + cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + @classmethod + def wait_server_ready(cls, url, timeout=60): + start_time = time.perf_counter() + while True: + try: + response = requests.get(url) + if response.status_code == 200: + print(f"Server {url} is ready") + return + except Exception: + pass + + if time.perf_counter() - start_time > timeout: + raise RuntimeError(f"Server {url} failed to start in {timeout}s") + time.sleep(1) + + @classmethod + def tearDownClass(cls): + # Restore JIT DeepGEMM environment variable + if cls.original_jit_deepgemm is not None: + os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm + else: + os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None) + + for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: + if process: + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process {process.pid}: {e}") + # wait for 5 seconds + time.sleep(5) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host=f"http://{self.base_host}", + port=int(self.lb_port), ) metrics = run_eval_few_shot_gsm8k(args) print(f"Evaluation metrics: {metrics}")