sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Test package root for router Python tests."""

View File

@@ -0,0 +1,8 @@
import sys
from pathlib import Path
# Put the in-repo ``py_src`` tree at the front of ``sys.path`` so the tests
# import the local sources ahead of any installed copy of the package.
_ROOT = Path(__file__).resolve().parents[1]
_SRC = _ROOT.joinpath("py_src")
if all(entry != str(_SRC) for entry in sys.path):
    sys.path.insert(0, str(_SRC))

View File

@@ -0,0 +1,512 @@
import json
import logging
import os
import shutil
import signal
import socket
import subprocess
import time
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, Optional
from urllib.parse import urlparse
import pytest
import requests
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
)
logger = logging.getLogger(__name__)
def _find_available_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _parse_url(base_url: str) -> tuple[str, str]:
"""Parse a base URL and return (host, port) as strings.
This is more robust than simple string splitting and supports different schemes
and URL shapes like trailing paths.
"""
parsed = urlparse(base_url)
return parsed.hostname or "127.0.0.1", (
str(parsed.port) if parsed.port is not None else ""
)
def _wait_router_health(base_url: str, timeout: float) -> None:
    """Poll ``GET {base_url}/health`` until it answers 200.

    Request errors are treated as "not ready yet"; the probe is retried every
    two seconds.

    Raises:
        TimeoutError: if no 200 response is seen within ``timeout`` seconds.
    """
    health_url = f"{base_url}/health"
    deadline = time.perf_counter() + timeout
    with requests.Session() as http:
        while time.perf_counter() < deadline:
            try:
                healthy = http.get(health_url, timeout=5).status_code == 200
            except requests.RequestException:
                healthy = False
            if healthy:
                return
            time.sleep(2)
    raise TimeoutError("Router failed to become healthy in time")
def _popen_launch_router(
    model: str,
    base_url: str,
    dp_size: int,
    timeout: float,
    policy: str = "cache_aware",
) -> subprocess.Popen:
    """Start ``sglang_router.launch_server`` and wait until it reports healthy.

    Args:
        model: Value passed to ``--model-path``.
        base_url: URL the server should listen on; host and port are parsed
            from it and it is also the URL polled for health.
        dp_size: Value passed to ``--dp``.
        timeout: Seconds to wait for ``/health`` to return 200.
        policy: Value passed to ``--router-policy``.

    Returns:
        The running server process.

    Raises:
        TimeoutError: if the server does not become healthy in time. The
            spawned process is terminated before the error propagates.
    """
    host, port = _parse_url(base_url)
    # Use a dedicated free port so concurrent routers don't fight over the
    # Prometheus exporter port.
    prom_port = _find_available_port()
    cmd = [
        "python3",
        "-m",
        "sglang_router.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--dp",
        str(dp_size),
        "--router-policy",
        policy,
        "--allow-auto-truncate",
        "--router-prometheus-port",
        str(prom_port),
        "--router-prometheus-host",
        "127.0.0.1",
    ]
    proc = subprocess.Popen(cmd)
    try:
        _wait_router_health(base_url, timeout)
    except BaseException:
        # The caller never receives the handle on failure, so stop the child
        # here instead of leaking a half-started server.
        _terminate(proc)
        raise
    return proc
def _popen_launch_worker(
    model: str,
    base_url: str,
    *,
    dp_size: int | None = None,
    api_key: str | None = None,
    base_gpu_id: int | None = 0,
) -> subprocess.Popen:
    """Spawn an ``sglang.launch_server`` worker without waiting for readiness.

    Args:
        model: Value passed to ``--model-path``.
        base_url: URL whose host and port the worker should bind to.
        dp_size: When given, forwarded as ``--dp-size``.
        api_key: When given, forwarded as ``--api-key``.
        base_gpu_id: Forwarded as ``--base-gpu-id``; ``None`` is treated as 0.

    Returns:
        The spawned worker process.
    """
    host, port = _parse_url(base_url)
    gpu_id = base_gpu_id if base_gpu_id else 0
    argv = ["python3", "-m", "sglang.launch_server"]
    argv.extend(["--model-path", model])
    argv.extend(["--host", host])
    argv.extend(["--port", port])
    argv.extend(["--base-gpu-id", str(gpu_id)])
    # Optional flags are only emitted when the caller supplied a value.
    if dp_size is not None:
        argv.extend(["--dp-size", str(dp_size)])
    if api_key is not None:
        argv.extend(["--api-key", api_key])
    return subprocess.Popen(argv)
def _popen_launch_router_only(
    base_url: str,
    policy: str = "round_robin",
    timeout: float = 120.0,
    *,
    dp_aware: bool = False,
    api_key: str | None = None,
) -> subprocess.Popen:
    """Start ``sglang_router.launch_router`` (no workers) and wait for health.

    Args:
        base_url: URL the router should listen on; also polled for health.
        policy: Value passed to ``--policy``.
        timeout: Seconds to wait for ``/health`` to return 200.
        dp_aware: When true, ``--dp-aware`` is added to the command line.
        api_key: When given, forwarded as ``--api-key``.

    Returns:
        The running router process.

    Raises:
        TimeoutError: if the router does not become healthy in time. The
            spawned process is terminated before the error propagates.
    """
    host, port = _parse_url(base_url)
    # Use a dedicated free port so concurrent routers don't fight over the
    # Prometheus exporter port.
    prom_port = _find_available_port()
    cmd = [
        "python3",
        "-m",
        "sglang_router.launch_router",
        "--host",
        host,
        "--port",
        port,
        "--policy",
        policy,
    ]
    if dp_aware:
        cmd += ["--dp-aware"]
    if api_key is not None:
        cmd += ["--api-key", api_key]
    cmd += [
        "--prometheus-port",
        str(prom_port),
        "--prometheus-host",
        "127.0.0.1",
    ]
    proc = subprocess.Popen(cmd)
    try:
        _wait_router_health(base_url, timeout)
    except BaseException:
        # The caller never receives the handle on failure, so stop the child
        # here instead of leaking a half-started router.
        _terminate(proc)
        raise
    return proc
def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None:
if proc is None:
return
proc.terminate()
start = time.perf_counter()
while proc.poll() is None:
if time.perf_counter() - start > timeout:
proc.kill()
break
time.sleep(1)
def _which(cmd: str) -> Optional[str]:
try:
return shutil.which(cmd)
except Exception as e:
logger.warning("shutil.which(%r) failed: %s", cmd, e)
return None
def _graceful_stop_popen(p: subprocess.Popen) -> None:
if p is None:
return
try:
if p.poll() is None:
p.terminate()
for _ in range(5):
if p.poll() is not None:
break
time.sleep(1)
if p.poll() is None:
p.kill()
except Exception as e:
logger.warning("Exception during graceful stop of popen: %s", e)
def _pid_alive(pid: int) -> bool:
try:
os.kill(pid, 0)
return True
except Exception:
return False
def _graceful_stop_pid(pid: int) -> None:
    """Best-effort stop of a process by pid: SIGTERM, then SIGKILL.

    After SIGTERM the process is checked up to five times, one second apart;
    if it is still alive afterwards it is sent SIGKILL. Every error is
    swallowed so this can be used freely during teardown.
    """

    def _signal_quietly(sig) -> None:
        # Signalling may fail if the process disappears; that is fine here.
        try:
            os.kill(pid, sig)
        except Exception:
            pass

    try:
        if not _pid_alive(pid):
            return
        _signal_quietly(signal.SIGTERM)
        waited = 0
        while waited < 5 and _pid_alive(pid):
            time.sleep(1)
            waited += 1
        if _pid_alive(pid):
            _signal_quietly(signal.SIGKILL)
    except Exception:
        pass
def _graceful_stop_any(obj) -> None:
try:
if isinstance(obj, subprocess.Popen):
_graceful_stop_popen(obj)
return
if isinstance(obj, int):
_graceful_stop_pid(obj)
return
proc_obj = getattr(obj, "proc", None)
if isinstance(proc_obj, subprocess.Popen):
_graceful_stop_popen(proc_obj)
except Exception:
pass
@pytest.fixture(scope="session")
def genai_bench_runner() -> Callable[..., None]:
    """Provide a callable to run genai-bench and validate metrics.

    Usage in tests:
        def test(..., genai_bench_runner):
            genai_bench_runner(router_url=..., model_path=..., experiment_folder=...)

    The callable runs the ``genai-bench`` CLI against ``router_url``, locates
    the experiment folder under the current working directory, logs the mean
    TTFT / end-to-end latency / throughput from every result JSON, and, when
    ``thresholds`` is given, raises ``AssertionError`` for any metric outside
    its bound. Only the threshold keys actually present are enforced.
    Processes listed in ``kill_procs`` are always stopped afterwards.
    """

    def _run(
        *,
        router_url: str,
        model_path: str,
        experiment_folder: str,
        timeout_sec: int | None = None,
        thresholds: dict | None = None,
        extra_env: dict | None = None,
        num_concurrency: int = 32,
        traffic_scenario: str = "D(4000,100)",
        max_requests_per_run: int | None = None,
        clean_experiment: bool = True,
        kill_procs: list | None = None,
        drain_delay_sec: int = 6,
    ) -> None:
        cli = _which("genai-bench")
        if not cli:
            pytest.fail(
                "genai-bench CLI not found; please install it to run benchmarks"
            )

        # Clean previous experiment folder under current working directory
        if clean_experiment:
            exp_dir = Path.cwd() / experiment_folder
            if exp_dir.exists():
                shutil.rmtree(exp_dir, ignore_errors=True)

        # Default requests per run if not provided
        mrr = (
            max_requests_per_run
            if max_requests_per_run is not None
            else num_concurrency * 3
        )

        cmd = [
            cli,
            "benchmark",
            "--api-backend",
            "openai",
            "--api-base",
            router_url,
            "--api-key",
            "dummy-token",
            "--api-model-name",
            model_path,
            "--model-tokenizer",
            model_path,
            "--task",
            "text-to-text",
            "--num-concurrency",
            str(num_concurrency),
            "--traffic-scenario",
            traffic_scenario,
            "--max-requests-per-run",
            str(mrr),
            "--max-time-per-run",
            "2",
            "--experiment-folder-name",
            experiment_folder,
            "--experiment-base-dir",
            str(Path.cwd()),
        ]

        env = os.environ.copy()
        if extra_env:
            env.update(extra_env)

        to = timeout_sec or int(os.environ.get("GENAI_BENCH_TEST_TIMEOUT", "120"))
        proc = subprocess.Popen(
            cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        stdout = stderr = ""
        rc = None
        try:
            try:
                stdout, stderr = proc.communicate(timeout=to)
            except subprocess.TimeoutExpired:
                # Simple: kill the CLI process if it doesn't exit in time
                try:
                    proc.kill()
                except Exception:
                    pass
                stdout, stderr = proc.communicate()
            rc = proc.returncode

            # Prefer exact path under cwd; fallback to rglob search
            base = Path.cwd()
            direct = base / experiment_folder
            candidates = [direct] if direct.is_dir() else []
            if not candidates:
                for p in base.rglob(experiment_folder):
                    if p.is_dir() and p.name == experiment_folder:
                        candidates = [p]
                        break
            if not candidates:
                raise AssertionError(
                    "Benchmark failed: experiment folder not found: "
                    f"{experiment_folder}\nExit code: {rc}\nSTDOUT (tail):\n{stdout[-1000:]}\nSTDERR (tail):\n{stderr[-1000:]}"
                )
            actual_folder = candidates[0]

            json_files = [
                p
                for p in actual_folder.rglob("*.json")
                if "experiment_metadata" not in p.name
            ]
            if not json_files:
                raise AssertionError(
                    "Benchmark failed: no JSON results found\n"
                    f"Exit code: {rc}\nSTDOUT (tail):\n{stdout[-1000:]}\nSTDERR (tail):\n{stderr[-1000:]}"
                )

            th = thresholds  # None means "log only", no validation

            for jf in json_files:
                # JSON is UTF-8 by specification; don't depend on the locale.
                with jf.open("r", encoding="utf-8") as f:
                    data = json.load(f)
                stats = data.get("aggregated_metrics", {}).get("stats", {})
                # Missing latency metrics default to +inf and missing
                # throughput to 0 so that absent data fails validation.
                ttft_mean = float(stats.get("ttft", {}).get("mean", float("inf")))
                e2e_latency_mean = float(
                    stats.get("e2e_latency", {}).get("mean", float("inf"))
                )
                input_tp_mean = float(
                    stats.get("input_throughput", {}).get("mean", 0.0)
                )
                output_tp_mean = float(
                    stats.get("output_throughput", {}).get("mean", 0.0)
                )

                logger.info(
                    "genai-bench[%s] %s ttft_mean=%.3fs e2e_latency_mean=%.3fs input_tp_mean=%.1f tok/s output_tp_mean=%.1f tok/s",
                    experiment_folder,
                    jf.name,
                    ttft_mean,
                    e2e_latency_mean,
                    input_tp_mean,
                    output_tp_mean,
                )

                if th is None:
                    continue

                # Explicit raises (not ``assert``) so validation still runs
                # under ``python -O``. Keys absent from ``th`` are skipped
                # instead of raising KeyError.
                upper_bounds = (
                    ("ttft_mean_max", "TTFT", ttft_mean),
                    ("e2e_latency_mean_max", "E2E latency", e2e_latency_mean),
                )
                lower_bounds = (
                    ("input_throughput_mean_min", "Input throughput", input_tp_mean),
                    (
                        "output_throughput_mean_min",
                        "Output throughput",
                        output_tp_mean,
                    ),
                )
                for key, label, value in upper_bounds:
                    if key in th and not value <= th[key]:
                        raise AssertionError(
                            f"{label} validation failed: {value} > {th[key]} (file={jf.name})"
                        )
                for key, label, value in lower_bounds:
                    if key in th and not value >= th[key]:
                        raise AssertionError(
                            f"{label} validation failed: {value} < {th[key]} (file={jf.name})"
                        )

        finally:
            # Always attempt to stop workers to avoid resource leakage
            if kill_procs:
                # Give router/workers a small grace period to finish any last drains
                if drain_delay_sec > 0:
                    try:
                        time.sleep(drain_delay_sec)
                    except Exception:
                        pass
                for p in kill_procs:
                    _graceful_stop_any(p)
                try:
                    time.sleep(2)
                except Exception:
                    pass

    return _run
def pytest_configure(config):
    """Pytest hook: register the custom ``e2e`` marker with the config."""
    marker_spec = "e2e: mark as end-to-end test"
    config.addinivalue_line("markers", marker_spec)
@pytest.fixture(scope="session")
def e2e_model() -> str:
    """Session-scoped fixture yielding the model used by all e2e tests."""
    # Always use the default test model
    model_name = DEFAULT_MODEL_NAME_FOR_TEST
    return model_name
@pytest.fixture
def e2e_router(e2e_model: str):
    """Launch a router together with dp=2 workers at the default test URL.

    Yields a namespace with ``proc`` (the server process) and ``url``; the
    process is terminated on teardown.
    """
    # Keep this available but tests below use router-only to avoid GPU contention
    router_url = DEFAULT_URL_FOR_TEST
    router_proc = _popen_launch_router(
        e2e_model,
        router_url,
        dp_size=2,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    )
    handle = SimpleNamespace(proc=router_proc, url=router_url)
    try:
        yield handle
    finally:
        _terminate(router_proc)
@pytest.fixture
def e2e_router_only_rr():
    """Launch a worker-less round-robin router on a free local port.

    Yields a namespace with ``proc`` and ``url``; the router is terminated on
    teardown.
    """
    router_url = f"http://127.0.0.1:{_find_available_port()}"
    router_proc = _popen_launch_router_only(router_url, policy="round_robin")
    handle = SimpleNamespace(proc=router_proc, url=router_url)
    try:
        yield handle
    finally:
        _terminate(router_proc)
@pytest.fixture(scope="session")
def e2e_primary_worker(e2e_model: str):
    """Launch one worker on a free local port for the whole test session.

    Yields a namespace with ``proc`` and ``url``. The fixture does not wait
    for the worker to become ready; the worker is terminated on teardown.
    """
    worker_url = f"http://127.0.0.1:{_find_available_port()}"
    worker_proc = _popen_launch_worker(e2e_model, worker_url)
    handle = SimpleNamespace(proc=worker_proc, url=worker_url)
    # Router health gate will handle worker readiness
    try:
        yield handle
    finally:
        _terminate(worker_proc)
@pytest.fixture
def e2e_router_only_rr_dp_aware_api():
    """Router-only with dp-aware enabled and an API key.

    Yields a namespace with ``proc``, ``url`` and ``api_key``; the router is
    terminated on teardown.
    """
    router_url = f"http://127.0.0.1:{_find_available_port()}"
    api_key = "secret"
    router_proc = _popen_launch_router_only(
        router_url,
        policy="round_robin",
        timeout=180.0,
        dp_aware=True,
        api_key=api_key,
    )
    handle = SimpleNamespace(proc=router_proc, url=router_url, api_key=api_key)
    try:
        yield handle
    finally:
        _terminate(router_proc)
@pytest.fixture
def e2e_worker_dp2_api(e2e_model: str, e2e_router_only_rr_dp_aware_api):
    """Worker with dp-size=2 and the same API key as the dp-aware router.

    Yields a namespace with ``proc`` and ``url``; the worker is terminated on
    teardown.
    """
    worker_url = f"http://127.0.0.1:{_find_available_port()}"
    shared_key = e2e_router_only_rr_dp_aware_api.api_key
    worker_proc = _popen_launch_worker(
        e2e_model, worker_url, dp_size=2, api_key=shared_key
    )
    handle = SimpleNamespace(proc=worker_proc, url=worker_url)
    try:
        yield handle
    finally:
        _terminate(worker_proc)
@pytest.fixture(scope="session")
def e2e_two_workers_dp2(e2e_model: str):
    """Launch two workers, each with dp_size=2, mapped to GPUs [0,1] and [2,3].

    Yields a list of namespaces (``proc``, ``url``), one per worker. Every
    worker that was started is terminated on teardown, even if a later launch
    failed.
    """
    workers = []
    try:
        # Worker A starts at GPU 0 (GPUs 0-1), worker B at GPU 2 (GPUs 2-3).
        for first_gpu in (0, 2):
            worker_url = f"http://127.0.0.1:{_find_available_port()}"
            worker_proc = _popen_launch_worker(
                e2e_model, worker_url, dp_size=2, base_gpu_id=first_gpu
            )
            workers.append(SimpleNamespace(proc=worker_proc, url=worker_url))
        yield workers
    finally:
        for worker in workers:
            _terminate(worker.proc)

View File

@@ -0,0 +1,262 @@
import logging
import os
import socket
import subprocess
import time
from types import SimpleNamespace
from typing import Optional
import pytest
import requests
from sglang.test.run_eval import run_eval
logger = logging.getLogger(__name__)
def _find_available_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _wait_health(url: str, timeout: float = 180.0) -> None:
    """Poll ``GET {url}/health`` once per second until it answers 200.

    Request errors are treated as "not ready yet".

    Raises:
        TimeoutError: if no 200 response is seen within ``timeout`` seconds.
    """
    health_url = f"{url}/health"
    deadline = time.perf_counter() + timeout
    with requests.Session() as http:
        while time.perf_counter() < deadline:
            try:
                ready = http.get(health_url, timeout=5).status_code == 200
            except requests.RequestException:
                ready = False
            if ready:
                return
            time.sleep(1)
    raise TimeoutError(f"Service at {url} failed to become healthy in time")
def _detect_ib_device() -> Optional[str]:
"""Return first active IB device name (e.g., mlx5_0) or None if unavailable."""
# Fast check that ibv_devinfo exists
try:
subprocess.run(
["ibv_devinfo", "-l"],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
timeout=1,
)
except (FileNotFoundError, subprocess.TimeoutExpired):
return None
for i in range(12):
dev = f"mlx5_{i}"
try:
res = subprocess.run(
["ibv_devinfo", dev],
capture_output=True,
text=True,
timeout=2,
)
if res.returncode == 0 and ("state:" in res.stdout):
for line in res.stdout.splitlines():
if "state:" in line and "PORT_ACTIVE" in line:
return dev
except Exception:
pass
return None
def _popen_launch_prefill_worker(
    model: str,
    bootstrap_port: int,
    ib_device: Optional[str] = None,
    base_gpu_id: int = 0,
) -> SimpleNamespace:
    """Start a prefill-mode ``sglang.launch_server`` worker and wait for health.

    Args:
        model: Value passed to ``--model-path``.
        bootstrap_port: Value passed to ``--disaggregation-bootstrap-port``.
        ib_device: When set, forwarded as ``--disaggregation-ib-device``.
        base_gpu_id: Value passed to ``--base-gpu-id``.

    Returns:
        Namespace with ``proc``, ``url`` and ``bootstrap_port``.

    Raises:
        TimeoutError: if the worker is not healthy within 300 seconds. The
            spawned process is terminated before the error propagates.
    """
    port = _find_available_port()
    url = f"http://127.0.0.1:{port}"
    cmd = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--disaggregation-mode",
        "prefill",
        "--host",
        "127.0.0.1",
        "--port",
        str(port),
        "--disaggregation-bootstrap-port",
        str(bootstrap_port),
        "--base-gpu-id",
        str(base_gpu_id),
    ]
    if ib_device:
        cmd += ["--disaggregation-ib-device", ib_device]
    proc = subprocess.Popen(cmd)
    try:
        _wait_health(url, timeout=300.0)
    except BaseException:
        # The caller never receives the handle on failure, so stop the child
        # here instead of leaking it.
        _terminate(proc)
        raise
    return SimpleNamespace(proc=proc, url=url, bootstrap_port=bootstrap_port)
def _popen_launch_decode_worker(
    model: str, ib_device: Optional[str] = None, base_gpu_id: int = 0
) -> SimpleNamespace:
    """Start a decode-mode ``sglang.launch_server`` worker and wait for health.

    Args:
        model: Value passed to ``--model-path``.
        ib_device: When set, forwarded as ``--disaggregation-ib-device``.
        base_gpu_id: Value passed to ``--base-gpu-id``.

    Returns:
        Namespace with ``proc`` and ``url``.

    Raises:
        TimeoutError: if the worker is not healthy within 300 seconds. The
            spawned process is terminated before the error propagates.
    """
    port = _find_available_port()
    url = f"http://127.0.0.1:{port}"
    cmd = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--disaggregation-mode",
        "decode",
        "--host",
        "127.0.0.1",
        "--port",
        str(port),
        "--base-gpu-id",
        str(base_gpu_id),
    ]
    if ib_device:
        cmd += ["--disaggregation-ib-device", ib_device]
    proc = subprocess.Popen(cmd)
    try:
        _wait_health(url, timeout=300.0)
    except BaseException:
        # The caller never receives the handle on failure, so stop the child
        # here instead of leaking it.
        _terminate(proc)
        raise
    return SimpleNamespace(proc=proc, url=url)
def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None:
if proc is None:
return
proc.terminate()
start = time.perf_counter()
while proc.poll() is None:
if time.perf_counter() - start > timeout:
proc.kill()
break
time.sleep(1)
@pytest.fixture(scope="module")
def pd_cluster(e2e_model: str):
    """Start 2 prefill + 2 decode workers and one PD router, once per module.

    Fails the requesting test when ``sgl_kernel``, ``torch`` or CUDA is
    unavailable. Yields a namespace with ``router_url``, ``workers`` and
    ``router_proc``. On teardown, and on any failure during startup, the
    router and every worker launched so far are terminated.
    """
    # Environment capability checks: require sgl_kernel and GPU backend
    try:
        import sgl_kernel  # noqa: F401
    except Exception as e:  # pragma: no cover - environment dependent
        pytest.fail(f"PD e2e requires sgl_kernel but it is not available: {e}")
    try:
        import torch  # noqa: F401
    except Exception as e:  # pragma: no cover - environment dependent
        pytest.fail(
            f"PD e2e requires torch but it is not available or misconfigured: {e}"
        )
    if not torch.cuda.is_available():  # pragma: no cover - environment dependent
        pytest.fail("PD e2e requires CUDA backend, but CUDA is not available")

    workers: list[SimpleNamespace] = []
    router_proc = None
    try:
        ib_device = _detect_ib_device()

        # Launch 4 workers across 4 GPUs: prefill on 0,1 and decode on 2,3.
        # Each worker is registered in ``workers`` as soon as it is up, so a
        # failure while launching a later one still lets the ``finally``
        # block below terminate the ones already running.
        pf1 = _popen_launch_prefill_worker(
            e2e_model,
            bootstrap_port=_find_available_port(),
            ib_device=ib_device,
            base_gpu_id=0,
        )
        workers.append(pf1)
        pf2 = _popen_launch_prefill_worker(
            e2e_model,
            bootstrap_port=_find_available_port(),
            ib_device=ib_device,
            base_gpu_id=1,
        )
        workers.append(pf2)
        dc1 = _popen_launch_decode_worker(e2e_model, ib_device=ib_device, base_gpu_id=2)
        workers.append(dc1)
        dc2 = _popen_launch_decode_worker(e2e_model, ib_device=ib_device, base_gpu_id=3)
        workers.append(dc2)
        prefills = [pf1, pf2]
        decodes = [dc1, dc2]

        # PD router with two prefill and two decode endpoints
        rport = _find_available_port()
        router_url = f"http://127.0.0.1:{rport}"
        pport = _find_available_port()

        prefill = [(pf.url, pf.bootstrap_port) for pf in prefills]
        decode = [dc.url for dc in decodes]

        cmd = [
            "python3",
            "-m",
            "sglang_router.launch_router",
            "--host",
            "127.0.0.1",
            "--port",
            str(rport),
            "--policy",
            "round_robin",
            "--pd-disaggregation",
        ]
        for url, bport in prefill:
            cmd += ["--prefill", url, str(bport)]
        for url in decode:
            cmd += ["--decode", url]
        cmd += [
            "--prometheus-port",
            str(pport),
            "--prometheus-host",
            "127.0.0.1",
        ]

        router_proc = subprocess.Popen(cmd)
        _wait_health(router_url, timeout=180.0)

        yield SimpleNamespace(
            router_url=router_url, workers=workers, router_proc=router_proc
        )
    finally:
        if router_proc is not None:
            _terminate(router_proc)
        for w in workers:
            _terminate(w.proc)
@pytest.mark.e2e
def test_pd_mmlu(e2e_model: str, pd_cluster):
    """
    Run MMLU through the shared PD router (2 prefill + 2 decode workers) and
    require a score of at least 0.65.
    """
    eval_settings = {
        "base_url": pd_cluster.router_url,
        "model": e2e_model,
        "eval_name": "mmlu",
        "num_examples": 64,
        "num_threads": 32,
        "temperature": 0.1,
    }
    metrics = run_eval(SimpleNamespace(**eval_settings))
    assert metrics["score"] >= 0.65
@pytest.mark.e2e
def test_pd_genai_bench(e2e_model: str, pd_cluster, genai_bench_runner):
    """
    Run a short genai-bench benchmark against the shared PD router
    (2 prefill + 2 decode workers) and validate aggregate metrics.
    """
    metric_bounds = {
        "ttft_mean_max": 12,
        "e2e_latency_mean_max": 15,
        "input_throughput_mean_min": 400,
        "output_throughput_mean_min": 20,
    }
    # Run genai-bench against the shared router
    genai_bench_runner(
        router_url=pd_cluster.router_url,
        model_path=e2e_model,
        experiment_folder="benchmark_round_robin_pd",
        thresholds=metric_bounds,
        kill_procs=pd_cluster.workers,
    )

View File

@@ -0,0 +1,169 @@
import threading
import time
from types import SimpleNamespace
import pytest
import requests
from sglang.test.run_eval import run_eval
@pytest.mark.e2e
def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model):
    """Attach both dp=2 workers to a fresh router and require MMLU >= 0.65."""
    # Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
    router_url = e2e_router_only_rr.url
    for worker in e2e_two_workers_dp2:
        resp = requests.post(
            f"{router_url}/add_worker", params={"url": worker.url}, timeout=180
        )
        resp.raise_for_status()

    eval_settings = {
        "base_url": router_url,
        "model": e2e_model,
        "eval_name": "mmlu",
        "num_examples": 64,
        "num_threads": 32,
        "temperature": 0.1,
    }
    metrics = run_eval(SimpleNamespace(**eval_settings))
    assert metrics["score"] >= 0.65
@pytest.mark.e2e
def test_genai_bench(
    e2e_router_only_rr, e2e_two_workers_dp2, e2e_model, genai_bench_runner
):
    """Attach the workers to the regular router and run a short genai-bench."""
    router_url = e2e_router_only_rr.url
    for worker in e2e_two_workers_dp2:
        resp = requests.post(
            f"{router_url}/add_worker", params={"url": worker.url}, timeout=180
        )
        resp.raise_for_status()

    metric_bounds = {
        "ttft_mean_max": 6,
        "e2e_latency_mean_max": 14,
        "input_throughput_mean_min": 1000,
        "output_throughput_mean_min": 12,
    }
    genai_bench_runner(
        router_url=router_url,
        model_path=e2e_model,
        experiment_folder="benchmark_round_robin_regular",
        thresholds=metric_bounds,
        kill_procs=e2e_two_workers_dp2,
    )
@pytest.mark.e2e
def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    """Add a worker, send eight completions through the router, then remove it."""
    router_url = e2e_router_only_rr.url
    worker_url = e2e_primary_worker.url

    added = requests.post(
        f"{router_url}/add_worker", params={"url": worker_url}, timeout=180
    )
    added.raise_for_status()

    completions_url = f"{router_url}/v1/completions"
    with requests.Session() as session:
        for i in range(8):
            payload = {
                "model": e2e_model,
                "prompt": f"x{i}",
                "max_tokens": 1,
                "stream": False,
            }
            reply = session.post(completions_url, json=payload, timeout=120)
            reply.raise_for_status()

    # Remove the worker
    removed = requests.post(
        f"{router_url}/remove_worker", params={"url": worker_url}, timeout=60
    )
    removed.raise_for_status()
@pytest.mark.e2e
def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    """Terminate the only worker ten seconds into an MMLU run.

    The eval is only required to complete and report a score in ``[0, 1]``.
    """
    router_url = e2e_router_only_rr.url
    worker = e2e_primary_worker

    added = requests.post(
        f"{router_url}/add_worker", params={"url": worker.url}, timeout=180
    )
    added.raise_for_status()

    def killer():
        # Let the eval get going before pulling the worker out from under it.
        time.sleep(10)
        try:
            worker.proc.terminate()
        except Exception:
            pass

    killer_thread = threading.Thread(target=killer, daemon=True)
    killer_thread.start()

    eval_settings = {
        "base_url": router_url,
        "model": e2e_model,
        "eval_name": "mmlu",
        "num_examples": 32,
        "num_threads": 16,
        "temperature": 0.0,
    }
    metrics = run_eval(SimpleNamespace(**eval_settings))
    assert 0.0 <= metrics["score"] <= 1.0
@pytest.mark.e2e
def test_dp_aware_worker_expansion_and_api_key(
    e2e_model,
    e2e_router_only_rr_dp_aware_api,
    e2e_worker_dp2_api,
):
    """
    Launch a router-only instance in dp_aware mode and a single worker with dp_size=2
    and API key protection. Verify expansion, auth enforcement, and basic eval.
    """
    import os

    router_url = e2e_router_only_rr_dp_aware_api.url
    worker_url = e2e_worker_dp2_api.url
    api_key = e2e_router_only_rr_dp_aware_api.api_key

    # Attach worker; router should expand to dp_size logical workers
    r = requests.post(
        f"{router_url}/add_worker", params={"url": worker_url}, timeout=180
    )
    r.raise_for_status()

    r = requests.get(f"{router_url}/list_workers", timeout=30)
    r.raise_for_status()
    urls = r.json().get("urls", [])
    assert len(urls) == 2
    assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}

    # Verify API key enforcement path-through
    # 1) Without Authorization -> 401 from backend
    r = requests.post(
        f"{router_url}/v1/completions",
        json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
        timeout=60,
    )
    assert r.status_code == 401

    # 2) With correct Authorization -> 200
    r = requests.post(
        f"{router_url}/v1/completions",
        json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=60,
    )
    assert r.status_code == 200

    # Finally, run MMLU eval through the router with auth. The key is set in
    # the process environment for the eval; restore the previous value
    # afterwards so it does not leak into later tests in this process.
    previous_key = os.environ.get("OPENAI_API_KEY")
    os.environ["OPENAI_API_KEY"] = api_key
    try:
        args = SimpleNamespace(
            base_url=router_url,
            model=e2e_model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
            temperature=0.1,
        )
        metrics = run_eval(args)
    finally:
        if previous_key is None:
            os.environ.pop("OPENAI_API_KEY", None)
        else:
            os.environ["OPENAI_API_KEY"] = previous_key
    assert metrics["score"] >= 0.65

View File

@@ -0,0 +1 @@
"""Shared fixtures for router integration tests."""

View File

@@ -0,0 +1,252 @@
"""
Lightweight mock worker HTTP server for router integration tests.
Implements minimal endpoints used by the router:
- GET /health, /health_generate
- POST /generate, /v1/completions, /v1/chat/completions
- POST /flush_cache
- GET /get_server_info, /get_model_info, /v1/models
Behavior knobs are controlled via CLI flags to simulate failures, latency, and load.
"""
import argparse
import asyncio
import json
import os
import random
import signal
import sys
import time
from contextlib import asynccontextmanager
from typing import Optional
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
# Global state (per-process)
_inflight = 0
_failures_seen = 0
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser()
p.add_argument("--host", default="127.0.0.1")
p.add_argument("--port", type=int, required=True)
p.add_argument("--worker-id", default=None)
p.add_argument("--latency-ms", type=int, default=0)
p.add_argument("--timeout", action="store_true")
p.add_argument("--status-code", type=int, default=200)
p.add_argument("--fail-first-n", type=int, default=0)
p.add_argument("--random-fail-rate", type=float, default=0.0)
p.add_argument("--require-api-key", action="store_true")
p.add_argument("--api-key", default=None)
p.add_argument("--max-payload-bytes", type=int, default=10 * 1024 * 1024)
p.add_argument("--stream", action="store_true")
p.add_argument("--dp-size", type=int, default=1)
p.add_argument("--crash-on-request", action="store_true")
p.add_argument("--health-fail-after-ms", type=int, default=0)
return p.parse_args()
def _extract_worker_id(args: argparse.Namespace) -> str:
if args.worker_id:
return str(args.worker_id)
# default to port (unique enough for tests)
return f"worker-{args.port}"
def create_app(args: argparse.Namespace) -> FastAPI:
    """Build the mock worker FastAPI app whose behavior is driven by ``args``.

    Registers the health, info and load endpoints plus the three generation
    endpoints (``/generate``, ``/v1/completions``, ``/v1/chat/completions``),
    which all share the same handler logic.
    """
    app = FastAPI()
    worker_id = _extract_worker_id(args)
    start_ts = time.time()
    crashed = {"done": False}

    async def maybe_delay():
        # Sleep for the configured latency; non-positive values mean none.
        if args.latency_ms <= 0:
            return
        await asyncio.sleep(args.latency_ms / 1000.0)

    def should_fail() -> Optional[int]:
        # Returns the status code to fail with, or None to succeed.
        global _failures_seen
        # Fail first N requests (500)
        fail_budget = args.fail_first_n
        if fail_budget > 0 and _failures_seen < fail_budget:
            _failures_seen += 1
            return 500
        # Random failure probability (500)
        fail_rate = args.random_fail_rate
        if fail_rate > 0.0 and random.random() < fail_rate:
            return 500
        # Forced status code override (non-200) for all responses
        forced = args.status_code
        return int(forced) if forced != 200 else None

    def check_api_key(request: Request):
        # Raises 401 unless a Bearer token is present; when ``--api-key`` is
        # configured the token must also match it.
        if not args.require_api_key:
            return
        auth = request.headers.get("Authorization")
        if not (auth and auth.startswith("Bearer ")):
            raise HTTPException(status_code=401, detail="Unauthorized")
        _scheme, _sep, key = auth.partition(" ")
        if args.api_key and key != args.api_key:
            raise HTTPException(status_code=401, detail="Unauthorized")

    @asynccontextmanager
    async def track_inflight():
        # Counts requests currently inside a generation handler.
        global _inflight
        _inflight += 1
        try:
            yield
        finally:
            _inflight -= 1

    @app.get("/health")
    async def health():
        fail_after_ms = args.health_fail_after_ms
        uptime_ms = (time.time() - start_ts) * 1000.0
        # A non-zero --health-fail-after-ms flips health to 500 once the app
        # has been up that long.
        if fail_after_ms and uptime_ms >= fail_after_ms:
            return PlainTextResponse("bad", status_code=500)
        return PlainTextResponse("ok", status_code=200)

    @app.get("/health_generate")
    async def health_generate():
        return PlainTextResponse("ok", status_code=200)

    @app.post("/flush_cache")
    async def flush_cache():
        return PlainTextResponse("ok", status_code=200)

    @app.get("/get_model_info")
    async def get_model_info():
        return JSONResponse({"model": "mock", "vocab_size": 32000})

    @app.get("/v1/models")
    async def list_models():
        return JSONResponse({"data": [{"id": "mock", "object": "model"}]})

    @app.get("/get_server_info")
    async def get_server_info(request: Request):
        # Enforce API key on server info when required (used by dp_aware probing)
        check_api_key(request)
        info = {
            "worker_id": worker_id,
            "load_in_flight": _inflight,
            "cache": {"size": 0, "hit_rate": 0.0},
            "dp_size": int(args.dp_size),
        }
        return JSONResponse(info)

    @app.get("/get_load")
    async def get_load():
        return JSONResponse({"load": _inflight})

    def make_json_response(obj: dict, status_code: int = 200) -> JSONResponse:
        # Every JSON reply is tagged with the worker that produced it.
        tagged = JSONResponse(obj, status_code=status_code)
        tagged.headers["X-Worker-Id"] = worker_id
        return tagged

    async def handle_text_request(request: Request):
        # Authorization
        check_api_key(request)

        # Payload limit
        raw_body = await request.body()
        if len(raw_body) > args.max_payload_bytes:
            return make_json_response({"error": "payload too large"}, status_code=413)

        # Simulate crash on first request
        if args.crash_on_request and not crashed["done"]:
            crashed["done"] = True
            os._exit(1)

        # Optional timeout (simulate hang)
        if args.timeout:
            await asyncio.sleep(3600)

        # Optional latency
        await maybe_delay()

        # Optional failures
        fail_code = should_fail()
        if fail_code is not None and fail_code != 200:
            return make_json_response(
                {"error": f"mock failure {fail_code}"}, status_code=fail_code
            )

        # Build response echoing minimal shape
        try:
            echoed = await request.json()
        except (json.JSONDecodeError, ValueError):
            echoed = {}

        now = time.time()
        choice = {"text": "ok", "index": 0, "finish_reason": "stop"}
        ret = {
            "id": f"cmpl-{int(now*1000)}",
            "object": "text_completion",
            "created": int(now),
            "model": "mock",
            "choices": [choice],
            "worker_id": worker_id,
            "echo": echoed,
        }
        return make_json_response(ret, status_code=200)

    async def handle_stream_request(request: Request):
        check_api_key(request)

        async def gen():
            # minimal 2-chunk stream then [DONE]
            for _ in range(2):
                await asyncio.sleep(0.01)
                chunk = {
                    "choices": [{"delta": {"content": "x"}}],
                    "worker_id": worker_id,
                }
                yield f"data: {json.dumps(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        headers = {"X-Worker-Id": worker_id}
        return StreamingResponse(gen(), media_type="text/event-stream", headers=headers)

    async def _serve(request: Request):
        # Shared body of the three generation endpoints: streaming or plain
        # JSON depending on --stream, tracked as in-flight while handled.
        async with track_inflight():
            if args.stream:
                return await handle_stream_request(request)
            return await handle_text_request(request)

    @app.post("/generate")
    async def generate(request: Request):
        return await _serve(request)

    @app.post("/v1/completions")
    async def completions(request: Request):
        return await _serve(request)

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request):
        return await _serve(request)

    return app
def main() -> None:
    """Parse CLI flags, build the app and serve it with uvicorn."""
    cli_args = _parse_args()
    app = create_app(cli_args)

    def _exit_on_sigterm(*_):
        sys.exit(0)

    # Handle SIGTERM gracefully for fast test teardown
    signal.signal(signal.SIGTERM, _exit_on_sigterm)
    uvicorn.run(app, host=cli_args.host, port=cli_args.port, log_level="warning")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,8 @@
import socket
def find_free_port() -> int:
    """Return an available TCP port on localhost.

    The OS picks the port (bind to port 0); the probing socket is closed
    before returning, so the port is not reserved for the caller.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
    finally:
        probe.close()
    return port

