sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
1
sgl-router/py_test/__init__.py
Normal file
1
sgl-router/py_test/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test package root for router Python tests."""
|
||||
8
sgl-router/py_test/conftest.py
Normal file
8
sgl-router/py_test/conftest.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Make the in-tree ``py_src`` directory importable, and place it first on
# ``sys.path`` so it wins over any installed copy of the package.
_ROOT = Path(__file__).resolve().parent.parent
_SRC = _ROOT.joinpath("py_src")
if not any(entry == str(_SRC) for entry in sys.path):
    sys.path.insert(0, str(_SRC))
|
||||
512
sgl-router/py_test/e2e/conftest.py
Normal file
512
sgl-router/py_test/e2e/conftest.py
Normal file
@@ -0,0 +1,512 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Callable, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _find_available_port() -> int:
    """Return a TCP port number the OS reports as free on 127.0.0.1.

    The probe socket is bound to port 0 (so the OS picks a port) and is closed
    before returning; the port is therefore not reserved for the caller.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _addr, chosen_port = probe.getsockname()
        return chosen_port
    finally:
        probe.close()
|
||||
|
||||
|
||||
def _parse_url(base_url: str) -> tuple[str, str]:
    """Parse a base URL and return ``(host, port)`` as strings.

    This is more robust than simple string splitting and supports different
    schemes and URL shapes like trailing paths.

    Args:
        base_url: URL such as ``http://127.0.0.1:30000``; a bare ``host:port``
            without a scheme is accepted as well.

    Returns:
        The host (``"127.0.0.1"`` when the URL has none) and the port. When the
        URL carries no explicit port, the default port of an ``http``/``https``
        scheme is returned; for any other scheme the port is ``""``.
    """
    # Without a scheme separator urlparse() reads "host:port" as "scheme:path";
    # prefixing "//" makes it treat the value as a network location instead.
    if "://" in base_url or base_url.startswith("//"):
        target = base_url
    else:
        target = f"//{base_url}"
    parsed = urlparse(target)
    host = parsed.hostname or "127.0.0.1"
    if parsed.port is not None:
        return host, str(parsed.port)
    # Fall back to the scheme's well-known port so a server launched from this
    # value listens where a request to the same URL would connect.
    default_ports = {"http": "80", "https": "443"}
    return host, default_ports.get(parsed.scheme, "")
|
||||
|
||||
|
||||
def _wait_router_health(base_url: str, timeout: float) -> None:
    """Block until ``GET {base_url}/health`` answers with HTTP 200.

    Request errors and non-200 answers are retried after a 2 second pause.

    Raises:
        TimeoutError: if no 200 answer was seen within ``timeout`` seconds.
    """
    deadline = time.perf_counter() + timeout
    health_url = f"{base_url}/health"
    with requests.Session() as session:
        while time.perf_counter() < deadline:
            try:
                healthy = session.get(health_url, timeout=5).status_code == 200
            except requests.RequestException:
                healthy = False
            if healthy:
                return
            time.sleep(2)
    raise TimeoutError("Router failed to become healthy in time")
|
||||
|
||||
|
||||
def _popen_launch_router(
    model: str,
    base_url: str,
    dp_size: int,
    timeout: float,
    policy: str = "cache_aware",
) -> subprocess.Popen:
    """Start ``python3 -m sglang_router.launch_server`` and wait until it is healthy.

    Args:
        model: value passed as ``--model-path``.
        base_url: URL whose host/port the server is told to bind; also the URL
            polled for health.
        dp_size: value passed as ``--dp``.
        timeout: seconds to wait for the health endpoint.
        policy: value passed as ``--router-policy``.

    Returns:
        The running process handle.

    Raises:
        TimeoutError: if the server never becomes healthy; the spawned process
            is terminated before the error propagates.
    """
    host, port = _parse_url(base_url)

    # Prometheus gets its own free port so concurrent launches do not collide.
    prom_port = _find_available_port()

    cmd = [
        "python3",
        "-m",
        "sglang_router.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--dp",
        str(dp_size),
        "--router-policy",
        policy,
        "--allow-auto-truncate",
        "--router-prometheus-port",
        str(prom_port),
        "--router-prometheus-host",
        "127.0.0.1",
    ]

    proc = subprocess.Popen(cmd)
    try:
        _wait_router_health(base_url, timeout)
    except BaseException:
        # The caller never receives `proc` on failure, so stop it here rather
        # than leaving an orphaned server behind.
        _terminate(proc)
        raise
    return proc
|
||||
|
||||
|
||||
def _popen_launch_worker(
    model: str,
    base_url: str,
    *,
    dp_size: int | None = None,
    api_key: str | None = None,
    base_gpu_id: int | None = 0,
) -> subprocess.Popen:
    """Spawn ``python3 -m sglang.launch_server`` and return the process handle.

    The function returns right after spawning; it does not wait for the worker
    to become healthy.

    Args:
        model: value passed as ``--model-path``.
        base_url: URL whose host/port the worker is told to bind.
        dp_size: when given, passed as ``--dp-size``.
        api_key: when given, passed as ``--api-key``.
        base_gpu_id: passed as ``--base-gpu-id``; ``None`` is treated as 0.
    """
    host, port = _parse_url(base_url)
    gpu_id = base_gpu_id if base_gpu_id else 0

    launch_cmd = ["python3", "-m", "sglang.launch_server"]
    launch_cmd.extend(("--model-path", model))
    launch_cmd.extend(("--host", host))
    launch_cmd.extend(("--port", port))
    launch_cmd.extend(("--base-gpu-id", str(gpu_id)))
    if dp_size is not None:
        launch_cmd.extend(("--dp-size", str(dp_size)))
    if api_key is not None:
        launch_cmd.extend(("--api-key", api_key))
    return subprocess.Popen(launch_cmd)
|
||||
|
||||
|
||||
def _popen_launch_router_only(
    base_url: str,
    policy: str = "round_robin",
    timeout: float = 120.0,
    *,
    dp_aware: bool = False,
    api_key: str | None = None,
) -> subprocess.Popen:
    """Start ``python3 -m sglang_router.launch_router`` and wait until it is healthy.

    Args:
        base_url: URL whose host/port the router is told to bind; also the URL
            polled for health.
        policy: value passed as ``--policy``.
        timeout: seconds to wait for the health endpoint.
        dp_aware: when true, ``--dp-aware`` is added to the command line.
        api_key: when given, passed as ``--api-key``.

    Returns:
        The running router process handle.

    Raises:
        TimeoutError: if the router never becomes healthy; the spawned process
            is terminated before the error propagates.
    """
    host, port = _parse_url(base_url)

    prom_port = _find_available_port()
    cmd = [
        "python3",
        "-m",
        "sglang_router.launch_router",
        "--host",
        host,
        "--port",
        port,
        "--policy",
        policy,
    ]
    if dp_aware:
        cmd += ["--dp-aware"]
    if api_key is not None:
        cmd += ["--api-key", api_key]
    cmd += [
        "--prometheus-port",
        str(prom_port),
        "--prometheus-host",
        "127.0.0.1",
    ]
    proc = subprocess.Popen(cmd)
    try:
        _wait_router_health(base_url, timeout)
    except BaseException:
        # The caller never receives `proc` on failure, so stop it here rather
        # than leaving an orphaned router behind.
        _terminate(proc)
        raise
    return proc
|
||||
|
||||
|
||||
def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None:
    """Stop ``proc``: terminate first, kill if it does not exit in time.

    Args:
        proc: process to stop; ``None`` is accepted and ignored.
        timeout: seconds to wait after ``terminate()`` before calling ``kill()``.

    The process is always waited on, so it is reaped and ``proc.returncode``
    is set when this function returns.
    """
    if proc is None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        # Reap the killed child so it does not linger as a zombie.
        proc.wait()
|
||||
|
||||
|
||||
def _which(cmd: str) -> Optional[str]:
    """Return what ``shutil.which(cmd)`` finds, or ``None`` if the lookup raises.

    A failing lookup is logged as a warning instead of propagating.
    """
    try:
        located = shutil.which(cmd)
    except Exception as e:
        logger.warning("shutil.which(%r) failed: %s", cmd, e)
        located = None
    return located
|
||||
|
||||
|
||||
def _graceful_stop_popen(p: subprocess.Popen) -> None:
    """Best-effort stop of a ``Popen``: terminate, wait up to 5 s, then kill.

    Args:
        p: process to stop; ``None`` and already-exited processes are ignored.

    Any exception raised while stopping is logged as a warning and not
    propagated.
    """
    if p is None:
        return
    try:
        if p.poll() is not None:
            return
        p.terminate()
        try:
            p.wait(timeout=5)
        except subprocess.TimeoutExpired:
            p.kill()
            # Reap the killed child so it does not linger as a zombie.
            p.wait(timeout=5)
    except Exception as e:
        logger.warning("Exception during graceful stop of popen: %s", e)
|
||||
|
||||
|
||||
def _pid_alive(pid: int) -> bool:
    """Report whether a process with the given pid currently exists.

    The check sends signal 0, which performs the existence/permission check
    without delivering a signal.

    Args:
        pid: process id to probe. Non-positive values address process groups
            rather than a single process and are reported as not alive.

    Returns:
        ``True`` if the process exists (including when it exists but may not
        be signalled by this user), ``False`` otherwise.
    """
    if not isinstance(pid, int) or pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # The process exists; we are just not allowed to signal it.
        return True
    except Exception:
        # Anything else (e.g. an out-of-range pid) is treated as "not alive".
        return False
    return True
|
||||
|
||||
|
||||
def _graceful_stop_pid(pid: int) -> None:
    """Best-effort stop of a process by pid: SIGTERM, wait up to ~5 s, SIGKILL.

    Args:
        pid: id of the process to stop. Non-positive values are ignored,
            because ``os.kill`` would address a whole process group (possibly
            including the test runner itself) instead of one process.

    Errors are never propagated; unexpected ones are logged as a warning.
    """
    try:
        if pid <= 0:
            return
        if not _pid_alive(pid):
            return
        try:
            os.kill(pid, signal.SIGTERM)
        except OSError:
            # The process exited in the meantime or may not be signalled.
            pass
        for _ in range(5):
            if not _pid_alive(pid):
                return
            time.sleep(1)
        if _pid_alive(pid):
            try:
                os.kill(pid, signal.SIGKILL)
            except OSError:
                pass
    except Exception as e:
        logger.warning("Exception during graceful stop of pid %r: %s", pid, e)
