diff --git a/python/sglang/test/test_disaggregation_utils.py b/python/sglang/test/test_disaggregation_utils.py index f61b71a9d..5c4601601 100644 --- a/python/sglang/test/test_disaggregation_utils.py +++ b/python/sglang/test/test_disaggregation_utils.py @@ -1,10 +1,12 @@ import time +from urllib.parse import urlparse import requests from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, CustomTestCase, popen_with_error_check, ) @@ -13,8 +15,17 @@ from sglang.test.test_utils import ( class TestDisaggregationBase(CustomTestCase): @classmethod def setUpClass(cls): + parsed_url = urlparse(DEFAULT_URL_FOR_TEST) + cls.base_host = parsed_url.hostname + base_port = str(parsed_url.port) + cls.lb_port = base_port + cls.prefill_port = f"{int(base_port) + 100}" + cls.decode_port = f"{int(base_port) + 200}" + cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" + cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" + cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None - pass @classmethod def launch_lb(cls): diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 08900cae4..b0cfd44bf 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -146,12 +146,12 @@ suites = { ], "per-commit-8-gpu": [ TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400), - TestFile("lora/test_lora_llama4.py", 600), - TestFile("test_disaggregation.py", 499), + TestFile("lora/test_lora_llama4.py", 400), + TestFile("test_disaggregation.py", 600), TestFile("test_disaggregation_dp_attention.py", 155), - TestFile("test_disaggregation_different_tp.py", 155), - TestFile("test_disaggregation_pp.py", 60), - TestFile("test_full_deepseek_v3.py", 333), + TestFile("test_disaggregation_different_tp.py", 600), + TestFile("test_disaggregation_pp.py", 140), + TestFile("test_full_deepseek_v3.py", 550), ], "per-commit-4-gpu-b200": [ # TestFile("test_gpt_oss_4gpu.py", 600), diff --git a/test/srt/test_disaggregation.py b/test/srt/test_disaggregation.py index 827bfc3b8..9fecf5c59 100644 --- a/test/srt/test_disaggregation.py +++ b/test/srt/test_disaggregation.py @@ -3,7 +3,6 @@ import os import time import unittest from types import SimpleNamespace -from urllib.parse import urlparse import requests @@ -14,7 +13,6 @@ from sglang.test.test_utils import ( DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, popen_launch_pd_server, ) @@ -22,17 +20,8 @@ from sglang.test.test_utils import ( class TestDisaggregationAccuracy(TestDisaggregationBase): @classmethod def setUpClass(cls): + super().setUpClass() cls.model = DEFAULT_MODEL_NAME_FOR_TEST - parsed_url = urlparse(DEFAULT_URL_FOR_TEST) - cls.base_host = parsed_url.hostname - base_port = str(parsed_url.port) - cls.lb_port = base_port - cls.prefill_port = f"{int(base_port) + 100}" - cls.decode_port = f"{int(base_port) + 200}" - cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" - cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" - cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" - print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") # Non blocking start servers cls.start_prefill() @@ -51,9 +40,9 @@ class TestDisaggregationAccuracy(TestDisaggregationBase): "--disaggregation-mode", "prefill", "--tp", - "1", + "2", "--disaggregation-ib-device", - "mlx5_roce0", + "mlx5_roce0,mlx5_roce1", ] cls.process_prefill = popen_launch_pd_server( cls.model, @@ -69,11 +58,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase): "--disaggregation-mode", "decode", "--tp", - "1", + "2", "--base-gpu-id", - "1", + "2", "--disaggregation-ib-device", - "mlx5_roce1", + "mlx5_roce2,mlx5_roce3", ] cls.process_decode = popen_launch_pd_server( cls.model, @@ -154,20 +143,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase): class TestDisaggregationMooncakeFailure(TestDisaggregationBase): @classmethod def setUpClass(cls): + super().setUpClass() # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05" cls.model = DEFAULT_MODEL_NAME_FOR_TEST - parsed_url = urlparse(DEFAULT_URL_FOR_TEST) - cls.base_host = parsed_url.hostname - base_port = str(parsed_url.port) - cls.lb_port = base_port - cls.prefill_port = f"{int(base_port) + 100}" - cls.decode_port = f"{int(base_port) + 200}" - cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" - cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" - cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" - print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") # Non blocking start servers cls.start_prefill() @@ -191,9 +171,9 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase): "--disaggregation-mode", "prefill", "--tp", - "1", + "2", "--disaggregation-ib-device", - "mlx5_roce0", + "mlx5_roce0,mlx5_roce1", ] cls.process_prefill = popen_launch_pd_server( cls.model, @@ -209,11 +189,11 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase): "--disaggregation-mode", "decode", "--tp", - "1", + "2", "--base-gpu-id", - "1", + "2", "--disaggregation-ib-device", - "mlx5_roce1", + "mlx5_roce2,mlx5_roce3", ] cls.process_decode = popen_launch_pd_server( cls.model, @@ -254,17 +234,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase): @classmethod def setUpClass(cls): + super().setUpClass() cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST - parsed_url = urlparse(DEFAULT_URL_FOR_TEST) - cls.base_host = parsed_url.hostname - base_port = str(parsed_url.port) - cls.lb_port = base_port - cls.prefill_port = f"{int(base_port) + 100}" - cls.decode_port = f"{int(base_port) + 200}" - cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" - cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" - cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" cls.spec_args = [ "--speculative-algorithm", "EAGLE", @@ -348,18 +320,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase): class TestDisaggregationSimulatedRetract(TestDisaggregationBase): @classmethod def setUpClass(cls): + super().setUpClass() os.environ["SGLANG_TEST_RETRACT"] = "true" cls.model = DEFAULT_MODEL_NAME_FOR_TEST - parsed_url = urlparse(DEFAULT_URL_FOR_TEST) - cls.base_host = parsed_url.hostname - base_port = str(parsed_url.port) - cls.lb_port = base_port - cls.prefill_port = f"{int(base_port) + 100}" - cls.decode_port = f"{int(base_port) + 200}" - cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" - cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" - cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" - print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") # Non blocking start servers cls.start_prefill() @@ -383,9 +346,9 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase): "--disaggregation-mode", "prefill", "--tp", - "1", + "2", "--disaggregation-ib-device", - "mlx5_roce0", + "mlx5_roce0,mlx5_roce1", ] cls.process_prefill = popen_launch_pd_server( cls.model, @@ -401,11 +364,11 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase): "--disaggregation-mode", "decode", "--tp", - "1", + "2", "--base-gpu-id", - "1", + "2", "--disaggregation-ib-device", - "mlx5_roce1", + "mlx5_roce2,mlx5_roce3", ] cls.process_decode = popen_launch_pd_server( cls.model, diff --git a/test/srt/test_disaggregation_different_tp.py b/test/srt/test_disaggregation_different_tp.py index 67a3afcbe..3fd00c217 100644 --- a/test/srt/test_disaggregation_different_tp.py +++ b/test/srt/test_disaggregation_different_tp.py @@ -2,14 +2,13 @@ import os import time import unittest from types import SimpleNamespace -from urllib.parse import urlparse from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, popen_launch_pd_server, ) @@ -17,21 +16,86 @@ from sglang.test.test_utils import ( class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): @classmethod def setUpClass(cls): + super().setUpClass() + # Temporarily disable JIT DeepGEMM + cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") + os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + + # Non blocking start servers + cls.start_prefill() + cls.start_decode() + + # Block until both + cls.wait_server_ready(cls.prefill_url + "/health") + cls.wait_server_ready(cls.decode_url + "/health") + + cls.launch_lb() + + @classmethod + def start_prefill(cls): + prefill_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "prefill", + "--tp", + "4", + "--disaggregation-ib-device", + "mlx5_roce0,mlx5_roce1", + ] + cls.process_prefill = popen_launch_pd_server( + cls.model, + cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + @classmethod + def start_decode(cls): + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--tp", + "2", + "--base-gpu-id", + "4", + "--disaggregation-ib-device", + "mlx5_roce4,mlx5_roce5", + ] + cls.process_decode = popen_launch_pd_server( + cls.model, + cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host=f"http://{self.base_host}", + port=int(self.lb_port), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"Evaluation metrics: {metrics}") + + self.assertGreater(metrics["accuracy"], 0.60) + + +class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): + @classmethod + def setUpClass(cls): + super().setUpClass() # Temporarily disable JIT DeepGEMM cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - parsed_url = urlparse(DEFAULT_URL_FOR_TEST) - cls.base_host = parsed_url.hostname - base_port = str(parsed_url.port) - cls.lb_port = base_port - cls.prefill_port = f"{int(base_port) + 100}" - cls.decode_port = f"{int(base_port) + 200}" - cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" - cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" - cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" - print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") # Non blocking start servers cls.start_prefill() @@ -68,11 +132,11 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): "--disaggregation-mode", "decode", "--tp", - "1", + "4", "--base-gpu-id", - "2", + "4", "--disaggregation-ib-device", - "mlx5_roce2", + "mlx5_roce4,mlx5_roce5", ] cls.process_decode = popen_launch_pd_server( cls.model, @@ -97,24 +161,15 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): self.assertGreater(metrics["accuracy"], 0.60) -class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): +class TestDisaggregationMooncakeMHAPrefillLargerTP(TestDisaggregationBase): @classmethod def setUpClass(cls): + super().setUpClass() # Temporarily disable JIT DeepGEMM cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - parsed_url = urlparse(DEFAULT_URL_FOR_TEST) - cls.base_host = parsed_url.hostname - base_port = str(parsed_url.port) - cls.lb_port = base_port - cls.prefill_port = f"{int(base_port) + 100}" - cls.decode_port = f"{int(base_port) + 200}" - cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" - cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" - cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" - print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") + cls.model = DEFAULT_MODEL_NAME_FOR_TEST # Non blocking start servers cls.start_prefill() @@ -133,9 +188,9 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): "--disaggregation-mode", "prefill", "--tp", - "1", + "4", "--disaggregation-ib-device", - "mlx5_roce0", + "mlx5_roce0,mlx5_roce1", ] cls.process_prefill = popen_launch_pd_server( cls.model, @@ -153,9 +208,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): "--tp", "2", "--base-gpu-id", - "1", + "4", "--disaggregation-ib-device", - "mlx5_roce1,mlx5_roce2", + "mlx5_roce4,mlx5_roce5", + ] + cls.process_decode = popen_launch_pd_server( + cls.model, + cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host=f"http://{self.base_host}", + port=int(self.lb_port), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"Evaluation metrics: {metrics}") + + self.assertGreater(metrics["accuracy"], 0.60) + + +class TestDisaggregationMooncakeMHADecodeLargerTP(TestDisaggregationBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + # Temporarily disable JIT DeepGEMM + cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") + os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + + # Non blocking start servers + cls.start_prefill() + cls.start_decode() + + # Block until both + cls.wait_server_ready(cls.prefill_url + "/health") + cls.wait_server_ready(cls.decode_url + "/health") + + cls.launch_lb() + + @classmethod + def start_prefill(cls): + prefill_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "prefill", + "--tp", + "2", + "--disaggregation-ib-device", + "mlx5_roce0,mlx5_roce1", + ] + cls.process_prefill = popen_launch_pd_server( + cls.model, + cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + @classmethod + def start_decode(cls): + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--tp", + "4", + "--base-gpu-id", + "4", + "--disaggregation-ib-device", + "mlx5_roce4,mlx5_roce5", ] cls.process_decode = popen_launch_pd_server( cls.model, diff --git a/test/srt/test_disaggregation_dp_attention.py b/test/srt/test_disaggregation_dp_attention.py index 4f3c6ae86..bf934a913 100644 --- a/test/srt/test_disaggregation_dp_attention.py +++ b/test/srt/test_disaggregation_dp_attention.py @@ -17,21 +17,12 @@ from sglang.test.test_utils import ( class TestDisaggregationDPAttention(TestDisaggregationBase): @classmethod def setUpClass(cls): + super().setUpClass() # Temporarily disable JIT DeepGEMM cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - parsed_url = urlparse(DEFAULT_URL_FOR_TEST) - cls.base_host = parsed_url.hostname - base_port = str(parsed_url.port) - cls.lb_port = base_port - cls.prefill_port = f"{int(base_port) + 100}" - cls.decode_port = f"{int(base_port) + 200}" - cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" - cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" - cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" - print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") # Non blocking start servers cls.start_prefill() diff --git a/test/srt/test_disaggregation_pp.py b/test/srt/test_disaggregation_pp.py index a8bab8f81..7367e95a0 100644 --- a/test/srt/test_disaggregation_pp.py +++ b/test/srt/test_disaggregation_pp.py @@ -1,14 +1,12 @@ import time import unittest from types import SimpleNamespace -from urllib.parse import urlparse from sglang.test.few_shot_gsm8k import run_eval from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, popen_launch_pd_server, ) @@ -16,17 +14,8 @@ from sglang.test.test_utils import ( class TestDisaggregationPPAccuracy(TestDisaggregationBase): @classmethod def setUpClass(cls): + super().setUpClass() cls.model = DEFAULT_MODEL_NAME_FOR_TEST - parsed_url = urlparse(DEFAULT_URL_FOR_TEST) - cls.base_host = parsed_url.hostname - base_port = str(parsed_url.port) - cls.lb_port = base_port - cls.prefill_port = f"{int(base_port) + 100}" - cls.decode_port = f"{int(base_port) + 200}" - cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" - cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" - cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" - print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") # Non blocking start servers cls.start_prefill() @@ -45,7 +34,7 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase): "--disaggregation-mode", "prefill", "--tp-size", - "1", + "2", "--pp-size", "2", "--disaggregation-ib-device", @@ -66,11 +55,11 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase): "--disaggregation-mode", "decode", "--tp", - "1", - "--base-gpu-id", "2", + "--base-gpu-id", + "4", "--disaggregation-ib-device", - "mlx5_roce2", + "mlx5_roce4,mlx5_roce5", ] cls.process_decode = popen_launch_pd_server( cls.model,