View File

@@ -0,0 +1,158 @@
import subprocess
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
import requests
from .ports import find_free_port
@dataclass
class ProcHandle:
    """A spawned process paired with the base URL it is reachable at."""

    # The child process object.
    process: subprocess.Popen
    # Base HTTP URL of the service run by ``process``.
    url: str
class RouterManager:
"""Helper to spawn a router process and interact with admin endpoints."""
def __init__(self):
self._children: List[subprocess.Popen] = []
def start_router(
self,
worker_urls: Optional[List[str]] = None,
policy: str = "round_robin",
port: Optional[int] = None,
extra: Optional[Dict] = None,
# PD options
pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[str] = None,
decode_policy: Optional[str] = None,
) -> ProcHandle:
worker_urls = worker_urls or []
port = port or find_free_port()
cmd = [
"python3",
"-m",
"sglang_router.launch_router",
"--host",
"127.0.0.1",
"--port",
str(port),
"--policy",
policy,
]
# Avoid Prometheus port collisions by assigning a free port per router
prom_port = find_free_port()
cmd.extend(
["--prometheus-port", str(prom_port), "--prometheus-host", "127.0.0.1"]
)
if worker_urls:
cmd.extend(["--worker-urls", *worker_urls])
# PD routing configuration
if pd_disaggregation:
cmd.append("--pd-disaggregation")
if prefill_urls:
for url, bport in prefill_urls:
if bport is None:
cmd.extend(["--prefill", url, "none"])
else:
cmd.extend(["--prefill", url, str(bport)])
if decode_urls:
for url in decode_urls:
cmd.extend(["--decode", url])
if prefill_policy:
cmd.extend(["--prefill-policy", prefill_policy])
if decode_policy:
cmd.extend(["--decode-policy", decode_policy])
# Map supported extras to CLI flags (subset for integration)
if extra:
flag_map = {
"max_payload_size": "--max-payload-size",
"dp_aware": "--dp-aware",
"api_key": "--api-key",
# Health/monitoring
"worker_startup_check_interval": "--worker-startup-check-interval",
# Cache-aware tuning
"cache_threshold": "--cache-threshold",
"balance_abs_threshold": "--balance-abs-threshold",
"balance_rel_threshold": "--balance-rel-threshold",
# Retry
"retry_max_retries": "--retry-max-retries",
"retry_initial_backoff_ms": "--retry-initial-backoff-ms",
"retry_max_backoff_ms": "--retry-max-backoff-ms",
"retry_backoff_multiplier": "--retry-backoff-multiplier",
"retry_jitter_factor": "--retry-jitter-factor",
"disable_retries": "--disable-retries",
# Circuit breaker
"cb_failure_threshold": "--cb-failure-threshold",
"cb_success_threshold": "--cb-success-threshold",
"cb_timeout_duration_secs": "--cb-timeout-duration-secs",
"cb_window_duration_secs": "--cb-window-duration-secs",
"disable_circuit_breaker": "--disable-circuit-breaker",
# Rate limiting
"max_concurrent_requests": "--max-concurrent-requests",
"queue_size": "--queue-size",
"queue_timeout_secs": "--queue-timeout-secs",
"rate_limit_tokens_per_second": "--rate-limit-tokens-per-second",
}
for k, v in extra.items():
if v is None:
continue
flag = flag_map.get(k)
if not flag:
continue
if isinstance(v, bool):
if v:
cmd.append(flag)
else:
cmd.extend([flag, str(v)])
proc = subprocess.Popen(cmd)
self._children.append(proc)
url = f"http://127.0.0.1:{port}"
self._wait_health(url)
return ProcHandle(process=proc, url=url)
def _wait_health(self, base_url: str, timeout: float = 30.0):
    """Block until the router at ``base_url`` answers ``GET /health`` with 200.

    The endpoint is polled every 0.2 s over a single reused HTTP session.
    Connection errors and non-200 answers are treated as "not ready yet".

    Args:
        base_url: Router base URL, without a trailing ``/health``.
        timeout: Maximum number of seconds to keep polling.

    Raises:
        TimeoutError: If no 200 response was seen within ``timeout`` seconds.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments
    # (NTP steps, manual changes) that could shorten or extend the wait.
    deadline = time.monotonic() + timeout
    with requests.Session() as s:
        while time.monotonic() < deadline:
            try:
                r = s.get(f"{base_url}/health", timeout=2)
                if r.status_code == 200:
                    return
            except requests.RequestException:
                # Router is not accepting connections yet; retry below.
                pass
            time.sleep(0.2)
    raise TimeoutError(f"Router at {base_url} did not become healthy")
def add_worker(self, base_url: str, worker_url: str) -> None:
    """Register ``worker_url`` with the router via ``POST /add_worker``.

    Fails with an ``AssertionError`` when the router does not answer 200.
    """
    endpoint = f"{base_url}/add_worker"
    resp = requests.post(endpoint, params={"url": worker_url})
    assert (
        resp.status_code == 200
    ), f"add_worker failed: {resp.status_code} {resp.text}"
def remove_worker(self, base_url: str, worker_url: str) -> None:
    """Deregister ``worker_url`` from the router via ``POST /remove_worker``.

    Fails with an ``AssertionError`` when the router does not answer 200.
    """
    endpoint = f"{base_url}/remove_worker"
    resp = requests.post(endpoint, params={"url": worker_url})
    assert (
        resp.status_code == 200
    ), f"remove_worker failed: {resp.status_code} {resp.text}"
def list_workers(self, base_url: str) -> list[str]:
    """Return the worker URLs reported by the router's ``GET /list_workers``.

    The JSON response's ``"urls"`` entry is returned; an empty list is used
    when that key is absent. Fails with an ``AssertionError`` on non-200.
    """
    resp = requests.get(f"{base_url}/list_workers")
    assert (
        resp.status_code == 200
    ), f"list_workers failed: {resp.status_code} {resp.text}"
    return resp.json().get("urls", [])
def stop_all(self):
    """Stop every child process started by this manager and forget them.

    Each still-running child is asked to terminate and given 5 seconds to
    exit. A child that ignores the request is killed and then waited on so
    that it is reaped instead of lingering as a zombie. Children that have
    already exited are left untouched. The internal list is cleared at the
    end, so calling this method again is a no-op.
    """
    for proc in self._children:
        if proc.poll() is not None:
            # Already exited; nothing to do.
            continue
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            # Reap the killed child; without this it stays a zombie.
            proc.wait()
    self._children.clear()

View File

@@ -0,0 +1 @@
"""Integration test package for the router."""

View File

@@ -0,0 +1,109 @@
import os
import subprocess
import time
from pathlib import Path
from typing import Dict, Iterable, List, Tuple
import pytest
import requests
from ..fixtures.ports import find_free_port
from ..fixtures.router_manager import RouterManager
def pytest_configure(config):
    """Register the ``integration`` marker on the given pytest config."""
    marker_spec = "integration: mark as router integration test"
    config.addinivalue_line("markers", marker_spec)
@pytest.fixture
def router_manager() -> Iterable[RouterManager]:
    """Yield a fresh ``RouterManager`` and stop all of its routers afterwards."""
    manager = RouterManager()
    try:
        yield manager
    finally:
        # Runs on both success and failure so router processes never outlive
        # the test that started them.
        manager.stop_all()
def _spawn_mock_worker(args: List[str]) -> Tuple[subprocess.Popen, str, str]:
    """Launch ``fixtures/mock_worker.py`` on a free port and wait until healthy.

    Args:
        args: Extra command-line arguments appended after ``--port`` and
            ``--worker-id``.

    Returns:
        ``(process, url, worker_id)`` where ``url`` is
        ``http://127.0.0.1:<port>`` and ``worker_id`` is ``worker-<port>``.

    Raises:
        TimeoutError: Propagated from ``_wait_health`` when the worker never
            becomes healthy; the child process is killed before re-raising.
    """
    import sys  # local import: keeps this block independent of file-level imports

    repo_root = Path(__file__).resolve().parents[2]
    script = repo_root / "py_test" / "fixtures" / "mock_worker.py"
    port = find_free_port()
    worker_id = f"worker-{port}"
    # Run the worker with the interpreter executing the tests so that it sees
    # the same virtualenv; fall back to "python3" if it cannot be determined.
    interpreter = sys.executable or "python3"
    cmd = [
        interpreter,
        str(script),
        "--port",
        str(port),
        "--worker-id",
        worker_id,
        *args,
    ]
    proc = subprocess.Popen(cmd)
    url = f"http://127.0.0.1:{port}"
    try:
        _wait_health(url)
    except BaseException:
        # The caller never receives the handle on failure, so nobody else
        # could clean this process up; kill and reap it here.
        proc.kill()
        proc.wait()
        raise
    return proc, url, worker_id
def _wait_health(url: str, timeout: float = 10.0):
    """Block until the mock worker at ``url`` answers ``GET /health`` with 200.

    The endpoint is polled every 0.1 s over a single reused HTTP session.
    Connection errors and non-200 answers are treated as "not ready yet".

    Args:
        url: Worker base URL, without a trailing ``/health``.
        timeout: Maximum number of seconds to keep polling.

    Raises:
        TimeoutError: If no 200 response was seen within ``timeout`` seconds.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    with requests.Session() as s:
        while time.monotonic() < deadline:
            try:
                r = s.get(f"{url}/health", timeout=1)
                if r.status_code == 200:
                    return
            except requests.RequestException:
                # Worker is not listening yet; retry below.
                pass
            time.sleep(0.1)
    raise TimeoutError(f"Mock worker at {url} did not become healthy")