|
||||
|
||||
|
||||
def _graceful_stop_any(obj) -> None:
    """Stop whatever ``obj`` refers to, swallowing every error.

    Accepted shapes: a ``subprocess.Popen``, an integer pid, or any object
    whose ``proc`` attribute is a ``subprocess.Popen``. Anything else is
    ignored.
    """
    try:
        if isinstance(obj, subprocess.Popen):
            _graceful_stop_popen(obj)
        elif isinstance(obj, int):
            _graceful_stop_pid(obj)
        else:
            wrapped = getattr(obj, "proc", None)
            if isinstance(wrapped, subprocess.Popen):
                _graceful_stop_popen(wrapped)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def genai_bench_runner() -> Callable[..., None]:
    """Return a callable that runs the genai-bench CLI and validates its metrics.

    Usage in tests:
        def test(..., genai_bench_runner):
            genai_bench_runner(router_url=..., model_path=..., experiment_folder=...)
    """

    def _output_tail(rc, stdout: str, stderr: str) -> str:
        # Common suffix of the failure messages: the exit code plus the last
        # 1000 characters of each captured stream.
        return (
            f"Exit code: {rc}\nSTDOUT (tail):\n{stdout[-1000:]}"
            f"\nSTDERR (tail):\n{stderr[-1000:]}"
        )

    def _locate_experiment_dir(experiment_folder: str) -> Optional[Path]:
        # Prefer exact path under cwd; fallback to rglob search
        base = Path.cwd()
        direct = base / experiment_folder
        if direct.is_dir():
            return direct
        for found in base.rglob(experiment_folder):
            if found.is_dir() and found.name == experiment_folder:
                return found
        return None

    def _mean_of(stats: dict, metric: str, missing: float) -> float:
        # Mean of one metric; `missing` is used when the metric is absent.
        return float(stats.get(metric, {}).get("mean", missing))

    def _stop_processes(kill_procs, drain_delay_sec: int) -> None:
        # Give router/workers a small grace period to finish any last drains
        if drain_delay_sec > 0:
            try:
                time.sleep(drain_delay_sec)
            except Exception:
                pass
        for target in kill_procs:
            _graceful_stop_any(target)
        try:
            time.sleep(2)
        except Exception:
            pass

    def _run(
        *,
        router_url: str,
        model_path: str,
        experiment_folder: str,
        timeout_sec: int | None = None,
        thresholds: dict | None = None,
        extra_env: dict | None = None,
        num_concurrency: int = 32,
        traffic_scenario: str = "D(4000,100)",
        max_requests_per_run: int | None = None,
        clean_experiment: bool = True,
        kill_procs: list | None = None,
        drain_delay_sec: int = 6,
    ) -> None:
        """Run one benchmark against ``router_url`` and check the JSON results.

        When ``thresholds`` is ``None`` the metrics are only logged. Entries of
        ``kill_procs`` are stopped once the run and its validation are over,
        whether or not they succeeded.
        """
        cli = _which("genai-bench")
        if not cli:
            pytest.fail(
                "genai-bench CLI not found; please install it to run benchmarks"
            )

        # Clean previous experiment folder under current working directory
        if clean_experiment:
            stale_dir = Path.cwd() / experiment_folder
            if stale_dir.exists():
                shutil.rmtree(stale_dir, ignore_errors=True)

        # Default requests per run if not provided
        if max_requests_per_run is not None:
            mrr = max_requests_per_run
        else:
            mrr = num_concurrency * 3

        cmd = [cli, "benchmark"]
        cmd += ["--api-backend", "openai"]
        cmd += ["--api-base", router_url]
        cmd += ["--api-key", "dummy-token"]
        cmd += ["--api-model-name", model_path]
        cmd += ["--model-tokenizer", model_path]
        cmd += ["--task", "text-to-text"]
        cmd += ["--num-concurrency", str(num_concurrency)]
        cmd += ["--traffic-scenario", traffic_scenario]
        cmd += ["--max-requests-per-run", str(mrr)]
        cmd += ["--max-time-per-run", "2"]
        cmd += ["--experiment-folder-name", experiment_folder]
        cmd += ["--experiment-base-dir", str(Path.cwd())]

        env = os.environ.copy()
        if extra_env:
            env.update(extra_env)

        to = timeout_sec or int(os.environ.get("GENAI_BENCH_TEST_TIMEOUT", "120"))
        proc = subprocess.Popen(
            cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        stdout = stderr = ""
        rc = None
        try:
            try:
                stdout, stderr = proc.communicate(timeout=to)
            except subprocess.TimeoutExpired:
                # Simple: kill the CLI process if it doesn't exit in time
                try:
                    proc.kill()
                except Exception:
                    pass
                stdout, stderr = proc.communicate()
            rc = proc.returncode

            actual_folder = _locate_experiment_dir(experiment_folder)
            if actual_folder is None:
                raise AssertionError(
                    "Benchmark failed: experiment folder not found: "
                    f"{experiment_folder}\n" + _output_tail(rc, stdout, stderr)
                )

            json_files = []
            for candidate in actual_folder.rglob("*.json"):
                if "experiment_metadata" not in candidate.name:
                    json_files.append(candidate)
            if not json_files:
                raise AssertionError(
                    "Benchmark failed: no JSON results found\n"
                    + _output_tail(rc, stdout, stderr)
                )

            th = thresholds  # None means "log only", no validation

            for jf in json_files:
                with jf.open("r") as f:
                    data = json.load(f)
                stats = data.get("aggregated_metrics", {}).get("stats", {})
                # Missing latencies default to +inf and missing throughputs to
                # 0, so an absent metric fails its threshold check.
                ttft_mean = _mean_of(stats, "ttft", float("inf"))
                e2e_latency_mean = _mean_of(stats, "e2e_latency", float("inf"))
                input_tp_mean = _mean_of(stats, "input_throughput", 0.0)
                output_tp_mean = _mean_of(stats, "output_throughput", 0.0)

                logger.info(
                    "genai-bench[%s] %s ttft_mean=%.3fs e2e_latency_mean=%.3fs input_tp_mean=%.1f tok/s output_tp_mean=%.1f tok/s",
                    experiment_folder,
                    jf.name,
                    ttft_mean,
                    e2e_latency_mean,
                    input_tp_mean,
                    output_tp_mean,
                )

                if th is None:
                    continue
                assert (
                    ttft_mean <= th["ttft_mean_max"]
                ), f"TTFT validation failed: {ttft_mean} > {th['ttft_mean_max']} (file={jf.name})"
                assert (
                    e2e_latency_mean <= th["e2e_latency_mean_max"]
                ), f"E2E latency validation failed: {e2e_latency_mean} > {th['e2e_latency_mean_max']} (file={jf.name})"
                assert (
                    input_tp_mean >= th["input_throughput_mean_min"]
                ), f"Input throughput validation failed: {input_tp_mean} < {th['input_throughput_mean_min']} (file={jf.name})"
                assert (
                    output_tp_mean >= th["output_throughput_mean_min"]
                ), f"Output throughput validation failed: {output_tp_mean} < {th['output_throughput_mean_min']} (file={jf.name})"

        finally:
            # Always attempt to stop workers to avoid resource leakage
            if kill_procs:
                _stop_processes(kill_procs, drain_delay_sec)

    return _run
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the custom ``e2e`` marker in pytest's ``markers`` ini list."""
    marker_line = "e2e: mark as end-to-end test"
    config.addinivalue_line("markers", marker_line)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def e2e_model() -> str:
    """Return the model name used by every e2e test in the session."""
    # Always the default test model; there is no override.
    model_name = DEFAULT_MODEL_NAME_FOR_TEST
    return model_name
|
||||
|
||||
|
||||
@pytest.fixture
def e2e_router(e2e_model: str):
    """Yield a healthy router (``dp_size=2``) at the default test URL.

    The yielded namespace has ``proc`` and ``url``; the process is terminated
    on teardown.
    """
    # Keep this available but tests below use router-only to avoid GPU contention
    handle = SimpleNamespace(
        proc=_popen_launch_router(
            e2e_model,
            DEFAULT_URL_FOR_TEST,
            dp_size=2,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        ),
        url=DEFAULT_URL_FOR_TEST,
    )
    try:
        yield handle
    finally:
        _terminate(handle.proc)
|
||||
|
||||
|
||||
@pytest.fixture
def e2e_router_only_rr():
    """Yield a router-only instance (round-robin policy) on a free local port.

    The yielded namespace has ``proc`` and ``url``; the process is terminated
    on teardown.
    """
    router_url = f"http://127.0.0.1:{_find_available_port()}"
    router_proc = _popen_launch_router_only(router_url, policy="round_robin")
    handle = SimpleNamespace(proc=router_proc, url=router_url)
    try:
        yield handle
    finally:
        _terminate(router_proc)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def e2e_primary_worker(e2e_model: str):
    """Yield one worker process on a free local port, shared across the session.

    The yielded namespace has ``proc`` and ``url``; the process is terminated
    on teardown.
    """
    worker_url = f"http://127.0.0.1:{_find_available_port()}"
    # Router health gate will handle worker readiness
    worker_proc = _popen_launch_worker(e2e_model, worker_url)
    handle = SimpleNamespace(proc=worker_proc, url=worker_url)
    try:
        yield handle
    finally:
        _terminate(worker_proc)
|
||||
|
||||
|
||||
@pytest.fixture
def e2e_router_only_rr_dp_aware_api():
    """Router-only with dp-aware enabled and an API key.

    The yielded namespace has ``proc``, ``url`` and ``api_key``; the process is
    terminated on teardown.
    """
    api_key = "secret"
    router_url = f"http://127.0.0.1:{_find_available_port()}"
    router_proc = _popen_launch_router_only(
        router_url,
        policy="round_robin",
        timeout=180.0,
        dp_aware=True,
        api_key=api_key,
    )
    handle = SimpleNamespace(proc=router_proc, url=router_url, api_key=api_key)
    try:
        yield handle
    finally:
        _terminate(router_proc)
|
||||
|
||||
|
||||
@pytest.fixture
def e2e_worker_dp2_api(e2e_model: str, e2e_router_only_rr_dp_aware_api):
    """Worker with dp-size=2 and the same API key as the dp-aware router.

    The yielded namespace has ``proc`` and ``url``; the process is terminated
    on teardown.
    """
    worker_url = f"http://127.0.0.1:{_find_available_port()}"
    shared_key = e2e_router_only_rr_dp_aware_api.api_key
    worker_proc = _popen_launch_worker(
        e2e_model, worker_url, dp_size=2, api_key=shared_key
    )
    handle = SimpleNamespace(proc=worker_proc, url=worker_url)
    try:
        yield handle
    finally:
        _terminate(worker_proc)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def e2e_two_workers_dp2(e2e_model: str):
    """Launch two workers, each with dp_size=2, mapped to GPUs [0,1] and [2,3].

    Yields a list of namespaces (``proc``, ``url``) in launch order; every
    worker that was started is terminated on teardown.
    """
    workers = []
    try:
        # Worker A starts at GPU 0, worker B at GPU 2.
        for first_gpu in (0, 2):
            worker_url = f"http://127.0.0.1:{_find_available_port()}"
            worker_proc = _popen_launch_worker(
                e2e_model, worker_url, dp_size=2, base_gpu_id=first_gpu
            )
            workers.append(SimpleNamespace(proc=worker_proc, url=worker_url))

        yield workers
    finally:
        for started in workers:
            _terminate(started.proc)
|
||||
262
sgl-router/py_test/e2e/test_pd_router.py
Normal file
262
sgl-router/py_test/e2e/test_pd_router.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from sglang.test.run_eval import run_eval
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _find_available_port() -> int:
    """Return a TCP port number the OS reports as free on 127.0.0.1.

    The probe socket is bound to port 0 (so the OS picks a port) and is closed
    before returning; the port is therefore not reserved for the caller.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _addr, chosen_port = probe.getsockname()
        return chosen_port
    finally:
        probe.close()
|
||||
|
||||
|
||||
def _wait_health(url: str, timeout: float = 180.0) -> None:
    """Block until ``GET {url}/health`` answers with HTTP 200.

    Request errors and non-200 answers are retried after a 1 second pause.

    Raises:
        TimeoutError: if no 200 answer was seen within ``timeout`` seconds.
    """
    deadline = time.perf_counter() + timeout
    health_url = f"{url}/health"
    with requests.Session() as session:
        while time.perf_counter() < deadline:
            try:
                healthy = session.get(health_url, timeout=5).status_code == 200
            except requests.RequestException:
                healthy = False
            if healthy:
                return
            time.sleep(1)
    raise TimeoutError(f"Service at {url} failed to become healthy in time")
|
||||
|
||||
|
||||
def _detect_ib_device() -> Optional[str]:
    """Return first active IB device name (e.g., mlx5_0) or None if unavailable.

    Devices ``mlx5_0`` .. ``mlx5_11`` are probed in order with ``ibv_devinfo``;
    a device counts as active when a ``state:`` line of its output contains
    ``PORT_ACTIVE``. A probe that cannot run or times out is skipped.
    """
    # Fast check that ibv_devinfo exists and can be executed. OSError covers a
    # missing binary as well as one that may not be executed.
    try:
        subprocess.run(
            ["ibv_devinfo", "-l"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=1,
        )
    except (OSError, subprocess.TimeoutExpired):
        return None

    for i in range(12):
        dev = f"mlx5_{i}"
        try:
            res = subprocess.run(
                ["ibv_devinfo", dev],
                capture_output=True,
                text=True,
                timeout=2,
            )
        except (OSError, subprocess.SubprocessError, ValueError):
            # Could not run, timed out, or produced undecodable output.
            continue
        if res.returncode != 0:
            continue
        if any(
            "state:" in line and "PORT_ACTIVE" in line
            for line in res.stdout.splitlines()
        ):
            return dev
    return None
|
||||
|
||||
|
||||
def _popen_launch_prefill_worker(
    model: str,
    bootstrap_port: int,
    ib_device: Optional[str] = None,
    base_gpu_id: int = 0,
) -> SimpleNamespace:
    """Start a prefill-mode ``sglang.launch_server`` and wait until it is healthy.

    Args:
        model: value passed as ``--model-path``.
        bootstrap_port: value passed as ``--disaggregation-bootstrap-port``.
        ib_device: when truthy, passed as ``--disaggregation-ib-device``.
        base_gpu_id: value passed as ``--base-gpu-id``.

    Returns:
        A namespace with ``proc``, ``url`` and ``bootstrap_port``.

    Raises:
        TimeoutError: if the worker is not healthy within 300 seconds; the
            spawned process is terminated before the error propagates.
    """
    port = _find_available_port()
    url = f"http://127.0.0.1:{port}"
    cmd = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--disaggregation-mode",
        "prefill",
        "--host",
        "127.0.0.1",
        "--port",
        str(port),
        "--disaggregation-bootstrap-port",
        str(bootstrap_port),
        "--base-gpu-id",
        str(base_gpu_id),
    ]
    if ib_device:
        cmd += ["--disaggregation-ib-device", ib_device]
    proc = subprocess.Popen(cmd)
    try:
        _wait_health(url, timeout=300.0)
    except BaseException:
        # The caller never receives `proc` on failure, so stop it here rather
        # than leaving an orphaned worker behind.
        _terminate(proc)
        raise
    return SimpleNamespace(proc=proc, url=url, bootstrap_port=bootstrap_port)
|
||||
|
||||
|
||||
def _popen_launch_decode_worker(
    model: str, ib_device: Optional[str] = None, base_gpu_id: int = 0
) -> SimpleNamespace:
    """Start a decode-mode ``sglang.launch_server`` and wait until it is healthy.

    Args:
        model: value passed as ``--model-path``.
        ib_device: when truthy, passed as ``--disaggregation-ib-device``.
        base_gpu_id: value passed as ``--base-gpu-id``.

    Returns:
        A namespace with ``proc`` and ``url``.

    Raises:
        TimeoutError: if the worker is not healthy within 300 seconds; the
            spawned process is terminated before the error propagates.
    """
    port = _find_available_port()
    url = f"http://127.0.0.1:{port}"
    cmd = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--disaggregation-mode",
        "decode",
        "--host",
        "127.0.0.1",
        "--port",
        str(port),
        "--base-gpu-id",
        str(base_gpu_id),
    ]
    if ib_device:
        cmd += ["--disaggregation-ib-device", ib_device]
    proc = subprocess.Popen(cmd)
    try:
        _wait_health(url, timeout=300.0)
    except BaseException:
        # The caller never receives `proc` on failure, so stop it here rather
        # than leaving an orphaned worker behind.
        _terminate(proc)
        raise
    return SimpleNamespace(proc=proc, url=url)
|
||||
|
||||
|
||||
def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None:
    """Stop ``proc``: terminate first, kill if it does not exit in time.

    Args:
        proc: process to stop; ``None`` is accepted and ignored.
        timeout: seconds to wait after ``terminate()`` before calling ``kill()``.

    The process is always waited on, so it is reaped and ``proc.returncode``
    is set when this function returns.
    """
    if proc is None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        # Reap the killed child so it does not linger as a zombie.
        proc.wait()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def pd_cluster(e2e_model: str):
    """Start 2 prefill + 2 decode workers and one PD router, once per module.

    Yields a namespace with ``router_url``, ``workers`` (prefill workers first,
    then decode workers) and ``router_proc``. On teardown, and also when
    start-up fails part-way, the router and every worker launched so far are
    terminated.
    """
    # Environment capability checks: require sgl_kernel and GPU backend
    try:
        import sgl_kernel  # noqa: F401
    except Exception as e:  # pragma: no cover - environment dependent
        pytest.fail(f"PD e2e requires sgl_kernel but it is not available: {e}")

    try:
        import torch  # noqa: F401
    except Exception as e:  # pragma: no cover - environment dependent
        pytest.fail(
            f"PD e2e requires torch but it is not available or misconfigured: {e}"
        )

    if not torch.cuda.is_available():  # pragma: no cover - environment dependent
        pytest.fail("PD e2e requires CUDA backend, but CUDA is not available")

    workers: list[SimpleNamespace] = []
    router_proc = None
    try:
        ib_device = _detect_ib_device()

        # Launch 4 workers across 4 GPUs: prefill on 0,1 and decode on 2,3.
        # Each worker is registered in `workers` as soon as it is up, so a
        # failure while launching a later one still cleans up the earlier ones.
        prefills = []
        for gpu_id in (0, 1):
            pf = _popen_launch_prefill_worker(
                e2e_model,
                bootstrap_port=_find_available_port(),
                ib_device=ib_device,
                base_gpu_id=gpu_id,
            )
            workers.append(pf)
            prefills.append(pf)

        decodes = []
        for gpu_id in (2, 3):
            dc = _popen_launch_decode_worker(
                e2e_model, ib_device=ib_device, base_gpu_id=gpu_id
            )
            workers.append(dc)
            decodes.append(dc)

        # PD router with two prefill and two decode endpoints
        rport = _find_available_port()
        router_url = f"http://127.0.0.1:{rport}"
        pport = _find_available_port()

        cmd = [
            "python3",
            "-m",
            "sglang_router.launch_router",
            "--host",
            "127.0.0.1",
            "--port",
            str(rport),
            "--policy",
            "round_robin",
            "--pd-disaggregation",
        ]
        for pf in prefills:
            cmd += ["--prefill", pf.url, str(pf.bootstrap_port)]
        for dc in decodes:
            cmd += ["--decode", dc.url]
        cmd += [
            "--prometheus-port",
            str(pport),
            "--prometheus-host",
            "127.0.0.1",
        ]

        router_proc = subprocess.Popen(cmd)
        _wait_health(router_url, timeout=180.0)

        yield SimpleNamespace(
            router_url=router_url, workers=workers, router_proc=router_proc
        )
    finally:
        if router_proc is not None:
            _terminate(router_proc)
        for w in workers:
            _terminate(w.proc)
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_pd_mmlu(e2e_model: str, pd_cluster):
    """
    Run a 64-example MMLU eval through the shared PD router (2 prefill +
    2 decode workers) and require a score of at least 0.65.
    """
    eval_args = SimpleNamespace(
        base_url=pd_cluster.router_url,
        model=e2e_model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    score = run_eval(eval_args)["score"]
    assert score >= 0.65
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_pd_genai_bench(e2e_model: str, pd_cluster, genai_bench_runner):
    """
    Run a short genai-bench benchmark against the shared PD router (2 prefill +
    2 decode workers) and validate aggregate metrics.
    """
    # Limits checked against the mean values reported by genai-bench.
    limits = {
        "ttft_mean_max": 12,
        "e2e_latency_mean_max": 15,
        "input_throughput_mean_min": 400,
        "output_throughput_mean_min": 20,
    }
    # Run genai-bench against the shared router; the cluster's workers are
    # handed over as kill_procs.
    genai_bench_runner(
        router_url=pd_cluster.router_url,
        model_path=e2e_model,
        experiment_folder="benchmark_round_robin_pd",
        thresholds=limits,
        kill_procs=pd_cluster.workers,
    )
|
||||
169
sgl-router/py_test/e2e/test_regular_router.py
Normal file
169
sgl-router/py_test/e2e/test_regular_router.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import threading
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from sglang.test.run_eval import run_eval
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model):
    """Run MMLU through a fresh router with two dp=2 workers attached."""
    # Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
    router_url = e2e_router_only_rr.url
    add_worker_url = f"{router_url}/add_worker"
    for worker in e2e_two_workers_dp2:
        resp = requests.post(add_worker_url, params={"url": worker.url}, timeout=180)
        resp.raise_for_status()

    eval_args = SimpleNamespace(
        base_url=router_url,
        model=e2e_model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    score = run_eval(eval_args)["score"]
    assert score >= 0.65
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_genai_bench(
    e2e_router_only_rr, e2e_two_workers_dp2, e2e_model, genai_bench_runner
):
    """Attach the two dp=2 workers to the regular router and run a short genai-bench."""
    router_url = e2e_router_only_rr.url
    add_worker_url = f"{router_url}/add_worker"
    for worker in e2e_two_workers_dp2:
        resp = requests.post(add_worker_url, params={"url": worker.url}, timeout=180)
        resp.raise_for_status()

    # Limits checked against the mean values reported by genai-bench.
    limits = {
        "ttft_mean_max": 6,
        "e2e_latency_mean_max": 14,
        "input_throughput_mean_min": 1000,
        "output_throughput_mean_min": 12,
    }
    genai_bench_runner(
        router_url=router_url,
        model_path=e2e_model,
        experiment_folder="benchmark_round_robin_regular",
        thresholds=limits,
        kill_procs=e2e_two_workers_dp2,
    )
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    """Add a worker to the router, send 8 completions through it, then remove it."""
    router_url = e2e_router_only_rr.url
    worker_params = {"url": e2e_primary_worker.url}

    resp = requests.post(f"{router_url}/add_worker", params=worker_params, timeout=180)
    resp.raise_for_status()

    completions_url = f"{router_url}/v1/completions"
    with requests.Session() as session:
        for idx in range(8):
            payload = {
                "model": e2e_model,
                "prompt": f"x{idx}",
                "max_tokens": 1,
                "stream": False,
            }
            resp = session.post(completions_url, json=payload, timeout=120)
            resp.raise_for_status()

    # Remove the worker
    resp = requests.post(
        f"{router_url}/remove_worker", params=worker_params, timeout=60
    )
    resp.raise_for_status()
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    """Run an MMLU eval while a background thread terminates the worker after 10 s.

    Only requires that the eval returns a score within [0, 1].
    """
    router_url = e2e_router_only_rr.url
    victim = e2e_primary_worker

    resp = requests.post(
        f"{router_url}/add_worker", params={"url": victim.url}, timeout=180
    )
    resp.raise_for_status()

    def _terminate_worker_later():
        # Runs concurrently with the eval below; errors from terminate() are
        # deliberately ignored.
        time.sleep(10)
        try:
            victim.proc.terminate()
        except Exception:
            pass

    killer_thread = threading.Thread(target=_terminate_worker_later, daemon=True)
    killer_thread.start()

    eval_args = SimpleNamespace(
        base_url=router_url,
        model=e2e_model,
        eval_name="mmlu",
        num_examples=32,
        num_threads=16,
        temperature=0.0,
    )
    score = run_eval(eval_args)["score"]
    assert 0.0 <= score <= 1.0
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_dp_aware_worker_expansion_and_api_key(
    e2e_model,
    e2e_router_only_rr_dp_aware_api,
    e2e_worker_dp2_api,
):
    """
    Launch a router-only instance in dp_aware mode and a single worker with dp_size=2
    and API key protection. Verify expansion, auth enforcement, and basic eval.
    """
    import os

    router_url = e2e_router_only_rr_dp_aware_api.url
    worker_url = e2e_worker_dp2_api.url
    api_key = e2e_router_only_rr_dp_aware_api.api_key

    # Attach worker; router should expand to dp_size logical workers
    r = requests.post(
        f"{router_url}/add_worker", params={"url": worker_url}, timeout=180
    )
    r.raise_for_status()

    r = requests.get(f"{router_url}/list_workers", timeout=30)
    r.raise_for_status()
    urls = r.json().get("urls", [])
    assert len(urls) == 2
    assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}

    # Verify API key enforcement pass-through
    # 1) Without Authorization -> 401 from backend
    r = requests.post(
        f"{router_url}/v1/completions",
        json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
        timeout=60,
    )
    assert r.status_code == 401

    # 2) With correct Authorization -> 200
    r = requests.post(
        f"{router_url}/v1/completions",
        json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=60,
    )
    assert r.status_code == 200

    # Finally, run MMLU eval through the router with auth. OPENAI_API_KEY is set
    # only for the duration of the eval and restored afterwards, so the key does
    # not leak into other tests running in this process.
    args = SimpleNamespace(
        base_url=router_url,
        model=e2e_model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    previous_key = os.environ.get("OPENAI_API_KEY")
    os.environ["OPENAI_API_KEY"] = api_key
    try:
        metrics = run_eval(args)
    finally:
        if previous_key is None:
            os.environ.pop("OPENAI_API_KEY", None)
        else:
            os.environ["OPENAI_API_KEY"] = previous_key
    assert metrics["score"] >= 0.65
|
||||
1
sgl-router/py_test/fixtures/__init__.py
Normal file
1
sgl-router/py_test/fixtures/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Shared fixtures for router integration tests."""
|
||||
252
sgl-router/py_test/fixtures/mock_worker.py
Normal file
252
sgl-router/py_test/fixtures/mock_worker.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
Lightweight mock worker HTTP server for router integration tests.
|
||||
|
||||
Implements minimal endpoints used by the router:
|
||||
- GET /health, /health_generate
|
||||
- POST /generate, /v1/completions, /v1/chat/completions
|
||||
- POST /flush_cache
|
||||
- GET /get_server_info, /get_model_info, /v1/models
|
||||
|
||||
Behavior knobs are controlled via CLI flags to simulate failures, latency, and load.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
|
||||
|
||||
# Global state (per-process)
|
||||
_inflight = 0
|
||||
_failures_seen = 0
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
    """Parse the mock worker's command-line flags from ``sys.argv``.

    ``--port`` is the only required option; every other flag has a default.
    """
    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = (
        ("--host", {"default": "127.0.0.1"}),
        ("--port", {"type": int, "required": True}),
        ("--worker-id", {"default": None}),
        ("--latency-ms", {"type": int, "default": 0}),
        ("--timeout", {"action": "store_true"}),
        ("--status-code", {"type": int, "default": 200}),
        ("--fail-first-n", {"type": int, "default": 0}),
        ("--random-fail-rate", {"type": float, "default": 0.0}),
        ("--require-api-key", {"action": "store_true"}),
        ("--api-key", {"default": None}),
        ("--max-payload-bytes", {"type": int, "default": 10 * 1024 * 1024}),
        ("--stream", {"action": "store_true"}),
        ("--dp-size", {"type": int, "default": 1}),
        ("--crash-on-request", {"action": "store_true"}),
        ("--health-fail-after-ms", {"type": int, "default": 0}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
||||
|
||||
|
||||
def _extract_worker_id(args: argparse.Namespace) -> str:
    """Return the worker id: ``--worker-id`` if given, else ``worker-<port>``."""
    explicit_id = args.worker_id
    # The port-based fallback is unique enough for tests.
    return str(explicit_id) if explicit_id else f"worker-{args.port}"
|
||||
|
||||
|
||||
def create_app(args: argparse.Namespace) -> FastAPI:
    """Build the mock worker's FastAPI application from parsed CLI options.

    The app exposes health/info endpoints plus three generation endpoints
    (``/generate``, ``/v1/completions``, ``/v1/chat/completions``) whose
    behavior (latency, failures, auth, streaming, crash, hang) is driven by
    ``args``.
    """
    app = FastAPI()
    worker_id = _extract_worker_id(args)
    started_at = time.time()
    crash_state = {"done": False}

    async def maybe_delay():
        """Sleep for the configured artificial latency, if any."""
        delay_ms = args.latency_ms
        if delay_ms > 0:
            await asyncio.sleep(delay_ms / 1000.0)

    def should_fail() -> Optional[int]:
        """Return the status code to fail with, or None to answer normally."""
        global _failures_seen
        # Fail the first N requests with 500.
        first_n = args.fail_first_n
        if first_n > 0 and _failures_seen < first_n:
            _failures_seen += 1
            return 500
        # Fail randomly with 500 at the configured probability.
        rate = args.random_fail_rate
        if rate > 0.0 and random.random() < rate:
            return 500
        # A non-200 --status-code overrides every remaining response.
        return None if args.status_code == 200 else int(args.status_code)

    def check_api_key(request: Request):
        """Raise 401 unless the request carries an acceptable Bearer token."""
        if not args.require_api_key:
            return
        header = request.headers.get("Authorization")
        if not (header and header.startswith("Bearer ")):
            raise HTTPException(status_code=401, detail="Unauthorized")
        _scheme, token = header.split(" ", 1)
        expected = args.api_key
        # With no --api-key configured, any Bearer token is accepted.
        if expected and token != expected:
            raise HTTPException(status_code=401, detail="Unauthorized")

    @asynccontextmanager
    async def track_inflight():
        """Count the enclosed request in the process-wide in-flight total."""
        global _inflight
        _inflight += 1
        try:
            yield
        finally:
            _inflight -= 1

    @app.get("/health")
    async def health():
        fail_after_ms = args.health_fail_after_ms
        if fail_after_ms:
            uptime_ms = (time.time() - started_at) * 1000.0
            if uptime_ms >= fail_after_ms:
                return PlainTextResponse("bad", status_code=500)
        return PlainTextResponse("ok", status_code=200)

    @app.get("/health_generate")
    async def health_generate():
        return PlainTextResponse("ok", status_code=200)

    @app.post("/flush_cache")
    async def flush_cache():
        return PlainTextResponse("ok", status_code=200)

    @app.get("/get_model_info")
    async def get_model_info():
        return JSONResponse({"model": "mock", "vocab_size": 32000})

    @app.get("/v1/models")
    async def list_models():
        return JSONResponse({"data": [{"id": "mock", "object": "model"}]})

    @app.get("/get_server_info")
    async def get_server_info(request: Request):
        # Enforce API key on server info when required (used by dp_aware probing)
        check_api_key(request)
        info = {
            "worker_id": worker_id,
            "load_in_flight": _inflight,
            "cache": {"size": 0, "hit_rate": 0.0},
            "dp_size": int(args.dp_size),
        }
        return JSONResponse(info)

    @app.get("/get_load")
    async def get_load():
        return JSONResponse({"load": _inflight})

    def make_json_response(obj: dict, status_code: int = 200) -> JSONResponse:
        """Build a JSON response tagged with this worker's id header."""
        response = JSONResponse(obj, status_code=status_code)
        response.headers["X-Worker-Id"] = worker_id
        return response

    async def handle_text_request(request: Request):
        """Serve a non-streaming generation request."""
        # Authorization
        check_api_key(request)

        # Payload limit
        raw_body = await request.body()
        if len(raw_body) > args.max_payload_bytes:
            return make_json_response({"error": "payload too large"}, status_code=413)

        # Simulate a crash on the first request by exiting the process.
        if args.crash_on_request and not crash_state["done"]:
            crash_state["done"] = True
            os._exit(1)

        # Optional timeout (simulate hang)
        if args.timeout:
            await asyncio.sleep(3600)

        # Optional latency
        await maybe_delay()

        # Optional failures
        fail_code = should_fail()
        if fail_code not in (None, 200):
            return make_json_response(
                {"error": f"mock failure {fail_code}"}, status_code=fail_code
            )

        # Echo the request JSON back; a non-JSON body echoes as {}.
        try:
            echoed = await request.json()
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            echoed = {}

        now = time.time()
        choice = {"text": "ok", "index": 0, "finish_reason": "stop"}
        payload = {
            "id": f"cmpl-{int(now*1000)}",
            "object": "text_completion",
            "created": int(now),
            "model": "mock",
            "choices": [choice],
            "worker_id": worker_id,
            "echo": echoed,
        }
        return make_json_response(payload, status_code=200)

    async def handle_stream_request(request: Request):
        """Serve a streaming generation request as server-sent events."""
        check_api_key(request)

        async def event_stream():
            # minimal 2-chunk stream then [DONE]
            for _ in range(2):
                await asyncio.sleep(0.01)
                chunk = {
                    "choices": [{"delta": {"content": "x"}}],
                    "worker_id": worker_id,
                }
                yield f"data: {json.dumps(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={"X-Worker-Id": worker_id},
        )

    async def serve_generation(request: Request):
        """Dispatch to the streaming or text handler while counted in-flight."""
        async with track_inflight():
            handler = handle_stream_request if args.stream else handle_text_request
            return await handler(request)

    @app.post("/generate")
    async def generate(request: Request):
        return await serve_generation(request)

    @app.post("/v1/completions")
    async def completions(request: Request):
        return await serve_generation(request)

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request):
        return await serve_generation(request)

    return app
|
||||
|
||||
|
||||
def main() -> None:
    """Parse CLI options, build the app, and serve it with uvicorn."""
    cli_args = _parse_args()
    application = create_app(cli_args)

    def _exit_on_sigterm(*_ignored):
        sys.exit(0)

    # Handle SIGTERM gracefully for fast test teardown
    signal.signal(signal.SIGTERM, _exit_on_sigterm)
    uvicorn.run(
        application,
        host=cli_args.host,
        port=cli_args.port,
        log_level="warning",
    )


if __name__ == "__main__":
    main()
|
||||
8
sgl-router/py_test/fixtures/ports.py
Normal file
8
sgl-router/py_test/fixtures/ports.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import socket
|
||||
|
||||
|
||||
def find_free_port() -> int:
    """Return an available TCP port on localhost."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 lets the OS choose an unused port.
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
        return port
    finally:
        probe.close()
|
||||
158
sgl-router/py_test/fixtures/router_manager.py
Normal file
158
sgl-router/py_test/fixtures/router_manager.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from .ports import find_free_port
|
||||
|
||||
|
||||
@dataclass
class ProcHandle:
    """Pair a spawned process with the base URL it was started on."""

    # Popen handle of the spawned process.
    process: subprocess.Popen
    # Base URL of the process, e.g. "http://127.0.0.1:<port>".
    url: str
|
||||
|
||||
|
||||
class RouterManager:
    """Helper to spawn a router process and interact with admin endpoints."""

    # Keys accepted in ``extra`` and the router CLI flag each one sets.
    _EXTRA_FLAGS = {
        "max_payload_size": "--max-payload-size",
        "dp_aware": "--dp-aware",
        "api_key": "--api-key",
        # Health/monitoring
        "worker_startup_check_interval": "--worker-startup-check-interval",
        # Cache-aware tuning
        "cache_threshold": "--cache-threshold",
        "balance_abs_threshold": "--balance-abs-threshold",
        "balance_rel_threshold": "--balance-rel-threshold",
        # Retry
        "retry_max_retries": "--retry-max-retries",
        "retry_initial_backoff_ms": "--retry-initial-backoff-ms",
        "retry_max_backoff_ms": "--retry-max-backoff-ms",
        "retry_backoff_multiplier": "--retry-backoff-multiplier",
        "retry_jitter_factor": "--retry-jitter-factor",
        "disable_retries": "--disable-retries",
        # Circuit breaker
        "cb_failure_threshold": "--cb-failure-threshold",
        "cb_success_threshold": "--cb-success-threshold",
        "cb_timeout_duration_secs": "--cb-timeout-duration-secs",
        "cb_window_duration_secs": "--cb-window-duration-secs",
        "disable_circuit_breaker": "--disable-circuit-breaker",
        # Rate limiting
        "max_concurrent_requests": "--max-concurrent-requests",
        "queue_size": "--queue-size",
        "queue_timeout_secs": "--queue-timeout-secs",
        "rate_limit_tokens_per_second": "--rate-limit-tokens-per-second",
    }

    def __init__(self):
        # Every router process started by this manager, for stop_all().
        self._children: List[subprocess.Popen] = []

    @classmethod
    def _extra_cli_args(cls, extra: Optional[Dict]) -> List[str]:
        """Translate the ``extra`` mapping into router CLI arguments.

        ``None`` values and keys without a known flag are skipped. Boolean
        values become a bare flag when true and nothing when false; any
        other value is passed as ``flag str(value)``.
        """
        cli: List[str] = []
        for key, value in (extra or {}).items():
            if value is None:
                continue
            flag = cls._EXTRA_FLAGS.get(key)
            if not flag:
                continue
            if isinstance(value, bool):
                if value:
                    cli.append(flag)
            else:
                cli.extend([flag, str(value)])
        return cli

    def start_router(
        self,
        worker_urls: Optional[List[str]] = None,
        policy: str = "round_robin",
        port: Optional[int] = None,
        extra: Optional[Dict] = None,
        # PD options
        pd_disaggregation: bool = False,
        prefill_urls: Optional[List[tuple]] = None,
        decode_urls: Optional[List[str]] = None,
        prefill_policy: Optional[str] = None,
        decode_policy: Optional[str] = None,
    ) -> "ProcHandle":
        """Launch a router subprocess and wait until ``/health`` returns 200.

        Args:
            worker_urls: Worker base URLs passed via ``--worker-urls``.
            policy: Value for ``--policy``.
            port: Port to listen on; a free port is picked when falsy.
            extra: Optional settings mapped to CLI flags (see ``_EXTRA_FLAGS``).
            pd_disaggregation: Add ``--pd-disaggregation`` and the PD options.
            prefill_urls: ``(url, bootstrap_port)`` pairs; a ``None`` port is
                passed to the CLI as the literal ``none``.
            decode_urls: Decode worker URLs, one ``--decode`` each.
            prefill_policy: Value for ``--prefill-policy``.
            decode_policy: Value for ``--decode-policy``.

        Returns:
            A ProcHandle with the process and its ``http://127.0.0.1:<port>`` URL.

        Raises:
            TimeoutError: If the router does not become healthy in time.
        """
        worker_urls = worker_urls or []
        port = port or find_free_port()
        cmd = [
            "python3",
            "-m",
            "sglang_router.launch_router",
            "--host",
            "127.0.0.1",
            "--port",
            str(port),
            "--policy",
            policy,
        ]
        # Avoid Prometheus port collisions by assigning a free port per router
        prom_port = find_free_port()
        cmd.extend(
            ["--prometheus-port", str(prom_port), "--prometheus-host", "127.0.0.1"]
        )
        if worker_urls:
            cmd.extend(["--worker-urls", *worker_urls])

        # PD routing configuration
        if pd_disaggregation:
            cmd.append("--pd-disaggregation")
            for url, bport in prefill_urls or []:
                bootstrap = "none" if bport is None else str(bport)
                cmd.extend(["--prefill", url, bootstrap])
            for url in decode_urls or []:
                cmd.extend(["--decode", url])
            if prefill_policy:
                cmd.extend(["--prefill-policy", prefill_policy])
            if decode_policy:
                cmd.extend(["--decode-policy", decode_policy])

        # Map supported extras to CLI flags (subset for integration)
        cmd.extend(self._extra_cli_args(extra))

        proc = subprocess.Popen(cmd)
        # Register before the health wait so stop_all() cleans up on failure.
        self._children.append(proc)
        url = f"http://127.0.0.1:{port}"
        self._wait_health(url)
        return ProcHandle(process=proc, url=url)

    def _wait_health(self, base_url: str, timeout: float = 30.0):
        """Poll ``{base_url}/health`` until it returns 200.

        Raises:
            TimeoutError: If no 200 is seen within ``timeout`` seconds.
        """
        start = time.time()
        with requests.Session() as s:
            while time.time() - start < timeout:
                try:
                    r = s.get(f"{base_url}/health", timeout=2)
                    if r.status_code == 200:
                        return
                except requests.RequestException:
                    # Not accepting connections yet; retry after a pause.
                    pass
                time.sleep(0.2)
        raise TimeoutError(f"Router at {base_url} did not become healthy")

    def add_worker(self, base_url: str, worker_url: str) -> None:
        """POST ``/add_worker`` for ``worker_url``; assert a 200 response."""
        r = requests.post(f"{base_url}/add_worker", params={"url": worker_url})
        assert r.status_code == 200, f"add_worker failed: {r.status_code} {r.text}"

    def remove_worker(self, base_url: str, worker_url: str) -> None:
        """POST ``/remove_worker`` for ``worker_url``; assert a 200 response."""
        r = requests.post(f"{base_url}/remove_worker", params={"url": worker_url})
        assert r.status_code == 200, f"remove_worker failed: {r.status_code} {r.text}"

    def list_workers(self, base_url: str) -> list[str]:
        """Return the ``urls`` list reported by ``/list_workers``."""
        r = requests.get(f"{base_url}/list_workers")
        assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}"
        data = r.json()
        return data.get("urls", [])

    def stop_all(self):
        """Terminate every router this manager started and forget them.

        Each live child gets ``terminate()`` and up to 5 seconds to exit;
        one that does not is killed and then waited on so it is reaped
        rather than left behind as a zombie.
        """
        for p in self._children:
            if p.poll() is not None:
                continue
            p.terminate()
            try:
                p.wait(timeout=5)
            except subprocess.TimeoutExpired:
                p.kill()
                p.wait()
        self._children.clear()
|
||||
1
sgl-router/py_test/integration/__init__.py
Normal file
1
sgl-router/py_test/integration/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Integration test package for the router."""
|
||||
109
sgl-router/py_test/integration/conftest.py
Normal file
109
sgl-router/py_test/integration/conftest.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Tuple
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ..fixtures.ports import find_free_port
|
||||
from ..fixtures.router_manager import RouterManager
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the ``integration`` marker with pytest."""
    marker_line = "integration: mark as router integration test"
    config.addinivalue_line("markers", marker_line)
|
||||
|
||||
|
||||
@pytest.fixture
def router_manager() -> Iterable[RouterManager]:
    """Yield a fresh RouterManager and stop all its routers on teardown."""
    manager = RouterManager()
    try:
        yield manager
    finally:
        # Runs on test failure too, so no router outlives its test.
        manager.stop_all()
|
||||
|
||||
|
||||
def _spawn_mock_worker(args: List[str]) -> Tuple[subprocess.Popen, str, str]:
    """Start a mock worker subprocess and wait for it to report healthy.

    Args:
        args: Extra CLI arguments appended after the port and worker id.

    Returns:
        ``(process, base_url, worker_id)`` for the started worker.

    Raises:
        TimeoutError: If the worker does not become healthy; the spawned
            process is killed and reaped before the error propagates.
    """
    repo_root = Path(__file__).resolve().parents[2]
    script = repo_root / "py_test" / "fixtures" / "mock_worker.py"
    port = find_free_port()
    worker_id = f"worker-{port}"
    cmd = [
        "python3",
        str(script),
        "--port",
        str(port),
        "--worker-id",
        worker_id,
        *args,
    ]
    proc = subprocess.Popen(cmd)
    url = f"http://127.0.0.1:{port}"
    try:
        _wait_health(url)
    except BaseException:
        # The caller never receives the handle on failure, so clean up here
        # or the process would be leaked.
        if proc.poll() is None:
            proc.kill()
        proc.wait()
        raise
    return proc, url, worker_id
|
||||
|
||||
|
||||
def _wait_health(url: str, timeout: float = 10.0):
    """Poll ``{url}/health`` until it answers 200.

    Raises:
        TimeoutError: If no 200 response is seen within ``timeout`` seconds.
    """
    began = time.time()
    probe_url = f"{url}/health"
    with requests.Session() as session:
        while (time.time() - began) < timeout:
            try:
                response = session.get(probe_url, timeout=1)
            except requests.RequestException:
                # Server not reachable yet; treat like an unhealthy reply.
                response = None
            if response is not None and response.status_code == 200:
                return
            time.sleep(0.1)
    raise TimeoutError(f"Mock worker at {url} did not become healthy")
|
||||
|
||||
|
||||
@pytest.fixture
def mock_worker():
    """Start a single healthy mock worker; yields (process, url, worker_id).

    On teardown the worker is terminated; if it has not exited within
    3 seconds it is killed and then reaped.
    """
    proc, url, worker_id = _spawn_mock_worker([])
    try:
        yield proc, url, worker_id
    finally:
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                proc.wait()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_workers():
    """Factory to start N workers with custom args.

    Usage:
        procs, urls, ids = mock_workers(n=3, args=["--latency-ms", "5"])  # same args for all
        ...

    Every worker started through the factory is stopped on teardown.
    """

    procs: List[subprocess.Popen] = []

    def _start(n: int, args: List[str] | None = None):
        """Start ``n`` workers with ``args``; return (procs, urls, ids)."""
        args = args or []
        new_procs: List[subprocess.Popen] = []
        urls: List[str] = []
        ids: List[str] = []
        for _ in range(n):
            p, url, wid = _spawn_mock_worker(args)
            # Track immediately so teardown covers workers started before a
            # later spawn in this call fails.
            procs.append(p)
            new_procs.append(p)
            urls.append(url)
            ids.append(wid)
        return new_procs, urls, ids

    try:
        yield _start
    finally:
        # Signal every live worker first so they shut down concurrently
        # instead of waiting up to 3 seconds for each one in turn.
        live = [p for p in procs if p.poll() is None]
        for p in live:
            p.terminate()
        for p in live:
            try:
                p.wait(timeout=3)
            except subprocess.TimeoutExpired:
                p.kill()
                # Reap the killed child so it does not linger as a zombie.
                p.wait()
|
||||
@@ -0,0 +1 @@
|
||||
"""Load balancing integration tests."""
|
||||
@@ -0,0 +1,73 @@
|
||||
import collections
|
||||
import concurrent.futures
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_cache_aware_affinity(mock_workers, router_manager):
    """Repeating one prompt under cache_aware should stick to one worker."""
    # Two workers; same prompt should stick to one due to cache tree
    _, worker_urls, _worker_ids = mock_workers(n=2)
    router = router_manager.start_router(worker_urls=worker_urls, policy="cache_aware")

    endpoint = f"{router.url}/v1/completions"
    body = {
        "model": "test-model",
        "prompt": "repeated prompt for cache",
        "max_tokens": 1,
        "stream": False,
    }
    hits = collections.Counter()
    with requests.Session() as session:
        for _ in range(12):
            resp = session.post(endpoint, json=body)
            assert resp.status_code == 200
            served_by = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            hits[served_by] += 1

    # Expect strong skew toward one worker (tree match): at least 10 of 12.
    busiest = max(hits.values())
    assert busiest >= 10, hits
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_cache_aware_diverse_prompts_balances(mock_workers, router_manager):
    """Unrelated prompts under cache_aware should reach at least two workers."""
    # Add latency so concurrent requests overlap and influence load-based selection
    _, worker_urls, _worker_ids = mock_workers(n=3, args=["--latency-ms", "30"])
    tuning = {
        "cache_threshold": 0.99,
        "balance_abs_threshold": 0,
        "balance_rel_threshold": 1.0,
    }
    router = router_manager.start_router(
        worker_urls=worker_urls, policy="cache_aware", extra=tuning
    )
    endpoint = f"{router.url}/v1/completions"

    def call(i):
        # Use diverse, unrelated prompts to avoid prefix matches entirely
        body = {
            "model": "test-model",
            "prompt": str(uuid.uuid4()),
            "max_tokens": 1,
            "stream": False,
        }
        resp = requests.post(endpoint, json=body, timeout=5)
        assert resp.status_code == 200
        return resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")

    hits = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as pool:
        for served_by in pool.map(call, range(40)):
            hits[served_by] += 1

    # Expect participation of at least two workers
    active = [wid for wid, served in hits.items() if served > 0]
    assert len(active) >= 2, hits
|
||||
@@ -0,0 +1,89 @@
|
||||
import collections
|
||||
import concurrent.futures
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_power_of_two_prefers_less_loaded(mock_workers, router_manager):
    """power_of_two should send fewer requests to the slower worker.

    One worker has 200 ms latency, the other none. The test primes the
    router, loads the slow worker directly, then counts which worker
    serves 200 routed requests.
    """
    # Start one slow and one fast worker using the fixture factory.
    # Router monitors /get_load and Power-of-Two uses cached loads to choose.
    _, urls_slow, ids_slow = mock_workers(n=1, args=["--latency-ms", "200"])
    _, urls_fast, ids_fast = mock_workers(n=1, args=["--latency-ms", "0"])
    urls = urls_slow + urls_fast
    slow_id = ids_slow[0]
    # Previously undefined: the direct-load step raised NameError, which its
    # broad ``except Exception`` swallowed, so no direct load was ever sent.
    slow_url = urls_slow[0]
    fast_worker_id = ids_fast[0]

    rh = router_manager.start_router(
        worker_urls=urls,
        policy="power_of_two",
        extra={"worker_startup_check_interval": 1},
    )

    def _post(base_url, prompt):
        return requests.post(
            f"{base_url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": prompt,
                "max_tokens": 1,
                "stream": False,
            },
            timeout=5,
        )

    def _best_effort(base_url, prompt):
        # Load generation only: request failures are irrelevant here, but
        # programming errors must not be hidden.
        try:
            _post(base_url, prompt)
        except requests.RequestException:
            pass

    # Prime: fire a burst to create measurable load on slow worker, then wait for monitor tick
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(lambda i: _best_effort(rh.url, f"warm-{i}"), range(128)))
    time.sleep(2)

    # Apply direct background load on the slow worker to amplify load diff
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(lambda i: _best_effort(slow_url, f"bg-{i}"), range(128)))
    time.sleep(1)

    def call(i):
        r = _post(rh.url, f"p{i}")
        assert r.status_code == 200
        return r.headers.get("X-Worker-Id") or r.json().get("worker_id")

    counts = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        for wid in ex.map(call, range(200)):
            counts[wid] += 1

    # Expect the slow worker (higher latency/inflight) to receive fewer requests
    assert counts[slow_id] < counts[fast_worker_id], counts
|
||||
33
sgl-router/py_test/integration/load_balancing/test_random.py
Normal file
33
sgl-router/py_test/integration/load_balancing/test_random.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import collections
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_random_distribution(mock_workers, router_manager):
    """The random policy should spread 200 requests roughly evenly over 4 workers."""
    _, worker_urls, worker_ids = mock_workers(n=4)
    router = router_manager.start_router(worker_urls=worker_urls, policy="random")
    endpoint = f"{router.url}/v1/completions"

    total_requests = 200
    hits = collections.Counter()
    with requests.Session() as session:
        for i in range(total_requests):
            body = {
                "model": "test-model",
                "prompt": f"p{i}",
                "max_tokens": 1,
                "stream": False,
            }
            resp = session.post(endpoint, json=body)
            assert resp.status_code == 200
            served_by = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            hits[served_by] += 1

    # simple statistical tolerance: each worker should be within ±50% of mean
    expected = total_requests / len(worker_ids)
    lower, upper = 0.5 * expected, 1.5 * expected
    for worker in worker_ids:
        assert lower <= hits[worker] <= upper, hits
|
||||
@@ -0,0 +1,34 @@
|
||||
import collections
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_round_robin_distribution(mock_workers, router_manager):
    """round_robin should give each of 3 workers about a third of 30 requests."""
    _, worker_urls, worker_ids = mock_workers(n=3)
    router = router_manager.start_router(worker_urls=worker_urls, policy="round_robin")
    endpoint = f"{router.url}/v1/completions"

    hits = collections.Counter()
    with requests.Session() as session:
        for i in range(30):
            body = {
                "model": "test-model",
                "prompt": f"hello {i}",
                "max_tokens": 1,
                "stream": False,
            }
            resp = session.post(endpoint, json=body)
            assert resp.status_code == 200
            served_by = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            assert served_by in worker_ids
            hits[served_by] += 1

    # Expect near-even distribution across 3 workers
    # 30 requests -> ideally 10 each; allow small tolerance ±3
    for worker in worker_ids:
        assert 7 <= hits[worker] <= 13, hits
|
||||
38
sgl-router/py_test/integration/test_api_auth.py
Normal file
38
sgl-router/py_test/integration/test_api_auth.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_router_api_key_enforcement(router_manager, mock_workers):
    """Missing or wrong Bearer tokens yield 401; the right one yields 200."""
    # Start backend requiring API key; router should forward Authorization header transparently
    _, worker_urls, _ = mock_workers(
        n=1, args=["--require-api-key", "--api-key", "correct_api_key"]
    )
    router = router_manager.start_router(
        worker_urls=worker_urls,
        policy="round_robin",
        extra={},
    )
    endpoint = f"{router.url}/v1/completions"
    body = {"model": "test-model", "prompt": "x", "max_tokens": 1, "stream": False}

    # No auth -> 401
    unauthenticated = requests.post(endpoint, json=body)
    assert unauthenticated.status_code == 401

    # Invalid auth -> 401
    wrong_key = requests.post(
        endpoint, json=body, headers={"Authorization": "Bearer wrong"}
    )
    assert wrong_key.status_code == 401

    # Correct auth -> 200
    right_key = requests.post(
        endpoint, json=body, headers={"Authorization": "Bearer correct_api_key"}
    )
    assert right_key.status_code == 200
|
||||
191
sgl-router/py_test/integration/test_circuit_breaker.py
Normal file
191
sgl-router/py_test/integration/test_circuit_breaker.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_circuit_breaker_opens_and_recovers(router_manager, mock_workers):
    """The breaker opens (503) after failures, then allows 200s after the timeout."""
    # A single worker that fails first 3 requests, then succeeds
    _, worker_urls, _ = mock_workers(n=1, args=["--fail-first-n", "3"])
    (wurl,) = worker_urls
    breaker_config = {
        "cb_failure_threshold": 3,
        "cb_success_threshold": 2,
        "cb_timeout_duration_secs": 3,
        "cb_window_duration_secs": 10,
        "disable_retries": True,
    }
    router = router_manager.start_router(
        worker_urls=[wurl], policy="round_robin", extra=breaker_config
    )

    def post_once():
        body = {
            "model": "test-model",
            "prompt": "trigger",
            "max_tokens": 1,
            "stream": False,
        }
        return requests.post(f"{router.url}/v1/completions", json=body, timeout=3)

    # Stop at the first 503, sending at most 8 requests.
    saw_503 = any(post_once().status_code == 503 for _ in range(8))
    assert saw_503, "circuit breaker did not open to return 503"

    # Sleep past the 3-second breaker timeout configured above.
    time.sleep(4)
    first = post_once()
    second = post_once()
    assert first.status_code == 200 and second.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_circuit_breaker_half_open_failure_reopens(router_manager, mock_workers):
    """After the timeout, one failing probe (500) should reopen the breaker (503)."""
    _, worker_urls, _ = mock_workers(n=1, args=["--status-code", "500"])  # always fail
    (wurl,) = worker_urls
    breaker_config = {
        "cb_failure_threshold": 2,
        "cb_success_threshold": 2,
        "cb_timeout_duration_secs": 2,
        "cb_window_duration_secs": 5,
        "disable_retries": True,
    }
    router = router_manager.start_router(
        worker_urls=[wurl], policy="round_robin", extra=breaker_config
    )

    def post_once():
        body = {
            "model": "test-model",
            "prompt": "x",
            "max_tokens": 1,
            "stream": False,
        }
        return requests.post(f"{router.url}/v1/completions", json=body, timeout=3)

    # Stop at the first 503, sending at most 8 requests.
    opened = any(post_once().status_code == 503 for _ in range(8))
    assert opened, "circuit breaker did not open"

    # Sleep past the 2-second breaker timeout configured above.
    time.sleep(3)
    probe = post_once()
    assert probe.status_code == 500
    after_probe = post_once()
    assert after_probe.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_circuit_breaker_disable_flag(router_manager, mock_workers):
    """With breaker and retries disabled, the worker's 500 is returned as-is."""
    _, worker_urls, _ = mock_workers(n=1, args=["--status-code", "500"])  # always fail
    (wurl,) = worker_urls
    router = router_manager.start_router(
        worker_urls=[wurl],
        policy="round_robin",
        extra={"disable_circuit_breaker": True, "disable_retries": True},
    )
    body = {
        "model": "test-model",
        "prompt": "x",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=3)
    assert resp.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_circuit_breaker_per_worker_isolation(router_manager, mock_workers):
    """Once the failing worker's breaker is open, every request must return 200."""
    _, failing_urls, _ = mock_workers(n=1, args=["--status-code", "500"])  # always fail
    (fail_url,) = failing_urls
    _, healthy_urls, _ = mock_workers(n=1)
    (ok_url,) = healthy_urls
    breaker_config = {
        "cb_failure_threshold": 2,
        "cb_success_threshold": 1,
        "cb_timeout_duration_secs": 2,
        "cb_window_duration_secs": 10,
        "disable_retries": True,
    }
    router = router_manager.start_router(
        worker_urls=[fail_url, ok_url], policy="round_robin", extra=breaker_config
    )

    def post_once():
        body = {
            "model": "test-model",
            "prompt": "y",
            "max_tokens": 1,
            "stream": False,
        }
        return requests.post(f"{router.url}/v1/completions", json=body, timeout=3)

    failures = 0
    successes_after_open = 0
    opened = False
    for _ in range(30):
        resp = post_once()
        if opened:
            assert (
                resp.status_code == 200
            ), f"Unexpected non-200 after CB open: {resp.status_code}"
            successes_after_open += 1
            continue
        if resp.status_code == 500:
            failures += 1
        if failures >= 2:
            # Two extra requests, results ignored, before treating the
            # breaker as open.
            post_once()
            post_once()
            opened = True
    assert opened and successes_after_open >= 5
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_circuit_breaker_with_retries(router_manager, mock_workers):
    """With retries enabled, one request succeeds despite an always-failing worker."""
    _, failing_urls, _ = mock_workers(n=1, args=["--status-code", "500"])  # always fail
    (fail_url,) = failing_urls
    _, healthy_urls, _ = mock_workers(n=1)
    (ok_url,) = healthy_urls
    settings = {
        "retry_max_retries": 3,
        "retry_initial_backoff_ms": 10,
        "retry_max_backoff_ms": 50,
        "cb_failure_threshold": 2,
        "cb_success_threshold": 1,
        "cb_timeout_duration_secs": 2,
        "cb_window_duration_secs": 10,
    }
    router = router_manager.start_router(
        worker_urls=[fail_url, ok_url], policy="round_robin", extra=settings
    )

    body = {
        "model": "test-model",
        "prompt": "z",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=5)
    assert resp.status_code == 200
|
||||
36
sgl-router/py_test/integration/test_fault_tolerance.py
Normal file
36
sgl-router/py_test/integration/test_fault_tolerance.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import concurrent.futures
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_worker_crash_reroute_with_retries(router_manager, mock_workers):
    """A request succeeds even though one worker crashes on its first request."""
    # Start one healthy and one that will crash on first request
    _, healthy_urls, _ = mock_workers(n=1)
    (ok_url,) = healthy_urls
    _, crashing_urls, _ = mock_workers(n=1, args=["--crash-on-request"])
    (crash_url,) = crashing_urls
    retry_config = {
        "retry_max_retries": 3,
        "retry_initial_backoff_ms": 10,
        "retry_max_backoff_ms": 50,
    }
    router = router_manager.start_router(
        worker_urls=[crash_url, ok_url], policy="round_robin", extra=retry_config
    )

    # A single request should succeed via retry to the healthy worker
    body = {
        "model": "test-model",
        "prompt": "crash",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=5)
    assert resp.status_code == 200
    # mock_workers fixture handles cleanup
|
||||
33
sgl-router/py_test/integration/test_payload_size.py
Normal file
33
sgl-router/py_test/integration/test_payload_size.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_payload_size_limit(router_manager, mock_workers):
    """A ~0.5MB body passes a 1MB router limit; a ~1.2MB body gets 413."""
    # Start one backend and a router with a 1MB payload limit
    _, worker_urls, _ = mock_workers(n=1)
    router = router_manager.start_router(
        worker_urls=worker_urls,
        policy="round_robin",
        extra={"max_payload_size": 1 * 1024 * 1024},  # 1MB
    )
    endpoint = f"{router.url}/v1/completions"

    def body_with_prompt_of(megabytes):
        return {
            "model": "test-model",
            "prompt": "x" * int(megabytes * 1024 * 1024),
            "max_tokens": 1,
            "stream": False,
        }

    # Payload well under 1MB (~0.5MB) should succeed
    under_limit = requests.post(endpoint, json=body_with_prompt_of(0.5))
    assert under_limit.status_code == 200

    # Payload over 1MB (~1.2MB) should fail with 413
    over_limit = requests.post(endpoint, json=body_with_prompt_of(1.2))
    assert over_limit.status_code == 413
|
||||
127
sgl-router/py_test/integration/test_pd_routing.py
Normal file
127
sgl-router/py_test/integration/test_pd_routing.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import collections
|
||||
import concurrent.futures
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_pd_power_of_two_decode_attribution(router_manager, mock_workers):
    """Check that PD-mode responses are attributed to decode workers.

    Two prefill and three decode mock workers are placed behind a
    ``power_of_two`` router with PD disaggregation enabled.  Thirty
    sequential requests are sent; every response must report a decode worker
    id, and at least two distinct decode workers must have served traffic.
    """
    # Start two prefill and three decode mock workers via fixture
    _, prefill_urls_raw, _ = mock_workers(n=2)
    _, decode_urls_raw, decode_ids_list = mock_workers(n=3)
    # Prefill entries are (url, bootstrap_port) pairs; no bootstrap port here.
    prefill_urls = [(u, None) for u in prefill_urls_raw]
    decode_urls = list(decode_urls_raw)
    decode_ids = set(decode_ids_list)

    rh = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        prefill_urls=prefill_urls,
        decode_urls=decode_urls,
        extra={"worker_startup_check_interval": 1},
    )

    counts = collections.Counter()
    with requests.Session() as s:
        for i in range(30):
            r = s.post(
                f"{rh.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"p{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
                # Same bound as the sibling PD test; without it a stuck
                # router would block the test run indefinitely.
                timeout=8,
            )
            assert r.status_code == 200
            # Worker id comes from the response header, falling back to the body.
            wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
            assert wid in decode_ids, wid
            counts[wid] += 1

    # Every key in the Counter was incremented at least once, so the number
    # of keys is the number of decode workers that actually served requests.
    assert len(counts) >= 2, counts
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_pd_power_of_two_skews_to_faster_decode(router_manager, mock_workers):
    """Check that PD ``power_of_two`` routing favours the faster decode worker.

    Sequence of events:
      1. Start two prefill workers, one decode worker launched with
         ``--latency-ms 300`` and one decode worker with default settings.
      2. Send 128 best-effort warm-up requests through the router, then
         pause for two seconds.
      3. Send 128 best-effort requests directly at the slow decode worker,
         then pause for one second.
      4. Send 200 measured requests through the router and require that the
         slow decode worker served fewer of them than the fast one.
    """
    # Two prefill workers (fast)
    _, prefill_raw, _ = mock_workers(n=2)

    # Two decode workers: one slow, one fast
    _, [decode_slow_url], [slow_id] = mock_workers(n=1, args=["--latency-ms", "300"])
    _, [decode_fast_url], [fast_id] = mock_workers(n=1)

    rh = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        prefill_urls=[(url, None) for url in prefill_raw],
        decode_urls=[decode_slow_url, decode_fast_url],
        extra={"worker_startup_check_interval": 1},
    )

    def _post(base_url, prompt):
        """Send one non-streaming completion request to ``base_url``."""
        return requests.post(
            f"{base_url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": prompt,
                "max_tokens": 1,
                "stream": False,
            },
            timeout=8,
        )

    def _fire_and_forget(base_url, prefix):
        """Build a per-index sender that deliberately ignores any failure."""

        def _send(i):
            try:
                _post(base_url, f"{prefix}-{i}")
            except Exception:
                # Best-effort load generation: individual failures are irrelevant.
                pass

        return _send

    def _measured(i):
        """Send one request via the router and return the serving worker id."""
        resp = _post(rh.url, f"p{i}")
        assert resp.status_code == 200
        return resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")

    def _burst(fn, total):
        """Run ``fn`` over ``range(total)`` on 32 threads and collect results."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as pool:
            return list(pool.map(fn, range(total)))

    # Warm-up traffic through the router
    _burst(_fire_and_forget(rh.url, "warm"), 128)
    time.sleep(2)

    # Extra traffic aimed straight at the slow decode worker
    _burst(_fire_and_forget(decode_slow_url, "bg"), 128)
    time.sleep(1)

    # Measured traffic: tally which decode worker answered each request
    counts = collections.Counter()
    for worker_id in _burst(_measured, 200):
        counts[worker_id] += 1

    assert counts[slow_id] < counts[fast_id], counts
|
||||
91
sgl-router/py_test/integration/test_rate_limiting.py
Normal file
91
sgl-router/py_test/integration/test_rate_limiting.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import concurrent.futures
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_rate_limit_and_queue(router_manager, mock_workers):
    """Burst 16 concurrent requests at a router capped at 2 with no queue.

    With ``max_concurrent_requests=2`` and ``queue_size=0`` the burst is
    expected to yield at least one 200 and at least one 429.
    """
    # Single backend with default (fast) settings
    _, backend_urls, _ = mock_workers(n=1)
    router = router_manager.start_router(
        worker_urls=backend_urls,
        policy="round_robin",
        extra={
            "max_concurrent_requests": 2,
            # no queue -> immediate 429 when limit exceeded
            "queue_size": 0,
        },
    )
    endpoint = f"{router.url}/v1/completions"

    def status_of(idx):
        """Return the HTTP status of one request, or 599 on any client error."""
        body = {
            "model": "test-model",
            "prompt": f"p{idx}",
            "max_tokens": 1,
            "stream": False,
        }
        try:
            return requests.post(endpoint, json=body, timeout=3).status_code
        except Exception:
            return 599

    # Launch the whole burst at once
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as pool:
        statuses = list(pool.map(status_of, range(16)))

    # Some requests must get through and some must be rate limited (429)
    assert 200 in statuses
    assert 429 in statuses
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_rate_limit_queue_and_timeout(router_manager, mock_workers):
    """Exercise queueing plus queue timeout on a slow backend.

    The backend is launched with ``--latency-ms 2000`` and the router allows
    one concurrent request, a queue of one, and a one-second queue timeout.
    Four concurrent requests must produce at least one 200, at least one 408,
    and at least two non-200 responses that are all either 408 or 429.
    """
    # Slow backend: ~2s per request ensures queue wait > timeout
    _, urls, _ = mock_workers(n=1, args=["--latency-ms", "2000"])  # 2.0s per request

    # Allow 1 concurrent, queue up to 1, with 1s queue timeout
    rh = router_manager.start_router(
        worker_urls=urls,
        policy="round_robin",
        extra={
            "max_concurrent_requests": 1,
            "queue_size": 1,
            "queue_timeout_secs": 1,
        },
    )

    def call_once(i):
        """Return the HTTP status of one request, or 599 on any client error."""
        try:
            r = requests.post(
                f"{rh.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"q{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
                timeout=5,
            )
            return r.status_code
        except Exception:
            return 599

    # Fire 4 concurrent requests: 1 runs (~2s), 1 queued (times out at 1s -> 408), 2 overflow -> 429
    # (concurrent.futures is already imported at module level; no local import needed.)
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as ex:
        results = list(ex.map(call_once, range(4)))

    # We expect:
    # - Some 200s (processed)
    # - At least one 408 (queued too long and timed out)
    # - Remaining non-200s are either 429 (queue overflow) or additional 408s depending on scheduling
    assert any(code == 200 for code in results), results
    assert any(code == 408 for code in results), results
    non200 = [c for c in results if c != 200]
    assert len(non200) >= 2 and all(c in (408, 429) for c in non200), results
|
||||
65
sgl-router/py_test/integration/test_retries.py
Normal file
65
sgl-router/py_test/integration/test_retries.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import concurrent.futures
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_retry_reroutes_to_healthy_worker(router_manager, mock_workers):
    """A request that hits a failing worker is retried onto a healthy one.

    Worker A is launched with ``--status-code 500``; workers B and C use the
    defaults.  With retries enabled on a round-robin router, a single request
    must end in 200 and be attributed to worker B.
    """
    # Worker A always 500; Worker B/C healthy
    _, [failing_url], _ = mock_workers(n=1, args=["--status-code", "500"])  # fail
    _, [healthy_b_url], [healthy_b_id] = mock_workers(n=1)
    _, [healthy_c_url], _ = mock_workers(n=1)

    retry_cfg = {
        "retry_max_retries": 3,
        "retry_initial_backoff_ms": 10,
        "retry_max_backoff_ms": 50,
    }
    router = router_manager.start_router(
        worker_urls=[failing_url, healthy_b_url, healthy_c_url],
        policy="round_robin",
        extra=retry_cfg,
    )

    body = {
        "model": "test-model",
        "prompt": "x",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=5)

    assert resp.status_code == 200
    served_by = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
    # should have retried onto healthy worker
    assert served_by == healthy_b_id
    # mock_workers fixture handles cleanup
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_disable_retries_surfaces_failure(router_manager, mock_workers):
    """With retries disabled, a failing worker's 500 reaches the client."""
    # Only one worker, and it always answers 500
    _, [failing_url], _ = mock_workers(n=1, args=["--status-code", "500"])  # always fail

    router = router_manager.start_router(
        worker_urls=[failing_url],
        policy="round_robin",
        extra={"disable_retries": True},
    )

    body = {
        "model": "test-model",
        "prompt": "x",
        "max_tokens": 1,
        "stream": False,
    }
    resp = requests.post(f"{router.url}/v1/completions", json=body, timeout=5)

    assert resp.status_code == 500
    # mock_workers fixture handles cleanup
