def popen_launch_pd_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: Optional[str] = None,
    other_args: list[str] = (),
    env: Optional[dict] = None,
    return_stdout_stderr: Optional[tuple] = None,
):
    """Start ``sglang.launch_server`` in a subprocess and wait until it is healthy.

    Args:
        model: Value passed to ``--model-path``.
        base_url: URL of the form ``http://host:port``. The host and port are
            extracted from it and passed as ``--host`` / ``--port``; the same
            URL is polled at ``/health``.
        timeout: Maximum number of seconds to wait for ``/health`` to
            return HTTP 200.
        api_key: Optional key forwarded as ``--api-key`` and sent as a bearer
            token on the health requests.
        other_args: Extra command-line arguments, inserted before
            ``--host`` / ``--port``. Each item is converted with ``str``.
        env: Environment for the child process (``None`` inherits).
        return_stdout_stderr: Optional ``(stdout, stderr)`` pair handed to
            ``subprocess.Popen``; when given, the child is opened in text mode.

    Returns:
        The ``subprocess.Popen`` handle of the launched server.

    Raises:
        Exception: If the server process exits before becoming healthy.
        TimeoutError: If the server is not healthy within ``timeout`` seconds
            (the process tree is killed first).
    """
    # "http://host:port".split(":") -> ("http", "//host", "port"); drop the "//".
    _, host, port = base_url.split(":")
    host = host[2:]

    command = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        *[str(x) for x in other_args],
        "--host",
        host,
        "--port",
        port,
    ]
    if api_key:
        command += ["--api-key", api_key]

    print(f"command={' '.join(command)}")

    if return_stdout_stderr:
        process = subprocess.Popen(
            command,
            stdout=return_stdout_stderr[0],
            stderr=return_stdout_stderr[1],
            env=env,
            text=True,
        )
    else:
        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

    # Built once: the headers do not change between polls. The Authorization
    # header is only sent when a key was actually supplied, instead of the
    # literal string "Bearer None".
    headers = {"Content-Type": "application/json; charset=utf-8"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # Upper bound for a single health request. Without it, one stalled
    # connection could block forever and the outer deadline would never fire.
    health_request_timeout = 10

    start_time = time.time()
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            try:
                response = session.get(
                    f"{base_url}/health",
                    headers=headers,
                    timeout=health_request_timeout,
                )
                if response.status_code == 200:
                    return process
            except requests.RequestException:
                # Not reachable yet (or the request timed out); keep polling.
                pass

            return_code = process.poll()
            if return_code is not None:
                raise Exception(f"Server unexpectedly exits ({return_code=}).")

            time.sleep(10)

    kill_process_tree(process.pid)
    raise TimeoutError("Server failed to start within the timeout period.")
class TestDisaggregationMooncake(CustomTestCase):
    """GSM8K accuracy check against a prefill/decode disaggregated deployment.

    ``setUpClass`` launches a prefill server, a decode server and the
    ``sglang.srt.disaggregation.mini_lb`` load balancer; the test sends its
    requests to the load balancer URL.
    """

    # Declared up front so tearDownClass is safe to call after a setUpClass
    # that failed before every process was started.
    process_lb = None
    process_prefill = None
    process_decode = None

    @classmethod
    def setUpClass(cls):
        """Start prefill, decode and load-balancer processes and wait for them."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_host = "127.0.0.1"
        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
        cls.lb_url = DEFAULT_URL_FOR_TEST
        # The prefill and decode servers listen on fixed offsets from the
        # load balancer port.
        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"

        try:
            run_with_timeout(
                cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
            )
            run_with_timeout(
                cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
            )

            cls.wait_server_ready(cls.prefill_url + "/health")
            cls.wait_server_ready(cls.decode_url + "/health")

            lb_command = [
                "python3",
                "-m",
                "sglang.srt.disaggregation.mini_lb",
                "--prefill",
                cls.prefill_url,
                "--decode",
                cls.decode_url,
                "--host",
                cls.base_host,
                "--port",
                str(cls.base_port),
            ]

            print("Starting load balancer:", " ".join(lb_command))
            # Inherit the parent's stdout/stderr. Pipes that nobody reads
            # would eventually fill up and block the load balancer.
            cls.process_lb = subprocess.Popen(lb_command)
            cls.wait_server_ready(cls.lb_url + "/health")
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # clean up whatever was already started before re-raising.
            cls.tearDownClass()
            raise

    @classmethod
    def start_prefill(cls):
        """Launch the prefill-mode server and store its process handle."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--host",
            cls.base_host,
            "--port",
            str(cls.base_port + 100),
            "--tp",
            "4",
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode-mode server and store its process handle."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--host",
            cls.base_host,
            "--port",
            str(cls.base_port + 200),
            "--tp",
            "4",
            # NOTE(review): presumably places decode on GPUs 4-7 so it does not
            # overlap with the prefill server -- confirm against the server args.
            "--base-gpu-id",
            "4",
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @classmethod
    def wait_server_ready(cls, url, timeout=60):
        """Poll ``url`` once per second until it returns HTTP 200.

        Raises:
            RuntimeError: If no 200 response is seen within ``timeout`` seconds.
        """
        start_time = time.time()
        while True:
            try:
                # Bound each request so a stalled connection cannot defeat the
                # overall deadline below.
                response = requests.get(url, timeout=5)
                if response.status_code == 200:
                    print(f"Server {url} is ready")
                    return
            except requests.RequestException:
                # Not reachable yet; retry until the deadline.
                pass

            if time.time() - start_time > timeout:
                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
            time.sleep(1)

    @classmethod
    def tearDownClass(cls):
        """Kill every process that was started; errors are reported, not raised."""
        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
            if process is None:
                continue
            try:
                kill_process_tree(process.pid)
            except Exception as e:
                print(f"Error killing process {process.pid}: {e}")

    def test_gsm8k(self):
        """Run few-shot GSM8K through the load balancer and check accuracy."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)