@pytest.fixture
def mock_worker():
    """Start a single healthy mock worker; yields (process, url, worker_id).

    On teardown a still-running worker is terminated and given 3 seconds to
    exit; if it does not, it is killed and then reaped.
    """
    proc, url, worker_id = _spawn_mock_worker([])
    try:
        yield proc, url, worker_id
    finally:
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                proc.wait()
@pytest.fixture
def mock_workers():
    """Factory to start N workers with custom args.

    Usage:
        procs, urls, ids = mock_workers(n=3, args=["--latency-ms", "5"])  # same args for all
        ...

    The factory may be called several times in one test. Every worker it
    started is stopped on teardown: terminated, given 3 seconds to exit, and
    killed and reaped if it does not.
    """
    procs: List[subprocess.Popen] = []

    def _start(n: int, args: List[str] | None = None):
        """Spawn ``n`` workers sharing ``args``; return (procs, urls, ids)."""
        args = args or []
        new_procs: List[subprocess.Popen] = []
        urls: List[str] = []
        ids: List[str] = []
        for _ in range(n):
            p, url, wid = _spawn_mock_worker(args)
            # Track in the fixture-wide list first so teardown sees it even
            # if a later spawn in this batch fails.
            procs.append(p)
            new_procs.append(p)
            urls.append(url)
            ids.append(wid)
        return new_procs, urls, ids

    try:
        yield _start
    finally:
        for p in procs:
            if p.poll() is None:
                p.terminate()
                try:
                    p.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    p.kill()
                    # Reap the killed child so it does not linger as a zombie.
                    p.wait()