|
||||
@@ -0,0 +1,36 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_discovery_shim_add_remove(router_manager, mock_workers):
    """Add and remove a worker on a router that starts with none.

    Steps: start an empty round-robin router, confirm the worker list is
    empty, add one mock worker and confirm it is listed and can serve a
    request, then remove it and confirm it is no longer listed.
    """
    # Start router without workers
    rh = router_manager.start_router(worker_urls=[], policy="round_robin")

    # Initially empty
    urls = router_manager.list_workers(rh.url)
    assert urls == []

    # Add a worker (simulate discovery event)
    _, [wurl], _ = mock_workers(n=1)
    router_manager.add_worker(rh.url, wurl)
    urls = router_manager.list_workers(rh.url)
    assert wurl in urls

    # Can serve a request
    r = requests.post(
        f"{rh.url}/v1/completions",
        json={
            "model": "test-model",
            "prompt": "hi",
            "max_tokens": 1,
            "stream": False,
        },
        # Bound the wait so a stuck router fails the test instead of hanging it.
        timeout=5,
    )
    assert r.status_code == 200

    # Remove worker (simulate pod deletion)
    router_manager.remove_worker(rh.url, wurl)
    urls = router_manager.list_workers(rh.url)
    assert wurl not in urls
    # mock_workers fixture handles cleanup
