diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index e43c57dab..33d98f176 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -105,11 +105,11 @@ jobs: pip install fastapi uvicorn orjson pytest -q -m integration - - name: Run e2e test + - name: Run Python E2E tests run: | bash scripts/killall_sglang.sh "nuk_gpus" - cd sgl-router/py_test - python3 run_suite.py + cd sgl-router + pytest -m e2e -s -vv -o log_cli=true --log-cli-level=INFO finish: needs: [unit-test-rust, e2e-python] diff --git a/sgl-router/py_test/e2e/conftest.py b/sgl-router/py_test/e2e/conftest.py new file mode 100644 index 000000000..02eea55d4 --- /dev/null +++ b/sgl-router/py_test/e2e/conftest.py @@ -0,0 +1,235 @@ +import socket +import subprocess +import time +from types import SimpleNamespace +from urllib.parse import urlparse + +import pytest +import requests + +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, +) + + +def _find_available_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _parse_url(base_url: str) -> tuple[str, str]: + """Parse a base URL and return (host, port) as strings. + + This is more robust than simple string splitting and supports different schemes + and URL shapes like trailing paths. + """ + parsed = urlparse(base_url) + return parsed.hostname or "127.0.0.1", ( + str(parsed.port) if parsed.port is not None else "" + ) + + +def _wait_router_health(base_url: str, timeout: float) -> None: + start = time.perf_counter() + with requests.Session() as session: + while time.perf_counter() - start < timeout: + try: + r = session.get(f"{base_url}/health", timeout=5) + if r.status_code == 200: + return + except requests.RequestException: + pass + time.sleep(2) + raise TimeoutError("Router failed to become healthy in time") + + +def _popen_launch_router( + model: str, + base_url: str, + dp_size: int, + timeout: float, + policy: str = "cache_aware", +) -> subprocess.Popen: + host, port = _parse_url(base_url) + + prom_port = _find_available_port() + + cmd = [ + "python3", + "-m", + "sglang_router.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + "--dp", + str(dp_size), + "--router-policy", + policy, + "--allow-auto-truncate", + "--router-prometheus-port", + str(prom_port), + "--router-prometheus-host", + "127.0.0.1", + ] + + proc = subprocess.Popen(cmd) + _wait_router_health(base_url, timeout) + return proc + + +def _popen_launch_worker( + model: str, + base_url: str, + *, + dp_size: int | None = None, + api_key: str | None = None, +) -> subprocess.Popen: + host, port = _parse_url(base_url) + + cmd = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + "--base-gpu-id", + "0", + ] + if dp_size is not None: + cmd += ["--dp-size", str(dp_size)] + if api_key is not None: + cmd += ["--api-key", api_key] + return subprocess.Popen(cmd) + + +def _popen_launch_router_only( + base_url: str, + policy: str = "round_robin", + timeout: float = 120.0, + *, + dp_aware: bool = False, + api_key: str | None = None, +) -> subprocess.Popen: + host, port = _parse_url(base_url) + + prom_port = _find_available_port() + cmd = [ + "python3", + "-m", + "sglang_router.launch_router", + "--host", + host, + "--port", + port, + "--policy", + policy, + ] + if dp_aware: + cmd += ["--dp-aware"] + if api_key is not None: + cmd += ["--api-key", api_key] + cmd += [ + "--prometheus-port", + str(prom_port), + "--prometheus-host", + "127.0.0.1", + ] + proc = subprocess.Popen(cmd) + _wait_router_health(base_url, timeout) + return proc + + +def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None: + if proc is None: + return + proc.terminate() + start = time.perf_counter() + while proc.poll() is None: + if time.perf_counter() - start > timeout: + proc.kill() + break + time.sleep(1) + + +def pytest_configure(config): + config.addinivalue_line("markers", "e2e: mark as end-to-end test") + + +@pytest.fixture(scope="session") +def e2e_model() -> str: + # Always use the default test model + return DEFAULT_MODEL_NAME_FOR_TEST + + +@pytest.fixture +def e2e_router(e2e_model: str): + # Keep this available but tests below use router-only to avoid GPU contention + base_url = DEFAULT_URL_FOR_TEST + proc = _popen_launch_router( + e2e_model, base_url, dp_size=2, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + try: + yield SimpleNamespace(proc=proc, url=base_url) + finally: + _terminate(proc) + + +@pytest.fixture +def e2e_router_only_rr(): + port = _find_available_port() + base_url = f"http://127.0.0.1:{port}" + proc = _popen_launch_router_only(base_url, policy="round_robin") + try: + yield SimpleNamespace(proc=proc, url=base_url) + finally: + _terminate(proc) + + +@pytest.fixture(scope="session") +def e2e_primary_worker(e2e_model: str): + port = _find_available_port() + base_url = f"http://127.0.0.1:{port}" + proc = _popen_launch_worker(e2e_model, base_url) + # Router health gate will handle worker readiness + try: + yield SimpleNamespace(proc=proc, url=base_url) + finally: + _terminate(proc) + + +@pytest.fixture +def e2e_router_only_rr_dp_aware_api(): + """Router-only with dp-aware enabled and an API key.""" + port = _find_available_port() + base_url = f"http://127.0.0.1:{port}" + api_key = "secret" + proc = _popen_launch_router_only( + base_url, policy="round_robin", timeout=180.0, dp_aware=True, api_key=api_key + ) + try: + yield SimpleNamespace(proc=proc, url=base_url, api_key=api_key) + finally: + _terminate(proc) + + +@pytest.fixture +def e2e_worker_dp2_api(e2e_model: str, e2e_router_only_rr_dp_aware_api): + """Worker with dp-size=2 and the same API key as the dp-aware router.""" + port = _find_available_port() + base_url = f"http://127.0.0.1:{port}" + api_key = e2e_router_only_rr_dp_aware_api.api_key + proc = _popen_launch_worker(e2e_model, base_url, dp_size=2, api_key=api_key) + try: + yield SimpleNamespace(proc=proc, url=base_url) + finally: + _terminate(proc) diff --git a/sgl-router/py_test/e2e/test_e2e_router.py b/sgl-router/py_test/e2e/test_e2e_router.py new file mode 100644 index 000000000..b40c43408 --- /dev/null +++ b/sgl-router/py_test/e2e/test_e2e_router.py @@ -0,0 +1,146 @@ +import threading +import time +from types import SimpleNamespace + +import pytest +import requests + +from sglang.test.run_eval import run_eval + + +@pytest.mark.e2e +def test_mmlu(e2e_router_only_rr, e2e_primary_worker, e2e_model): + # Attach the primary worker to a fresh router-only instance (single model) + base = e2e_router_only_rr.url + r = requests.post( + f"{base}/add_worker", params={"url": e2e_primary_worker.url}, timeout=180 + ) + r.raise_for_status() + + args = SimpleNamespace( + base_url=base, + model=e2e_model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + metrics = run_eval(args) + assert metrics["score"] >= 0.65 + + +@pytest.mark.e2e +def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_model): + base = e2e_router_only_rr.url + worker_url = e2e_primary_worker.url + + r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180) + r.raise_for_status() + + with requests.Session() as s: + for i in range(8): + r = s.post( + f"{base}/v1/completions", + json={ + "model": e2e_model, + "prompt": f"x{i}", + "max_tokens": 1, + "stream": False, + }, + timeout=120, + ) + r.raise_for_status() + + # Remove the worker + r = requests.post(f"{base}/remove_worker", params={"url": worker_url}, timeout=60) + r.raise_for_status() + + +@pytest.mark.e2e +def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_model): + base = e2e_router_only_rr.url + worker = e2e_primary_worker + + r = requests.post(f"{base}/add_worker", params={"url": worker.url}, timeout=180) + r.raise_for_status() + + def killer(): + time.sleep(10) + try: + worker.proc.terminate() + except Exception: + pass + + t = threading.Thread(target=killer, daemon=True) + t.start() + + args = SimpleNamespace( + base_url=base, + model=e2e_model, + eval_name="mmlu", + num_examples=32, + num_threads=16, + temperature=0.0, + ) + metrics = run_eval(args) + assert 0.0 <= metrics["score"] <= 1.0 + + +@pytest.mark.e2e +def test_dp_aware_worker_expansion_and_api_key( + e2e_model, + e2e_router_only_rr_dp_aware_api, + e2e_worker_dp2_api, +): + """ + Launch a router-only instance in dp_aware mode and a single worker with dp_size=2 + and API key protection. Verify expansion, auth enforcement, and basic eval. + """ + import os + + router_url = e2e_router_only_rr_dp_aware_api.url + worker_url = e2e_worker_dp2_api.url + api_key = e2e_router_only_rr_dp_aware_api.api_key + + # Attach worker; router should expand to dp_size logical workers + r = requests.post( + f"{router_url}/add_worker", params={"url": worker_url}, timeout=180 + ) + r.raise_for_status() + + r = requests.get(f"{router_url}/list_workers", timeout=30) + r.raise_for_status() + urls = r.json().get("urls", []) + assert len(urls) == 2 + assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"} + + # Verify API key enforcement path-through + # 1) Without Authorization -> 401 from backend + r = requests.post( + f"{router_url}/v1/completions", + json={"model": e2e_model, "prompt": "hi", "max_tokens": 1}, + timeout=60, + ) + assert r.status_code == 401 + + # 2) With correct Authorization -> 200 + r = requests.post( + f"{router_url}/v1/completions", + json={"model": e2e_model, "prompt": "hi", "max_tokens": 1}, + headers={"Authorization": f"Bearer {api_key}"}, + timeout=60, + ) + assert r.status_code == 200 + + # Finally, run MMLU eval through the router with auth + os.environ["OPENAI_API_KEY"] = api_key + args = SimpleNamespace( + base_url=router_url, + model=e2e_model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + metrics = run_eval(args) + assert metrics["score"] >= 0.65 diff --git a/sgl-router/py_test/fixtures/mock_worker.py b/sgl-router/py_test/fixtures/mock_worker.py index 92d1e9a73..13b3c71a6 100644 --- a/sgl-router/py_test/fixtures/mock_worker.py +++ b/sgl-router/py_test/fixtures/mock_worker.py @@ -44,6 +44,7 @@ def _parse_args() -> argparse.Namespace: p.add_argument("--api-key", default=None) p.add_argument("--max-payload-bytes", type=int, default=10 * 1024 * 1024) p.add_argument("--stream", action="store_true") + p.add_argument("--dp-size", type=int, default=1) p.add_argument("--crash-on-request", action="store_true") p.add_argument("--health-fail-after-ms", type=int, default=0) return p.parse_args() @@ -125,12 +126,15 @@ def create_app(args: argparse.Namespace) -> FastAPI: return JSONResponse({"data": [{"id": "mock", "object": "model"}]}) @app.get("/get_server_info") - async def get_server_info(): + async def get_server_info(request: Request): + # Enforce API key on server info when required (used by dp_aware probing) + check_api_key(request) return JSONResponse( { "worker_id": worker_id, "load_in_flight": _inflight, "cache": {"size": 0, "hit_rate": 0.0}, + "dp_size": int(args.dp_size), } ) diff --git a/sgl-router/py_test/integration/test_payload_size.py b/sgl-router/py_test/integration/test_payload_size.py new file mode 100644 index 000000000..b3583ab28 --- /dev/null +++ b/sgl-router/py_test/integration/test_payload_size.py @@ -0,0 +1,33 @@ +import pytest +import requests + + +@pytest.mark.integration +def test_payload_size_limit(router_manager, mock_workers): + # Start one backend and a router with a 1MB payload limit + _, urls, _ = mock_workers(n=1) + rh = router_manager.start_router( + worker_urls=urls, + policy="round_robin", + extra={"max_payload_size": 1 * 1024 * 1024}, # 1MB + ) + + # Payload just under 1MB should succeed + payload_small = { + "model": "test-model", + "prompt": "x" * int(0.5 * 1024 * 1024), # ~0.5MB + "max_tokens": 1, + "stream": False, + } + r = requests.post(f"{rh.url}/v1/completions", json=payload_small) + assert r.status_code == 200 + + # Payload over 1MB should fail with 413 + payload_large = { + "model": "test-model", + "prompt": "x" * int(1.2 * 1024 * 1024), # ~1.2MB + "max_tokens": 1, + "stream": False, + } + r = requests.post(f"{rh.url}/v1/completions", json=payload_large) + assert r.status_code == 413 diff --git a/sgl-router/py_test/run_suite.py b/sgl-router/py_test/run_suite.py deleted file mode 100644 index 195c2b36e..000000000 --- a/sgl-router/py_test/run_suite.py +++ /dev/null @@ -1,27 +0,0 @@ -import argparse -import glob - -from sglang.test.test_utils import TestFile, run_unittest_files - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser() - arg_parser.add_argument( - "--timeout-per-file", - type=int, - default=2000, - help="The time limit for running one file in seconds.", - ) - args = arg_parser.parse_args() - - files = glob.glob("**/test_*.py", recursive=True) - # Exclude integration tests from the e2e suite; those are run separately via pytest -m integration - files = [ - f - for f in files - if "/integration/" not in f and not f.startswith("integration/") - ] - files.sort() - - test_files = [TestFile(name=file) for file in files] - exit_code = run_unittest_files(test_files, args.timeout_per_file) - exit(exit_code) diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py deleted file mode 100644 index 031ad5d08..000000000 --- a/sgl-router/py_test/test_launch_router.py +++ /dev/null @@ -1,354 +0,0 @@ -import multiprocessing -import time -import unittest -from types import SimpleNamespace - - -def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None: - """Terminate a process gracefully, with forced kill as fallback. - - Args: - process: The process to terminate - timeout: Seconds to wait for graceful termination before forcing kill - """ - if not process.is_alive(): - return - - process.terminate() - process.join(timeout=timeout) - if process.is_alive(): - process.kill() # Force kill if terminate didn't work - process.join() - - -class TestLaunchRouter(unittest.TestCase): - def setUp(self): - """Set up default arguments for router tests.""" - self.default_args = SimpleNamespace( - host="127.0.0.1", - port=30000, - policy="cache_aware", - worker_startup_timeout_secs=600, - worker_startup_check_interval=10, - cache_threshold=0.5, - balance_abs_threshold=32, - balance_rel_threshold=1.0001, - eviction_interval_secs=60, - max_tree_size=2**24, - max_payload_size=256 * 1024 * 1024, # 256MB - verbose=False, - log_dir=None, - log_level=None, - service_discovery=False, - selector=None, - service_discovery_port=80, - service_discovery_namespace=None, - dp_aware=False, - prometheus_port=None, - prometheus_host=None, - request_timeout_secs=60, - max_concurrent_requests=64, - cors_allowed_origins=[], - pd_disaggregation=False, - prefill=None, - decode=None, - worker_urls=[], - retry_max_retries=3, - retry_initial_backoff_ms=100, - retry_max_backoff_ms=10_000, - retry_backoff_multiplier=2.0, - retry_jitter_factor=0.1, - cb_failure_threshold=5, - cb_success_threshold=2, - cb_timeout_duration_secs=30, - cb_window_duration_secs=60, - disable_retries=False, - disable_circuit_breaker=False, - model_path=None, - tokenizer_path=None, - ) - - def create_router_args(self, **kwargs): - """Create router arguments by updating default args with provided kwargs.""" - args_dict = vars(self.default_args).copy() - args_dict.update(kwargs) - return SimpleNamespace(**args_dict) - - def run_router_process(self, args): - """Run router in a separate process and verify it starts successfully.""" - - def run_router(): - try: - from sglang_router.launch_router import launch_router - - router = launch_router(args) - if router is None: - return 1 - return 0 - except Exception as e: - print(e) - return 1 - - process = multiprocessing.Process(target=run_router) - try: - process.start() - # Wait 3 seconds - time.sleep(3) - # Process is still running means router started successfully - self.assertTrue(process.is_alive()) - finally: - terminate_process(process) - - def test_launch_router_common(self): - args = self.create_router_args(worker_urls=["http://localhost:8000"]) - self.run_router_process(args) - - def test_launch_router_with_empty_worker_urls(self): - args = self.create_router_args(worker_urls=[]) - self.run_router_process( - args - ) # Should start successfully with empty worker list - - def test_launch_router_with_service_discovery(self): - # Test router startup with service discovery enabled but no selectors - args = self.create_router_args( - worker_urls=[], service_discovery=True, selector=["app=test-worker"] - ) - self.run_router_process(args) - - def test_launch_router_with_service_discovery_namespace(self): - # Test router startup with service discovery enabled and namespace specified - args = self.create_router_args( - worker_urls=[], - service_discovery=True, - selector=["app=test-worker"], - service_discovery_namespace="test-namespace", - ) - self.run_router_process(args) - - def test_launch_router_common_with_dp_aware(self): - args = self.create_router_args( - worker_urls=["http://localhost:8000"], - dp_aware=True, - ) - self.run_router_process(args) - - def test_launch_router_with_empty_worker_urls_with_dp_aware(self): - args = self.create_router_args( - worker_urls=[], - dp_aware=True, - ) - self.run_router_process(args) - - def test_launch_router_common_with_dp_aware_service_discovery(self): - # Test launch router with bot srevice_discovery and dp_aware enabled - # Should fail since service_discovery and dp_aware is conflict - args = self.create_router_args( - worker_urls=["http://localhost:8000"], - dp_aware=True, - service_discovery=True, - selector=["app=test-worker"], - ) - - def run_router(): - try: - from sglang_router.launch_router import launch_router - - router = launch_router(args) - if router is None: - return 1 - return 0 - except Exception as e: - print(e) - return 1 - - process = multiprocessing.Process(target=run_router) - try: - process.start() - # Wait 3 seconds - time.sleep(3) - # Should fail since service_discovery and dp_aware is conflict - self.assertFalse(process.is_alive()) - finally: - terminate_process(process) - - def test_launch_router_pd_mode_basic(self): - """Test basic PD router functionality without actually starting servers.""" - # This test just verifies the PD router can be created and configured - # without actually starting it (which would require real prefill/decode servers) - from sglang_router.launch_router import RouterArgs - from sglang_router.router import PolicyType, Router - - # Test RouterArgs parsing for PD mode - # Simulate the parsed args structure from argparse with action="append" - args = self.create_router_args( - pd_disaggregation=True, - policy="power_of_two", # PowerOfTwo is only valid in PD mode - prefill=[ - ["http://prefill1:8080", "9000"], - ["http://prefill2:8080", "none"], - ], - decode=[ - ["http://decode1:8081"], - ["http://decode2:8081"], - ], - worker_urls=[], # Empty for PD mode - ) - - router_args = RouterArgs.from_cli_args(args) - self.assertTrue(router_args.pd_disaggregation) - self.assertEqual(router_args.policy, "power_of_two") - self.assertEqual(len(router_args.prefill_urls), 2) - self.assertEqual(len(router_args.decode_urls), 2) - - # Verify the parsed URLs and bootstrap ports - self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000)) - self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None)) - self.assertEqual(router_args.decode_urls[0], "http://decode1:8081") - self.assertEqual(router_args.decode_urls[1], "http://decode2:8081") - - # Test Router creation in PD mode - router = Router.from_args(router_args) - self.assertIsNotNone(router) - - def test_policy_validation(self): - """Test that policy validation works correctly for PD and regular modes.""" - from sglang_router.launch_router import RouterArgs, launch_router - - # Test 1: PowerOfTwo requires at least 2 workers - args = self.create_router_args( - pd_disaggregation=False, - policy="power_of_two", - worker_urls=["http://localhost:8000"], # Only 1 worker - ) - - # Should raise error - with self.assertRaises(ValueError) as cm: - launch_router(args) - self.assertIn( - "Power-of-two policy requires at least 2 workers", - str(cm.exception), - ) - - # Test 2: PowerOfTwo with sufficient workers should succeed - args = self.create_router_args( - pd_disaggregation=False, - policy="power_of_two", - worker_urls=["http://localhost:8000", "http://localhost:8001"], # 2 workers - ) - # This should not raise an error (validation passes) - - # Test 3: All policies now work in both modes - # Regular mode with RoundRobin - args = self.create_router_args( - pd_disaggregation=False, - policy="round_robin", - worker_urls=["http://localhost:8000"], - ) - # This should not raise validation error - - # PD mode with RoundRobin (now supported!) - args = self.create_router_args( - pd_disaggregation=True, - policy="round_robin", - prefill=[["http://prefill1:8080", "9000"]], - decode=[["http://decode1:8081"]], - worker_urls=[], - ) - # This should not raise validation error - - def test_pd_service_discovery_args_parsing(self): - """Test PD service discovery CLI argument parsing.""" - import argparse - - from sglang_router.launch_router import RouterArgs - - parser = argparse.ArgumentParser() - RouterArgs.add_cli_args(parser) - - args = parser.parse_args( - [ - "--pd-disaggregation", - "--service-discovery", - "--prefill-selector", - "app=sglang", - "component=prefill", - "--decode-selector", - "app=sglang", - "component=decode", - "--service-discovery-port", - "8000", - "--service-discovery-namespace", - "production", - "--policy", - "cache_aware", - ] - ) - - router_args = RouterArgs.from_cli_args(args) - - self.assertTrue(router_args.pd_disaggregation) - self.assertTrue(router_args.service_discovery) - self.assertEqual( - router_args.prefill_selector, {"app": "sglang", "component": "prefill"} - ) - self.assertEqual( - router_args.decode_selector, {"app": "sglang", "component": "decode"} - ) - self.assertEqual(router_args.service_discovery_port, 8000) - self.assertEqual(router_args.service_discovery_namespace, "production") - - def test_regular_service_discovery_args_parsing(self): - """Test regular mode service discovery CLI argument parsing.""" - import argparse - - from sglang_router.launch_router import RouterArgs - - parser = argparse.ArgumentParser() - RouterArgs.add_cli_args(parser) - - args = parser.parse_args( - [ - "--service-discovery", - "--selector", - "app=sglang-worker", - "environment=staging", - "--service-discovery-port", - "8000", - "--policy", - "round_robin", - ] - ) - - router_args = RouterArgs.from_cli_args(args) - - self.assertFalse(router_args.pd_disaggregation) - self.assertTrue(router_args.service_discovery) - self.assertEqual( - router_args.selector, {"app": "sglang-worker", "environment": "staging"} - ) - self.assertEqual(router_args.prefill_selector, {}) - self.assertEqual(router_args.decode_selector, {}) - - def test_empty_worker_urls_args_parsing(self): - """Test that router accepts no worker URLs and defaults to empty list.""" - import argparse - - from sglang_router.launch_router import RouterArgs - - parser = argparse.ArgumentParser() - RouterArgs.add_cli_args(parser) - - # Test with no --worker-urls argument at all - args = parser.parse_args(["--policy", "random", "--port", "30000"]) - router_args = RouterArgs.from_cli_args(args) - self.assertEqual(router_args.worker_urls, []) - - # Test with explicit empty --worker-urls - args = parser.parse_args(["--worker-urls", "--policy", "random"]) - router_args = RouterArgs.from_cli_args(args) - self.assertEqual(router_args.worker_urls, []) - - -if __name__ == "__main__": - unittest.main() diff --git a/sgl-router/py_test/test_launch_server.py b/sgl-router/py_test/test_launch_server.py deleted file mode 100644 index cdad0b9a1..000000000 --- a/sgl-router/py_test/test_launch_server.py +++ /dev/null @@ -1,735 +0,0 @@ -import socket -import subprocess -import time -import unittest -from types import SimpleNamespace - -import requests - -from sglang.srt.utils import kill_process_tree -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, -) - - -def popen_launch_router( - model: str, - base_url: str, - dp_size: int, - timeout: float, - policy: str = "cache_aware", - max_payload_size: int = None, - api_key: str = None, - log_dir: str = None, - service_discovery: bool = False, - selector: list = None, - service_discovery_port: int = 80, - service_discovery_namespace: str = None, - prometheus_port: int = None, - prometheus_host: str = None, - dp_aware: bool = False, - # Router retry/CB tuning (optional) - router_retry_max_retries: int = None, - router_retry_initial_backoff_ms: int = None, - router_retry_max_backoff_ms: int = None, - router_retry_backoff_multiplier: float = None, - router_retry_jitter_factor: float = None, - router_cb_failure_threshold: int = None, - router_cb_success_threshold: int = None, - router_cb_timeout_duration_secs: int = None, - router_cb_window_duration_secs: int = None, -): - """ - Launch the router server process. - - Args: - model: Model path/name - base_url: Server base URL - dp_size: Data parallel size - timeout: Server launch timeout - policy: Router policy, one of "cache_aware", "round_robin", "random" - max_payload_size: Maximum payload size in bytes - api_key: API key for the router - log_dir: Directory to store log files. If None, logs are only output to console. - service_discovery: Enable Kubernetes service discovery - selector: List of label selectors in format ["key1=value1", "key2=value2"] - service_discovery_port: Port to use for service discovery - service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces. - prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled. - prometheus_host: Host address to bind the Prometheus metrics server. - dp_aware: Enable data parallelism aware routing strategy. - """ - _, host, port = base_url.split(":") - host = host[2:] - - command = [ - "python3", - "-m", - "sglang_router.launch_server", - "--model-path", - model, - "--host", - host, - "--port", - port, - "--dp", - str(dp_size), - "--router-eviction-interval-secs", - "5", - "--router-policy", - policy, - "--allow-auto-truncate", - ] - - if api_key is not None: - command.extend(["--api-key", api_key]) - command.extend(["--router-api-key", api_key]) - - if max_payload_size is not None: - command.extend(["--router-max-payload-size", str(max_payload_size)]) - - if service_discovery: - command.append("--router-service-discovery") - - if selector: - command.extend(["--router-selector"] + selector) - - if service_discovery_port != 80: - command.extend(["--router-service-discovery-port", str(service_discovery_port)]) - - if service_discovery_namespace: - command.extend( - ["--router-service-discovery-namespace", service_discovery_namespace] - ) - - if prometheus_port is not None: - command.extend(["--router-prometheus-port", str(prometheus_port)]) - - if prometheus_host is not None: - command.extend(["--router-prometheus-host", prometheus_host]) - - if log_dir is not None: - command.extend(["--log-dir", log_dir]) - - if dp_aware: - command.append("--router-dp-aware") - - # Append router retry/CB tuning flags if provided - def _add(flag: str, val): - if val is not None: - command.extend([flag, str(val)]) - - _add("--router-retry-max-retries", router_retry_max_retries) - _add("--router-retry-initial-backoff-ms", router_retry_initial_backoff_ms) - _add("--router-retry-max-backoff-ms", router_retry_max_backoff_ms) - _add("--router-retry-backoff-multiplier", router_retry_backoff_multiplier) - _add("--router-retry-jitter-factor", router_retry_jitter_factor) - _add("--router-cb-failure-threshold", router_cb_failure_threshold) - _add("--router-cb-success-threshold", router_cb_success_threshold) - _add("--router-cb-timeout-duration-secs", router_cb_timeout_duration_secs) - _add("--router-cb-window-duration-secs", router_cb_window_duration_secs) - - process = subprocess.Popen(command, stdout=None, stderr=None) - - start_time = time.perf_counter() - with requests.Session() as session: - while time.perf_counter() - start_time < timeout: - try: - response = session.get(f"{base_url}/health") - if response.status_code == 200: - print(f"Router {base_url} is healthy") - return process - except requests.RequestException: - pass - time.sleep(10) - - raise TimeoutError("Router failed to start within the timeout period.") - - -def find_available_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -def popen_launch_server( - model: str, - base_url: str, - timeout: float, - api_key: str = None, -): - _, host, port = base_url.split(":") - host = host[2:] - - command = [ - "python3", - "-m", - "sglang.launch_server", - "--model-path", - model, - "--host", - host, - "--port", - port, - "--base-gpu-id", - "1", - ] - - if api_key is not None: - command.extend(["--api-key", api_key]) - - process = subprocess.Popen(command, stdout=None, stderr=None) - - # intentionally don't wait and defer the job to the router health check - return process - - -def terminate_and_wait(process, timeout=300): - """Terminate a process and wait until it is terminated. - - Args: - process: subprocess.Popen object - timeout: maximum time to wait in seconds - - Raises: - TimeoutError: if process does not terminate within timeout - """ - if process is None: - return - - process.terminate() - start_time = time.perf_counter() - - while process.poll() is None: - print(f"Terminating process {process.pid}") - if time.perf_counter() - start_time > timeout: - raise TimeoutError( - f"Process {process.pid} failed to terminate within {timeout}s" - ) - time.sleep(1) - - print(f"Process {process.pid} is successfully terminated") - - -class TestLaunchServer(unittest.TestCase): - def setUp(self): - self.model = DEFAULT_MODEL_NAME_FOR_TEST - self.base_url = DEFAULT_URL_FOR_TEST - self.process = None - self.other_process = [] - - def tearDown(self): - print("Running tearDown...") - if self.process: - terminate_and_wait(self.process) - for process in self.other_process: - terminate_and_wait(process) - print("tearDown done") - - def test_1_mmlu(self): - print("Running test_1_mmlu...") - # DP size = 2 - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=2, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="cache_aware", - ) - - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - temperature=0.1, - ) - - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.635 - passed = score >= THRESHOLD - msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - def test_2_add_and_remove_worker(self): - print("Running test_2_add_and_remove_worker...") - # DP size = 1 - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="round_robin", # use round robin to make sure every worker processes requests - ) - # 1. start a worker - port = find_available_port() - worker_url = f"http://127.0.0.1:{port}" - worker_process = popen_launch_server( - self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH - ) - self.other_process.append(worker_process) - - # 2. use /add_worker api to add it to the router. It will be used by the router after it is healthy - with requests.Session() as session: - response = session.post(f"{self.base_url}/add_worker?url={worker_url}") - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual(response.status_code, 200) - - # 3. run mmlu - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - temperature=0.1, - ) - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.635 - passed = score >= THRESHOLD - msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - # 4. use /remove_worker api to remove it from the router - with requests.Session() as session: - response = session.post(f"{self.base_url}/remove_worker?url={worker_url}") - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual(response.status_code, 200) - - # 5. run mmlu again - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.635 - passed = score >= THRESHOLD - msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - def test_3_lazy_fault_tolerance(self): - print("Running test_3_lazy_fault_tolerance...") - # DP size = 1 - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="round_robin", - ) - - # 1. start a worker - port = find_available_port() - worker_url = f"http://127.0.0.1:{port}" - worker_process = popen_launch_server( - self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH - ) - self.other_process.append(worker_process) - - # 2. use /add_worker api to add it to the router. It will be used by the router after it is healthy - with requests.Session() as session: - response = session.post(f"{self.base_url}/add_worker?url={worker_url}") - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual(response.status_code, 200) - - # Start a thread to kill the worker after 10 seconds to mimic abrupt worker failure - def kill_worker(): - time.sleep(10) - kill_process_tree(worker_process.pid) - print("Worker process killed") - - import threading - - kill_thread = threading.Thread(target=kill_worker) - kill_thread.daemon = True - kill_thread.start() - - # 3. run mmlu - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=256, - num_threads=32, - temperature=0.1, - ) - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.635 - passed = score >= THRESHOLD - msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - def test_4_payload_size(self): - print("Running test_4_payload_size...") - # Start router with 1MB limit - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="round_robin", - max_payload_size=1 * 1024 * 1024, # 1MB limit - ) - - # Test case 1: Payload just under 1MB should succeed - payload_0_5_mb = { - "text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text - "temperature": 0.0, - } - - with requests.Session() as session: - response = session.post( - f"{self.base_url}/generate", - json=payload_0_5_mb, - headers={"Content-Type": "application/json"}, - ) - self.assertEqual( - response.status_code, - 200, - f"0.5MB payload should succeed but got status {response.status_code}", - ) - - # Test case 2: Payload over 1MB should fail - payload_1_plus_mb = { - "text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text - "temperature": 0.0, - } - - with requests.Session() as session: - response = session.post( - f"{self.base_url}/generate", - json=payload_1_plus_mb, - headers={"Content-Type": "application/json"}, - ) - self.assertEqual( - response.status_code, - 413, # Payload Too Large - f"1.2MB payload should fail with 413 but got status {response.status_code}", - ) - - def test_5_api_key(self): - print("Running test_5_api_key...") - - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="round_robin", - api_key="correct_api_key", - ) - - # Test case 1: request without api key should fail - with requests.Session() as session: - response = session.post( - f"{self.base_url}/generate", - json={"text": "Kanye west is, ", "temperature": 0}, - ) - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual( - response.status_code, - 401, - "Request without api key should fail with 401", - ) - - # Test case 2: request with invalid api key should fail - with requests.Session() as session: - response = requests.post( - f"{self.base_url}/generate", - json={"text": "Kanye west is, ", "temperature": 0}, - headers={"Authorization": "Bearer 123"}, - ) - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual( - response.status_code, - 401, - "Request with invalid api key should fail with 401", - ) - - # Test case 3: request with correct api key should succeed - with requests.Session() as session: - response = session.post( - f"{self.base_url}/generate", - json={"text": "Kanye west is ", "temperature": 0}, - headers={"Authorization": "Bearer correct_api_key"}, - ) - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual( - response.status_code, 200, "Request with correct api key should succeed" - ) - - def test_6_mmlu_with_dp_aware(self): - print("Running test_6_mmlu_with_dp_aware...") - # DP size = 2 - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=2, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="cache_aware", - dp_aware=True, - ) - - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - temperature=0.1, - ) - - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.635 - passed = score >= THRESHOLD - msg = f"dp aware MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - def test_7_add_and_remove_worker_with_dp_aware(self): - print("Running test_7_add_and_remove_worker_with_dp_aware...") - - # Set dp_size = 1 - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="round_robin", # make sure every worker processes requests - dp_aware=True, # dp aware strategy should work well with RR - ) - - # 1. Start a worker - port = find_available_port() - worker_url = f"http://127.0.0.1:{port}" - worker_process = popen_launch_server( - self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH - ) - self.other_process.append(worker_process) - - # 2. Use the /add_worker API to add it to the router - # It will be used by router after it is healthy - with requests.Session() as session: - response = session.post(f"{self.base_url}/add_worker?url={worker_url}") - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual(response.status_code, 200) - - # 3. Run mmlu - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - temperature=0.1, - ) - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.635 - passed = score >= THRESHOLD - msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - # 4. Use the /remove_worker API to remove it from the router - with requests.Session() as session: - response = session.post(f"{self.base_url}/remove_worker?url={worker_url}") - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual(response.status_code, 200) - - # 5. Run mmlu again - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.635 - passed = score >= THRESHOLD - msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - # 6. Start another worker with api_key set - terminate_and_wait(worker_process) # terminate the old worker process - port = find_available_port() - worker_url = f"http://127.0.0.1:{port}" - worker_process = popen_launch_server( - self.model, - worker_url, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - api_key="correct_api_key", - ) - self.other_process.append(worker_process) - - # 7. Use the /add_worker API to add it to the router - # Should fail since the router would contact the worker's - # /get_server_info endpoint for the dp_size info, but it - # has no knowledge of the api key - with requests.Session() as session: - response = session.post(f"{self.base_url}/add_worker?url={worker_url}") - print(f"status code: {response.status_code}, response: {response.text}") - self.assertNotEqual(response.status_code, 200) - - def test_8_lazy_fault_tolerance_with_dp_aware(self): - print("Running test_8_lazy_fault_tolerance_with_dp_aware...") - - # Set dp_size = 1 - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="round_robin", - dp_aware=True, - ) - - # 1. Start a worker - port = find_available_port() - worker_url = f"http://127.0.0.1:{port}" - worker_process = popen_launch_server( - self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH - ) - self.other_process.append(worker_process) - - # 2. Use the /add_worker API to add it to the router - # It will be used by router after it is healthy - with requests.Session() as session: - response = session.post(f"{self.base_url}/add_worker?url={worker_url}") - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual(response.status_code, 200) - - # Start a thread to kill the worker after 10 seconds to mimic - # abrupt worker failure - def kill_worker(): - time.sleep(10) - kill_process_tree(worker_process.pid) - print("Worker process killed") - - import threading - - kill_thread = threading.Thread(target=kill_worker) - kill_thread.daemon = True - kill_thread.start() - - # 3. Run mmlu - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=256, - num_threads=32, - temperature=0.1, - ) - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.635 - passed = score >= THRESHOLD - msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - def test_9_payload_size_with_dp_aware(self): - print("Running test_9_payload_size_with_dp_aware...") - - # Start the router with 1MB limit - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="round_robin", - max_payload_size=1 * 1024 * 1024, # 1MB limit - dp_aware=True, - ) - - # Test case 1: Payload just under 1MB should succeed - payload_0_5_mb = { - "text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text - "temperature": 0.0, - } - - with requests.Session() as session: - response = session.post( - f"{self.base_url}/generate", - json=payload_0_5_mb, - headers={"Content-Type": "application/json"}, - ) - self.assertEqual( - response.status_code, - 200, - f"0.5MB payload should succeed but got status {response.status_code}", - ) - - # Test case 2: Payload over 1MB should fail - payload_1_plus_mb = { - "text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text - "temperature": 0.0, - } - - with requests.Session() as session: - response = session.post( - f"{self.base_url}/generate", - json=payload_1_plus_mb, - headers={"Content-Type": "application/json"}, - ) - self.assertEqual( - response.status_code, - 413, # Payload Too Large - f"1.2MB payload should fail with 413 but got status {response.status_code}", - ) - - def test_10_api_key_with_dp_aware(self): - print("Running test_10_api_key_with_dp_aware...") - - self.process = popen_launch_router( - self.model, - self.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - policy="round_robin", - api_key="correct_api_key", - dp_aware=True, - ) - - # Test case 1: request without api key should fail - with requests.Session() as session: - response = session.post( - f"{self.base_url}/generate", - json={"text": "Kanye west is, ", "temperature": 0}, - ) - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual( - response.status_code, - 401, - f"Request without api key should fail with 401 but got status {response.status_code}", - ) - - # Test case 2: request with invalid api key should fail - with requests.Session() as session: - response = requests.post( - f"{self.base_url}/generate", - json={"text": "Kanye west is, ", "temperature": 0}, - headers={"Authorization": "Bearer 123"}, - ) - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual( - response.status_code, - 401, - f"Request without api key should fail with 401 but got status {response.status_code}", - ) - - # Test case 3: request with correct api key should succeed - with requests.Session() as session: - response = session.post( - f"{self.base_url}/generate", - json={"text": "Kanye west is ", "temperature": 0}, - headers={"Authorization": "Bearer correct_api_key"}, - ) - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual( - response.status_code, - 200, - f"Request with correct api key should succeed but got status {response.status_code}", - ) - - -if __name__ == "__main__": - unittest.main()