View File

@@ -0,0 +1 @@
"""Load balancing integration tests."""

View File

@@ -0,0 +1,73 @@
import collections
import concurrent.futures
import uuid
import pytest
import requests
@pytest.mark.integration
def test_cache_aware_affinity(mock_workers, router_manager):
    """An identical prompt sent repeatedly should mostly land on one worker."""
    # Two workers; the cache-aware policy should pin the repeated prompt.
    _, worker_urls, _ = mock_workers(n=2)
    router = router_manager.start_router(worker_urls=worker_urls, policy="cache_aware")
    payload = {
        "model": "test-model",
        "prompt": "repeated prompt for cache",
        "max_tokens": 1,
        "stream": False,
    }
    hits = collections.Counter()
    with requests.Session() as session:
        for _ in range(12):
            resp = session.post(f"{router.url}/v1/completions", json=payload)
            assert resp.status_code == 200
            served_by = resp.headers.get("X-Worker-Id") or resp.json().get(
                "worker_id"
            )
            hits[served_by] += 1
    # At least 10 of the 12 requests must have gone to the same worker.
    busiest = max(hits.values())
    assert busiest >= 10, hits
@pytest.mark.integration
def test_cache_aware_diverse_prompts_balances(mock_workers, router_manager):
    """Unrelated concurrent prompts should be spread over at least two workers."""
    # Per-request latency makes concurrent requests overlap so that load
    # influences worker selection.
    _, worker_urls, _ = mock_workers(n=3, args=["--latency-ms", "30"])
    tuning = {
        "cache_threshold": 0.99,
        "balance_abs_threshold": 0,
        "balance_rel_threshold": 1.0,
    }
    router = router_manager.start_router(
        worker_urls=worker_urls, policy="cache_aware", extra=tuning
    )

    def call(i):
        # A random UUID per request avoids any shared prompt prefix.
        body = {
            "model": "test-model",
            "prompt": str(uuid.uuid4()),
            "max_tokens": 1,
            "stream": False,
        }
        resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=5)
        assert resp.status_code == 200
        return resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")

    hits = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as pool:
        # Counter.update tallies each worker id yielded by the map.
        hits.update(pool.map(call, range(40)))
    # At least two distinct workers must have served requests.
    active = [wid for wid, served in hits.items() if served > 0]
    assert len(active) >= 2, hits

View File

@@ -0,0 +1,89 @@
import collections
import concurrent.futures
import time
import pytest
import requests
@pytest.mark.integration
def test_power_of_two_prefers_less_loaded(mock_workers, router_manager):
    """Power-of-two routing should send fewer requests to the slower worker.

    One worker answers with 200 ms latency and one with none. The test first
    sends a burst through the router, then a burst directly at the slow
    worker, and finally checks that a measured burst through the router
    favours the fast worker.
    """
    # Start one slow and one fast worker using the fixture factory.
    _, urls_slow, ids_slow = mock_workers(n=1, args=["--latency-ms", "200"])
    _, urls_fast, ids_fast = mock_workers(n=1, args=["--latency-ms", "0"])
    urls = urls_slow + urls_fast
    slow_id = ids_slow[0]
    fast_id = ids_fast[0]
    # Needed by _direct_load below. It was previously undefined, and the
    # resulting NameError was swallowed, so the background load never ran.
    slow_url = urls_slow[0]

    rh = router_manager.start_router(
        worker_urls=urls,
        policy="power_of_two",
        extra={"worker_startup_check_interval": 1},
    )

    def _best_effort_post(base_url, prompt):
        """Send one completion request, ignoring network-level failures only."""
        try:
            requests.post(
                f"{base_url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": prompt,
                    "max_tokens": 1,
                    "stream": False,
                },
                timeout=5,
            )
        except requests.RequestException:
            # Load generation is best-effort; programming errors still surface.
            pass

    # Prime: fire a burst through the router, then wait for a monitor tick.
    def _prime_call(i):
        _best_effort_post(rh.url, f"warm-{i}")

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(_prime_call, range(128)))
    time.sleep(2)

    # Apply direct background load on the slow worker to amplify the load diff.
    def _direct_load(i):
        _best_effort_post(slow_url, f"bg-{i}")

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(_direct_load, range(128)))
    time.sleep(1)

    def call(i):
        r = requests.post(
            f"{rh.url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": f"p{i}",
                "max_tokens": 1,
                "stream": False,
            },
            timeout=5,
        )
        assert r.status_code == 200
        return r.headers.get("X-Worker-Id") or r.json().get("worker_id")

    counts = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        for wid in ex.map(call, range(200)):
            counts[wid] += 1

    # Expect the slow worker (higher latency/inflight) to receive fewer requests.
    assert counts[slow_id] < counts[fast_id], counts

View File

@@ -0,0 +1,33 @@
import collections
import math
import pytest
import requests
@pytest.mark.integration
def test_random_distribution(mock_workers, router_manager):
    """The random policy should give each of four workers a fair share."""
    _, worker_urls, ids = mock_workers(n=4)
    router = router_manager.start_router(worker_urls=worker_urls, policy="random")
    N = 200
    hits = collections.Counter()
    with requests.Session() as session:
        for idx in range(N):
            body = {
                "model": "test-model",
                "prompt": f"p{idx}",
                "max_tokens": 1,
                "stream": False,
            }
            resp = session.post(f"{router.url}/v1/completions", json=body)
            assert resp.status_code == 200
            served_by = resp.headers.get("X-Worker-Id") or resp.json().get(
                "worker_id"
            )
            hits[served_by] += 1
    # Simple statistical tolerance: every worker within ±50% of the mean.
    mean = N / len(ids)
    lower, upper = 0.5 * mean, 1.5 * mean
    for wid in ids:
        assert lower <= hits[wid] <= upper, hits

View File

@@ -0,0 +1,34 @@
import collections
import time
import pytest
import requests
@pytest.mark.integration
def test_round_robin_distribution(mock_workers, router_manager):
    """Round-robin over three workers should split 30 requests near-evenly."""
    _, worker_urls, ids = mock_workers(n=3)
    router = router_manager.start_router(
        worker_urls=worker_urls, policy="round_robin"
    )
    hits = collections.Counter()
    with requests.Session() as session:
        for idx in range(30):
            body = {
                "model": "test-model",
                "prompt": f"hello {idx}",
                "max_tokens": 1,
                "stream": False,
            }
            resp = session.post(f"{router.url}/v1/completions", json=body)
            assert resp.status_code == 200
            served_by = resp.headers.get("X-Worker-Id") or resp.json().get(
                "worker_id"
            )
            assert served_by in ids
            hits[served_by] += 1
    # 30 requests over 3 workers is ideally 10 each; allow a tolerance of ±3.
    for wid in ids:
        assert 7 <= hits[wid] <= 13, hits

View File

@@ -0,0 +1,38 @@
import pytest
import requests
@pytest.mark.integration
def test_router_api_key_enforcement(router_manager, mock_workers):
    """Requests through the router are rejected unless the correct key is sent."""
    # The backend requires an API key; the router is started without extras
    # and is expected to pass the Authorization header through.
    _, worker_urls, _ = mock_workers(
        n=1, args=["--require-api-key", "--api-key", "correct_api_key"]
    )
    router = router_manager.start_router(
        worker_urls=worker_urls,
        policy="round_robin",
        extra={},
    )

    def _complete(headers=None):
        return requests.post(
            f"{router.url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": "x",
                "max_tokens": 1,
                "stream": False,
            },
            headers=headers,
        )

    # No auth -> 401
    assert _complete().status_code == 401
    # Invalid auth -> 401
    assert _complete({"Authorization": "Bearer wrong"}).status_code == 401
    # Correct auth -> 200
    assert _complete({"Authorization": "Bearer correct_api_key"}).status_code == 200

View File

@@ -0,0 +1,191 @@
import time
import pytest
import requests
@pytest.mark.integration
def test_circuit_breaker_opens_and_recovers(router_manager, mock_workers):
    """The breaker should return 503 after failures and serve 200 after waiting."""
    # Single worker that fails its first 3 requests, then succeeds.
    _, [wurl], _ = mock_workers(n=1, args=["--fail-first-n", "3"])
    breaker_cfg = {
        "cb_failure_threshold": 3,
        "cb_success_threshold": 2,
        "cb_timeout_duration_secs": 3,
        "cb_window_duration_secs": 10,
        "disable_retries": True,
    }
    router = router_manager.start_router(
        worker_urls=[wurl], policy="round_robin", extra=breaker_cfg
    )

    def post_once():
        body = {
            "model": "test-model",
            "prompt": "trigger",
            "max_tokens": 1,
            "stream": False,
        }
        return requests.post(f"{router.url}/v1/completions", json=body, timeout=3)

    # Send up to 8 requests, stopping at the first 503.
    saw_503 = any(post_once().status_code == 503 for _ in range(8))
    assert saw_503, "circuit breaker did not open to return 503"

    # Sleep past the configured 3 s breaker timeout before probing again.
    time.sleep(4)
    first = post_once()
    second = post_once()
    assert first.status_code == 200 and second.status_code == 200
@pytest.mark.integration
def test_circuit_breaker_half_open_failure_reopens(router_manager, mock_workers):
    """A failing probe after the breaker timeout should re-open the breaker."""
    # Worker that always answers 500.
    _, [wurl], _ = mock_workers(n=1, args=["--status-code", "500"])
    breaker_cfg = {
        "cb_failure_threshold": 2,
        "cb_success_threshold": 2,
        "cb_timeout_duration_secs": 2,
        "cb_window_duration_secs": 5,
        "disable_retries": True,
    }
    router = router_manager.start_router(
        worker_urls=[wurl], policy="round_robin", extra=breaker_cfg
    )

    def post_once():
        body = {
            "model": "test-model",
            "prompt": "x",
            "max_tokens": 1,
            "stream": False,
        }
        return requests.post(f"{router.url}/v1/completions", json=body, timeout=3)

    # Send up to 8 requests, stopping at the first 503.
    opened = any(post_once().status_code == 503 for _ in range(8))
    assert opened, "circuit breaker did not open"

    # Sleep past the configured 2 s breaker timeout.
    time.sleep(3)
    # The first request reaches the worker again and surfaces its 500 ...
    probe = post_once()
    assert probe.status_code == 500
    # ... after which the router answers 503 again.
    follow_up = post_once()
    assert follow_up.status_code == 503