|
||||
61
sgl-router/py_test/integration/test_worker_management.py
Normal file
61
sgl-router/py_test/integration/test_worker_management.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import collections
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_add_and_remove_worker(mock_worker, router_manager, mock_workers):
    """Traffic follows worker membership as workers are added and removed.

    The router starts with one worker; a second is added and both must be
    observed serving requests within 20 attempts.  After the second worker is
    removed, the next 10 requests must all be served by the first worker.
    """
    # Start with a single worker
    _, url1, id1 = mock_worker
    rh = router_manager.start_router(worker_urls=[url1], policy="round_robin")

    # Add a second worker
    _, urls2, ids2 = mock_workers(n=1)
    url2 = urls2[0]
    id2 = ids2[0]
    router_manager.add_worker(rh.url, url2)

    def serving_worker(session, prompt):
        """Send one request through the router and return the worker id."""
        r = session.post(
            f"{rh.url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": prompt,
                "max_tokens": 1,
                "stream": False,
            },
            # Bound each request so a stuck router cannot hang the suite.
            timeout=5,
        )
        assert r.status_code == 200
        return r.headers.get("X-Worker-Id") or r.json().get("worker_id")

    # Send some requests and ensure both workers are seen
    seen = set()
    with requests.Session() as s:
        for i in range(20):
            seen.add(serving_worker(s, f"x{i}"))
            if len(seen) == 2:
                break

    assert id1 in seen and id2 in seen, seen

    # Now remove the second worker
    router_manager.remove_worker(rh.url, url2)

    # After removal, subsequent requests should only come from first worker
    with requests.Session() as s:
        for i in range(10):
            assert serving_worker(s, f"y{i}") == id1
    # mock_workers fixture handles cleanup
