From 21b9a4b4353e3b148dc1c37a1cd137972ec3381f Mon Sep 17 00:00:00 2001 From: Keyang Ru Date: Fri, 5 Sep 2025 18:52:53 -0700 Subject: [PATCH] [router] Introduce router integration tests (#10086) --- .github/workflows/pr-test-rust.yml | 10 +- sgl-router/py_test/__init__.py | 1 + sgl-router/py_test/fixtures/__init__.py | 1 + sgl-router/py_test/fixtures/mock_worker.py | 248 ++++++++++++++++++ sgl-router/py_test/fixtures/ports.py | 8 + sgl-router/py_test/fixtures/router_manager.py | 158 +++++++++++ sgl-router/py_test/integration/__init__.py | 1 + sgl-router/py_test/integration/conftest.py | 109 ++++++++ .../integration/load_balancing/__init__.py | 1 + .../load_balancing/test_cache_aware.py | 73 ++++++ .../load_balancing/test_power_of_two.py | 89 +++++++ .../integration/load_balancing/test_random.py | 33 +++ .../load_balancing/test_round_robin.py | 34 +++ .../py_test/integration/test_api_auth.py | 38 +++ .../integration/test_circuit_breaker.py | 191 ++++++++++++++ .../integration/test_fault_tolerance.py | 36 +++ .../py_test/integration/test_pd_routing.py | 127 +++++++++ .../py_test/integration/test_rate_limiting.py | 91 +++++++ .../py_test/integration/test_retries.py | 65 +++++ .../test_service_discovery_shim.py | 36 +++ .../integration/test_worker_management.py | 61 +++++ sgl-router/py_test/run_suite.py | 7 + sgl-router/pytest.ini | 1 - 23 files changed, 1417 insertions(+), 2 deletions(-) create mode 100644 sgl-router/py_test/__init__.py create mode 100644 sgl-router/py_test/fixtures/__init__.py create mode 100644 sgl-router/py_test/fixtures/mock_worker.py create mode 100644 sgl-router/py_test/fixtures/ports.py create mode 100644 sgl-router/py_test/fixtures/router_manager.py create mode 100644 sgl-router/py_test/integration/__init__.py create mode 100644 sgl-router/py_test/integration/conftest.py create mode 100644 sgl-router/py_test/integration/load_balancing/__init__.py create mode 100644 sgl-router/py_test/integration/load_balancing/test_cache_aware.py create mode 100644 sgl-router/py_test/integration/load_balancing/test_power_of_two.py create mode 100644 sgl-router/py_test/integration/load_balancing/test_random.py create mode 100644 sgl-router/py_test/integration/load_balancing/test_round_robin.py create mode 100644 sgl-router/py_test/integration/test_api_auth.py create mode 100644 sgl-router/py_test/integration/test_circuit_breaker.py create mode 100644 sgl-router/py_test/integration/test_fault_tolerance.py create mode 100644 sgl-router/py_test/integration/test_pd_routing.py create mode 100644 sgl-router/py_test/integration/test_rate_limiting.py create mode 100644 sgl-router/py_test/integration/test_retries.py create mode 100644 sgl-router/py_test/integration/test_service_discovery_shim.py create mode 100644 sgl-router/py_test/integration/test_worker_management.py diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 6c403b83b..ff54c5c32 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -95,7 +95,15 @@ jobs: cd sgl-router source "$HOME/.cargo/env" pip install pytest pytest-cov pytest-xdist - pytest -q py_test/unit + pytest -q py_test/unit --cov=sglang_router --cov-report=term-missing --cov-fail-under=80 + + - name: Run Python integration tests + run: | + cd sgl-router + source "$HOME/.cargo/env" + # Integration tests use FastAPI/uvicorn for mock workers + pip install fastapi uvicorn orjson + pytest -q -m integration - name: Run e2e test run: | diff --git a/sgl-router/py_test/__init__.py 
b/sgl-router/py_test/__init__.py new file mode 100644 index 000000000..893097780 --- /dev/null +++ b/sgl-router/py_test/__init__.py @@ -0,0 +1 @@ +"""Test package root for router Python tests.""" diff --git a/sgl-router/py_test/fixtures/__init__.py b/sgl-router/py_test/fixtures/__init__.py new file mode 100644 index 000000000..4ac754df8 --- /dev/null +++ b/sgl-router/py_test/fixtures/__init__.py @@ -0,0 +1 @@ +"""Shared fixtures for router integration tests.""" diff --git a/sgl-router/py_test/fixtures/mock_worker.py b/sgl-router/py_test/fixtures/mock_worker.py new file mode 100644 index 000000000..92d1e9a73 --- /dev/null +++ b/sgl-router/py_test/fixtures/mock_worker.py @@ -0,0 +1,248 @@ +""" +Lightweight mock worker HTTP server for router integration tests. + +Implements minimal endpoints used by the router: +- GET /health, /health_generate +- POST /generate, /v1/completions, /v1/chat/completions +- POST /flush_cache +- GET /get_server_info, /get_model_info, /v1/models + +Behavior knobs are controlled via CLI flags to simulate failures, latency, and load. +""" + +import argparse +import asyncio +import json +import os +import random +import signal +import sys +import time +from contextlib import asynccontextmanager +from typing import Optional + +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse + +# Global state (per-process) +_inflight = 0 +_failures_seen = 0 + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--host", default="127.0.0.1") + p.add_argument("--port", type=int, required=True) + p.add_argument("--worker-id", default=None) + p.add_argument("--latency-ms", type=int, default=0) + p.add_argument("--timeout", action="store_true") + p.add_argument("--status-code", type=int, default=200) + p.add_argument("--fail-first-n", type=int, default=0) + p.add_argument("--random-fail-rate", type=float, default=0.0) + p.add_argument("--require-api-key", action="store_true") + p.add_argument("--api-key", default=None) + p.add_argument("--max-payload-bytes", type=int, default=10 * 1024 * 1024) + p.add_argument("--stream", action="store_true") + p.add_argument("--crash-on-request", action="store_true") + p.add_argument("--health-fail-after-ms", type=int, default=0) + return p.parse_args() + + +def _extract_worker_id(args: argparse.Namespace) -> str: + if args.worker_id: + return str(args.worker_id) + # default to port (unique enough for tests) + return f"worker-{args.port}" + + +def create_app(args: argparse.Namespace) -> FastAPI: + app = FastAPI() + worker_id = _extract_worker_id(args) + start_ts = time.time() + crashed = {"done": False} + + async def maybe_delay(): + if args.latency_ms > 0: + await asyncio.sleep(args.latency_ms / 1000.0) + + def should_fail() -> Optional[int]: + global _failures_seen + # Fail first N requests (500) + if args.fail_first_n > 0 and _failures_seen < args.fail_first_n: + _failures_seen += 1 + return 500 + # Random failure probability (500) + if args.random_fail_rate > 0.0 and random.random() < args.random_fail_rate: + return 500 + # Forced status code override (non-200) for all responses + if args.status_code != 200: + return int(args.status_code) + return None + + def check_api_key(request: Request): + if not args.require_api_key: + return + auth = request.headers.get("Authorization") + if not auth or not auth.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Unauthorized") + key = auth.split(" ", 
1)[1] + if args.api_key and key != args.api_key: + raise HTTPException(status_code=401, detail="Unauthorized") + + @asynccontextmanager + async def track_inflight(): + global _inflight + _inflight += 1 + try: + yield + finally: + _inflight -= 1 + + @app.get("/health") + async def health(): + if ( + args.health_fail_after_ms + and (time.time() - start_ts) * 1000.0 >= args.health_fail_after_ms + ): + return PlainTextResponse("bad", status_code=500) + return PlainTextResponse("ok", status_code=200) + + @app.get("/health_generate") + async def health_generate(): + return PlainTextResponse("ok", status_code=200) + + @app.post("/flush_cache") + async def flush_cache(): + return PlainTextResponse("ok", status_code=200) + + @app.get("/get_model_info") + async def get_model_info(): + return JSONResponse({"model": "mock", "vocab_size": 32000}) + + @app.get("/v1/models") + async def list_models(): + return JSONResponse({"data": [{"id": "mock", "object": "model"}]}) + + @app.get("/get_server_info") + async def get_server_info(): + return JSONResponse( + { + "worker_id": worker_id, + "load_in_flight": _inflight, + "cache": {"size": 0, "hit_rate": 0.0}, + } + ) + + @app.get("/get_load") + async def get_load(): + return JSONResponse({"load": _inflight}) + + def make_json_response(obj: dict, status_code: int = 200) -> JSONResponse: + resp = JSONResponse(obj, status_code=status_code) + resp.headers["X-Worker-Id"] = worker_id + return resp + + async def handle_text_request(request: Request): + # Authorization + check_api_key(request) + + # Payload limit + body = await request.body() + if len(body) > args.max_payload_bytes: + return make_json_response({"error": "payload too large"}, status_code=413) + + # Simulate crash on first request + if args.crash_on_request and not crashed["done"]: + crashed["done"] = True + os._exit(1) + + # Optional timeout (simulate hang) + if args.timeout: + await asyncio.sleep(3600) + + # Optional latency + await maybe_delay() + + # Optional failures + fail_code = should_fail() + if fail_code is not None and fail_code != 200: + return make_json_response( + {"error": f"mock failure {fail_code}"}, status_code=fail_code + ) + + # Build response echoing minimal shape + try: + data = await request.json() + except (json.JSONDecodeError, ValueError): + data = {} + + now = time.time() + ret = { + "id": f"cmpl-{int(now*1000)}", + "object": "text_completion", + "created": int(now), + "model": "mock", + "choices": [ + { + "text": "ok", + "index": 0, + "finish_reason": "stop", + } + ], + "worker_id": worker_id, + "echo": data, + } + return make_json_response(ret, status_code=200) + + async def handle_stream_request(request: Request): + check_api_key(request) + + async def gen(): + # minimal 2-chunk stream then [DONE] + for i in range(2): + await asyncio.sleep(0.01) + chunk = { + "choices": [{"delta": {"content": "x"}}], + "worker_id": worker_id, + } + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + + headers = {"X-Worker-Id": worker_id} + return StreamingResponse(gen(), media_type="text/event-stream", headers=headers) + + @app.post("/generate") + async def generate(request: Request): + async with track_inflight(): + if args.stream: + return await handle_stream_request(request) + return await handle_text_request(request) + + @app.post("/v1/completions") + async def completions(request: Request): + async with track_inflight(): + if args.stream: + return await handle_stream_request(request) + return await handle_text_request(request) + + @app.post("/v1/chat/completions") + 
async def chat_completions(request: Request): + async with track_inflight(): + if args.stream: + return await handle_stream_request(request) + return await handle_text_request(request) + + return app + + +def main() -> None: + args = _parse_args() + app = create_app(args) + # Handle SIGTERM gracefully for fast test teardown + signal.signal(signal.SIGTERM, lambda *_: sys.exit(0)) + uvicorn.run(app, host=args.host, port=args.port, log_level="warning") + + +if __name__ == "__main__": + main() diff --git a/sgl-router/py_test/fixtures/ports.py b/sgl-router/py_test/fixtures/ports.py new file mode 100644 index 000000000..d616cffa1 --- /dev/null +++ b/sgl-router/py_test/fixtures/ports.py @@ -0,0 +1,8 @@ +import socket + + +def find_free_port() -> int: + """Return an available TCP port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] diff --git a/sgl-router/py_test/fixtures/router_manager.py b/sgl-router/py_test/fixtures/router_manager.py new file mode 100644 index 000000000..c536a0015 --- /dev/null +++ b/sgl-router/py_test/fixtures/router_manager.py @@ -0,0 +1,158 @@ +import subprocess +import time +from dataclasses import dataclass +from typing import Dict, List, Optional + +import requests + +from .ports import find_free_port + + +@dataclass +class ProcHandle: + process: subprocess.Popen + url: str + + +class RouterManager: + """Helper to spawn a router process and interact with admin endpoints.""" + + def __init__(self): + self._children: List[subprocess.Popen] = [] + + def start_router( + self, + worker_urls: Optional[List[str]] = None, + policy: str = "round_robin", + port: Optional[int] = None, + extra: Optional[Dict] = None, + # PD options + pd_disaggregation: bool = False, + prefill_urls: Optional[List[tuple]] = None, + decode_urls: Optional[List[str]] = None, + prefill_policy: Optional[str] = None, + decode_policy: Optional[str] = None, + ) -> ProcHandle: + worker_urls = worker_urls or [] + port = port or find_free_port() + cmd = [ + "python3", + "-m", + "sglang_router.launch_router", + "--host", + "127.0.0.1", + "--port", + str(port), + "--policy", + policy, + ] + # Avoid Prometheus port collisions by assigning a free port per router + prom_port = find_free_port() + cmd.extend( + ["--prometheus-port", str(prom_port), "--prometheus-host", "127.0.0.1"] + ) + if worker_urls: + cmd.extend(["--worker-urls", *worker_urls]) + + # PD routing configuration + if pd_disaggregation: + cmd.append("--pd-disaggregation") + if prefill_urls: + for url, bport in prefill_urls: + if bport is None: + cmd.extend(["--prefill", url, "none"]) + else: + cmd.extend(["--prefill", url, str(bport)]) + if decode_urls: + for url in decode_urls: + cmd.extend(["--decode", url]) + if prefill_policy: + cmd.extend(["--prefill-policy", prefill_policy]) + if decode_policy: + cmd.extend(["--decode-policy", decode_policy]) + + # Map supported extras to CLI flags (subset for integration) + if extra: + flag_map = { + "max_payload_size": "--max-payload-size", + "dp_aware": "--dp-aware", + "api_key": "--api-key", + # Health/monitoring + "worker_startup_check_interval": "--worker-startup-check-interval", + # Cache-aware tuning + "cache_threshold": "--cache-threshold", + "balance_abs_threshold": "--balance-abs-threshold", + "balance_rel_threshold": "--balance-rel-threshold", + # Retry + "retry_max_retries": "--retry-max-retries", + "retry_initial_backoff_ms": "--retry-initial-backoff-ms", + "retry_max_backoff_ms": "--retry-max-backoff-ms", + 
"retry_backoff_multiplier": "--retry-backoff-multiplier", + "retry_jitter_factor": "--retry-jitter-factor", + "disable_retries": "--disable-retries", + # Circuit breaker + "cb_failure_threshold": "--cb-failure-threshold", + "cb_success_threshold": "--cb-success-threshold", + "cb_timeout_duration_secs": "--cb-timeout-duration-secs", + "cb_window_duration_secs": "--cb-window-duration-secs", + "disable_circuit_breaker": "--disable-circuit-breaker", + # Rate limiting + "max_concurrent_requests": "--max-concurrent-requests", + "queue_size": "--queue-size", + "queue_timeout_secs": "--queue-timeout-secs", + "rate_limit_tokens_per_second": "--rate-limit-tokens-per-second", + } + for k, v in extra.items(): + if v is None: + continue + flag = flag_map.get(k) + if not flag: + continue + if isinstance(v, bool): + if v: + cmd.append(flag) + else: + cmd.extend([flag, str(v)]) + + proc = subprocess.Popen(cmd) + self._children.append(proc) + url = f"http://127.0.0.1:{port}" + self._wait_health(url) + return ProcHandle(process=proc, url=url) + + def _wait_health(self, base_url: str, timeout: float = 30.0): + start = time.time() + with requests.Session() as s: + while time.time() - start < timeout: + try: + r = s.get(f"{base_url}/health", timeout=2) + if r.status_code == 200: + return + except requests.RequestException: + pass + time.sleep(0.2) + raise TimeoutError(f"Router at {base_url} did not become healthy") + + def add_worker(self, base_url: str, worker_url: str) -> None: + r = requests.post(f"{base_url}/add_worker", params={"url": worker_url}) + assert r.status_code == 200, f"add_worker failed: {r.status_code} {r.text}" + + def remove_worker(self, base_url: str, worker_url: str) -> None: + r = requests.post(f"{base_url}/remove_worker", params={"url": worker_url}) + assert r.status_code == 200, f"remove_worker failed: {r.status_code} {r.text}" + + def list_workers(self, base_url: str) -> list[str]: + r = requests.get(f"{base_url}/list_workers") + assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}" + data = r.json() + return data.get("urls", []) + + def stop_all(self): + for p in self._children: + if p.poll() is None: + p.terminate() + try: + p.wait(timeout=5) + except subprocess.TimeoutExpired: + p.kill() + self._children.clear() diff --git a/sgl-router/py_test/integration/__init__.py b/sgl-router/py_test/integration/__init__.py new file mode 100644 index 000000000..1e342eca0 --- /dev/null +++ b/sgl-router/py_test/integration/__init__.py @@ -0,0 +1 @@ +"""Integration test package for the router.""" diff --git a/sgl-router/py_test/integration/conftest.py b/sgl-router/py_test/integration/conftest.py new file mode 100644 index 000000000..21b9369d7 --- /dev/null +++ b/sgl-router/py_test/integration/conftest.py @@ -0,0 +1,109 @@ +import os +import subprocess +import time +from pathlib import Path +from typing import Dict, Iterable, List, Tuple + +import pytest +import requests + +from ..fixtures.ports import find_free_port +from ..fixtures.router_manager import RouterManager + + +def pytest_configure(config): + config.addinivalue_line("markers", "integration: mark as router integration test") + + +@pytest.fixture +def router_manager() -> Iterable[RouterManager]: + mgr = RouterManager() + try: + yield mgr + finally: + mgr.stop_all() + + +def _spawn_mock_worker(args: List[str]) -> Tuple[subprocess.Popen, str, str]: + repo_root = Path(__file__).resolve().parents[2] + script = repo_root / "py_test" / "fixtures" / "mock_worker.py" + port = find_free_port() + worker_id = f"worker-{port}" 
+ base_cmd = [ + "python3", + str(script), + "--port", + str(port), + "--worker-id", + worker_id, + ] + cmd = base_cmd + args + proc = subprocess.Popen(cmd) + url = f"http://127.0.0.1:{port}" + _wait_health(url) + return proc, url, worker_id + + +def _wait_health(url: str, timeout: float = 10.0): + start = time.time() + with requests.Session() as s: + while time.time() - start < timeout: + try: + r = s.get(f"{url}/health", timeout=1) + if r.status_code == 200: + return + except requests.RequestException: + pass + time.sleep(0.1) + raise TimeoutError(f"Mock worker at {url} did not become healthy") + + +@pytest.fixture +def mock_worker(): + """Start a single healthy mock worker; yields (process, url, worker_id).""" + proc, url, worker_id = _spawn_mock_worker([]) + try: + yield proc, url, worker_id + finally: + if proc.poll() is None: + proc.terminate() + try: + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + proc.kill() + + +@pytest.fixture +def mock_workers(): + """Factory to start N workers with custom args. + + Usage: + procs, urls, ids = mock_workers(n=3, args=["--latency-ms", "5"]) # same args for all + ... + """ + + procs: List[subprocess.Popen] = [] + + def _start(n: int, args: List[str] | None = None): + args = args or [] + new_procs: List[subprocess.Popen] = [] + urls: List[str] = [] + ids: List[str] = [] + for _ in range(n): + p, url, wid = _spawn_mock_worker(args) + procs.append(p) + new_procs.append(p) + urls.append(url) + ids.append(wid) + return new_procs, urls, ids + + try: + yield _start + finally: + for p in procs: + if p.poll() is None: + p.terminate() + try: + p.wait(timeout=3) + except subprocess.TimeoutExpired: + p.kill() diff --git a/sgl-router/py_test/integration/load_balancing/__init__.py b/sgl-router/py_test/integration/load_balancing/__init__.py new file mode 100644 index 000000000..77b8c2460 --- /dev/null +++ b/sgl-router/py_test/integration/load_balancing/__init__.py @@ -0,0 +1 @@ +"""Load balancing integration tests.""" diff --git a/sgl-router/py_test/integration/load_balancing/test_cache_aware.py b/sgl-router/py_test/integration/load_balancing/test_cache_aware.py new file mode 100644 index 000000000..acbbd3682 --- /dev/null +++ b/sgl-router/py_test/integration/load_balancing/test_cache_aware.py @@ -0,0 +1,73 @@ +import collections +import concurrent.futures +import uuid + +import pytest +import requests + + +@pytest.mark.integration +def test_cache_aware_affinity(mock_workers, router_manager): + # Two workers; same prompt should stick to one due to cache tree + _, urls, ids = mock_workers(n=2) + rh = router_manager.start_router(worker_urls=urls, policy="cache_aware") + + counts = collections.Counter() + with requests.Session() as s: + for i in range(12): + r = s.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "repeated prompt for cache", + "max_tokens": 1, + "stream": False, + }, + ) + assert r.status_code == 200 + wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id") + counts[wid] += 1 + + # Expect strong skew toward one worker (tree match); majority > 80% + top = max(counts.values()) + assert top >= 10, counts + + +@pytest.mark.integration +def test_cache_aware_diverse_prompts_balances(mock_workers, router_manager): + # Add latency so concurrent requests overlap and influence load-based selection + _, urls, ids = mock_workers(n=3, args=["--latency-ms", "30"]) + rh = router_manager.start_router( + worker_urls=urls, + policy="cache_aware", + extra={ + "cache_threshold": 0.99, + "balance_abs_threshold": 0, + 
"balance_rel_threshold": 1.0, + }, + ) + + counts = collections.Counter() + + def call(i): + # Use diverse, unrelated prompts to avoid prefix matches entirely + prompt = str(uuid.uuid4()) + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": prompt, + "max_tokens": 1, + "stream": False, + }, + timeout=5, + ) + assert r.status_code == 200 + return r.headers.get("X-Worker-Id") or r.json().get("worker_id") + + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as ex: + for wid in ex.map(call, range(40)): + counts[wid] += 1 + + # Expect participation of at least two workers + assert sum(1 for v in counts.values() if v > 0) >= 2, counts diff --git a/sgl-router/py_test/integration/load_balancing/test_power_of_two.py b/sgl-router/py_test/integration/load_balancing/test_power_of_two.py new file mode 100644 index 000000000..c56f4d38a --- /dev/null +++ b/sgl-router/py_test/integration/load_balancing/test_power_of_two.py @@ -0,0 +1,89 @@ +import collections +import concurrent.futures +import time + +import pytest +import requests + + +@pytest.mark.integration +def test_power_of_two_prefers_less_loaded(mock_workers, router_manager): + # Start two workers: one slow (higher inflight), one fast + # Router monitors /get_load and Power-of-Two uses cached loads to choose + # Start one slow and one fast worker using the fixture factory + procs_slow, urls_slow, ids_slow = mock_workers(n=1, args=["--latency-ms", "200"]) + procs_fast, urls_fast, ids_fast = mock_workers(n=1, args=["--latency-ms", "0"]) + procs = procs_slow + procs_fast + urls = urls_slow + urls_fast + ids = ids_slow + ids_fast + slow_id = ids_slow[0] + + rh = router_manager.start_router( + worker_urls=urls, + policy="power_of_two", + extra={"worker_startup_check_interval": 1}, + ) + + # Prime: fire a burst to create measurable load on slow worker, then wait for monitor tick + + def _prime_call(i): + try: + requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"warm-{i}", + "max_tokens": 1, + "stream": False, + }, + timeout=5, + ) + except Exception: + pass + + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex: + list(ex.map(_prime_call, range(128))) + time.sleep(2) + + # Apply direct background load on the slow worker to amplify load diff + def _direct_load(i): + try: + requests.post( + f"{slow_url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"bg-{i}", + "max_tokens": 1, + "stream": False, + }, + timeout=5, + ) + except Exception: + pass + + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex: + list(ex.map(_direct_load, range(128))) + time.sleep(1) + + def call(i): + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"p{i}", + "max_tokens": 1, + "stream": False, + }, + timeout=5, + ) + assert r.status_code == 200 + return r.headers.get("X-Worker-Id") or r.json().get("worker_id") + + counts = collections.Counter() + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex: + for wid in ex.map(call, range(200)): + counts[wid] += 1 + + # Expect the slow worker (higher latency/inflight) to receive fewer requests + fast_worker_id = [i for i in ids if i != slow_id][0] + assert counts[slow_id] < counts[fast_worker_id], counts diff --git a/sgl-router/py_test/integration/load_balancing/test_random.py b/sgl-router/py_test/integration/load_balancing/test_random.py new file mode 100644 index 000000000..41a613e12 --- /dev/null +++ 
b/sgl-router/py_test/integration/load_balancing/test_random.py @@ -0,0 +1,33 @@ +import collections +import math + +import pytest +import requests + + +@pytest.mark.integration +def test_random_distribution(mock_workers, router_manager): + procs, urls, ids = mock_workers(n=4) + rh = router_manager.start_router(worker_urls=urls, policy="random") + + counts = collections.Counter() + N = 200 + with requests.Session() as s: + for i in range(N): + r = s.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"p{i}", + "max_tokens": 1, + "stream": False, + }, + ) + assert r.status_code == 200 + wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id") + counts[wid] += 1 + + # simple statistical tolerance: each worker should be within ±50% of mean + mean = N / len(ids) + for wid in ids: + assert 0.5 * mean <= counts[wid] <= 1.5 * mean, counts diff --git a/sgl-router/py_test/integration/load_balancing/test_round_robin.py b/sgl-router/py_test/integration/load_balancing/test_round_robin.py new file mode 100644 index 000000000..966f3747a --- /dev/null +++ b/sgl-router/py_test/integration/load_balancing/test_round_robin.py @@ -0,0 +1,34 @@ +import collections +import time + +import pytest +import requests + + +@pytest.mark.integration +def test_round_robin_distribution(mock_workers, router_manager): + procs, urls, ids = mock_workers(n=3) + + rh = router_manager.start_router(worker_urls=urls, policy="round_robin") + + counts = collections.Counter() + with requests.Session() as s: + for i in range(30): + r = s.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"hello {i}", + "max_tokens": 1, + "stream": False, + }, + ) + assert r.status_code == 200 + wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id") + assert wid in ids + counts[wid] += 1 + + # Expect near-even distribution across 3 workers + # 30 requests -> ideally 10 each; allow small tolerance ±3 + for wid in ids: + assert 7 <= counts[wid] <= 13, counts diff --git a/sgl-router/py_test/integration/test_api_auth.py b/sgl-router/py_test/integration/test_api_auth.py new file mode 100644 index 000000000..b8ba5c670 --- /dev/null +++ b/sgl-router/py_test/integration/test_api_auth.py @@ -0,0 +1,38 @@ +import pytest +import requests + + +@pytest.mark.integration
+def test_router_api_key_enforcement(router_manager, mock_workers): + # Start backend requiring API key; router should forward Authorization header transparently + _, urls, _ = mock_workers( + n=1, args=["--require-api-key", "--api-key", "correct_api_key"] + ) + rh = router_manager.start_router( + worker_urls=urls, + policy="round_robin", + extra={}, + ) + + # No auth -> 401 + r = requests.post( + f"{rh.url}/v1/completions", + json={"model": "test-model", "prompt": "x", "max_tokens": 1, "stream": False}, + ) + assert r.status_code == 401 + + # Invalid auth -> 401 + r = requests.post( + f"{rh.url}/v1/completions", + json={"model": "test-model", "prompt": "x", "max_tokens": 1, "stream": False}, + headers={"Authorization": "Bearer wrong"}, + ) + assert r.status_code == 401 + + # Correct auth -> 200 + r = requests.post( + f"{rh.url}/v1/completions", + json={"model": "test-model", "prompt": "x", "max_tokens": 1, "stream": False}, + headers={"Authorization": "Bearer correct_api_key"}, + ) + assert r.status_code == 200 diff --git a/sgl-router/py_test/integration/test_circuit_breaker.py b/sgl-router/py_test/integration/test_circuit_breaker.py new file mode 100644 index 000000000..7e7ba409b --- /dev/null +++ 
b/sgl-router/py_test/integration/test_circuit_breaker.py @@ -0,0 +1,191 @@ +import time + +import pytest +import requests + + +@pytest.mark.integration +def test_circuit_breaker_opens_and_recovers(router_manager, mock_workers): + # A single worker that fails first 3 requests, then succeeds + _, [wurl], _ = mock_workers(n=1, args=["--fail-first-n", "3"]) # fails first 3 + rh = router_manager.start_router( + worker_urls=[wurl], + policy="round_robin", + extra={ + "cb_failure_threshold": 3, + "cb_success_threshold": 2, + "cb_timeout_duration_secs": 3, + "cb_window_duration_secs": 10, + "disable_retries": True, + }, + ) + + def post_once(): + return requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "trigger", + "max_tokens": 1, + "stream": False, + }, + timeout=3, + ) + + saw_503 = False + for _ in range(8): + r = post_once() + if r.status_code == 503: + saw_503 = True + break + assert saw_503, "circuit breaker did not open to return 503" + + time.sleep(4) + r1 = post_once() + r2 = post_once() + assert r1.status_code == 200 and r2.status_code == 200 + + +@pytest.mark.integration +def test_circuit_breaker_half_open_failure_reopens(router_manager, mock_workers): + _, [wurl], _ = mock_workers(n=1, args=["--status-code", "500"]) # always fail + rh = router_manager.start_router( + worker_urls=[wurl], + policy="round_robin", + extra={ + "cb_failure_threshold": 2, + "cb_success_threshold": 2, + "cb_timeout_duration_secs": 2, + "cb_window_duration_secs": 5, + "disable_retries": True, + }, + ) + + def post_once(): + return requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "x", + "max_tokens": 1, + "stream": False, + }, + timeout=3, + ) + + opened = False + for _ in range(8): + r = post_once() + if r.status_code == 503: + opened = True + break + assert opened, "circuit breaker did not open" + + time.sleep(3) + r = post_once() + assert r.status_code == 500 + r2 = post_once() + assert r2.status_code == 503 + + +@pytest.mark.integration +def test_circuit_breaker_disable_flag(router_manager, mock_workers): + _, [wurl], _ = mock_workers(n=1, args=["--status-code", "500"]) # always fail + rh = router_manager.start_router( + worker_urls=[wurl], + policy="round_robin", + extra={ + "disable_circuit_breaker": True, + "disable_retries": True, + }, + ) + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "x", + "max_tokens": 1, + "stream": False, + }, + timeout=3, + ) + assert r.status_code == 500 + + +@pytest.mark.integration +def test_circuit_breaker_per_worker_isolation(router_manager, mock_workers): + _, [fail_url], _ = mock_workers(n=1, args=["--status-code", "500"]) # always fail + _, [ok_url], _ = mock_workers(n=1) + rh = router_manager.start_router( + worker_urls=[fail_url, ok_url], + policy="round_robin", + extra={ + "cb_failure_threshold": 2, + "cb_success_threshold": 1, + "cb_timeout_duration_secs": 2, + "cb_window_duration_secs": 10, + "disable_retries": True, + }, + ) + + def post_once(): + return requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "y", + "max_tokens": 1, + "stream": False, + }, + timeout=3, + ) + + failures = 0 + successes_after_open = 0 + opened = False + for _ in range(30): + r = post_once() + if not opened: + if r.status_code == 500: + failures += 1 + if failures >= 2: + _ = post_once() + _ = post_once() + opened = True + else: + if r.status_code == 200: + successes_after_open += 1 + else: + assert False, f"Unexpected 
non-200 after CB open: {r.status_code}" + assert opened and successes_after_open >= 5 + + +@pytest.mark.integration +def test_circuit_breaker_with_retries(router_manager, mock_workers): + _, [fail_url], _ = mock_workers(n=1, args=["--status-code", "500"]) # always fail + _, [ok_url], _ = mock_workers(n=1) + rh = router_manager.start_router( + worker_urls=[fail_url, ok_url], + policy="round_robin", + extra={ + "retry_max_retries": 3, + "retry_initial_backoff_ms": 10, + "retry_max_backoff_ms": 50, + "cb_failure_threshold": 2, + "cb_success_threshold": 1, + "cb_timeout_duration_secs": 2, + "cb_window_duration_secs": 10, + }, + ) + + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "z", + "max_tokens": 1, + "stream": False, + }, + timeout=5, + ) + assert r.status_code == 200 diff --git a/sgl-router/py_test/integration/test_fault_tolerance.py b/sgl-router/py_test/integration/test_fault_tolerance.py new file mode 100644 index 000000000..78e5968ce --- /dev/null +++ b/sgl-router/py_test/integration/test_fault_tolerance.py @@ -0,0 +1,36 @@ +import concurrent.futures +import subprocess +import time + +import pytest +import requests + + +@pytest.mark.integration +def test_worker_crash_reroute_with_retries(router_manager, mock_workers): + # Start one healthy and one that will crash on first request + _, [ok_url], _ = mock_workers(n=1) + _, [crash_url], _ = mock_workers(n=1, args=["--crash-on-request"]) + rh = router_manager.start_router( + worker_urls=[crash_url, ok_url], + policy="round_robin", + extra={ + "retry_max_retries": 3, + "retry_initial_backoff_ms": 10, + "retry_max_backoff_ms": 50, + }, + ) + + # A single request should succeed via retry to the healthy worker + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "crash", + "max_tokens": 1, + "stream": False, + }, + timeout=5, + ) + assert r.status_code == 200 + # mock_workers fixture handles cleanup diff --git a/sgl-router/py_test/integration/test_pd_routing.py b/sgl-router/py_test/integration/test_pd_routing.py new file mode 100644 index 000000000..d0ae7d552 --- /dev/null +++ b/sgl-router/py_test/integration/test_pd_routing.py @@ -0,0 +1,127 @@ +import collections +import concurrent.futures +import subprocess +import time + +import pytest +import requests + + +@pytest.mark.integration +def test_pd_power_of_two_decode_attribution(router_manager, mock_workers): + # Start two prefill and three decode mock workers via fixture + _, prefill_urls_raw, prefill_ids = mock_workers(n=2) + _, decode_urls_raw, decode_ids_list = mock_workers(n=3) + prefill_urls = [(u, None) for u in prefill_urls_raw] + decode_urls = list(decode_urls_raw) + decode_ids = set(decode_ids_list) + + rh = router_manager.start_router( + policy="power_of_two", + pd_disaggregation=True, + prefill_urls=prefill_urls, + decode_urls=decode_urls, + extra={"worker_startup_check_interval": 1}, + ) + + counts = collections.Counter() + with requests.Session() as s: + for i in range(30): + r = s.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"p{i}", + "max_tokens": 1, + "stream": False, + }, + ) + assert r.status_code == 200 + wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id") + assert wid in decode_ids + counts[wid] += 1 + + assert sum(1 for v in counts.values() if v > 0) >= 2 + + +@pytest.mark.integration +def test_pd_power_of_two_skews_to_faster_decode(router_manager, mock_workers): + # Start two prefill workers (fast) + _, prefill_urls_raw, _ = 
mock_workers(n=2) + + # Start two decode workers: one slow, one fast + _, [decode_slow_url], [slow_id] = mock_workers( + n=1, args=["--latency-ms", "300"] + ) # slower decode + _, [decode_fast_url], [fast_id] = mock_workers(n=1) + decode_urls_raw = [decode_slow_url, decode_fast_url] + + prefill_urls = [(u, None) for u in prefill_urls_raw] + decode_urls = list(decode_urls_raw) + + rh = router_manager.start_router( + policy="power_of_two", + pd_disaggregation=True, + prefill_urls=prefill_urls, + decode_urls=decode_urls, + extra={"worker_startup_check_interval": 1}, + ) + + def _prime_call(i): + try: + requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"warm-{i}", + "max_tokens": 1, + "stream": False, + }, + timeout=8, + ) + except Exception: + pass + + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex: + list(ex.map(_prime_call, range(128))) + time.sleep(2) + + def _direct_decode_load(i): + try: + requests.post( + f"{decode_slow_url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"bg-{i}", + "max_tokens": 1, + "stream": False, + }, + timeout=8, + ) + except Exception: + pass + + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex: + list(ex.map(_direct_decode_load, range(128))) + time.sleep(1) + + def call(i): + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"p{i}", + "max_tokens": 1, + "stream": False, + }, + timeout=8, + ) + assert r.status_code == 200 + return r.headers.get("X-Worker-Id") or r.json().get("worker_id") + + counts = collections.Counter() + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex: + for wid in ex.map(call, range(200)): + counts[wid] += 1 + + assert counts[slow_id] < counts[fast_id], counts diff --git a/sgl-router/py_test/integration/test_rate_limiting.py b/sgl-router/py_test/integration/test_rate_limiting.py new file mode 100644 index 000000000..4297d77c9 --- /dev/null +++ b/sgl-router/py_test/integration/test_rate_limiting.py @@ -0,0 +1,91 @@ +import concurrent.futures +import time + +import pytest +import requests + + +@pytest.mark.integration +def test_rate_limit_and_queue(router_manager, mock_workers): + # One fast backend + _, urls, _ = mock_workers(n=1) + rh = router_manager.start_router( + worker_urls=urls, + policy="round_robin", + extra={ + "max_concurrent_requests": 2, + "queue_size": 0, # no queue -> immediate 429 when limit exceeded + }, + ) + + def call_once(i): + try: + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"p{i}", + "max_tokens": 1, + "stream": False, + }, + timeout=3, + ) + return r.status_code + except Exception: + return 599 + + # Fire a burst of concurrent requests + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as ex: + results = list(ex.map(call_once, range(16))) + + # Expect some to succeed and some to be rate limited (429) + assert any(code == 200 for code in results) + assert any(code == 429 for code in results) + + +@pytest.mark.integration +def test_rate_limit_queue_and_timeout(router_manager, mock_workers): + # Slow backend: ~2s per request ensures queue wait > timeout + _, urls, _ = mock_workers(n=1, args=["--latency-ms", "2000"]) # 2.0s per request + + # Allow 1 concurrent, queue up to 1, with 1s queue timeout + rh = router_manager.start_router( + worker_urls=urls, + policy="round_robin", + extra={ + "max_concurrent_requests": 1, + "queue_size": 1, + "queue_timeout_secs": 1, + }, + ) + + def call_once(i): + try: + r = 
requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"q{i}", + "max_tokens": 1, + "stream": False, + }, + timeout=5, + ) + return r.status_code + except Exception: + return 599 + + # Fire 4 concurrent requests: 1 runs (~2s), 1 is queued and times out + # after 1s (-> 408), and the remaining 2 overflow the queue (-> 429) + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as ex: + results = list(ex.map(call_once, range(4))) + + # We expect: + # - Some 200s (processed) + # - At least one 408 (queued too long and timed out) + # - Remaining non-200s are either 429 (queue overflow) or additional 408s depending on scheduling + assert any(code == 200 for code in results) + assert any(code == 408 for code in results), results + non200 = [c for c in results if c != 200] + assert len(non200) >= 2 and all(c in (408, 429) for c in non200), results diff --git a/sgl-router/py_test/integration/test_retries.py b/sgl-router/py_test/integration/test_retries.py new file mode 100644 index 000000000..5f3d4ffee --- /dev/null +++ b/sgl-router/py_test/integration/test_retries.py @@ -0,0 +1,65 @@ +import concurrent.futures +import subprocess +import time + +import pytest +import requests + + +@pytest.mark.integration +def test_retry_reroutes_to_healthy_worker(router_manager, mock_workers): + # Worker A always returns 500; workers B and C are healthy + # Round robin starts at A, so the retried request should land on worker B + _, [url_a], [id_a] = mock_workers(n=1, args=["--status-code", "500"]) # fail + _, [url_b], [id_b] = mock_workers(n=1) + _, [url_c], [id_c] = mock_workers(n=1) + rh = router_manager.start_router( + worker_urls=[url_a, url_b, url_c], + policy="round_robin", + extra={ + "retry_max_retries": 3, + "retry_initial_backoff_ms": 10, + "retry_max_backoff_ms": 50, + }, + ) + + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "x", + "max_tokens": 1, + "stream": False, + }, + timeout=5, + ) + assert r.status_code == 200 + wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id") + assert wid == id_b # should have retried onto healthy worker + # mock_workers fixture handles cleanup + + +@pytest.mark.integration +def test_disable_retries_surfaces_failure(router_manager, mock_workers): + # Single failing worker, retries disabled -> should return 500 + _, [url], [wid] = mock_workers(n=1, args=["--status-code", "500"]) # always fail + rh = router_manager.start_router( + worker_urls=[url], + policy="round_robin", + extra={ + "disable_retries": True, + }, + ) + + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "x", + "max_tokens": 1, + "stream": False, + }, + timeout=5, + ) + assert r.status_code == 500 + # mock_workers fixture handles cleanup diff --git a/sgl-router/py_test/integration/test_service_discovery_shim.py b/sgl-router/py_test/integration/test_service_discovery_shim.py new file mode 100644 index 000000000..5cc1d6734 --- /dev/null +++ b/sgl-router/py_test/integration/test_service_discovery_shim.py @@ -0,0 +1,36 @@ +import pytest +import requests + + +@pytest.mark.integration +def test_discovery_shim_add_remove(router_manager, mock_workers): + # Start router without workers + rh = router_manager.start_router(worker_urls=[], policy="round_robin") + + # Initially empty + urls = router_manager.list_workers(rh.url) + assert urls == [] + + # Add a worker (simulate discovery event) + _, [wurl], [wid] = mock_workers(n=1) + router_manager.add_worker(rh.url, wurl) + urls = router_manager.list_workers(rh.url) + assert wurl in urls + 
+ # Can serve a request + r = requests.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": "hi", + "max_tokens": 1, + "stream": False, + }, + ) + assert r.status_code == 200 + + # Remove worker (simulate pod deletion) + router_manager.remove_worker(rh.url, wurl) + urls = router_manager.list_workers(rh.url) + assert wurl not in urls + # mock_workers fixture handles cleanup diff --git a/sgl-router/py_test/integration/test_worker_management.py b/sgl-router/py_test/integration/test_worker_management.py new file mode 100644 index 000000000..8acb94114 --- /dev/null +++ b/sgl-router/py_test/integration/test_worker_management.py @@ -0,0 +1,61 @@ +import collections +import subprocess +import time + +import pytest +import requests + + +@pytest.mark.integration +def test_add_and_remove_worker(mock_worker, router_manager, mock_workers): + # Start with a single worker + proc1, url1, id1 = mock_worker + rh = router_manager.start_router(worker_urls=[url1], policy="round_robin") + + # Add a second worker + + procs2, urls2, ids2 = mock_workers(n=1) + url2 = urls2[0] + id2 = ids2[0] + router_manager.add_worker(rh.url, url2) + + # Send some requests and ensure both workers are seen + seen = set() + with requests.Session() as s: + for i in range(20): + r = s.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"x{i}", + "max_tokens": 1, + "stream": False, + }, + ) + assert r.status_code == 200 + wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id") + seen.add(wid) + if len(seen) == 2: + break + + assert id1 in seen and id2 in seen + + # Now remove the second worker + router_manager.remove_worker(rh.url, url2) + + # After removal, subsequent requests should only come from first worker + with requests.Session() as s: + for i in range(10): + r = s.post( + f"{rh.url}/v1/completions", + json={ + "model": "test-model", + "prompt": f"y{i}", + "max_tokens": 1, + "stream": False, + }, + ) + assert r.status_code == 200 + wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id") + assert wid == id1 + # mock_workers fixture handles cleanup diff --git a/sgl-router/py_test/run_suite.py b/sgl-router/py_test/run_suite.py index ac7f9c140..195c2b36e 100644 --- a/sgl-router/py_test/run_suite.py +++ b/sgl-router/py_test/run_suite.py @@ -14,6 +14,13 @@ if __name__ == "__main__": args = arg_parser.parse_args() files = glob.glob("**/test_*.py", recursive=True) + # Exclude integration tests from the e2e suite; those are run separately via pytest -m integration + files = [ + f + for f in files + if "/integration/" not in f and not f.startswith("integration/") + ] + files.sort() test_files = [TestFile(name=file) for file in files] exit_code = run_unittest_files(test_files, args.timeout_per_file) diff --git a/sgl-router/pytest.ini b/sgl-router/pytest.ini index d28b847e6..c9f400753 100644 --- a/sgl-router/pytest.ini +++ b/sgl-router/pytest.ini @@ -3,4 +3,3 @@ testpaths = py_test python_files = test_*.py python_classes = Test* python_functions = test_* -addopts = --cov=sglang_router --cov-report=term-missing
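
Local repro note: the commands below mirror the CI steps added in pr-test-rust.yml and should work from the sgl-router/ directory once the router wheel is installed; the worker port and failure flags are arbitrary examples of the knobs parsed in py_test/fixtures/mock_worker.py.

    # Run the integration suite the same way the new CI step does
    pip install pytest fastapi uvicorn orjson requests
    pytest -q -m integration

    # Spawn a single mock worker by hand to exercise its failure knobs
    python3 py_test/fixtures/mock_worker.py --port 31000 --latency-ms 50 --fail-first-n 2
    curl http://127.0.0.1:31000/health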