@pytest.mark.integration
def test_circuit_breaker_disable_flag(router_manager, mock_workers):
    """With breaker and retries disabled, the worker's 500 is returned as-is."""
    # Worker that always answers 500.
    _, [wurl], _ = mock_workers(n=1, args=["--status-code", "500"])
    switches = {"disable_circuit_breaker": True, "disable_retries": True}
    router = router_manager.start_router(
        worker_urls=[wurl], policy="round_robin", extra=switches
    )
    body = {
        "model": "test-model",
        "prompt": "x",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=3)
    assert resp.status_code == 500
@pytest.mark.integration
def test_circuit_breaker_per_worker_isolation(router_manager, mock_workers):
    """Once the failing worker has tripped, only 200 responses should follow."""
    # One worker that always answers 500 and one healthy worker.
    _, [fail_url], _ = mock_workers(n=1, args=["--status-code", "500"])
    _, [ok_url], _ = mock_workers(n=1)
    breaker_cfg = {
        "cb_failure_threshold": 2,
        "cb_success_threshold": 1,
        "cb_timeout_duration_secs": 2,
        "cb_window_duration_secs": 10,
        "disable_retries": True,
    }
    router = router_manager.start_router(
        worker_urls=[fail_url, ok_url], policy="round_robin", extra=breaker_cfg
    )

    def post_once():
        body = {
            "model": "test-model",
            "prompt": "y",
            "max_tokens": 1,
            "stream": False,
        }
        return requests.post(f"{router.url}/v1/completions", json=body, timeout=3)

    failures = 0
    successes_after_open = 0
    opened = False
    for _ in range(30):
        resp = post_once()
        if opened:
            # Past this point every response must come from the healthy worker.
            assert (
                resp.status_code == 200
            ), f"Unexpected non-200 after CB open: {resp.status_code}"
            successes_after_open += 1
            continue
        if resp.status_code == 500:
            failures += 1
        if failures >= 2:
            # Two extra requests whose results are ignored, then treat the
            # breaker as open for the remainder of the loop.
            post_once()
            post_once()
            opened = True
    assert opened and successes_after_open >= 5
@pytest.mark.integration
def test_circuit_breaker_with_retries(router_manager, mock_workers):
    """With retries enabled, one request succeeds despite a failing worker."""
    # One worker that always answers 500 and one healthy worker.
    _, [fail_url], _ = mock_workers(n=1, args=["--status-code", "500"])
    _, [ok_url], _ = mock_workers(n=1)
    settings = {
        "retry_max_retries": 3,
        "retry_initial_backoff_ms": 10,
        "retry_max_backoff_ms": 50,
        "cb_failure_threshold": 2,
        "cb_success_threshold": 1,
        "cb_timeout_duration_secs": 2,
        "cb_window_duration_secs": 10,
    }
    router = router_manager.start_router(
        worker_urls=[fail_url, ok_url], policy="round_robin", extra=settings
    )
    body = {
        "model": "test-model",
        "prompt": "z",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=5)
    assert resp.status_code == 200

View File

@@ -0,0 +1,36 @@
import concurrent.futures
import subprocess
import time
import pytest
import requests
@pytest.mark.integration
def test_worker_crash_reroute_with_retries(router_manager, mock_workers):
    """A request should still succeed when one worker crashes on first use."""
    # One healthy worker and one started with --crash-on-request.
    _, [ok_url], _ = mock_workers(n=1)
    _, [crash_url], _ = mock_workers(n=1, args=["--crash-on-request"])
    retry_cfg = {
        "retry_max_retries": 3,
        "retry_initial_backoff_ms": 10,
        "retry_max_backoff_ms": 50,
    }
    router = router_manager.start_router(
        worker_urls=[crash_url, ok_url], policy="round_robin", extra=retry_cfg
    )
    # A single request should succeed via retry onto the healthy worker.
    body = {
        "model": "test-model",
        "prompt": "crash",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=5)
    assert resp.status_code == 200
    # Worker processes are cleaned up by the mock_workers fixture.

View File

@@ -0,0 +1,33 @@
import pytest
import requests
@pytest.mark.integration
def test_payload_size_limit(router_manager, mock_workers):
    """Bodies under the 1MB limit pass; bodies over it are rejected with 413."""
    # One backend and a router configured with a 1MB payload limit.
    _, worker_urls, _ = mock_workers(n=1)
    one_mb = 1 * 1024 * 1024
    router = router_manager.start_router(
        worker_urls=worker_urls,
        policy="round_robin",
        extra={"max_payload_size": one_mb},
    )
    endpoint = f"{router.url}/v1/completions"

    def _body(prompt_chars):
        return {
            "model": "test-model",
            "prompt": "x" * prompt_chars,
            "max_tokens": 1,
            "stream": False,
        }

    # ~0.5MB prompt: below the limit, should succeed.
    under = requests.post(endpoint, json=_body(int(0.5 * 1024 * 1024)))
    assert under.status_code == 200

    # ~1.2MB prompt: above the limit, should fail with 413.
    over = requests.post(endpoint, json=_body(int(1.2 * 1024 * 1024)))
    assert over.status_code == 413

View File

@@ -0,0 +1,127 @@
import collections
import concurrent.futures
import subprocess
import time
import pytest
import requests
@pytest.mark.integration
def test_pd_power_of_two_decode_attribution(router_manager, mock_workers):
    """In PD mode, responses are attributed to decode workers, at least two of them."""
    # Two prefill and three decode mock workers via the fixture factory.
    _, prefill_raw, _ = mock_workers(n=2)
    _, decode_raw, decode_id_list = mock_workers(n=3)
    decode_ids = set(decode_id_list)
    router = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        # Prefill entries are (url, bootstrap_port) pairs; no bootstrap port here.
        prefill_urls=[(u, None) for u in prefill_raw],
        decode_urls=list(decode_raw),
        extra={"worker_startup_check_interval": 1},
    )
    hits = collections.Counter()
    with requests.Session() as session:
        for idx in range(30):
            body = {
                "model": "test-model",
                "prompt": f"p{idx}",
                "max_tokens": 1,
                "stream": False,
            }
            resp = session.post(f"{router.url}/v1/completions", json=body)
            assert resp.status_code == 200
            served_by = resp.headers.get("X-Worker-Id") or resp.json().get(
                "worker_id"
            )
            assert served_by in decode_ids
            hits[served_by] += 1
    # At least two distinct decode workers must have served requests.
    active = [wid for wid, served in hits.items() if served > 0]
    assert len(active) >= 2
@pytest.mark.integration
def test_pd_power_of_two_skews_to_faster_decode(router_manager, mock_workers):
    """In PD mode, the slower decode worker should receive fewer requests."""
    # Two prefill workers with default latency.
    _, prefill_raw, _ = mock_workers(n=2)
    # Two decode workers: one with 300 ms latency, one with default latency.
    _, [decode_slow_url], [slow_id] = mock_workers(n=1, args=["--latency-ms", "300"])
    _, [decode_fast_url], [fast_id] = mock_workers(n=1)
    router = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        prefill_urls=[(u, None) for u in prefill_raw],
        decode_urls=[decode_slow_url, decode_fast_url],
        extra={"worker_startup_check_interval": 1},
    )

    def _body(prompt):
        return {
            "model": "test-model",
            "prompt": prompt,
            "max_tokens": 1,
            "stream": False,
        }

    def _fire_and_forget(base_url, prompt):
        # Best-effort load generation: any failure is ignored.
        try:
            requests.post(f"{base_url}/v1/completions", json=_body(prompt), timeout=8)
        except Exception:
            pass

    def _prime_call(i):
        _fire_and_forget(router.url, f"warm-{i}")

    def _direct_decode_load(i):
        _fire_and_forget(decode_slow_url, f"bg-{i}")

    # Warm-up burst through the router, then pause.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as pool:
        list(pool.map(_prime_call, range(128)))
    time.sleep(2)

    # Extra burst sent straight at the slow decode worker, then pause.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as pool:
        list(pool.map(_direct_decode_load, range(128)))
    time.sleep(1)

    def call(i):
        resp = requests.post(
            f"{router.url}/v1/completions", json=_body(f"p{i}"), timeout=8
        )
        assert resp.status_code == 200
        return resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")

    hits = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as pool:
        hits.update(pool.map(call, range(200)))

    assert hits[slow_id] < hits[fast_id], hits

View File

@@ -0,0 +1,91 @@
import concurrent.futures
import time
import pytest
import requests
@pytest.mark.integration
def test_rate_limit_and_queue(router_manager, mock_workers):
    """With 2 concurrent slots and no queue, a burst yields both 200s and 429s."""
    # One backend with default latency.
    _, worker_urls, _ = mock_workers(n=1)
    limits = {
        "max_concurrent_requests": 2,
        # No queue: requests over the limit are rejected immediately with 429.
        "queue_size": 0,
    }
    router = router_manager.start_router(
        worker_urls=worker_urls, policy="round_robin", extra=limits
    )

    def call_once(i):
        body = {
            "model": "test-model",
            "prompt": f"p{i}",
            "max_tokens": 1,
            "stream": False,
        }
        try:
            resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=3)
        except Exception:
            # Sentinel status for any client-side failure.
            return 599
        return resp.status_code

    # Fire a burst of 16 concurrent requests.
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as pool:
        statuses = list(pool.map(call_once, range(16)))

    # Some must succeed and some must be rate limited.
    assert 200 in statuses
    assert 429 in statuses