|
||||
7
sgl-router/py_test/unit/__init__.py
Normal file
7
sgl-router/py_test/unit/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Unit tests for sglang_router.
|
||||
|
||||
This package contains fast, isolated unit tests for Python components
|
||||
of the SGLang router. These tests focus on testing individual functions
|
||||
and classes in isolation without starting actual router instances.
|
||||
"""
|
||||
628
sgl-router/py_test/unit/test_arg_parser.py
Normal file
628
sgl-router/py_test/unit/test_arg_parser.py
Normal file
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
Unit tests for argument parsing functionality in sglang_router.
|
||||
|
||||
These tests focus on testing the argument parsing logic in isolation,
|
||||
without starting actual router instances.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, parse_router_args
|
||||
from sglang_router.router import policy_from_str
|
||||
|
||||
|
||||
class TestRouterArgs:
    """Checks for RouterArgs defaults, its parsing helpers, and CLI mapping."""

    def test_default_values(self):
        """A bare ``RouterArgs()`` exposes the expected default values."""
        defaults = RouterArgs()

        # Basic, service-discovery and retry/circuit-breaker values (by equality)
        expected_values = {
            "host": "127.0.0.1",
            "port": 30000,
            "policy": "cache_aware",
            "worker_urls": [],
            "prefill_urls": [],
            "decode_urls": [],
            "selector": {},
            "service_discovery_port": 80,
            "retry_max_retries": 5,
            "cb_failure_threshold": 10,
        }
        for field, value in expected_values.items():
            assert getattr(defaults, field) == value, field

        # Flags that must be exactly ``False``
        for flag in (
            "pd_disaggregation",
            "service_discovery",
            "disable_retries",
            "disable_circuit_breaker",
        ):
            assert getattr(defaults, flag) is False, flag

        # Optional settings that must be unset
        for optional in (
            "prefill_policy",
            "decode_policy",
            "service_discovery_namespace",
        ):
            assert getattr(defaults, optional) is None, optional

    def test_parse_selector_valid(self):
        """Well-formed selector lists (and empty/None) become dicts."""
        cases = [
            # single key-value pair
            (["app=worker"], {"app": "worker"}),
            # multiple key-value pairs
            (
                ["app=worker", "env=prod", "version=v1"],
                {"app": "worker", "env": "prod", "version": "v1"},
            ),
            # empty list
            ([], {}),
            # None
            (None, {}),
        ]
        for raw, expected in cases:
            assert RouterArgs._parse_selector(raw) == expected

    def test_parse_selector_invalid(self):
        """Malformed selector entries are dropped or split on the first '='."""
        cases = [
            # no equals sign at all
            (["app"], {}),
            # several equals signs: only the first one separates key and value
            (["app=worker=extra"], {"app": "worker=extra"}),
        ]
        for raw, expected in cases:
            assert RouterArgs._parse_selector(raw) == expected

    def test_parse_prefill_urls_valid(self):
        """Prefill entries parse into ``(url, bootstrap_port_or_None)`` tuples."""
        cases = [
            # explicit bootstrap port
            ([["http://prefill1:8000", "9000"]], [("http://prefill1:8000", 9000)]),
            # literal 'none' bootstrap port
            ([["http://prefill1:8000", "none"]], [("http://prefill1:8000", None)]),
            # bootstrap port omitted
            ([["http://prefill1:8000"]], [("http://prefill1:8000", None)]),
            # several entries mixing all three shapes
            (
                [
                    ["http://prefill1:8000", "9000"],
                    ["http://prefill2:8000", "none"],
                    ["http://prefill3:8000"],
                ],
                [
                    ("http://prefill1:8000", 9000),
                    ("http://prefill2:8000", None),
                    ("http://prefill3:8000", None),
                ],
            ),
            # empty list
            ([], []),
            # None
            (None, []),
        ]
        for raw, expected in cases:
            assert RouterArgs._parse_prefill_urls(raw) == expected

    def test_parse_prefill_urls_invalid(self):
        """A non-numeric bootstrap port raises ``ValueError``."""
        bad_entry = [["http://prefill1:8000", "invalid"]]
        with pytest.raises(ValueError, match="Invalid bootstrap port"):
            RouterArgs._parse_prefill_urls(bad_entry)

    def test_parse_decode_urls_valid(self):
        """Nested decode URL lists are flattened into a list of URLs."""
        cases = [
            # single decode URL
            ([["http://decode1:8001"]], ["http://decode1:8001"]),
            # multiple decode URLs
            (
                [["http://decode1:8001"], ["http://decode2:8001"]],
                ["http://decode1:8001", "http://decode2:8001"],
            ),
            # empty list
            ([], []),
            # None
            (None, []),
        ]
        for raw, expected in cases:
            assert RouterArgs._parse_decode_urls(raw) == expected

    def test_from_cli_args_basic(self):
        """``from_cli_args`` maps ``router_``-prefixed CLI fields onto RouterArgs."""
        cli = {
            "host": "0.0.0.0",
            "port": 30001,
            "worker_urls": ["http://worker1:8000", "http://worker2:8000"],
            "policy": "round_robin",
            "prefill": None,
            "decode": None,
            "router_policy": "round_robin",
            "router_pd_disaggregation": False,
            "router_prefill_policy": None,
            "router_decode_policy": None,
            "router_worker_startup_timeout_secs": 300,
            "router_worker_startup_check_interval": 15,
            "router_cache_threshold": 0.7,
            "router_balance_abs_threshold": 128,
            "router_balance_rel_threshold": 2.0,
            "router_eviction_interval": 180,
            "router_max_tree_size": 2**28,
            "router_max_payload_size": 1024 * 1024 * 1024,  # 1GB
            "router_dp_aware": True,
            "router_api_key": "test-key",
            "router_log_dir": "/tmp/logs",
            "router_log_level": "debug",
            "router_service_discovery": True,
            "router_selector": ["app=worker", "env=test"],
            "router_service_discovery_port": 8080,
            "router_service_discovery_namespace": "default",
            "router_prefill_selector": ["app=prefill"],
            "router_decode_selector": ["app=decode"],
            "router_prometheus_port": 29000,
            "router_prometheus_host": "0.0.0.0",
            "router_request_id_headers": ["x-request-id", "x-trace-id"],
            "router_request_timeout_secs": 1200,
            "router_max_concurrent_requests": 512,
            "router_queue_size": 200,
            "router_queue_timeout_secs": 120,
            "router_rate_limit_tokens_per_second": 100,
            "router_cors_allowed_origins": ["http://localhost:3000"],
            "router_retry_max_retries": 3,
            "router_retry_initial_backoff_ms": 100,
            "router_retry_max_backoff_ms": 10000,
            "router_retry_backoff_multiplier": 2.0,
            "router_retry_jitter_factor": 0.1,
            "router_cb_failure_threshold": 5,
            "router_cb_success_threshold": 2,
            "router_cb_timeout_duration_secs": 30,
            "router_cb_window_duration_secs": 60,
            "router_disable_retries": False,
            "router_disable_circuit_breaker": False,
            "router_health_failure_threshold": 2,
            "router_health_success_threshold": 1,
            "router_health_check_timeout_secs": 3,
            "router_health_check_interval_secs": 30,
            "router_health_check_endpoint": "/healthz",
        }

        router_args = RouterArgs.from_cli_args(
            SimpleNamespace(**cli), use_router_prefix=True
        )

        # Values compared by equality, grouped by area.
        expected_values = {
            # basic configuration
            "host": "0.0.0.0",
            "port": 30001,
            "worker_urls": ["http://worker1:8000", "http://worker2:8000"],
            "policy": "round_robin",
            # PD configuration
            "prefill_urls": [],
            "decode_urls": [],
            # service discovery
            "selector": {"app": "worker", "env": "test"},
            "service_discovery_port": 8080,
            "service_discovery_namespace": "default",
            "prefill_selector": {"app": "prefill"},
            "decode_selector": {"app": "decode"},
            # other configuration
            "api_key": "test-key",
            "log_dir": "/tmp/logs",
            "log_level": "debug",
            "prometheus_port": 29000,
            "prometheus_host": "0.0.0.0",
            "request_id_headers": ["x-request-id", "x-trace-id"],
            "request_timeout_secs": 1200,
            "max_concurrent_requests": 512,
            "queue_size": 200,
            "queue_timeout_secs": 120,
            "rate_limit_tokens_per_second": 100,
            "cors_allowed_origins": ["http://localhost:3000"],
            # retry configuration
            "retry_max_retries": 3,
            "retry_initial_backoff_ms": 100,
            "retry_max_backoff_ms": 10000,
            "retry_backoff_multiplier": 2.0,
            "retry_jitter_factor": 0.1,
            # circuit breaker configuration
            "cb_failure_threshold": 5,
            "cb_success_threshold": 2,
            "cb_timeout_duration_secs": 30,
            "cb_window_duration_secs": 60,
            # health check configuration
            "health_failure_threshold": 2,
            "health_success_threshold": 1,
            "health_check_timeout_secs": 3,
            "health_check_interval_secs": 30,
            "health_check_endpoint": "/healthz",
        }
        for field, value in expected_values.items():
            assert getattr(router_args, field) == value, field

        # Boolean flags are checked by identity.
        for flag in ("service_discovery", "dp_aware"):
            assert getattr(router_args, flag) is True, flag
        for flag in (
            "pd_disaggregation",
            "disable_retries",
            "disable_circuit_breaker",
        ):
            assert getattr(router_args, flag) is False, flag

        # Note: model_path and tokenizer_path are not available in current RouterArgs

    def test_from_cli_args_pd_mode(self):
        """``from_cli_args`` picks up prefill/decode settings in PD mode."""
        cli = {
            "host": "127.0.0.1",
            "port": 30000,
            "worker_urls": [],
            "policy": "cache_aware",
            "prefill": [
                ["http://prefill1:8000", "9000"],
                ["http://prefill2:8000", "none"],
            ],
            "decode": [["http://decode1:8001"], ["http://decode2:8001"]],
            "router_prefill": [
                ["http://prefill1:8000", "9000"],
                ["http://prefill2:8000", "none"],
            ],
            "router_decode": [["http://decode1:8001"], ["http://decode2:8001"]],
            "router_policy": "cache_aware",
            "router_pd_disaggregation": True,
            "router_prefill_policy": "power_of_two",
            "router_decode_policy": "round_robin",
            # Remaining fields supplied so the namespace is complete
            "router_worker_startup_timeout_secs": 600,
            "router_worker_startup_check_interval": 30,
            "router_cache_threshold": 0.3,
            "router_balance_abs_threshold": 64,
            "router_balance_rel_threshold": 1.5,
            "router_eviction_interval": 120,
            "router_max_tree_size": 2**26,
            "router_max_payload_size": 512 * 1024 * 1024,
            "router_dp_aware": False,
            "router_api_key": None,
            "router_log_dir": None,
            "router_log_level": None,
            "router_service_discovery": False,
            "router_selector": None,
            "router_service_discovery_port": 80,
            "router_service_discovery_namespace": None,
            "router_prefill_selector": None,
            "router_decode_selector": None,
            "router_prometheus_port": None,
            "router_prometheus_host": None,
            "router_request_id_headers": None,
            "router_request_timeout_secs": 1800,
            "router_max_concurrent_requests": 256,
            "router_queue_size": 100,
            "router_queue_timeout_secs": 60,
            "router_rate_limit_tokens_per_second": None,
            "router_cors_allowed_origins": [],
            "router_retry_max_retries": 5,
            "router_retry_initial_backoff_ms": 50,
            "router_retry_max_backoff_ms": 30000,
            "router_retry_backoff_multiplier": 1.5,
            "router_retry_jitter_factor": 0.2,
            "router_cb_failure_threshold": 10,
            "router_cb_success_threshold": 3,
            "router_cb_timeout_duration_secs": 60,
            "router_cb_window_duration_secs": 120,
            "router_disable_retries": False,
            "router_disable_circuit_breaker": False,
            "router_health_failure_threshold": 3,
            "router_health_success_threshold": 2,
            "router_health_check_timeout_secs": 5,
            "router_health_check_interval_secs": 60,
            "router_health_check_endpoint": "/health",
        }

        router_args = RouterArgs.from_cli_args(
            SimpleNamespace(**cli), use_router_prefix=True
        )

        # PD configuration
        assert router_args.pd_disaggregation is True
        expected_values = {
            "prefill_urls": [
                ("http://prefill1:8000", 9000),
                ("http://prefill2:8000", None),
            ],
            "decode_urls": ["http://decode1:8001", "http://decode2:8001"],
            "prefill_policy": "power_of_two",
            "decode_policy": "round_robin",
            # Main policy still set
            "policy": "cache_aware",
        }
        for field, value in expected_values.items():
            assert getattr(router_args, field) == value, field

    def test_from_cli_args_without_prefix(self):
        """``from_cli_args`` reads unprefixed fields when the prefix is off."""
        cli = {
            "host": "127.0.0.1",
            "port": 30000,
            "worker_urls": ["http://worker1:8000"],
            "policy": "random",
            "prefill": None,
            "decode": None,
            "pd_disaggregation": False,
            "prefill_policy": None,
            "decode_policy": None,
            "worker_startup_timeout_secs": 600,
            "worker_startup_check_interval": 30,
            "cache_threshold": 0.3,
            "balance_abs_threshold": 64,
            "balance_rel_threshold": 1.5,
            "eviction_interval": 120,
            "max_tree_size": 2**26,
            "max_payload_size": 512 * 1024 * 1024,
            "dp_aware": False,
            "api_key": None,
            "log_dir": None,
            "log_level": None,
            "service_discovery": False,
            "selector": None,
            "service_discovery_port": 80,
            "service_discovery_namespace": None,
            "prefill_selector": None,
            "decode_selector": None,
            "prometheus_port": None,
            "prometheus_host": None,
            "request_id_headers": None,
            "request_timeout_secs": 1800,
            "max_concurrent_requests": 256,
            "queue_size": 100,
            "queue_timeout_secs": 60,
            "rate_limit_tokens_per_second": None,
            "cors_allowed_origins": [],
            "retry_max_retries": 5,
            "retry_initial_backoff_ms": 50,
            "retry_max_backoff_ms": 30000,
            "retry_backoff_multiplier": 1.5,
            "retry_jitter_factor": 0.2,
            "cb_failure_threshold": 10,
            "cb_success_threshold": 3,
            "cb_timeout_duration_secs": 60,
            "cb_window_duration_secs": 120,
            "disable_retries": False,
            "disable_circuit_breaker": False,
            "health_failure_threshold": 3,
            "health_success_threshold": 2,
            "health_check_timeout_secs": 5,
            "health_check_interval_secs": 60,
            "health_check_endpoint": "/health",
            "model_path": None,
            "tokenizer_path": None,
        }

        router_args = RouterArgs.from_cli_args(
            SimpleNamespace(**cli), use_router_prefix=False
        )

        expected_values = {
            "host": "127.0.0.1",
            "port": 30000,
            "worker_urls": ["http://worker1:8000"],
            "policy": "random",
        }
        for field, value in expected_values.items():
            assert getattr(router_args, field) == value, field
        assert router_args.pd_disaggregation is False
|
||||
|
||||
|
||||
class TestPolicyFromStr:
    """Checks for ``policy_from_str`` (policy name -> ``PolicyType`` member)."""

    def test_valid_policies(self):
        """Each known policy name maps to its ``PolicyType`` member."""
        from sglang_router_rs import PolicyType

        name_to_member = {
            "random": PolicyType.Random,
            "round_robin": PolicyType.RoundRobin,
            "cache_aware": PolicyType.CacheAware,
            "power_of_two": PolicyType.PowerOfTwo,
        }
        for name, member in name_to_member.items():
            assert policy_from_str(name) == member

    def test_invalid_policy(self):
        """An unknown policy name raises ``KeyError``."""
        unknown = "invalid_policy"
        with pytest.raises(KeyError):
            policy_from_str(unknown)
|
||||
|
||||
|
||||
class TestParseRouterArgs:
|
||||
"""Test the parse_router_args function."""
|
||||
|
||||
def test_parse_basic_args(self):
    """Host, port, worker URLs and policy are read from the command line."""
    argv = [
        token
        for line in (
            "--host 0.0.0.0",
            "--port 30001",
            "--worker-urls http://worker1:8000 http://worker2:8000",
            "--policy round_robin",
        )
        for token in line.split()
    ]

    parsed = parse_router_args(argv)

    expected = {
        "host": "0.0.0.0",
        "port": 30001,
        "worker_urls": ["http://worker1:8000", "http://worker2:8000"],
        "policy": "round_robin",
    }
    for field, value in expected.items():
        assert getattr(parsed, field) == value
|
||||
|
||||
def test_parse_pd_args(self):
    """PD-disaggregation flags, prefill/decode URLs and policies are parsed."""
    argv = [
        token
        for line in (
            "--pd-disaggregation",
            "--prefill http://prefill1:8000 9000",
            "--prefill http://prefill2:8000 none",
            "--decode http://decode1:8001",
            "--decode http://decode2:8001",
            "--prefill-policy power_of_two",
            "--decode-policy round_robin",
        )
        for token in line.split()
    ]

    parsed = parse_router_args(argv)

    assert parsed.pd_disaggregation is True
    expected = {
        "prefill_urls": [
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
        ],
        "decode_urls": ["http://decode1:8001", "http://decode2:8001"],
        "prefill_policy": "power_of_two",
        "decode_policy": "round_robin",
    }
    for field, value in expected.items():
        assert getattr(parsed, field) == value
|
||||
|
||||
def test_parse_service_discovery_args(self):
    """Service-discovery flag, selector, port and namespace are parsed."""
    argv = [
        token
        for line in (
            "--service-discovery",
            "--selector app=worker env=prod",
            "--service-discovery-port 8080",
            "--service-discovery-namespace default",
        )
        for token in line.split()
    ]

    parsed = parse_router_args(argv)

    assert parsed.service_discovery is True
    expected = {
        "selector": {"app": "worker", "env": "prod"},
        "service_discovery_port": 8080,
        "service_discovery_namespace": "default",
    }
    for field, value in expected.items():
        assert getattr(parsed, field) == value
|
||||
|
||||
def test_parse_retry_and_circuit_breaker_args(self):
    """Retry and circuit-breaker options, including disable flags, are parsed."""
    argv = [
        token
        for line in (
            "--retry-max-retries 3",
            "--retry-initial-backoff-ms 100",
            "--retry-max-backoff-ms 10000",
            "--retry-backoff-multiplier 2.0",
            "--retry-jitter-factor 0.1",
            "--disable-retries",
            "--cb-failure-threshold 5",
            "--cb-success-threshold 2",
            "--cb-timeout-duration-secs 30",
            "--cb-window-duration-secs 60",
            "--disable-circuit-breaker",
        )
        for token in line.split()
    ]

    parsed = parse_router_args(argv)

    expected = {
        # retry configuration
        "retry_max_retries": 3,
        "retry_initial_backoff_ms": 100,
        "retry_max_backoff_ms": 10000,
        "retry_backoff_multiplier": 2.0,
        "retry_jitter_factor": 0.1,
        # circuit breaker configuration
        "cb_failure_threshold": 5,
        "cb_success_threshold": 2,
        "cb_timeout_duration_secs": 30,
        "cb_window_duration_secs": 60,
    }
    for field, value in expected.items():
        assert getattr(parsed, field) == value

    # Both disable switches were passed, so both must be exactly True.
    assert parsed.disable_retries is True
    assert parsed.disable_circuit_breaker is True
|
||||
|
||||
def test_parse_rate_limiting_args(self):
    """Check that the rate-limiting CLI flags are parsed into integers."""
    cli = []
    cli += ["--max-concurrent-requests", "512"]
    cli += ["--queue-size", "200"]
    cli += ["--queue-timeout-secs", "120"]
    cli += ["--rate-limit-tokens-per-second", "100"]

    parsed = parse_router_args(cli)

    observed = (
        parsed.max_concurrent_requests,
        parsed.queue_size,
        parsed.queue_timeout_secs,
        parsed.rate_limit_tokens_per_second,
    )
    assert observed == (512, 200, 120, 100)
|
||||
|
||||
def test_parse_health_check_args(self):
    """Check that the health-check CLI flags land on the parsed args."""
    cli = []
    cli += ["--health-failure-threshold", "2"]
    cli += ["--health-success-threshold", "1"]
    cli += ["--health-check-timeout-secs", "3"]
    cli += ["--health-check-interval-secs", "30"]
    cli += ["--health-check-endpoint", "/healthz"]

    parsed = parse_router_args(cli)

    expected = {
        "health_failure_threshold": 2,
        "health_success_threshold": 1,
        "health_check_timeout_secs": 3,
        "health_check_interval_secs": 30,
        "health_check_endpoint": "/healthz",
    }
    for field, wanted in expected.items():
        assert getattr(parsed, field) == wanted
|
||||
|
||||
def test_parse_cors_args(self):
    """Check that several CORS origins after one flag are collected in order."""
    origins = ["http://localhost:3000", "https://example.com"]

    parsed = parse_router_args(["--cors-allowed-origins", *origins])

    assert parsed.cors_allowed_origins == origins
|
||||
|
||||
def test_parse_tokenizer_args(self):
    """Placeholder for tokenizer argument parsing; always skipped for now."""
    # The model-path and tokenizer-path arguments are not available in the
    # current implementation, so there is nothing to exercise until they are
    # added.
    reason = "Tokenizer arguments not available in current implementation"
    pytest.skip(reason)
|
||||
|
||||
def test_parse_invalid_args(self):
    """Check that bad CLI input is rejected by the parser."""
    # An unknown policy name makes the parser exit.
    unknown_policy = ["--policy", "invalid_policy"]
    with pytest.raises(SystemExit):
        parse_router_args(unknown_policy)

    # A non-numeric bootstrap port following a prefill URL raises ValueError.
    bad_bootstrap_port = [
        "--pd-disaggregation",
        "--prefill",
        "http://prefill1:8000",
        "invalid_port",
    ]
    with pytest.raises(ValueError, match="Invalid bootstrap port"):
        parse_router_args(bad_bootstrap_port)
|
||||
|
||||
def test_help_output(self):
    """Check that asking for --help exits with a success status."""
    with pytest.raises(SystemExit) as raised:
        parse_router_args(["--help"])

    # Exit status 0 indicates help was displayed rather than an error.
    exit_status = raised.value.code
    assert exit_status == 0
|
||||
421
sgl-router/py_test/unit/test_router_config.py
Normal file
421
sgl-router/py_test/unit/test_router_config.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""
|
||||
Unit tests for router configuration validation and setup.
|
||||
|
||||
These tests focus on testing the router configuration logic in isolation,
|
||||
including validation of configuration parameters and their interactions.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
from sglang_router.router import policy_from_str
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
|
||||
class TestRouterConfigValidation:
    """Test router configuration storage and launch-time validation."""

    @staticmethod
    def _assert_launch_builds_router(args):
        """Launch with ``Router`` patched out and check ``from_args`` ran once.

        Shared by the tests that only need to know ``launch_router`` accepted
        the configuration and went on to build a router from it.
        """
        with patch("sglang_router.launch_router.Router") as router_mod:
            mock_router_instance = MagicMock()
            router_mod.from_args = MagicMock(return_value=mock_router_instance)

            launch_router(args)

            # Should create router instance via from_args
            router_mod.from_args.assert_called_once()

    def test_valid_basic_config(self):
        """Test that a basic configuration is stored as given."""
        args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000", "http://worker2:8000"],
            policy="cache_aware",
        )

        # Construction should not raise any exceptions
        assert args.host == "127.0.0.1"
        assert args.port == 30000
        assert args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
        assert args.policy == "cache_aware"

    def test_valid_pd_config(self):
        """Test that a PD configuration is stored as given."""
        args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            pd_disaggregation=True,
            # Each prefill entry is (url, bootstrap_port); the port may be None.
            prefill_urls=[
                ("http://prefill1:8000", 9000),
                ("http://prefill2:8000", None),
            ],
            decode_urls=["http://decode1:8001", "http://decode2:8001"],
            policy="cache_aware",
        )

        assert args.pd_disaggregation is True
        assert args.prefill_urls == [
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
        ]
        assert args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
        assert args.policy == "cache_aware"

    def test_pd_config_without_urls_raises_error(self):
        """Test that PD mode without URLs raises a validation error on launch."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=False,
        )

        # The error is raised when trying to launch, not at construction time
        with pytest.raises(
            ValueError, match="PD disaggregation mode requires --prefill"
        ):
            launch_router(args)

    def test_pd_config_with_service_discovery_allows_empty_urls(self):
        """Test that PD mode with service discovery allows empty URLs."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=True,
        )

        # Should not raise validation error when service discovery is enabled
        self._assert_launch_builds_router(args)

    def test_regular_mode_without_workers_allows_empty_urls(self):
        """Test that regular mode allows empty worker URLs."""
        args = RouterArgs(worker_urls=[], service_discovery=False)

        # Should not raise validation error
        self._assert_launch_builds_router(args)

    def test_cache_threshold_validation(self):
        """Test that cache thresholds, including both bounds, are stored."""
        # Valid cache threshold
        args = RouterArgs(cache_threshold=0.5)
        assert args.cache_threshold == 0.5

        # Edge cases
        args = RouterArgs(cache_threshold=0.0)
        assert args.cache_threshold == 0.0

        args = RouterArgs(cache_threshold=1.0)
        assert args.cache_threshold == 1.0

    def test_balance_threshold_validation(self):
        """Test that load balancing thresholds are stored."""
        # Valid thresholds
        args = RouterArgs(balance_abs_threshold=64, balance_rel_threshold=1.5)
        assert args.balance_abs_threshold == 64
        assert args.balance_rel_threshold == 1.5

        # Edge cases
        args = RouterArgs(balance_abs_threshold=0, balance_rel_threshold=1.0)
        assert args.balance_abs_threshold == 0
        assert args.balance_rel_threshold == 1.0

    def test_timeout_validation(self):
        """Test that timeout parameters are stored."""
        # Valid timeouts
        args = RouterArgs(
            worker_startup_timeout_secs=600,
            worker_startup_check_interval=30,
            request_timeout_secs=1800,
            queue_timeout_secs=60,
        )
        assert args.worker_startup_timeout_secs == 600
        assert args.worker_startup_check_interval == 30
        assert args.request_timeout_secs == 1800
        assert args.queue_timeout_secs == 60

    def test_retry_config_validation(self):
        """Test that retry configuration is stored."""
        # Valid retry config
        args = RouterArgs(
            retry_max_retries=5,
            retry_initial_backoff_ms=50,
            retry_max_backoff_ms=30000,
            retry_backoff_multiplier=1.5,
            retry_jitter_factor=0.2,
            disable_retries=False,
        )
        assert args.retry_max_retries == 5
        assert args.retry_initial_backoff_ms == 50
        assert args.retry_max_backoff_ms == 30000
        assert args.retry_backoff_multiplier == 1.5
        assert args.retry_jitter_factor == 0.2
        assert args.disable_retries is False

    def test_circuit_breaker_config_validation(self):
        """Test that circuit breaker configuration is stored."""
        # Valid circuit breaker config
        args = RouterArgs(
            cb_failure_threshold=10,
            cb_success_threshold=3,
            cb_timeout_duration_secs=60,
            cb_window_duration_secs=120,
            disable_circuit_breaker=False,
        )
        assert args.cb_failure_threshold == 10
        assert args.cb_success_threshold == 3
        assert args.cb_timeout_duration_secs == 60
        assert args.cb_window_duration_secs == 120
        assert args.disable_circuit_breaker is False

    def test_health_check_config_validation(self):
        """Test that health check configuration is stored."""
        # Valid health check config
        args = RouterArgs(
            health_failure_threshold=3,
            health_success_threshold=2,
            health_check_timeout_secs=5,
            health_check_interval_secs=60,
            health_check_endpoint="/health",
        )
        assert args.health_failure_threshold == 3
        assert args.health_success_threshold == 2
        assert args.health_check_timeout_secs == 5
        assert args.health_check_interval_secs == 60
        assert args.health_check_endpoint == "/health"

    def test_rate_limiting_config_validation(self):
        """Test that rate limiting configuration is stored."""
        # Valid rate limiting config
        args = RouterArgs(
            max_concurrent_requests=256,
            queue_size=100,
            queue_timeout_secs=60,
            rate_limit_tokens_per_second=100,
        )
        assert args.max_concurrent_requests == 256
        assert args.queue_size == 100
        assert args.queue_timeout_secs == 60
        assert args.rate_limit_tokens_per_second == 100

    def test_service_discovery_config_validation(self):
        """Test that service discovery configuration is stored."""
        # Valid service discovery config
        args = RouterArgs(
            service_discovery=True,
            selector={"app": "worker", "env": "prod"},
            service_discovery_port=8080,
            service_discovery_namespace="default",
        )
        assert args.service_discovery is True
        assert args.selector == {"app": "worker", "env": "prod"}
        assert args.service_discovery_port == 8080
        assert args.service_discovery_namespace == "default"

    def test_pd_service_discovery_config_validation(self):
        """Test that PD service discovery configuration is stored."""
        # Valid PD service discovery config
        args = RouterArgs(
            pd_disaggregation=True,
            service_discovery=True,
            prefill_selector={"app": "prefill"},
            decode_selector={"app": "decode"},
            bootstrap_port_annotation="sglang.ai/bootstrap-port",
        )
        assert args.pd_disaggregation is True
        assert args.service_discovery is True
        assert args.prefill_selector == {"app": "prefill"}
        assert args.decode_selector == {"app": "decode"}
        assert args.bootstrap_port_annotation == "sglang.ai/bootstrap-port"

    def test_prometheus_config_validation(self):
        """Test that Prometheus configuration is stored."""
        # Valid Prometheus config
        args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
        assert args.prometheus_port == 29000
        assert args.prometheus_host == "127.0.0.1"

    def test_cors_config_validation(self):
        """Test that CORS configuration is stored."""
        # Valid CORS config
        args = RouterArgs(
            cors_allowed_origins=["http://localhost:3000", "https://example.com"]
        )
        assert args.cors_allowed_origins == [
            "http://localhost:3000",
            "https://example.com",
        ]

    def test_tokenizer_config_validation(self):
        """Placeholder for tokenizer configuration; always skipped for now."""
        # Note: model_path and tokenizer_path are not available in current RouterArgs
        pytest.skip("Tokenizer configuration not available in current implementation")

    def test_dp_aware_config_validation(self):
        """Test that data parallelism aware configuration is stored."""
        # Valid DP aware config
        args = RouterArgs(dp_aware=True, api_key="test-api-key")
        assert args.dp_aware is True
        assert args.api_key == "test-api-key"

    def test_request_id_headers_validation(self):
        """Test that request ID headers configuration is stored in order."""
        # Valid request ID headers config
        args = RouterArgs(
            request_id_headers=["x-request-id", "x-trace-id", "x-correlation-id"]
        )
        assert args.request_id_headers == [
            "x-request-id",
            "x-trace-id",
            "x-correlation-id",
        ]

    def test_policy_consistency_validation(self):
        """Test that PD mode launches with separate prefill and decode policies."""
        # Test with both prefill and decode policies specified
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", None)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
            prefill_policy="power_of_two",
            decode_policy="round_robin",
        )

        # Should not raise validation error
        self._assert_launch_builds_router(args)

    def test_policy_fallback_validation(self):
        """Test that PD mode launches when only the prefill policy is given."""
        # Test with only prefill policy specified
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", None)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
            prefill_policy="power_of_two",
            decode_policy=None,
        )

        # Should not raise validation error
        self._assert_launch_builds_router(args)

    def test_policy_enum_conversion(self):
        """Test policy string to enum conversion."""
        # Test all valid policy conversions
        assert policy_from_str("random") == PolicyType.Random
        assert policy_from_str("round_robin") == PolicyType.RoundRobin
        assert policy_from_str("cache_aware") == PolicyType.CacheAware
        assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo

    def test_invalid_policy_enum_conversion(self):
        """Test that an unknown policy string raises KeyError."""
        with pytest.raises(KeyError):
            policy_from_str("invalid_policy")

    def test_config_immutability(self):
        """Document that configuration objects are mutable after creation.

        Despite the test name, nothing here is frozen: the test pins the
        current behavior that attributes can be reassigned.
        """
        args = RouterArgs(
            host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
        )

        # This is more of a design test - dataclasses are mutable by default,
        # so reassigning an attribute is expected to succeed.
        original_host = args.host
        args.host = "0.0.0.0"
        assert args.host == "0.0.0.0"  # Dataclasses are mutable
        assert args.host != original_host

    def test_config_defaults_consistency(self):
        """Test that two default-constructed configurations agree."""
        args1 = RouterArgs()
        args2 = RouterArgs()

        # Both instances should have the same defaults
        assert args1.host == args2.host
        assert args1.port == args2.port
        assert args1.policy == args2.policy
        assert args1.worker_urls == args2.worker_urls
        assert args1.pd_disaggregation == args2.pd_disaggregation

    def test_config_serialization(self):
        """Test that a configuration survives a round trip through a dict."""
        # Local import: only this test needs the dataclasses helpers.
        import dataclasses

        args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000"],
            policy="cache_aware",
            cache_threshold=0.5,
        )

        # Serialize to a plain dict and check the explicitly set fields.
        # NOTE(review): relies on RouterArgs being a dataclass, as the
        # mutability test in this class states - confirm if that changes.
        payload = dataclasses.asdict(args)
        assert payload["host"] == "127.0.0.1"
        assert payload["port"] == 30000
        assert payload["worker_urls"] == ["http://worker1:8000"]
        assert payload["policy"] == "cache_aware"
        assert payload["cache_threshold"] == 0.5

        # Deserialize and make sure no field was lost or altered.
        restored = RouterArgs(**payload)
        assert dataclasses.asdict(restored) == payload

    def test_config_with_none_values(self):
        """Test that explicit None values are preserved."""
        args = RouterArgs(
            api_key=None,
            log_dir=None,
            log_level=None,
            prometheus_port=None,
            prometheus_host=None,
            request_id_headers=None,
            rate_limit_tokens_per_second=None,
            service_discovery_namespace=None,
        )

        # All None values should be preserved
        assert args.api_key is None
        assert args.log_dir is None
        assert args.log_level is None
        assert args.prometheus_port is None
        assert args.prometheus_host is None
        assert args.request_id_headers is None
        assert args.rate_limit_tokens_per_second is None
        assert args.service_discovery_namespace is None

    def test_config_with_empty_lists(self):
        """Test that explicit empty lists are preserved."""
        args = RouterArgs(
            worker_urls=[], prefill_urls=[], decode_urls=[], cors_allowed_origins=[]
        )

        # All empty lists should be preserved
        assert args.worker_urls == []
        assert args.prefill_urls == []
        assert args.decode_urls == []
        assert args.cors_allowed_origins == []

    def test_config_with_empty_dicts(self):
        """Test that explicit empty dictionaries are preserved."""
        args = RouterArgs(selector={}, prefill_selector={}, decode_selector={})

        # All empty dictionaries should be preserved
        assert args.selector == {}
        assert args.prefill_selector == {}
        assert args.decode_selector == {}
|
||||
1053
sgl-router/py_test/unit/test_startup_sequence.py
Normal file
1053
sgl-router/py_test/unit/test_startup_sequence.py
Normal file
File diff suppressed because it is too large
Load Diff
506
sgl-router/py_test/unit/test_validation.py
Normal file
506
sgl-router/py_test/unit/test_validation.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Unit tests for validation logic in sglang_router.
|
||||
|
||||
These tests focus on testing the validation logic in isolation,
|
||||
including parameter validation, URL validation, and configuration validation.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
|
||||
class TestURLValidation:
    """Checks on how RouterArgs stores worker, prefill and decode URLs."""

    def test_valid_worker_urls(self):
        """Each well-formed worker URL is kept in ``worker_urls``."""
        candidates = (
            "http://worker1:8000",
            "https://worker2:8000",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://192.168.1.100:8000",
            "http://worker.example.com:8000",
        )

        for candidate in candidates:
            # Construction itself must not raise.
            config = RouterArgs(worker_urls=[candidate])
            assert candidate in config.worker_urls

    def test_valid_prefill_urls(self):
        """Each (url, bootstrap_port) pair is kept in ``prefill_urls``."""
        pairs = (
            ("http://prefill1:8000", 9000),
            ("https://prefill2:8000", None),
            ("http://localhost:8000", 9000),
            ("http://127.0.0.1:8000", None),
        )

        for pair in pairs:
            # Construction itself must not raise.
            config = RouterArgs(prefill_urls=[pair])
            assert pair in config.prefill_urls

    def test_valid_decode_urls(self):
        """Each well-formed decode URL is kept in ``decode_urls``."""
        candidates = (
            "http://decode1:8001",
            "https://decode2:8001",
            "http://localhost:8001",
            "http://127.0.0.1:8001",
        )

        for candidate in candidates:
            # Construction itself must not raise.
            config = RouterArgs(decode_urls=[candidate])
            assert candidate in config.decode_urls

    def test_malformed_urls(self):
        """Malformed URLs are currently stored without complaint.

        The current implementation doesn't validate URL format; this test
        documents that behavior, which might be worth tightening later.
        """
        questionable = (
            "not-a-url",
            "ftp://worker1:8000",  # Wrong protocol
            "http://",  # Missing host
            ":8000",  # Missing protocol and host
            "http://worker1",  # Missing port
        )

        for candidate in questionable:
            config = RouterArgs(worker_urls=[candidate])
            # Accepted as-is for now.
            assert candidate in config.worker_urls
|
||||
|
||||
|
||||
class TestPortValidation:
    """Checks on how RouterArgs stores listen and bootstrap ports."""

    def test_valid_ports(self):
        """In-range port numbers are stored unchanged."""
        for candidate in (1, 80, 8000, 30000, 65535):
            config = RouterArgs(port=candidate)
            assert config.port == candidate

    def test_invalid_ports(self):
        """Out-of-range port numbers are currently stored unchanged as well.

        The current implementation doesn't validate port ranges; this test
        documents that behavior, which might be worth tightening later.
        """
        for candidate in (0, -1, 65536, 70000):
            config = RouterArgs(port=candidate)
            # Accepted as-is for now.
            assert config.port == candidate

    def test_bootstrap_port_validation(self):
        """Bootstrap ports, including None, are kept on the prefill entry."""
        for candidate in (1, 80, 9000, 30000, 65535, None):
            config = RouterArgs(prefill_urls=[("http://prefill1:8000", candidate)])
            first_entry = config.prefill_urls[0]
            # Index 1 of a prefill entry is its bootstrap port.
            assert first_entry[1] == candidate
|
||||
|
||||
|
||||
class TestParameterValidation:
    """Checks that RouterArgs stores numeric tuning parameters unchanged."""

    @staticmethod
    def _assert_stored(field_names, values):
        """Set every named field to each value in turn and read it back."""
        for value in values:
            config = RouterArgs(**{name: value for name in field_names})
            for name in field_names:
                assert getattr(config, name) == value

    def test_cache_threshold_validation(self):
        """Cache thresholds across the 0.0-1.0 range are stored."""
        self._assert_stored(["cache_threshold"], [0.0, 0.1, 0.5, 0.9, 1.0])

    def test_balance_threshold_validation(self):
        """Absolute and relative load balancing thresholds are stored."""
        # Absolute thresholds
        self._assert_stored(["balance_abs_threshold"], [0, 1, 32, 64, 128, 1000])

        # Relative thresholds
        self._assert_stored(["balance_rel_threshold"], [1.0, 1.1, 1.5, 2.0, 10.0])

    def test_timeout_validation(self):
        """All four timeout parameters accept and store the same value."""
        timeout_fields = [
            "worker_startup_timeout_secs",
            "worker_startup_check_interval",
            "request_timeout_secs",
            "queue_timeout_secs",
        ]
        self._assert_stored(timeout_fields, [1, 30, 60, 300, 600, 1800, 3600])

    def test_retry_parameter_validation(self):
        """Retry counts, backoffs, multipliers and jitter are stored."""
        # Retry counts
        self._assert_stored(["retry_max_retries"], [0, 1, 3, 5, 10])

        # Initial and maximum backoff are set to the same value together
        self._assert_stored(
            ["retry_initial_backoff_ms", "retry_max_backoff_ms"],
            [1, 50, 100, 1000, 30000],
        )

        # Backoff multipliers
        self._assert_stored(["retry_backoff_multiplier"], [1.0, 1.5, 2.0, 3.0])

        # Jitter factors
        self._assert_stored(["retry_jitter_factor"], [0.0, 0.1, 0.2, 0.5])

    def test_circuit_breaker_parameter_validation(self):
        """Circuit breaker thresholds and durations are stored."""
        # Failure thresholds
        self._assert_stored(["cb_failure_threshold"], [1, 3, 5, 10, 20])

        # Success thresholds
        self._assert_stored(["cb_success_threshold"], [1, 2, 3, 5])

        # Timeout and window durations are set to the same value together
        self._assert_stored(
            ["cb_timeout_duration_secs", "cb_window_duration_secs"],
            [10, 30, 60, 120, 300],
        )

    def test_health_check_parameter_validation(self):
        """Health check thresholds, timeouts and intervals are stored."""
        # Failure thresholds
        self._assert_stored(["health_failure_threshold"], [1, 2, 3, 5, 10])

        # Success thresholds
        self._assert_stored(["health_success_threshold"], [1, 2, 3, 5])

        # Timeout and interval are set to the same value together
        self._assert_stored(
            ["health_check_timeout_secs", "health_check_interval_secs"],
            [1, 5, 10, 30, 60, 120],
        )

    def test_rate_limiting_parameter_validation(self):
        """Concurrency limits, queue sizes and token rates are stored."""
        # Concurrent request limits
        self._assert_stored(["max_concurrent_requests"], [1, 10, 64, 256, 512, 1000])

        # Queue sizes
        self._assert_stored(["queue_size"], [0, 10, 50, 100, 500, 1000])

        # Token rates
        self._assert_stored(
            ["rate_limit_tokens_per_second"], [1, 10, 50, 100, 500, 1000]
        )

    def test_tree_size_validation(self):
        """Power-of-two tree sizes are stored."""
        tree_sizes = [2**exponent for exponent in (10, 20, 24, 26, 28, 30)]
        self._assert_stored(["max_tree_size"], tree_sizes)

    def test_payload_size_validation(self):
        """Payload sizes from 1KB up to 1GB are stored."""
        kib = 1024
        mib = kib * kib
        gib = kib * mib
        payload_sizes = [kib, mib, 10 * mib, 100 * mib, 512 * mib, gib]
        self._assert_stored(["max_payload_size"], payload_sizes)
|
||||
|
||||
|
||||
class TestConfigurationValidation:
|
||||
"""Test configuration validation logic."""
|
||||
|
||||
def test_pd_mode_validation(self):
|
||||
"""Test PD mode configuration validation."""
|
||||
# Valid PD configuration
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", 9000)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
)
|
||||
|
||||
assert args.pd_disaggregation is True
|
||||
assert len(args.prefill_urls) > 0
|
||||
assert len(args.decode_urls) > 0
|
||||
|
||||
def test_service_discovery_validation(self):
|
||||
"""Test service discovery configuration validation."""
|
||||
# Valid service discovery configuration
|
||||
args = RouterArgs(
|
||||
service_discovery=True,
|
||||
selector={"app": "worker", "env": "prod"},
|
||||
service_discovery_port=8080,
|
||||
service_discovery_namespace="default",
|
||||
)
|
||||
|
||||
assert args.service_discovery is True
|
||||
assert args.selector == {"app": "worker", "env": "prod"}
|
||||
assert args.service_discovery_port == 8080
|
||||
assert args.service_discovery_namespace == "default"
|
||||
|
||||
def test_pd_service_discovery_validation(self):
|
||||
"""Test PD service discovery configuration validation."""
|
||||
# Valid PD service discovery configuration
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
service_discovery=True,
|
||||
prefill_selector={"app": "prefill"},
|
||||
decode_selector={"app": "decode"},
|
||||
)
|
||||
|
||||
assert args.pd_disaggregation is True
|
||||
assert args.service_discovery is True
|
||||
assert args.prefill_selector == {"app": "prefill"}
|
||||
assert args.decode_selector == {"app": "decode"}
|
||||
|
||||
def test_policy_validation(self):
|
||||
"""Test policy configuration validation."""
|
||||
# Valid policies
|
||||
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
|
||||
|
||||
for policy in valid_policies:
|
||||
args = RouterArgs(policy=policy)
|
||||
assert args.policy == policy
|
||||
|
||||
def test_pd_policy_validation(self):
|
||||
"""Test PD policy configuration validation."""
|
||||
# Valid PD policies
|
||||
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
|
||||
|
||||
for prefill_policy in valid_policies:
|
||||
for decode_policy in valid_policies:
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", None)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
)
|
||||
assert args.prefill_policy == prefill_policy
|
||||
assert args.decode_policy == decode_policy
|
||||
|
||||
def test_cors_validation(self):
    """Check that each sample CORS origin list is kept unchanged by RouterArgs."""
    # Origin lists treated as valid by this test.
    origin_samples = (
        [],
        ["http://localhost:3000"],
        ["https://example.com"],
        ["http://localhost:3000", "https://example.com"],
        ["*"],  # Wildcard (if supported)
    )

    for allowed_origins in origin_samples:
        router_args = RouterArgs(cors_allowed_origins=allowed_origins)
        assert router_args.cors_allowed_origins == allowed_origins
|
||||
|
||||
def test_logging_validation(self):
    """Check that each listed log level name is kept unchanged by RouterArgs."""
    # Log level names treated as valid by this test.
    level_names = ("debug", "info", "warning", "error", "critical")

    for level_name in level_names:
        router_args = RouterArgs(log_level=level_name)
        assert router_args.log_level == level_name
|
||||
|
||||
def test_prometheus_validation(self):
    """Check that the Prometheus port and host are kept unchanged by RouterArgs."""
    expected_port = 29000
    expected_host = "127.0.0.1"

    router_args = RouterArgs(
        prometheus_port=expected_port,
        prometheus_host=expected_host,
    )

    assert router_args.prometheus_port == 29000
    assert router_args.prometheus_host == "127.0.0.1"
|
||||
|
||||
def test_tokenizer_validation(self):
    """Placeholder for tokenizer configuration validation; always skipped."""
    # Note: model_path and tokenizer_path are not available in current RouterArgs,
    # so there is nothing to validate and the test is skipped unconditionally.
    skip_message = "Tokenizer configuration not available in current implementation"
    pytest.skip(skip_message)
|
||||
|
||||
def test_request_id_headers_validation(self):
    """Check that each sample request-ID header list is kept unchanged by RouterArgs."""
    # Header lists treated as valid by this test.
    header_samples = (
        ["x-request-id"],
        ["x-request-id", "x-trace-id"],
        ["x-request-id", "x-trace-id", "x-correlation-id"],
        ["custom-header"],
    )

    for header_names in header_samples:
        router_args = RouterArgs(request_id_headers=header_names)
        assert router_args.request_id_headers == header_names
|
||||
|
||||
|
||||
class TestLaunchValidation:
    """Exercise the validation ``launch_router`` performs on a ``RouterArgs``."""

    @staticmethod
    def _launch_with_patched_router(router_args):
        """Run ``launch_router`` with ``Router`` patched and check ``from_args`` ran once.

        ``sglang_router.launch_router.Router`` is replaced by a mock for the
        duration of the call, so no real router is constructed.
        """
        with patch("sglang_router.launch_router.Router") as patched_router:
            patched_router.from_args = MagicMock(return_value=MagicMock())

            launch_router(router_args)

            # The router instance must have been created via from_args.
            patched_router.from_args.assert_called_once()

    def test_pd_mode_requires_urls(self):
        """PD mode with no URLs and service discovery off must raise ValueError."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=False,
        )

        expected_failure = pytest.raises(
            ValueError, match="PD disaggregation mode requires --prefill"
        )
        with expected_failure:
            launch_router(router_args)

    def test_pd_mode_with_service_discovery_allows_empty_urls(self):
        """PD mode with empty URL lists is accepted when service discovery is on."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=True,
        )

        # No validation error is expected here.
        self._launch_with_patched_router(router_args)

    def test_regular_mode_allows_empty_worker_urls(self):
        """Regular mode accepts an empty worker URL list."""
        router_args = RouterArgs(worker_urls=[], service_discovery=False)

        # No validation error is expected here.
        self._launch_with_patched_router(router_args)

    def test_launch_with_valid_config(self):
        """Launching with a plain, valid configuration builds a router."""
        router_args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000"],
            policy="cache_aware",
        )

        # No validation error is expected here.
        self._launch_with_patched_router(router_args)

    def test_launch_with_pd_config(self):
        """Launching with a valid PD configuration builds a router."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", 9000)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
        )

        # No validation error is expected here.
        self._launch_with_patched_router(router_args)

    def test_launch_with_service_discovery_config(self):
        """Launching with a valid service discovery configuration builds a router."""
        router_args = RouterArgs(
            service_discovery=True,
            selector={"app": "worker"},
            service_discovery_port=8080,
        )

        # No validation error is expected here.
        self._launch_with_patched_router(router_args)
|
||||
Reference in New Issue
Block a user