@pytest.mark.integration
def test_rate_limit_queue_and_timeout(router_manager, mock_workers):
    """A queued request that waits longer than the queue timeout gets 408.

    The backend takes ~2 s per request while the router allows one request
    in flight, one queued, and a 1 s queue timeout.
    """
    # Slow backend: ~2s per request ensures queue wait > timeout.
    _, urls, _ = mock_workers(n=1, args=["--latency-ms", "2000"])

    # Allow 1 concurrent, queue up to 1, with 1s queue timeout.
    rh = router_manager.start_router(
        worker_urls=urls,
        policy="round_robin",
        extra={
            "max_concurrent_requests": 1,
            "queue_size": 1,
            "queue_timeout_secs": 1,
        },
    )

    def call_once(i):
        """Return the HTTP status of one request, or 599 on client failure."""
        try:
            r = requests.post(
                f"{rh.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"q{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
                timeout=5,
            )
            return r.status_code
        except Exception:
            return 599

    # Fire 4 concurrent requests: 1 runs (~2s), 1 queued (times out at 1s -> 408),
    # 2 overflow -> 429. concurrent.futures is already imported at module level,
    # so no function-local re-import is needed.
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as ex:
        results = list(ex.map(call_once, range(4)))

    # We expect:
    # - Some 200s (processed)
    # - At least one 408 (queued too long and timed out)
    # - Remaining non-200s are either 429 (queue overflow) or additional 408s,
    #   depending on scheduling
    assert any(code == 200 for code in results)
    assert any(code == 408 for code in results), results
    non200 = [c for c in results if c != 200]
    assert len(non200) >= 2 and all(c in (408, 429) for c in non200), results

View File

@@ -0,0 +1,65 @@
import concurrent.futures
import subprocess
import time
import pytest
import requests
@pytest.mark.integration
def test_retry_reroutes_to_healthy_worker(router_manager, mock_workers):
    """A request first routed to a failing worker is retried onto worker B."""
    # Worker A always answers 500; workers B and C are healthy.
    _, [url_a], [id_a] = mock_workers(n=1, args=["--status-code", "500"])
    _, [url_b], [id_b] = mock_workers(n=1)
    _, [url_c], [id_c] = mock_workers(n=1)
    retry_cfg = {
        "retry_max_retries": 3,
        "retry_initial_backoff_ms": 10,
        "retry_max_backoff_ms": 50,
    }
    router = router_manager.start_router(
        worker_urls=[url_a, url_b, url_c], policy="round_robin", extra=retry_cfg
    )
    body = {
        "model": "test-model",
        "prompt": "x",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=5)
    assert resp.status_code == 200
    served_by = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
    # The retry should have landed on the healthy worker B.
    assert served_by == id_b
    # Worker processes are cleaned up by the mock_workers fixture.
@pytest.mark.integration
def test_disable_retries_surfaces_failure(router_manager, mock_workers):
    """With retries disabled, a failing worker's 500 reaches the client."""
    # Single worker that always answers 500.
    _, [url], [wid] = mock_workers(n=1, args=["--status-code", "500"])
    router = router_manager.start_router(
        worker_urls=[url],
        policy="round_robin",
        extra={"disable_retries": True},
    )
    body = {
        "model": "test-model",
        "prompt": "x",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=5)
    assert resp.status_code == 500
    # Worker processes are cleaned up by the mock_workers fixture.

View File

@@ -0,0 +1,36 @@
import pytest
import requests
@pytest.mark.integration
def test_discovery_shim_add_remove(router_manager, mock_workers):
    """Workers can be added to and removed from a router started without any."""
    # Router starts with no workers and must report an empty list.
    router = router_manager.start_router(worker_urls=[], policy="round_robin")
    assert router_manager.list_workers(router.url) == []

    # Add a worker (simulates a discovery event).
    _, [wurl], [wid] = mock_workers(n=1)
    router_manager.add_worker(router.url, wurl)
    assert wurl in router_manager.list_workers(router.url)

    # The router can now serve a request.
    body = {
        "model": "test-model",
        "prompt": "hi",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body)
    assert resp.status_code == 200

    # Remove the worker (simulates pod deletion).
    router_manager.remove_worker(router.url, wurl)
    assert wurl not in router_manager.list_workers(router.url)
    # Worker processes are cleaned up by the mock_workers fixture.

View File

@@ -0,0 +1,61 @@
import collections
import subprocess
import time
import pytest
import requests
@pytest.mark.integration
def test_add_and_remove_worker(mock_worker, router_manager, mock_workers):
    """A worker added at runtime serves traffic; once removed it no longer does."""
    # Start with a single worker.
    _, url1, id1 = mock_worker
    router = router_manager.start_router(worker_urls=[url1], policy="round_robin")
    endpoint = f"{router.url}/v1/completions"

    def _body(prompt):
        return {
            "model": "test-model",
            "prompt": prompt,
            "max_tokens": 1,
            "stream": False,
        }

    # Add a second worker.
    _, [url2], [id2] = mock_workers(n=1)
    router_manager.add_worker(router.url, url2)

    # Send up to 20 requests, stopping once two distinct workers were seen.
    seen = set()
    with requests.Session() as session:
        for idx in range(20):
            resp = session.post(endpoint, json=_body(f"x{idx}"))
            assert resp.status_code == 200
            seen.add(resp.headers.get("X-Worker-Id") or resp.json().get("worker_id"))
            if len(seen) == 2:
                break
    assert id1 in seen and id2 in seen

    # Now remove the second worker.
    router_manager.remove_worker(router.url, url2)

    # After removal every request must be served by the first worker.
    with requests.Session() as session:
        for idx in range(10):
            resp = session.post(endpoint, json=_body(f"y{idx}"))
            assert resp.status_code == 200
            served_by = resp.headers.get("X-Worker-Id") or resp.json().get(
                "worker_id"
            )
            assert served_by == id1
    # Worker processes are cleaned up by the mock_workers fixture.

View File

@@ -0,0 +1,7 @@
"""
Unit tests for sglang_router.
This package contains fast, isolated unit tests for Python components
of the SGLang router. These tests focus on testing individual functions
and classes in isolation without starting actual router instances.
"""

View File

@@ -0,0 +1,628 @@
"""
Unit tests for argument parsing functionality in sglang_router.
These tests focus on testing the argument parsing logic in isolation,
without starting actual router instances.
"""
import argparse
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, parse_router_args
from sglang_router.router import policy_from_str
class TestRouterArgs:
"""Test RouterArgs dataclass and its methods."""
def test_default_values(self):
    """A default-constructed RouterArgs exposes the expected field defaults."""
    defaults = RouterArgs()

    # Fields compared by value: basic, service-discovery, retry and
    # circuit-breaker defaults.
    expected_values = {
        "host": "127.0.0.1",
        "port": 30000,
        "policy": "cache_aware",
        "worker_urls": [],
        "prefill_urls": [],
        "decode_urls": [],
        "selector": {},
        "service_discovery_port": 80,
        "retry_max_retries": 5,
        "cb_failure_threshold": 10,
    }
    for field_name, want in expected_values.items():
        assert getattr(defaults, field_name) == want

    # Boolean switches must be exactly False.
    for flag in (
        "pd_disaggregation",
        "service_discovery",
        "disable_retries",
        "disable_circuit_breaker",
    ):
        assert getattr(defaults, flag) is False

    # Optional settings must be unset.
    for optional in (
        "prefill_policy",
        "decode_policy",
        "service_discovery_namespace",
    ):
        assert getattr(defaults, optional) is None
def test_parse_selector_valid(self):
    """Well-formed ``key=value`` selectors are parsed into a dict."""
    cases = [
        # single key-value pair
        (["app=worker"], {"app": "worker"}),
        # multiple key-value pairs
        (
            ["app=worker", "env=prod", "version=v1"],
            {"app": "worker", "env": "prod", "version": "v1"},
        ),
        # empty list
        ([], {}),
        # None
        (None, {}),
    ]
    for raw, want in cases:
        assert RouterArgs._parse_selector(raw) == want
def test_parse_selector_invalid(self):
    """Malformed selectors are dropped; extra ``=`` signs stay in the value."""
    # No equals sign: the entry yields nothing.
    no_equals = RouterArgs._parse_selector(["app"])
    assert no_equals == {}

    # Several equals signs: only the first one separates key from value.
    many_equals = RouterArgs._parse_selector(["app=worker=extra"])
    assert many_equals == {"app": "worker=extra"}
def test_parse_prefill_urls_valid(self):
    """Prefill entries are parsed into ``(url, bootstrap_port_or_None)`` tuples."""
    cases = [
        # explicit bootstrap port
        ([["http://prefill1:8000", "9000"]], [("http://prefill1:8000", 9000)]),
        # literal 'none' bootstrap port
        ([["http://prefill1:8000", "none"]], [("http://prefill1:8000", None)]),
        # bootstrap port omitted
        ([["http://prefill1:8000"]], [("http://prefill1:8000", None)]),
        # several prefill URLs mixing all three forms
        (
            [
                ["http://prefill1:8000", "9000"],
                ["http://prefill2:8000", "none"],
                ["http://prefill3:8000"],
            ],
            [
                ("http://prefill1:8000", 9000),
                ("http://prefill2:8000", None),
                ("http://prefill3:8000", None),
            ],
        ),
        # empty list
        ([], []),
        # None
        (None, []),
    ]
    for raw, want in cases:
        assert RouterArgs._parse_prefill_urls(raw) == want
def test_parse_prefill_urls_invalid(self):
    """A non-numeric, non-'none' bootstrap port raises ``ValueError``."""
    bad_entry = ["http://prefill1:8000", "invalid"]
    with pytest.raises(ValueError, match="Invalid bootstrap port"):
        RouterArgs._parse_prefill_urls([bad_entry])
def test_parse_decode_urls_valid(self):
    """Decode entries (one-element lists) are flattened into a list of URLs."""
    cases = [
        # single decode URL
        ([["http://decode1:8001"]], ["http://decode1:8001"]),
        # multiple decode URLs
        (
            [["http://decode1:8001"], ["http://decode2:8001"]],
            ["http://decode1:8001", "http://decode2:8001"],
        ),
        # empty list
        ([], []),
        # None
        (None, []),
    ]
    for raw, want in cases:
        assert RouterArgs._parse_decode_urls(raw) == want
def test_from_cli_args_basic(self):
    """``from_cli_args`` with the router prefix maps CLI values onto RouterArgs."""
    cli_values = {
        "host": "0.0.0.0",
        "port": 30001,
        "worker_urls": ["http://worker1:8000", "http://worker2:8000"],
        "policy": "round_robin",
        "prefill": None,
        "decode": None,
        "router_policy": "round_robin",
        "router_pd_disaggregation": False,
        "router_prefill_policy": None,
        "router_decode_policy": None,
        "router_worker_startup_timeout_secs": 300,
        "router_worker_startup_check_interval": 15,
        "router_cache_threshold": 0.7,
        "router_balance_abs_threshold": 128,
        "router_balance_rel_threshold": 2.0,
        "router_eviction_interval": 180,
        "router_max_tree_size": 2**28,
        "router_max_payload_size": 1024 * 1024 * 1024,  # 1GB
        "router_dp_aware": True,
        "router_api_key": "test-key",
        "router_log_dir": "/tmp/logs",
        "router_log_level": "debug",
        "router_service_discovery": True,
        "router_selector": ["app=worker", "env=test"],
        "router_service_discovery_port": 8080,
        "router_service_discovery_namespace": "default",
        "router_prefill_selector": ["app=prefill"],
        "router_decode_selector": ["app=decode"],
        "router_prometheus_port": 29000,
        "router_prometheus_host": "0.0.0.0",
        "router_request_id_headers": ["x-request-id", "x-trace-id"],
        "router_request_timeout_secs": 1200,
        "router_max_concurrent_requests": 512,
        "router_queue_size": 200,
        "router_queue_timeout_secs": 120,
        "router_rate_limit_tokens_per_second": 100,
        "router_cors_allowed_origins": ["http://localhost:3000"],
        "router_retry_max_retries": 3,
        "router_retry_initial_backoff_ms": 100,
        "router_retry_max_backoff_ms": 10000,
        "router_retry_backoff_multiplier": 2.0,
        "router_retry_jitter_factor": 0.1,
        "router_cb_failure_threshold": 5,
        "router_cb_success_threshold": 2,
        "router_cb_timeout_duration_secs": 30,
        "router_cb_window_duration_secs": 60,
        "router_disable_retries": False,
        "router_disable_circuit_breaker": False,
        "router_health_failure_threshold": 2,
        "router_health_success_threshold": 1,
        "router_health_check_timeout_secs": 3,
        "router_health_check_interval_secs": 30,
        "router_health_check_endpoint": "/healthz",
    }
    router_args = RouterArgs.from_cli_args(
        SimpleNamespace(**cli_values), use_router_prefix=True
    )

    # Fields compared by value.
    expected_values = {
        # basic configuration
        "host": "0.0.0.0",
        "port": 30001,
        "worker_urls": ["http://worker1:8000", "http://worker2:8000"],
        "policy": "round_robin",
        # PD configuration
        "prefill_urls": [],
        "decode_urls": [],
        # service discovery
        "selector": {"app": "worker", "env": "test"},
        "service_discovery_port": 8080,
        "service_discovery_namespace": "default",
        "prefill_selector": {"app": "prefill"},
        "decode_selector": {"app": "decode"},
        # other configuration
        "api_key": "test-key",
        "log_dir": "/tmp/logs",
        "log_level": "debug",
        "prometheus_port": 29000,
        "prometheus_host": "0.0.0.0",
        "request_id_headers": ["x-request-id", "x-trace-id"],
        "request_timeout_secs": 1200,
        "max_concurrent_requests": 512,
        "queue_size": 200,
        "queue_timeout_secs": 120,
        "rate_limit_tokens_per_second": 100,
        "cors_allowed_origins": ["http://localhost:3000"],
        # retry configuration
        "retry_max_retries": 3,
        "retry_initial_backoff_ms": 100,
        "retry_max_backoff_ms": 10000,
        "retry_backoff_multiplier": 2.0,
        "retry_jitter_factor": 0.1,
        # circuit breaker configuration
        "cb_failure_threshold": 5,
        "cb_success_threshold": 2,
        "cb_timeout_duration_secs": 30,
        "cb_window_duration_secs": 60,
        # health check configuration
        "health_failure_threshold": 2,
        "health_success_threshold": 1,
        "health_check_timeout_secs": 3,
        "health_check_interval_secs": 30,
        "health_check_endpoint": "/healthz",
    }
    for field_name, want in expected_values.items():
        assert getattr(router_args, field_name) == want

    # Boolean switches are checked by identity, not just truthiness.
    for flag in ("service_discovery", "dp_aware"):
        assert getattr(router_args, flag) is True
    for flag in ("pd_disaggregation", "disable_retries", "disable_circuit_breaker"):
        assert getattr(router_args, flag) is False
    # Note: model_path and tokenizer_path are not available in current RouterArgs
def test_from_cli_args_pd_mode(self):
    """Check ``RouterArgs.from_cli_args`` on a ``router_``-prefixed PD namespace.

    The namespace enables PD disaggregation with two prefill entries (one
    with bootstrap port ``"9000"``, one with ``"none"``) and two decode
    entries. The assertions expect the prefill entries back as
    ``(url, port-or-None)`` tuples, the decode entries as plain URLs, and
    the per-role policies present alongside the main policy.
    """
    cli_namespace = SimpleNamespace(
        # --- attributes without the router_ prefix ---
        host="127.0.0.1",
        port=30000,
        worker_urls=[],
        policy="cache_aware",
        prefill=[
            ["http://prefill1:8000", "9000"],
            ["http://prefill2:8000", "none"],
        ],
        decode=[["http://decode1:8001"], ["http://decode2:8001"]],
        # --- PD topology and policies ---
        router_prefill=[
            ["http://prefill1:8000", "9000"],
            ["http://prefill2:8000", "none"],
        ],
        router_decode=[["http://decode1:8001"], ["http://decode2:8001"]],
        router_policy="cache_aware",
        router_pd_disaggregation=True,
        router_prefill_policy="power_of_two",
        router_decode_policy="round_robin",
        # --- worker startup / cache-aware routing ---
        router_worker_startup_timeout_secs=600,
        router_worker_startup_check_interval=30,
        router_cache_threshold=0.3,
        router_balance_abs_threshold=64,
        router_balance_rel_threshold=1.5,
        router_eviction_interval=120,
        router_max_tree_size=2**26,
        router_max_payload_size=512 * 1024 * 1024,
        # --- misc / logging ---
        router_dp_aware=False,
        router_api_key=None,
        router_log_dir=None,
        router_log_level=None,
        # --- service discovery ---
        router_service_discovery=False,
        router_selector=None,
        router_service_discovery_port=80,
        router_service_discovery_namespace=None,
        router_prefill_selector=None,
        router_decode_selector=None,
        # --- metrics / request handling ---
        router_prometheus_port=None,
        router_prometheus_host=None,
        router_request_id_headers=None,
        router_request_timeout_secs=1800,
        router_max_concurrent_requests=256,
        router_queue_size=100,
        router_queue_timeout_secs=60,
        router_rate_limit_tokens_per_second=None,
        router_cors_allowed_origins=[],
        # --- retries ---
        router_retry_max_retries=5,
        router_retry_initial_backoff_ms=50,
        router_retry_max_backoff_ms=30000,
        router_retry_backoff_multiplier=1.5,
        router_retry_jitter_factor=0.2,
        # --- circuit breaker ---
        router_cb_failure_threshold=10,
        router_cb_success_threshold=3,
        router_cb_timeout_duration_secs=60,
        router_cb_window_duration_secs=120,
        router_disable_retries=False,
        router_disable_circuit_breaker=False,
        # --- health checks ---
        router_health_failure_threshold=3,
        router_health_success_threshold=2,
        router_health_check_timeout_secs=5,
        router_health_check_interval_secs=60,
        router_health_check_endpoint="/health",
    )

    parsed = RouterArgs.from_cli_args(cli_namespace, use_router_prefix=True)

    expected_prefill = [
        ("http://prefill1:8000", 9000),
        ("http://prefill2:8000", None),
    ]
    expected_decode = ["http://decode1:8001", "http://decode2:8001"]

    # PD topology
    assert parsed.pd_disaggregation is True
    assert parsed.prefill_urls == expected_prefill
    assert parsed.decode_urls == expected_decode
    # Per-role policies, with the main policy still set
    assert parsed.prefill_policy == "power_of_two"
    assert parsed.decode_policy == "round_robin"
    assert parsed.policy == "cache_aware"
def test_from_cli_args_without_prefix(self):
    """Check ``RouterArgs.from_cli_args`` on a namespace with unprefixed names.

    Every attribute is supplied under its bare name and
    ``use_router_prefix=False`` is passed; the basic host/port/worker/policy
    values and the disabled PD flag are then asserted on the result.
    """
    cli_namespace = SimpleNamespace(
        # --- basic routing ---
        host="127.0.0.1",
        port=30000,
        worker_urls=["http://worker1:8000"],
        policy="random",
        # --- PD mode (disabled) ---
        prefill=None,
        decode=None,
        pd_disaggregation=False,
        prefill_policy=None,
        decode_policy=None,
        # --- worker startup / cache-aware routing ---
        worker_startup_timeout_secs=600,
        worker_startup_check_interval=30,
        cache_threshold=0.3,
        balance_abs_threshold=64,
        balance_rel_threshold=1.5,
        eviction_interval=120,
        max_tree_size=2**26,
        max_payload_size=512 * 1024 * 1024,
        # --- misc / logging ---
        dp_aware=False,
        api_key=None,
        log_dir=None,
        log_level=None,
        # --- service discovery ---
        service_discovery=False,
        selector=None,
        service_discovery_port=80,
        service_discovery_namespace=None,
        prefill_selector=None,
        decode_selector=None,
        # --- metrics / request handling ---
        prometheus_port=None,
        prometheus_host=None,
        request_id_headers=None,
        request_timeout_secs=1800,
        max_concurrent_requests=256,
        queue_size=100,
        queue_timeout_secs=60,
        rate_limit_tokens_per_second=None,
        cors_allowed_origins=[],
        # --- retries ---
        retry_max_retries=5,
        retry_initial_backoff_ms=50,
        retry_max_backoff_ms=30000,
        retry_backoff_multiplier=1.5,
        retry_jitter_factor=0.2,
        # --- circuit breaker ---
        cb_failure_threshold=10,
        cb_success_threshold=3,
        cb_timeout_duration_secs=60,
        cb_window_duration_secs=120,
        disable_retries=False,
        disable_circuit_breaker=False,
        # --- health checks ---
        health_failure_threshold=3,
        health_success_threshold=2,
        health_check_timeout_secs=5,
        health_check_interval_secs=60,
        health_check_endpoint="/health",
        # --- tokenizer-related attributes ---
        model_path=None,
        tokenizer_path=None,
    )

    parsed = RouterArgs.from_cli_args(cli_namespace, use_router_prefix=False)

    assert parsed.host == "127.0.0.1"
    assert parsed.port == 30000
    assert parsed.worker_urls == ["http://worker1:8000"]
    assert parsed.policy == "random"
    assert parsed.pd_disaggregation is False
class TestPolicyFromStr:
    """Exercise ``policy_from_str`` with known and unknown policy names."""

    def test_valid_policies(self):
        """Each supported policy name converts to its ``PolicyType`` member."""
        from sglang_router_rs import PolicyType

        name_to_member = {
            "random": PolicyType.Random,
            "round_robin": PolicyType.RoundRobin,
            "cache_aware": PolicyType.CacheAware,
            "power_of_two": PolicyType.PowerOfTwo,
        }
        for policy_name, member in name_to_member.items():
            assert policy_from_str(policy_name) == member

    def test_invalid_policy(self):
        """An unrecognised policy name is expected to raise ``KeyError``."""
        with pytest.raises(KeyError):
            policy_from_str("invalid_policy")
class TestParseRouterArgs:
    """Drive ``parse_router_args`` with command-line style argument lists."""

    def test_parse_basic_args(self):
        """Host, port, worker URLs and policy are read from the command line."""
        argv = (
            "--host 0.0.0.0 "
            "--port 30001 "
            "--worker-urls http://worker1:8000 http://worker2:8000 "
            "--policy round_robin"
        ).split()

        parsed = parse_router_args(argv)

        assert parsed.host == "0.0.0.0"
        assert parsed.port == 30001
        assert parsed.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
        assert parsed.policy == "round_robin"

    def test_parse_pd_args(self):
        """Repeated ``--prefill``/``--decode`` flags populate the PD URL lists."""
        argv = (
            "--pd-disaggregation "
            "--prefill http://prefill1:8000 9000 "
            "--prefill http://prefill2:8000 none "
            "--decode http://decode1:8001 "
            "--decode http://decode2:8001 "
            "--prefill-policy power_of_two "
            "--decode-policy round_robin"
        ).split()

        parsed = parse_router_args(argv)

        expected_prefill = [
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
        ]
        assert parsed.pd_disaggregation is True
        assert parsed.prefill_urls == expected_prefill
        assert parsed.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
        assert parsed.prefill_policy == "power_of_two"
        assert parsed.decode_policy == "round_robin"

    def test_parse_service_discovery_args(self):
        """``key=value`` selector tokens are expected back as a dict."""
        argv = (
            "--service-discovery "
            "--selector app=worker env=prod "
            "--service-discovery-port 8080 "
            "--service-discovery-namespace default"
        ).split()

        parsed = parse_router_args(argv)

        assert parsed.service_discovery is True
        assert parsed.selector == {"app": "worker", "env": "prod"}
        assert parsed.service_discovery_port == 8080
        assert parsed.service_discovery_namespace == "default"

    def test_parse_retry_and_circuit_breaker_args(self):
        """Retry and circuit-breaker flags, including both disable switches."""
        argv = (
            "--retry-max-retries 3 "
            "--retry-initial-backoff-ms 100 "
            "--retry-max-backoff-ms 10000 "
            "--retry-backoff-multiplier 2.0 "
            "--retry-jitter-factor 0.1 "
            "--disable-retries "
            "--cb-failure-threshold 5 "
            "--cb-success-threshold 2 "
            "--cb-timeout-duration-secs 30 "
            "--cb-window-duration-secs 60 "
            "--disable-circuit-breaker"
        ).split()

        parsed = parse_router_args(argv)

        # Retry settings
        assert parsed.retry_max_retries == 3
        assert parsed.retry_initial_backoff_ms == 100
        assert parsed.retry_max_backoff_ms == 10000
        assert parsed.retry_backoff_multiplier == 2.0
        assert parsed.retry_jitter_factor == 0.1
        assert parsed.disable_retries is True
        # Circuit-breaker settings
        assert parsed.cb_failure_threshold == 5
        assert parsed.cb_success_threshold == 2
        assert parsed.cb_timeout_duration_secs == 30
        assert parsed.cb_window_duration_secs == 60
        assert parsed.disable_circuit_breaker is True

    def test_parse_rate_limiting_args(self):
        """Concurrency, queue and token-rate flags are parsed as integers."""
        argv = (
            "--max-concurrent-requests 512 "
            "--queue-size 200 "
            "--queue-timeout-secs 120 "
            "--rate-limit-tokens-per-second 100"
        ).split()

        parsed = parse_router_args(argv)

        assert parsed.max_concurrent_requests == 512
        assert parsed.queue_size == 200
        assert parsed.queue_timeout_secs == 120
        assert parsed.rate_limit_tokens_per_second == 100

    def test_parse_health_check_args(self):
        """Health-check thresholds, timings and endpoint are parsed."""
        argv = (
            "--health-failure-threshold 2 "
            "--health-success-threshold 1 "
            "--health-check-timeout-secs 3 "
            "--health-check-interval-secs 30 "
            "--health-check-endpoint /healthz"
        ).split()

        parsed = parse_router_args(argv)

        assert parsed.health_failure_threshold == 2
        assert parsed.health_success_threshold == 1
        assert parsed.health_check_timeout_secs == 3
        assert parsed.health_check_interval_secs == 30
        assert parsed.health_check_endpoint == "/healthz"

    def test_parse_cors_args(self):
        """Multiple values after ``--cors-allowed-origins`` form a list."""
        argv = (
            "--cors-allowed-origins http://localhost:3000 https://example.com"
        ).split()

        parsed = parse_router_args(argv)

        expected_origins = ["http://localhost:3000", "https://example.com"]
        assert parsed.cors_allowed_origins == expected_origins

    def test_parse_tokenizer_args(self):
        """Placeholder: skipped because the tokenizer flags do not exist yet."""
        # The model-path and tokenizer-path arguments are not available in the
        # current implementation, so there is nothing to parse until they are
        # added.
        pytest.skip("Tokenizer arguments not available in current implementation")

    def test_parse_invalid_args(self):
        """Bad policy values exit; a bad bootstrap port raises ``ValueError``."""
        # An unknown policy value is expected to end in SystemExit.
        with pytest.raises(SystemExit):
            parse_router_args(["--policy", "invalid_policy"])

        # A non-numeric, non-"none" bootstrap port is expected to be rejected.
        bad_bootstrap_argv = [
            "--pd-disaggregation",
            "--prefill",
            "http://prefill1:8000",
            "invalid_port",
        ]
        with pytest.raises(ValueError, match="Invalid bootstrap port"):
            parse_router_args(bad_bootstrap_argv)

    def test_help_output(self):
        """``--help`` is expected to exit with status code 0."""
        with pytest.raises(SystemExit) as exit_info:
            parse_router_args(["--help"])

        # Exit code 0 means the help text was shown rather than an error.
        assert exit_info.value.code == 0

View File

@@ -0,0 +1,421 @@
"""
Unit tests for router configuration validation and setup.
These tests focus on testing the router configuration logic in isolation,
including validation of configuration parameters and their interactions.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, launch_router
from sglang_router.router import policy_from_str
from sglang_router_rs import PolicyType
class TestRouterConfigValidation:
    """Checks on ``RouterArgs`` construction and on ``launch_router`` validation."""

    @staticmethod
    def _launch_with_patched_router(config):
        """Call ``launch_router`` with ``Router`` patched out.

        Returns the ``from_args`` mock so the caller can assert on it.
        """
        with patch("sglang_router.launch_router.Router") as router_cls:
            router_cls.from_args = MagicMock(return_value=MagicMock())
            launch_router(config)
            return router_cls.from_args

    def test_valid_basic_config(self):
        """A plain worker-URL configuration keeps the values it was given."""
        config = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000", "http://worker2:8000"],
            policy="cache_aware",
        )

        # Construction itself must not raise; then the fields round-trip.
        assert config.host == "127.0.0.1"
        assert config.port == 30000
        assert config.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
        assert config.policy == "cache_aware"

    def test_valid_pd_config(self):
        """A PD configuration keeps its prefill/decode URLs and policy."""
        config = RouterArgs(
            host="127.0.0.1",
            port=30000,
            pd_disaggregation=True,
            prefill_urls=[
                ("http://prefill1:8000", 9000),
                ("http://prefill2:8000", None),
            ],
            decode_urls=["http://decode1:8001", "http://decode2:8001"],
            policy="cache_aware",
        )

        expected_prefill = [
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
        ]
        assert config.pd_disaggregation is True
        assert config.prefill_urls == expected_prefill
        assert config.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
        assert config.policy == "cache_aware"

    def test_pd_config_without_urls_raises_error(self):
        """PD mode with no URLs and no service discovery fails at launch."""
        config = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=False,
        )

        # The error is expected from launch_router, not from construction.
        with pytest.raises(
            ValueError, match="PD disaggregation mode requires --prefill"
        ):
            launch_router(config)

    def test_pd_config_with_service_discovery_allows_empty_urls(self):
        """PD mode with empty URL lists launches when service discovery is on."""
        config = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=True,
        )

        # No validation error, and the router is built through from_args.
        self._launch_with_patched_router(config).assert_called_once()

    def test_regular_mode_without_workers_allows_empty_urls(self):
        """Regular mode launches even when the worker URL list is empty."""
        config = RouterArgs(worker_urls=[], service_discovery=False)

        # No validation error, and the router is built through from_args.
        self._launch_with_patched_router(config).assert_called_once()

    def test_cache_threshold_validation(self):
        """A mid-range value and both edge values (0.0, 1.0) are stored."""
        for threshold in (0.5, 0.0, 1.0):
            config = RouterArgs(cache_threshold=threshold)
            assert config.cache_threshold == threshold

    def test_balance_threshold_validation(self):
        """Typical and edge-case balance thresholds are stored unchanged."""
        # (absolute, relative): a typical pair followed by the edge case.
        for abs_threshold, rel_threshold in ((64, 1.5), (0, 1.0)):
            config = RouterArgs(
                balance_abs_threshold=abs_threshold,
                balance_rel_threshold=rel_threshold,
            )
            assert config.balance_abs_threshold == abs_threshold
            assert config.balance_rel_threshold == rel_threshold

    def test_timeout_validation(self):
        """The four timeout/interval fields keep the values passed in."""
        timeouts = {
            "worker_startup_timeout_secs": 600,
            "worker_startup_check_interval": 30,
            "request_timeout_secs": 1800,
            "queue_timeout_secs": 60,
        }
        config = RouterArgs(**timeouts)
        for field_name, seconds in timeouts.items():
            assert getattr(config, field_name) == seconds

    def test_retry_config_validation(self):
        """Retry settings keep their values and retries stay enabled."""
        retry_settings = {
            "retry_max_retries": 5,
            "retry_initial_backoff_ms": 50,
            "retry_max_backoff_ms": 30000,
            "retry_backoff_multiplier": 1.5,
            "retry_jitter_factor": 0.2,
        }
        config = RouterArgs(disable_retries=False, **retry_settings)
        for field_name, value in retry_settings.items():
            assert getattr(config, field_name) == value
        assert config.disable_retries is False

    def test_circuit_breaker_config_validation(self):
        """Circuit-breaker settings keep their values and the breaker stays on."""
        breaker_settings = {
            "cb_failure_threshold": 10,
            "cb_success_threshold": 3,
            "cb_timeout_duration_secs": 60,
            "cb_window_duration_secs": 120,
        }
        config = RouterArgs(disable_circuit_breaker=False, **breaker_settings)
        for field_name, value in breaker_settings.items():
            assert getattr(config, field_name) == value
        assert config.disable_circuit_breaker is False

    def test_health_check_config_validation(self):
        """Health-check thresholds, timings and endpoint are stored."""
        health_settings = {
            "health_failure_threshold": 3,
            "health_success_threshold": 2,
            "health_check_timeout_secs": 5,
            "health_check_interval_secs": 60,
            "health_check_endpoint": "/health",
        }
        config = RouterArgs(**health_settings)
        for field_name, value in health_settings.items():
            assert getattr(config, field_name) == value

    def test_rate_limiting_config_validation(self):
        """Concurrency, queue and token-rate limits are stored."""
        limits = {
            "max_concurrent_requests": 256,
            "queue_size": 100,
            "queue_timeout_secs": 60,
            "rate_limit_tokens_per_second": 100,
        }
        config = RouterArgs(**limits)
        for field_name, value in limits.items():
            assert getattr(config, field_name) == value

    def test_service_discovery_config_validation(self):
        """Service-discovery flag, selector, port and namespace are stored."""
        config = RouterArgs(
            service_discovery=True,
            selector={"app": "worker", "env": "prod"},
            service_discovery_port=8080,
            service_discovery_namespace="default",
        )

        assert config.service_discovery is True
        assert config.selector == {"app": "worker", "env": "prod"}
        assert config.service_discovery_port == 8080
        assert config.service_discovery_namespace == "default"

    def test_pd_service_discovery_config_validation(self):
        """PD selectors and the bootstrap-port annotation are stored."""
        config = RouterArgs(
            pd_disaggregation=True,
            service_discovery=True,
            prefill_selector={"app": "prefill"},
            decode_selector={"app": "decode"},
            bootstrap_port_annotation="sglang.ai/bootstrap-port",
        )

        assert config.pd_disaggregation is True
        assert config.service_discovery is True
        assert config.prefill_selector == {"app": "prefill"}
        assert config.decode_selector == {"app": "decode"}
        assert config.bootstrap_port_annotation == "sglang.ai/bootstrap-port"

    def test_prometheus_config_validation(self):
        """Prometheus port and host are stored."""
        config = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")

        assert config.prometheus_port == 29000
        assert config.prometheus_host == "127.0.0.1"

    def test_cors_config_validation(self):
        """The list of allowed CORS origins is stored."""
        config = RouterArgs(
            cors_allowed_origins=["http://localhost:3000", "https://example.com"]
        )

        expected_origins = ["http://localhost:3000", "https://example.com"]
        assert config.cors_allowed_origins == expected_origins

    def test_tokenizer_config_validation(self):
        """Placeholder: skipped until RouterArgs gains tokenizer fields."""
        # model_path and tokenizer_path are not available in the current
        # RouterArgs, so there is nothing to check yet.
        pytest.skip("Tokenizer configuration not available in current implementation")

    def test_dp_aware_config_validation(self):
        """The DP-aware flag and API key are stored."""
        config = RouterArgs(dp_aware=True, api_key="test-api-key")

        assert config.dp_aware is True
        assert config.api_key == "test-api-key"

    def test_request_id_headers_validation(self):
        """The list of request-ID header names is stored."""
        config = RouterArgs(
            request_id_headers=["x-request-id", "x-trace-id", "x-correlation-id"]
        )

        expected_headers = ["x-request-id", "x-trace-id", "x-correlation-id"]
        assert config.request_id_headers == expected_headers

    def test_policy_consistency_validation(self):
        """PD mode launches when both prefill and decode policies are given."""
        config = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", None)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
            prefill_policy="power_of_two",
            decode_policy="round_robin",
        )

        # No validation error, and the router is built through from_args.
        self._launch_with_patched_router(config).assert_called_once()

    def test_policy_fallback_validation(self):
        """PD mode launches when only the prefill policy is given."""
        config = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", None)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
            prefill_policy="power_of_two",
            decode_policy=None,
        )

        # No validation error, and the router is built through from_args.
        self._launch_with_patched_router(config).assert_called_once()

    def test_policy_enum_conversion(self):
        """Every supported policy name converts to its ``PolicyType`` member."""
        name_to_member = {
            "random": PolicyType.Random,
            "round_robin": PolicyType.RoundRobin,
            "cache_aware": PolicyType.CacheAware,
            "power_of_two": PolicyType.PowerOfTwo,
        }
        for policy_name, member in name_to_member.items():
            assert policy_from_str(policy_name) == member

    def test_invalid_policy_enum_conversion(self):
        """An unrecognised policy name is expected to raise ``KeyError``."""
        with pytest.raises(KeyError):
            policy_from_str("invalid_policy")

    def test_config_immutability(self):
        """Document that ``RouterArgs`` instances can be modified after creation.

        Despite the test name, the assertions verify that assigning to an
        attribute succeeds and takes effect.
        """
        config = RouterArgs(
            host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
        )

        # More of a design check: dataclasses are mutable by default, so the
        # assignment below is expected to succeed.
        host_before = config.host
        config.host = "0.0.0.0"

        assert config.host == "0.0.0.0"  # Dataclasses are mutable
        assert config.host != host_before

    def test_config_defaults_consistency(self):
        """Two default-constructed instances agree on the core fields."""
        first = RouterArgs()
        second = RouterArgs()

        compared_fields = (
            "host",
            "port",
            "policy",
            "worker_urls",
            "pd_disaggregation",
        )
        for field_name in compared_fields:
            assert getattr(first, field_name) == getattr(second, field_name)

    def test_config_serialization(self):
        """The constructed instance exposes the core fields as attributes.

        No actual serialization is performed; only attribute presence is
        checked.
        """
        config = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000"],
            policy="cache_aware",
            cache_threshold=0.5,
        )

        for field_name in ("host", "port", "worker_urls", "policy", "cache_threshold"):
            assert hasattr(config, field_name)

    def test_config_with_none_values(self):
        """Fields explicitly set to ``None`` stay ``None``."""
        nullable_fields = (
            "api_key",
            "log_dir",
            "log_level",
            "prometheus_port",
            "prometheus_host",
            "request_id_headers",
            "rate_limit_tokens_per_second",
            "service_discovery_namespace",
        )
        config = RouterArgs(**dict.fromkeys(nullable_fields))

        for field_name in nullable_fields:
            assert getattr(config, field_name) is None

    def test_config_with_empty_lists(self):
        """List fields given as empty lists stay empty."""
        config = RouterArgs(
            worker_urls=[], prefill_urls=[], decode_urls=[], cors_allowed_origins=[]
        )

        list_fields = (
            "worker_urls",
            "prefill_urls",
            "decode_urls",
            "cors_allowed_origins",
        )
        for field_name in list_fields:
            assert getattr(config, field_name) == []

    def test_config_with_empty_dicts(self):
        """Selector fields given as empty dicts stay empty."""
        config = RouterArgs(selector={}, prefill_selector={}, decode_selector={})

        for field_name in ("selector", "prefill_selector", "decode_selector"):
            assert getattr(config, field_name) == {}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,506 @@
"""
Unit tests for validation logic in sglang_router.
These tests focus on testing the validation logic in isolation,
including parameter validation, URL validation, and configuration validation.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, launch_router
class TestURLValidation:
    """Checks on how ``RouterArgs`` stores worker, prefill and decode URLs."""

    def test_valid_worker_urls(self):
        """Well-formed worker URLs are accepted and retained."""
        candidates = (
            "http://worker1:8000",
            "https://worker2:8000",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://192.168.1.100:8000",
            "http://worker.example.com:8000",
        )
        for worker_url in candidates:
            config = RouterArgs(worker_urls=[worker_url])
            # Construction must not raise, and the URL must be kept.
            assert worker_url in config.worker_urls

    def test_valid_prefill_urls(self):
        """``(url, bootstrap_port)`` pairs are accepted and retained."""
        candidates = (
            ("http://prefill1:8000", 9000),
            ("https://prefill2:8000", None),
            ("http://localhost:8000", 9000),
            ("http://127.0.0.1:8000", None),
        )
        for prefill_url, port in candidates:
            config = RouterArgs(prefill_urls=[(prefill_url, port)])
            # Construction must not raise, and the pair must be kept.
            assert (prefill_url, port) in config.prefill_urls

    def test_valid_decode_urls(self):
        """Well-formed decode URLs are accepted and retained."""
        candidates = (
            "http://decode1:8001",
            "https://decode2:8001",
            "http://localhost:8001",
            "http://127.0.0.1:8001",
        )
        for decode_url in candidates:
            config = RouterArgs(decode_urls=[decode_url])
            # Construction must not raise, and the URL must be kept.
            assert decode_url in config.decode_urls

    def test_malformed_urls(self):
        """Document that malformed URLs are currently accepted as-is."""
        # The current implementation performs no URL format validation; this
        # test pins that behaviour rather than endorsing it.
        candidates = (
            "not-a-url",
            "ftp://worker1:8000",  # Wrong protocol
            "http://",  # Missing host
            ":8000",  # Missing protocol and host
            "http://worker1",  # Missing port
        )
        for worker_url in candidates:
            config = RouterArgs(worker_urls=[worker_url])
            # Accepted today; possibly something to tighten in the future.
            assert worker_url in config.worker_urls
class TestPortValidation:
    """Checks on how ``RouterArgs`` stores port numbers."""

    def test_valid_ports(self):
        """In-range port numbers are stored unchanged."""
        for port_number in (1, 80, 8000, 30000, 65535):
            assert RouterArgs(port=port_number).port == port_number

    def test_invalid_ports(self):
        """Document that out-of-range ports are currently accepted as-is."""
        # The current implementation performs no port range validation; this
        # test pins that behaviour. It may be worth tightening in the future.
        for port_number in (0, -1, 65536, 70000):
            assert RouterArgs(port=port_number).port == port_number

    def test_bootstrap_port_validation(self):
        """Bootstrap ports, including ``None``, are kept in the prefill tuple."""
        for bootstrap in (1, 80, 9000, 30000, 65535, None):
            config = RouterArgs(prefill_urls=[("http://prefill1:8000", bootstrap)])
            _url, stored_port = config.prefill_urls[0][0], config.prefill_urls[0][1]
            assert stored_port == bootstrap
class TestParameterValidation:
    """Checks that ``RouterArgs`` stores a range of numeric parameter values."""

    def test_cache_threshold_validation(self):
        """Cache thresholds from 0.0 to 1.0 are stored unchanged."""
        for value in (0.0, 0.1, 0.5, 0.9, 1.0):
            assert RouterArgs(cache_threshold=value).cache_threshold == value

    def test_balance_threshold_validation(self):
        """Absolute and relative balance thresholds are stored unchanged."""
        # Absolute thresholds
        for value in (0, 1, 32, 64, 128, 1000):
            config = RouterArgs(balance_abs_threshold=value)
            assert config.balance_abs_threshold == value

        # Relative thresholds
        for value in (1.0, 1.1, 1.5, 2.0, 10.0):
            config = RouterArgs(balance_rel_threshold=value)
            assert config.balance_rel_threshold == value

    def test_timeout_validation(self):
        """All four timeout/interval fields accept the same range of values."""
        timeout_fields = (
            "worker_startup_timeout_secs",
            "worker_startup_check_interval",
            "request_timeout_secs",
            "queue_timeout_secs",
        )
        for seconds in (1, 30, 60, 300, 600, 1800, 3600):
            config = RouterArgs(
                worker_startup_timeout_secs=seconds,
                worker_startup_check_interval=seconds,
                request_timeout_secs=seconds,
                queue_timeout_secs=seconds,
            )
            for field_name in timeout_fields:
                assert getattr(config, field_name) == seconds

    def test_retry_parameter_validation(self):
        """Retry counts, backoffs, multipliers and jitter are stored unchanged."""
        # Retry counts
        for retries in (0, 1, 3, 5, 10):
            assert RouterArgs(retry_max_retries=retries).retry_max_retries == retries

        # Backoff durations (same value used for initial and max)
        for backoff_ms in (1, 50, 100, 1000, 30000):
            config = RouterArgs(
                retry_initial_backoff_ms=backoff_ms, retry_max_backoff_ms=backoff_ms
            )
            assert config.retry_initial_backoff_ms == backoff_ms
            assert config.retry_max_backoff_ms == backoff_ms

        # Backoff multipliers
        for factor in (1.0, 1.5, 2.0, 3.0):
            config = RouterArgs(retry_backoff_multiplier=factor)
            assert config.retry_backoff_multiplier == factor

        # Jitter factors
        for jitter in (0.0, 0.1, 0.2, 0.5):
            assert RouterArgs(retry_jitter_factor=jitter).retry_jitter_factor == jitter

    def test_circuit_breaker_parameter_validation(self):
        """Circuit-breaker thresholds and durations are stored unchanged."""
        # Failure thresholds
        for count in (1, 3, 5, 10, 20):
            config = RouterArgs(cb_failure_threshold=count)
            assert config.cb_failure_threshold == count

        # Success thresholds
        for count in (1, 2, 3, 5):
            config = RouterArgs(cb_success_threshold=count)
            assert config.cb_success_threshold == count

        # Timeout / window durations (same value used for both)
        for seconds in (10, 30, 60, 120, 300):
            config = RouterArgs(
                cb_timeout_duration_secs=seconds, cb_window_duration_secs=seconds
            )
            assert config.cb_timeout_duration_secs == seconds
            assert config.cb_window_duration_secs == seconds

    def test_health_check_parameter_validation(self):
        """Health-check thresholds and timings are stored unchanged."""
        # Failure thresholds
        for count in (1, 2, 3, 5, 10):
            config = RouterArgs(health_failure_threshold=count)
            assert config.health_failure_threshold == count

        # Success thresholds
        for count in (1, 2, 3, 5):
            config = RouterArgs(health_success_threshold=count)
            assert config.health_success_threshold == count

        # Timeouts and intervals (same value used for both)
        for seconds in (1, 5, 10, 30, 60, 120):
            config = RouterArgs(
                health_check_timeout_secs=seconds, health_check_interval_secs=seconds
            )
            assert config.health_check_timeout_secs == seconds
            assert config.health_check_interval_secs == seconds

    def test_rate_limiting_parameter_validation(self):
        """Concurrency limits, queue sizes and token rates are stored unchanged."""
        # Concurrent request limits
        for limit in (1, 10, 64, 256, 512, 1000):
            config = RouterArgs(max_concurrent_requests=limit)
            assert config.max_concurrent_requests == limit

        # Queue sizes
        for capacity in (0, 10, 50, 100, 500, 1000):
            assert RouterArgs(queue_size=capacity).queue_size == capacity

        # Token rates
        for tokens_per_sec in (1, 10, 50, 100, 500, 1000):
            config = RouterArgs(rate_limit_tokens_per_second=tokens_per_sec)
            assert config.rate_limit_tokens_per_second == tokens_per_sec

    def test_tree_size_validation(self):
        """Power-of-two tree sizes are stored unchanged."""
        for exponent in (10, 20, 24, 26, 28, 30):
            tree_size = 2**exponent
            assert RouterArgs(max_tree_size=tree_size).max_tree_size == tree_size

    def test_payload_size_validation(self):
        """Payload sizes from 1KB to 1GB are stored unchanged."""
        kib = 1024
        mib = 1024 * kib
        gib = 1024 * mib
        payload_sizes = (kib, mib, 10 * mib, 100 * mib, 512 * mib, gib)
        for num_bytes in payload_sizes:
            config = RouterArgs(max_payload_size=num_bytes)
            assert config.max_payload_size == num_bytes
class TestConfigurationValidation:
"""Test configuration validation logic."""
def test_pd_mode_validation(self):
    """A PD configuration with one prefill and one decode URL is accepted."""
    config = RouterArgs(
        pd_disaggregation=True,
        prefill_urls=[("http://prefill1:8000", 9000)],
        decode_urls=["http://decode1:8001"],
    )

    assert config.pd_disaggregation is True
    # Both URL lists must be non-empty after construction.
    assert len(config.prefill_urls) > 0
    assert len(config.decode_urls) > 0
def test_service_discovery_validation(self):
    """Service-discovery flag, selector, port and namespace are stored."""
    config = RouterArgs(
        service_discovery=True,
        selector={"app": "worker", "env": "prod"},
        service_discovery_port=8080,
        service_discovery_namespace="default",
    )

    expected_selector = {"app": "worker", "env": "prod"}
    assert config.service_discovery is True
    assert config.selector == expected_selector
    assert config.service_discovery_port == 8080
    assert config.service_discovery_namespace == "default"
def test_pd_service_discovery_validation(self):
"""Test PD service discovery configuration validation."""
# Valid PD service discovery configuration
args = RouterArgs(
pd_disaggregation=True,
service_discovery=True,
prefill_selector={"app": "prefill"},
decode_selector={"app": "decode"},
)
assert args.pd_disaggregation is True
assert args.service_discovery is True
assert args.prefill_selector == {"app": "prefill"}
assert args.decode_selector == {"app": "decode"}
def test_policy_validation(self):
"""Test policy configuration validation."""
# Valid policies
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
for policy in valid_policies:
args = RouterArgs(policy=policy)
assert args.policy == policy
def test_pd_policy_validation(self):
"""Test PD policy configuration validation."""
# Valid PD policies
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
for prefill_policy in valid_policies:
for decode_policy in valid_policies:
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
prefill_policy=prefill_policy,
decode_policy=decode_policy,
)
assert args.prefill_policy == prefill_policy
assert args.decode_policy == decode_policy
def test_cors_validation(self):
"""Test CORS configuration validation."""
# Valid CORS origins
valid_origins = [
[],
["http://localhost:3000"],
["https://example.com"],
["http://localhost:3000", "https://example.com"],
["*"], # Wildcard (if supported)
]
for origins in valid_origins:
args = RouterArgs(cors_allowed_origins=origins)
assert args.cors_allowed_origins == origins
def test_logging_validation(self):
"""Test logging configuration validation."""
# Valid log levels
valid_log_levels = ["debug", "info", "warning", "error", "critical"]
for level in valid_log_levels:
args = RouterArgs(log_level=level)
assert args.log_level == level
def test_prometheus_validation(self):
"""Test Prometheus configuration validation."""
# Valid Prometheus configuration
args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
assert args.prometheus_port == 29000
assert args.prometheus_host == "127.0.0.1"
def test_tokenizer_validation(self):
"""Test tokenizer configuration validation."""
# Note: model_path and tokenizer_path are not available in current RouterArgs
pytest.skip("Tokenizer configuration not available in current implementation")
def test_request_id_headers_validation(self):
"""Test request ID headers configuration validation."""
# Valid request ID headers
valid_headers = [
["x-request-id"],
["x-request-id", "x-trace-id"],
["x-request-id", "x-trace-id", "x-correlation-id"],
["custom-header"],
]
for headers in valid_headers:
args = RouterArgs(request_id_headers=headers)
assert args.request_id_headers == headers
class TestLaunchValidation:
    """Checks how launch_router reacts to different RouterArgs configurations."""

    @staticmethod
    def _launch_with_patched_router(router_args):
        """Run launch_router with sglang_router.launch_router.Router patched out.

        Returns the MagicMock installed as ``Router.from_args`` so the caller
        can check how it was called.
        """
        with patch("sglang_router.launch_router.Router") as router_cls:
            router_cls.from_args = MagicMock(return_value=MagicMock())
            launch_router(router_args)
            return router_cls.from_args

    def test_pd_mode_requires_urls(self):
        """Expect a ValueError when PD mode has no URLs and no service discovery."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=False,
        )
        expected_message = "PD disaggregation mode requires --prefill"
        with pytest.raises(ValueError, match=expected_message):
            launch_router(router_args)

    def test_pd_mode_with_service_discovery_allows_empty_urls(self):
        """Expect PD mode with empty URLs to launch when service discovery is on."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=True,
        )
        # No validation error is expected; the router is built via from_args.
        from_args = self._launch_with_patched_router(router_args)
        from_args.assert_called_once()

    def test_regular_mode_allows_empty_worker_urls(self):
        """Expect regular mode to launch with an empty worker URL list."""
        router_args = RouterArgs(worker_urls=[], service_discovery=False)
        # No validation error is expected; the router is built via from_args.
        from_args = self._launch_with_patched_router(router_args)
        from_args.assert_called_once()

    def test_launch_with_valid_config(self):
        """Expect a launch with host, port, one worker URL and a policy to succeed."""
        router_args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000"],
            policy="cache_aware",
        )
        # No validation error is expected; the router is built via from_args.
        from_args = self._launch_with_patched_router(router_args)
        from_args.assert_called_once()

    def test_launch_with_pd_config(self):
        """Expect a PD launch with one prefill and one decode URL to succeed."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", 9000)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
        )
        # No validation error is expected; the router is built via from_args.
        from_args = self._launch_with_patched_router(router_args)
        from_args.assert_called_once()

    def test_launch_with_service_discovery_config(self):
        """Expect a launch with service discovery, a selector and a port to succeed."""
        router_args = RouterArgs(
            service_discovery=True,
            selector={"app": "worker"},
            service_discovery_port=8080,
        )
        # No validation error is expected; the router is built via from_args.
        from_args = self._launch_with_patched_router(router_args)
        from_args.assert_called_once()