[router] Improve the router e2e tests (#10102)
This commit is contained in:
6
.github/workflows/pr-test-rust.yml
vendored
6
.github/workflows/pr-test-rust.yml
vendored
@@ -105,11 +105,11 @@ jobs:
|
||||
pip install fastapi uvicorn orjson
|
||||
pytest -q -m integration
|
||||
|
||||
- name: Run e2e test
|
||||
- name: Run Python E2E tests
|
||||
run: |
|
||||
bash scripts/killall_sglang.sh "nuk_gpus"
|
||||
cd sgl-router/py_test
|
||||
python3 run_suite.py
|
||||
cd sgl-router
|
||||
pytest -m e2e -s -vv -o log_cli=true --log-cli-level=INFO
|
||||
|
||||
finish:
|
||||
needs: [unit-test-rust, e2e-python]
|
||||
|
||||
235
sgl-router/py_test/e2e/conftest.py
Normal file
235
sgl-router/py_test/e2e/conftest.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
)
|
||||
|
||||
|
||||
def _find_available_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _parse_url(base_url: str) -> tuple[str, str]:
|
||||
"""Parse a base URL and return (host, port) as strings.
|
||||
|
||||
This is more robust than simple string splitting and supports different schemes
|
||||
and URL shapes like trailing paths.
|
||||
"""
|
||||
parsed = urlparse(base_url)
|
||||
return parsed.hostname or "127.0.0.1", (
|
||||
str(parsed.port) if parsed.port is not None else ""
|
||||
)
|
||||
|
||||
|
||||
def _wait_router_health(base_url: str, timeout: float) -> None:
    """Poll ``{base_url}/health`` every 2s until it returns 200.

    Raises:
        TimeoutError: if the endpoint never reports healthy within *timeout*.
    """
    deadline = time.perf_counter() + timeout
    session = requests.Session()
    try:
        while time.perf_counter() < deadline:
            try:
                resp = session.get(f"{base_url}/health", timeout=5)
                if resp.status_code == 200:
                    return
            except requests.RequestException:
                # Connection refused / reset while the router boots -- retry.
                pass
            time.sleep(2)
    finally:
        session.close()
    raise TimeoutError("Router failed to become healthy in time")
|
||||
|
||||
|
||||
def _popen_launch_router(
    model: str,
    base_url: str,
    dp_size: int,
    timeout: float,
    policy: str = "cache_aware",
) -> subprocess.Popen:
    """Start ``sglang_router.launch_server`` (router plus dp workers).

    Grabs a free port for the Prometheus exporter so concurrent launches do
    not collide, then blocks until ``/health`` responds (or TimeoutError is
    raised by ``_wait_router_health``).  Returns the child process handle.
    """
    host, port = _parse_url(base_url)

    metrics_port = _find_available_port()

    cmd = ["python3", "-m", "sglang_router.launch_server"]
    cmd += ["--model-path", model]
    cmd += ["--host", host, "--port", port]
    cmd += ["--dp", str(dp_size)]
    cmd += ["--router-policy", policy]
    cmd += ["--allow-auto-truncate"]
    cmd += ["--router-prometheus-port", str(metrics_port)]
    cmd += ["--router-prometheus-host", "127.0.0.1"]

    process = subprocess.Popen(cmd)
    _wait_router_health(base_url, timeout)
    return process
|
||||
|
||||
|
||||
def _popen_launch_worker(
    model: str,
    base_url: str,
    *,
    dp_size: int | None = None,
    api_key: str | None = None,
) -> subprocess.Popen:
    """Start a plain ``sglang.launch_server`` worker process.

    Readiness is NOT awaited here -- callers rely on the router's health
    gate (or ``/add_worker``) to confirm the worker is up.
    """
    host, port = _parse_url(base_url)

    cmd = ["python3", "-m", "sglang.launch_server"]
    cmd.extend(["--model-path", model])
    cmd.extend(["--host", host, "--port", port])
    cmd.extend(["--base-gpu-id", "0"])

    if dp_size is not None:
        cmd.extend(["--dp-size", str(dp_size)])
    if api_key is not None:
        cmd.extend(["--api-key", api_key])

    return subprocess.Popen(cmd)
|
||||
|
||||
|
||||
def _popen_launch_router_only(
    base_url: str,
    policy: str = "round_robin",
    timeout: float = 120.0,
    *,
    dp_aware: bool = False,
    api_key: str | None = None,
) -> subprocess.Popen:
    """Start a standalone ``sglang_router.launch_router`` with no workers.

    Workers are expected to be registered later via ``/add_worker``.  A free
    port is reserved for the Prometheus exporter to avoid collisions between
    parallel routers.  Blocks until ``/health`` responds or *timeout* elapses.
    """
    host, port = _parse_url(base_url)

    metrics_port = _find_available_port()

    cmd = ["python3", "-m", "sglang_router.launch_router"]
    cmd.extend(["--host", host, "--port", port])
    cmd.extend(["--policy", policy])
    if dp_aware:
        cmd.append("--dp-aware")
    if api_key is not None:
        cmd.extend(["--api-key", api_key])
    cmd.extend(["--prometheus-port", str(metrics_port)])
    cmd.extend(["--prometheus-host", "127.0.0.1"])

    process = subprocess.Popen(cmd)
    _wait_router_health(base_url, timeout)
    return process
|
||||
|
||||
|
||||
def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None:
|
||||
if proc is None:
|
||||
return
|
||||
proc.terminate()
|
||||
start = time.perf_counter()
|
||||
while proc.poll() is None:
|
||||
if time.perf_counter() - start > timeout:
|
||||
proc.kill()
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the custom ``e2e`` marker so ``pytest -m e2e`` selection works."""
    marker_def = "e2e: mark as end-to-end test"
    config.addinivalue_line("markers", marker_def)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def e2e_model() -> str:
|
||||
# Always use the default test model
|
||||
return DEFAULT_MODEL_NAME_FOR_TEST
|
||||
|
||||
|
||||
@pytest.fixture
def e2e_router(e2e_model: str):
    """Full router+workers stack (dp=2) at the default test URL.

    Kept available, but tests below prefer the router-only fixtures to
    avoid GPU contention.
    """
    url = DEFAULT_URL_FOR_TEST
    handle = _popen_launch_router(
        e2e_model, url, dp_size=2, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
    )
    try:
        yield SimpleNamespace(proc=handle, url=url)
    finally:
        _terminate(handle)
|
||||
|
||||
|
||||
@pytest.fixture
def e2e_router_only_rr():
    """Fresh round-robin router with no workers, on a random free port."""
    url = f"http://127.0.0.1:{_find_available_port()}"
    handle = _popen_launch_router_only(url, policy="round_robin")
    try:
        yield SimpleNamespace(proc=handle, url=url)
    finally:
        _terminate(handle)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def e2e_primary_worker(e2e_model: str):
|
||||
port = _find_available_port()
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
proc = _popen_launch_worker(e2e_model, base_url)
|
||||
# Router health gate will handle worker readiness
|
||||
try:
|
||||
yield SimpleNamespace(proc=proc, url=base_url)
|
||||
finally:
|
||||
_terminate(proc)
|
||||
|
||||
|
||||
@pytest.fixture
def e2e_router_only_rr_dp_aware_api():
    """Router-only with dp-aware enabled and an API key."""
    url = f"http://127.0.0.1:{_find_available_port()}"
    key = "secret"
    handle = _popen_launch_router_only(
        url, policy="round_robin", timeout=180.0, dp_aware=True, api_key=key
    )
    try:
        yield SimpleNamespace(proc=handle, url=url, api_key=key)
    finally:
        _terminate(handle)
|
||||
|
||||
|
||||
@pytest.fixture
def e2e_worker_dp2_api(e2e_model: str, e2e_router_only_rr_dp_aware_api):
    """Worker with dp-size=2 and the same API key as the dp-aware router."""
    url = f"http://127.0.0.1:{_find_available_port()}"
    key = e2e_router_only_rr_dp_aware_api.api_key
    handle = _popen_launch_worker(e2e_model, url, dp_size=2, api_key=key)
    try:
        yield SimpleNamespace(proc=handle, url=url)
    finally:
        _terminate(handle)
|
||||
146
sgl-router/py_test/e2e/test_e2e_router.py
Normal file
146
sgl-router/py_test/e2e/test_e2e_router.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import threading
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from sglang.test.run_eval import run_eval
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_mmlu(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    """Attach the single worker to a fresh round-robin router and run MMLU."""
    router_url = e2e_router_only_rr.url
    resp = requests.post(
        f"{router_url}/add_worker",
        params={"url": e2e_primary_worker.url},
        timeout=180,
    )
    resp.raise_for_status()

    eval_args = SimpleNamespace(
        base_url=router_url,
        model=e2e_model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    metrics = run_eval(eval_args)
    assert metrics["score"] >= 0.65
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    """Register a worker, serve a few completions through it, then detach it."""
    router_url = e2e_router_only_rr.url
    worker_url = e2e_primary_worker.url

    requests.post(
        f"{router_url}/add_worker", params={"url": worker_url}, timeout=180
    ).raise_for_status()

    with requests.Session() as session:
        for idx in range(8):
            payload = {
                "model": e2e_model,
                "prompt": f"x{idx}",
                "max_tokens": 1,
                "stream": False,
            }
            resp = session.post(
                f"{router_url}/v1/completions", json=payload, timeout=120
            )
            resp.raise_for_status()

    # Detach the worker again; the router must accept the removal.
    requests.post(
        f"{router_url}/remove_worker", params={"url": worker_url}, timeout=60
    ).raise_for_status()
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    """Kill the only worker mid-eval and check the router itself survives.

    The score is not meaningful here (requests may fail once the worker
    dies); we only require run_eval to complete with a score in [0, 1].
    """
    router_url = e2e_router_only_rr.url
    worker = e2e_primary_worker

    requests.post(
        f"{router_url}/add_worker", params={"url": worker.url}, timeout=180
    ).raise_for_status()

    def _kill_worker_soon():
        # Give the eval a head start, then take the worker down.
        time.sleep(10)
        try:
            worker.proc.terminate()
        except Exception:
            pass

    killer_thread = threading.Thread(target=_kill_worker_soon, daemon=True)
    killer_thread.start()

    eval_args = SimpleNamespace(
        base_url=router_url,
        model=e2e_model,
        eval_name="mmlu",
        num_examples=32,
        num_threads=16,
        temperature=0.0,
    )
    metrics = run_eval(eval_args)
    assert 0.0 <= metrics["score"] <= 1.0
|
||||
|
||||
|
||||
@pytest.mark.e2e
def test_dp_aware_worker_expansion_and_api_key(
    e2e_model,
    e2e_router_only_rr_dp_aware_api,
    e2e_worker_dp2_api,
):
    """
    Launch a router-only instance in dp_aware mode and a single worker with dp_size=2
    and API key protection. Verify expansion, auth enforcement, and basic eval.
    """
    import os

    router_url = e2e_router_only_rr_dp_aware_api.url
    worker_url = e2e_worker_dp2_api.url
    api_key = e2e_router_only_rr_dp_aware_api.api_key

    # Attach worker; router should expand to dp_size logical workers
    r = requests.post(
        f"{router_url}/add_worker", params={"url": worker_url}, timeout=180
    )
    r.raise_for_status()

    r = requests.get(f"{router_url}/list_workers", timeout=30)
    r.raise_for_status()
    urls = r.json().get("urls", [])
    assert len(urls) == 2
    # dp-aware routers register one logical worker per dp rank: "<url>@<rank>".
    assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}

    # Verify API key enforcement path-through
    # 1) Without Authorization -> 401 from backend
    r = requests.post(
        f"{router_url}/v1/completions",
        json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
        timeout=60,
    )
    assert r.status_code == 401

    # 2) With correct Authorization -> 200
    r = requests.post(
        f"{router_url}/v1/completions",
        json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=60,
    )
    assert r.status_code == 200

    # Finally, run MMLU eval through the router with auth.
    # Fix: the original permanently overwrote OPENAI_API_KEY for the whole
    # pytest process, leaking state into later tests; save and restore it.
    prev_key = os.environ.get("OPENAI_API_KEY")
    os.environ["OPENAI_API_KEY"] = api_key
    try:
        args = SimpleNamespace(
            base_url=router_url,
            model=e2e_model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
            temperature=0.1,
        )
        metrics = run_eval(args)
    finally:
        if prev_key is None:
            os.environ.pop("OPENAI_API_KEY", None)
        else:
            os.environ["OPENAI_API_KEY"] = prev_key
    assert metrics["score"] >= 0.65
|
||||
@@ -44,6 +44,7 @@ def _parse_args() -> argparse.Namespace:
|
||||
p.add_argument("--api-key", default=None)
|
||||
p.add_argument("--max-payload-bytes", type=int, default=10 * 1024 * 1024)
|
||||
p.add_argument("--stream", action="store_true")
|
||||
p.add_argument("--dp-size", type=int, default=1)
|
||||
p.add_argument("--crash-on-request", action="store_true")
|
||||
p.add_argument("--health-fail-after-ms", type=int, default=0)
|
||||
return p.parse_args()
|
||||
@@ -125,12 +126,15 @@ def create_app(args: argparse.Namespace) -> FastAPI:
|
||||
return JSONResponse({"data": [{"id": "mock", "object": "model"}]})
|
||||
|
||||
@app.get("/get_server_info")
|
||||
async def get_server_info():
|
||||
async def get_server_info(request: Request):
|
||||
# Enforce API key on server info when required (used by dp_aware probing)
|
||||
check_api_key(request)
|
||||
return JSONResponse(
|
||||
{
|
||||
"worker_id": worker_id,
|
||||
"load_in_flight": _inflight,
|
||||
"cache": {"size": 0, "hit_rate": 0.0},
|
||||
"dp_size": int(args.dp_size),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
33
sgl-router/py_test/integration/test_payload_size.py
Normal file
33
sgl-router/py_test/integration/test_payload_size.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_payload_size_limit(router_manager, mock_workers):
    """Requests under the router's 1MB payload cap pass; oversized ones get 413."""
    one_mb = 1 * 1024 * 1024
    _, urls, _ = mock_workers(n=1)
    rh = router_manager.start_router(
        worker_urls=urls,
        policy="round_robin",
        extra={"max_payload_size": one_mb},  # 1MB
    )

    def completion_status(prompt_chars: int) -> int:
        # Helper: POST a completion whose prompt is `prompt_chars` bytes of 'x'.
        body = {
            "model": "test-model",
            "prompt": "x" * prompt_chars,
            "max_tokens": 1,
            "stream": False,
        }
        return requests.post(f"{rh.url}/v1/completions", json=body).status_code

    # ~0.5MB payload fits comfortably under the cap.
    assert completion_status(int(0.5 * 1024 * 1024)) == 200
    # ~1.2MB payload exceeds the cap and must be rejected with 413.
    assert completion_status(int(1.2 * 1024 * 1024)) == 413
|
||||
@@ -1,27 +0,0 @@
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
from sglang.test.test_utils import TestFile, run_unittest_files
|
||||
|
||||
if __name__ == "__main__":
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument(
|
||||
"--timeout-per-file",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The time limit for running one file in seconds.",
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
files = glob.glob("**/test_*.py", recursive=True)
|
||||
# Exclude integration tests from the e2e suite; those are run separately via pytest -m integration
|
||||
files = [
|
||||
f
|
||||
for f in files
|
||||
if "/integration/" not in f and not f.startswith("integration/")
|
||||
]
|
||||
files.sort()
|
||||
|
||||
test_files = [TestFile(name=file) for file in files]
|
||||
exit_code = run_unittest_files(test_files, args.timeout_per_file)
|
||||
exit(exit_code)
|
||||
@@ -1,354 +0,0 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None:
    """Terminate a process gracefully, with forced kill as fallback.

    Args:
        process: The process to terminate.
        timeout: Seconds to wait for graceful termination before forcing kill.
    """
    if not process.is_alive():
        # Already dead -- nothing to do.
        return

    process.terminate()
    process.join(timeout=timeout)
    if process.is_alive():
        # Graceful shutdown timed out; escalate to SIGKILL and reap.
        process.kill()
        process.join()
|
||||
|
||||
|
||||
class TestLaunchRouter(unittest.TestCase):
    """Unit tests for sglang_router launch-time argument handling and startup.

    Routers are started in a child process; "started successfully" is
    approximated by the process still being alive after a short sleep.
    """

    def setUp(self):
        """Set up default arguments for router tests."""
        # Mirrors the full RouterArgs surface; individual tests override
        # only the fields they care about via create_router_args().
        self.default_args = SimpleNamespace(
            host="127.0.0.1",
            port=30000,
            policy="cache_aware",
            worker_startup_timeout_secs=600,
            worker_startup_check_interval=10,
            cache_threshold=0.5,
            balance_abs_threshold=32,
            balance_rel_threshold=1.0001,
            eviction_interval_secs=60,
            max_tree_size=2**24,
            max_payload_size=256 * 1024 * 1024,  # 256MB
            verbose=False,
            log_dir=None,
            log_level=None,
            service_discovery=False,
            selector=None,
            service_discovery_port=80,
            service_discovery_namespace=None,
            dp_aware=False,
            prometheus_port=None,
            prometheus_host=None,
            request_timeout_secs=60,
            max_concurrent_requests=64,
            cors_allowed_origins=[],
            pd_disaggregation=False,
            prefill=None,
            decode=None,
            worker_urls=[],
            retry_max_retries=3,
            retry_initial_backoff_ms=100,
            retry_max_backoff_ms=10_000,
            retry_backoff_multiplier=2.0,
            retry_jitter_factor=0.1,
            cb_failure_threshold=5,
            cb_success_threshold=2,
            cb_timeout_duration_secs=30,
            cb_window_duration_secs=60,
            disable_retries=False,
            disable_circuit_breaker=False,
            model_path=None,
            tokenizer_path=None,
        )

    def create_router_args(self, **kwargs):
        """Create router arguments by updating default args with provided kwargs."""
        args_dict = vars(self.default_args).copy()
        args_dict.update(kwargs)
        return SimpleNamespace(**args_dict)

    def run_router_process(self, args):
        """Run router in a separate process and verify it starts successfully."""

        def run_router():
            try:
                from sglang_router.launch_router import launch_router

                router = launch_router(args)
                if router is None:
                    return 1
                return 0
            except Exception as e:
                print(e)
                return 1

        process = multiprocessing.Process(target=run_router)
        try:
            process.start()
            # Wait 3 seconds
            time.sleep(3)
            # Process is still running means router started successfully
            self.assertTrue(process.is_alive())
        finally:
            terminate_process(process)

    def test_launch_router_common(self):
        """Router starts with a single explicit worker URL."""
        args = self.create_router_args(worker_urls=["http://localhost:8000"])
        self.run_router_process(args)

    def test_launch_router_with_empty_worker_urls(self):
        """Router starts with no workers registered up front."""
        args = self.create_router_args(worker_urls=[])
        self.run_router_process(
            args
        )  # Should start successfully with empty worker list

    def test_launch_router_with_service_discovery(self):
        # Test router startup with service discovery enabled but no selectors
        args = self.create_router_args(
            worker_urls=[], service_discovery=True, selector=["app=test-worker"]
        )
        self.run_router_process(args)

    def test_launch_router_with_service_discovery_namespace(self):
        # Test router startup with service discovery enabled and namespace specified
        args = self.create_router_args(
            worker_urls=[],
            service_discovery=True,
            selector=["app=test-worker"],
            service_discovery_namespace="test-namespace",
        )
        self.run_router_process(args)

    def test_launch_router_common_with_dp_aware(self):
        """dp-aware routing with one worker URL."""
        args = self.create_router_args(
            worker_urls=["http://localhost:8000"],
            dp_aware=True,
        )
        self.run_router_process(args)

    def test_launch_router_with_empty_worker_urls_with_dp_aware(self):
        """dp-aware routing with no workers."""
        args = self.create_router_args(
            worker_urls=[],
            dp_aware=True,
        )
        self.run_router_process(args)

    def test_launch_router_common_with_dp_aware_service_discovery(self):
        # Test launch router with both service_discovery and dp_aware enabled.
        # Should fail since service_discovery and dp_aware conflict.
        args = self.create_router_args(
            worker_urls=["http://localhost:8000"],
            dp_aware=True,
            service_discovery=True,
            selector=["app=test-worker"],
        )

        def run_router():
            try:
                from sglang_router.launch_router import launch_router

                router = launch_router(args)
                if router is None:
                    return 1
                return 0
            except Exception as e:
                print(e)
                return 1

        process = multiprocessing.Process(target=run_router)
        try:
            process.start()
            # Wait 3 seconds
            time.sleep(3)
            # Should fail since service_discovery and dp_aware conflict,
            # so the child process is expected to have exited by now.
            self.assertFalse(process.is_alive())
        finally:
            terminate_process(process)

    def test_launch_router_pd_mode_basic(self):
        """Test basic PD router functionality without actually starting servers."""
        # This test just verifies the PD router can be created and configured
        # without actually starting it (which would require real prefill/decode servers)
        from sglang_router.launch_router import RouterArgs
        from sglang_router.router import PolicyType, Router

        # Test RouterArgs parsing for PD mode
        # Simulate the parsed args structure from argparse with action="append"
        args = self.create_router_args(
            pd_disaggregation=True,
            policy="power_of_two",  # PowerOfTwo is only valid in PD mode
            prefill=[
                ["http://prefill1:8080", "9000"],
                ["http://prefill2:8080", "none"],
            ],
            decode=[
                ["http://decode1:8081"],
                ["http://decode2:8081"],
            ],
            worker_urls=[],  # Empty for PD mode
        )

        router_args = RouterArgs.from_cli_args(args)
        self.assertTrue(router_args.pd_disaggregation)
        self.assertEqual(router_args.policy, "power_of_two")
        self.assertEqual(len(router_args.prefill_urls), 2)
        self.assertEqual(len(router_args.decode_urls), 2)

        # Verify the parsed URLs and bootstrap ports
        # NOTE: "none" bootstrap port parses to Python None.
        self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000))
        self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None))
        self.assertEqual(router_args.decode_urls[0], "http://decode1:8081")
        self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")

        # Test Router creation in PD mode
        router = Router.from_args(router_args)
        self.assertIsNotNone(router)

    def test_policy_validation(self):
        """Test that policy validation works correctly for PD and regular modes."""
        from sglang_router.launch_router import RouterArgs, launch_router

        # Test 1: PowerOfTwo requires at least 2 workers
        args = self.create_router_args(
            pd_disaggregation=False,
            policy="power_of_two",
            worker_urls=["http://localhost:8000"],  # Only 1 worker
        )

        # Should raise error
        with self.assertRaises(ValueError) as cm:
            launch_router(args)
        self.assertIn(
            "Power-of-two policy requires at least 2 workers",
            str(cm.exception),
        )

        # Test 2: PowerOfTwo with sufficient workers should succeed
        args = self.create_router_args(
            pd_disaggregation=False,
            policy="power_of_two",
            worker_urls=["http://localhost:8000", "http://localhost:8001"],  # 2 workers
        )
        # This should not raise an error (validation passes)

        # Test 3: All policies now work in both modes
        # Regular mode with RoundRobin
        args = self.create_router_args(
            pd_disaggregation=False,
            policy="round_robin",
            worker_urls=["http://localhost:8000"],
        )
        # This should not raise validation error

        # PD mode with RoundRobin (now supported!)
        args = self.create_router_args(
            pd_disaggregation=True,
            policy="round_robin",
            prefill=[["http://prefill1:8080", "9000"]],
            decode=[["http://decode1:8081"]],
            worker_urls=[],
        )
        # This should not raise validation error

    def test_pd_service_discovery_args_parsing(self):
        """Test PD service discovery CLI argument parsing."""
        import argparse

        from sglang_router.launch_router import RouterArgs

        parser = argparse.ArgumentParser()
        RouterArgs.add_cli_args(parser)

        args = parser.parse_args(
            [
                "--pd-disaggregation",
                "--service-discovery",
                "--prefill-selector",
                "app=sglang",
                "component=prefill",
                "--decode-selector",
                "app=sglang",
                "component=decode",
                "--service-discovery-port",
                "8000",
                "--service-discovery-namespace",
                "production",
                "--policy",
                "cache_aware",
            ]
        )

        router_args = RouterArgs.from_cli_args(args)

        self.assertTrue(router_args.pd_disaggregation)
        self.assertTrue(router_args.service_discovery)
        # "key=value" selector tokens are parsed into dicts.
        self.assertEqual(
            router_args.prefill_selector, {"app": "sglang", "component": "prefill"}
        )
        self.assertEqual(
            router_args.decode_selector, {"app": "sglang", "component": "decode"}
        )
        self.assertEqual(router_args.service_discovery_port, 8000)
        self.assertEqual(router_args.service_discovery_namespace, "production")

    def test_regular_service_discovery_args_parsing(self):
        """Test regular mode service discovery CLI argument parsing."""
        import argparse

        from sglang_router.launch_router import RouterArgs

        parser = argparse.ArgumentParser()
        RouterArgs.add_cli_args(parser)

        args = parser.parse_args(
            [
                "--service-discovery",
                "--selector",
                "app=sglang-worker",
                "environment=staging",
                "--service-discovery-port",
                "8000",
                "--policy",
                "round_robin",
            ]
        )

        router_args = RouterArgs.from_cli_args(args)

        self.assertFalse(router_args.pd_disaggregation)
        self.assertTrue(router_args.service_discovery)
        self.assertEqual(
            router_args.selector, {"app": "sglang-worker", "environment": "staging"}
        )
        # PD-only selectors stay empty in regular mode.
        self.assertEqual(router_args.prefill_selector, {})
        self.assertEqual(router_args.decode_selector, {})

    def test_empty_worker_urls_args_parsing(self):
        """Test that router accepts no worker URLs and defaults to empty list."""
        import argparse

        from sglang_router.launch_router import RouterArgs

        parser = argparse.ArgumentParser()
        RouterArgs.add_cli_args(parser)

        # Test with no --worker-urls argument at all
        args = parser.parse_args(["--policy", "random", "--port", "30000"])
        router_args = RouterArgs.from_cli_args(args)
        self.assertEqual(router_args.worker_urls, [])

        # Test with explicit empty --worker-urls
        args = parser.parse_args(["--worker-urls", "--policy", "random"])
        router_args = RouterArgs.from_cli_args(args)
        self.assertEqual(router_args.worker_urls, [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,735 +0,0 @@
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
)
|
||||
|
||||
|
||||
def popen_launch_router(
    model: str,
    base_url: str,
    dp_size: int,
    timeout: float,
    policy: str = "cache_aware",
    max_payload_size: int = None,
    api_key: str = None,
    log_dir: str = None,
    service_discovery: bool = False,
    selector: list = None,
    service_discovery_port: int = 80,
    service_discovery_namespace: str = None,
    prometheus_port: int = None,
    prometheus_host: str = None,
    dp_aware: bool = False,
    # Router retry/CB tuning (optional)
    router_retry_max_retries: int = None,
    router_retry_initial_backoff_ms: int = None,
    router_retry_max_backoff_ms: int = None,
    router_retry_backoff_multiplier: float = None,
    router_retry_jitter_factor: float = None,
    router_cb_failure_threshold: int = None,
    router_cb_success_threshold: int = None,
    router_cb_timeout_duration_secs: int = None,
    router_cb_window_duration_secs: int = None,
):
    """
    Launch the router server process.

    Args:
        model: Model path/name
        base_url: Server base URL
        dp_size: Data parallel size
        timeout: Server launch timeout
        policy: Router policy, one of "cache_aware", "round_robin", "random"
        max_payload_size: Maximum payload size in bytes
        api_key: API key for the router
        log_dir: Directory to store log files. If None, logs are only output to console.
        service_discovery: Enable Kubernetes service discovery
        selector: List of label selectors in format ["key1=value1", "key2=value2"]
        service_discovery_port: Port to use for service discovery
        service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces.
        prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled.
        prometheus_host: Host address to bind the Prometheus metrics server.
        dp_aware: Enable data parallelism aware routing strategy.

    Raises:
        TimeoutError: if the router never answers /health within `timeout`.
    """
    # NOTE(review): naive "scheme:host:port" split; host[2:] strips the "//"
    # after the scheme. Breaks on URLs with paths or missing ports — the
    # newer conftest uses urlparse instead. Kept as-is here.
    _, host, port = base_url.split(":")
    host = host[2:]

    command = [
        "python3",
        "-m",
        "sglang_router.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--dp",
        str(dp_size),
        "--router-eviction-interval-secs",
        "5",
        "--router-policy",
        policy,
        "--allow-auto-truncate",
    ]

    if api_key is not None:
        # The same key protects both the workers (--api-key) and the router
        # front door (--router-api-key).
        command.extend(["--api-key", api_key])
        command.extend(["--router-api-key", api_key])

    if max_payload_size is not None:
        command.extend(["--router-max-payload-size", str(max_payload_size)])

    if service_discovery:
        command.append("--router-service-discovery")

    if selector:
        command.extend(["--router-selector"] + selector)

    # 80 is the flag's default, so only forward a non-default value.
    if service_discovery_port != 80:
        command.extend(["--router-service-discovery-port", str(service_discovery_port)])

    if service_discovery_namespace:
        command.extend(
            ["--router-service-discovery-namespace", service_discovery_namespace]
        )

    if prometheus_port is not None:
        command.extend(["--router-prometheus-port", str(prometheus_port)])

    if prometheus_host is not None:
        command.extend(["--router-prometheus-host", prometheus_host])

    if log_dir is not None:
        command.extend(["--log-dir", log_dir])

    if dp_aware:
        command.append("--router-dp-aware")

    # Append router retry/CB tuning flags if provided
    def _add(flag: str, val):
        # Helper: forward an optional flag only when a value was given.
        if val is not None:
            command.extend([flag, str(val)])

    _add("--router-retry-max-retries", router_retry_max_retries)
    _add("--router-retry-initial-backoff-ms", router_retry_initial_backoff_ms)
    _add("--router-retry-max-backoff-ms", router_retry_max_backoff_ms)
    _add("--router-retry-backoff-multiplier", router_retry_backoff_multiplier)
    _add("--router-retry-jitter-factor", router_retry_jitter_factor)
    _add("--router-cb-failure-threshold", router_cb_failure_threshold)
    _add("--router-cb-success-threshold", router_cb_success_threshold)
    _add("--router-cb-timeout-duration-secs", router_cb_timeout_duration_secs)
    _add("--router-cb-window-duration-secs", router_cb_window_duration_secs)

    # stdout/stderr inherit the parent's streams (stdout=None, stderr=None).
    process = subprocess.Popen(command, stdout=None, stderr=None)

    # Poll /health every 10s until the router responds or `timeout` elapses.
    start_time = time.perf_counter()
    with requests.Session() as session:
        while time.perf_counter() - start_time < timeout:
            try:
                response = session.get(f"{base_url}/health")
                if response.status_code == 200:
                    print(f"Router {base_url} is healthy")
                    return process
            except requests.RequestException:
                pass
            time.sleep(10)

    raise TimeoutError("Router failed to start within the timeout period.")
|
||||
|
||||
|
||||
def find_available_port():
    """Ask the OS for a free TCP port on localhost and return its number."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 lets the kernel pick an unused ephemeral port.
        probe.bind(("127.0.0.1", 0))
        return probe.getsockname()[1]
    finally:
        probe.close()
|
||||
|
||||
|
||||
def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: str = None,
):
    """Spawn a plain sglang worker server as a subprocess and return it.

    The call returns immediately after spawning; readiness is deliberately
    deferred to the router's health check, so ``timeout`` is accepted for
    signature parity but not consumed here.
    """
    # base_url looks like "http://<host>:<port>"; drop the "//" prefix.
    scheme, raw_host, port = base_url.split(":")
    host = raw_host[2:]

    command = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--base-gpu-id",
        "1",
    ]
    if api_key is not None:
        command += ["--api-key", api_key]

    worker = subprocess.Popen(command, stdout=None, stderr=None)
    # Intentionally do not wait here; the router health check covers startup.
    return worker
|
||||
|
||||
|
||||
def terminate_and_wait(process, timeout=300):
    """Terminate a process and block until it has actually exited.

    Args:
        process: subprocess.Popen object, or None (treated as a no-op)
        timeout: maximum time to wait in seconds

    Raises:
        TimeoutError: if the process does not terminate within ``timeout``
    """
    if process is None:
        return

    process.terminate()
    deadline = time.perf_counter() + timeout

    # Poll once per second until the process reports an exit status.
    while process.poll() is None:
        print(f"Terminating process {process.pid}")
        if time.perf_counter() > deadline:
            raise TimeoutError(
                f"Process {process.pid} failed to terminate within {timeout}s"
            )
        time.sleep(1)

    print(f"Process {process.pid} is successfully terminated")
|
||||
|
||||
|
||||
class TestLaunchServer(unittest.TestCase):
    """End-to-end tests for the sglang router.

    Each test launches a router (optionally with co-launched workers) and
    exercises: eval accuracy, dynamic worker add/remove, lazy fault
    tolerance, payload-size limits, and API-key auth -- each both with and
    without the dp-aware routing strategy.
    """

    # Minimum acceptable MMLU accuracy; scores below this indicate a broken
    # deployment rather than normal eval noise.
    MMLU_THRESHOLD = 0.635

    def setUp(self):
        self.model = DEFAULT_MODEL_NAME_FOR_TEST
        self.base_url = DEFAULT_URL_FOR_TEST
        self.process = None  # router process under test
        self.other_process = []  # extra worker processes needing cleanup

    def tearDown(self):
        print("Running tearDown...")
        if self.process:
            terminate_and_wait(self.process)
        for process in self.other_process:
            terminate_and_wait(process)
        print("tearDown done")

    # ------------------------------------------------------------------
    # Shared helpers (deduplicate the copy-pasted bodies of tests 1-10)
    # ------------------------------------------------------------------

    def _assert_mmlu(self, num_examples=64, label="MMLU"):
        """Run the MMLU eval through the router; assert score >= threshold."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=num_examples,
            num_threads=32,
            temperature=0.1,
        )
        metrics = run_eval(args)
        score = metrics["score"]
        passed = score >= self.MMLU_THRESHOLD
        msg = f"{label} test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {self.MMLU_THRESHOLD})"
        self.assertGreaterEqual(score, self.MMLU_THRESHOLD, msg)

    def _launch_worker(self, api_key=None):
        """Start a standalone worker on a free port; register it for cleanup.

        Returns (worker_url, worker_process).
        """
        port = find_available_port()
        worker_url = f"http://127.0.0.1:{port}"
        worker_process = popen_launch_server(
            self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=api_key
        )
        self.other_process.append(worker_process)
        return worker_url, worker_process

    def _post_to_router(self, path):
        """POST to a router control endpoint, log the reply and return it."""
        with requests.Session() as session:
            response = session.post(f"{self.base_url}{path}")
        print(f"status code: {response.status_code}, response: {response.text}")
        return response

    def _kill_worker_soon(self, worker_process, delay=10):
        """Kill *worker_process* after *delay* seconds from a daemon thread
        to mimic an abrupt worker failure while requests are in flight."""
        import threading

        def kill_worker():
            time.sleep(delay)
            kill_process_tree(worker_process.pid)
            print("Worker process killed")

        kill_thread = threading.Thread(target=kill_worker, daemon=True)
        kill_thread.start()

    def _check_payload_limits(self):
        """With a 1MB router limit: 0.5MB must pass, 1.2MB must get 413."""
        payload_0_5_mb = {
            "text": "x" * int(0.5 * 1024 * 1024),  # 0.5MB of text
            "temperature": 0.0,
        }
        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json=payload_0_5_mb,
                headers={"Content-Type": "application/json"},
            )
        self.assertEqual(
            response.status_code,
            200,
            f"0.5MB payload should succeed but got status {response.status_code}",
        )

        payload_1_plus_mb = {
            "text": "x" * int(1.2 * 1024 * 1024),  # 1.2MB of text
            "temperature": 0.0,
        }
        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json=payload_1_plus_mb,
                headers={"Content-Type": "application/json"},
            )
        self.assertEqual(
            response.status_code,
            413,  # Payload Too Large
            f"1.2MB payload should fail with 413 but got status {response.status_code}",
        )

    def _check_api_key_auth(self):
        """Missing or wrong bearer token -> 401; the correct token -> 200."""
        # Case 1: request without api key should fail.
        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json={"text": "Kanye west is, ", "temperature": 0},
            )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            f"Request without api key should fail with 401 but got status {response.status_code}",
        )

        # Case 2: request with invalid api key should fail.
        # Fixed: the original called requests.post here, bypassing the
        # session created by the context manager.
        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json={"text": "Kanye west is, ", "temperature": 0},
                headers={"Authorization": "Bearer 123"},
            )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            f"Request with invalid api key should fail with 401 but got status {response.status_code}",
        )

        # Case 3: request with the correct api key should succeed.
        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json={"text": "Kanye west is ", "temperature": 0},
                headers={"Authorization": "Bearer correct_api_key"},
            )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            200,
            f"Request with correct api key should succeed but got status {response.status_code}",
        )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_1_mmlu(self):
        print("Running test_1_mmlu...")
        # DP size = 2
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=2,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="cache_aware",
        )
        self._assert_mmlu()

    def test_2_add_and_remove_worker(self):
        print("Running test_2_add_and_remove_worker...")
        # DP size = 1; round robin guarantees every worker processes requests.
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
        )
        # 1-2. Start a worker and add it; the router uses it once healthy.
        worker_url, _ = self._launch_worker()
        self.assertEqual(
            self._post_to_router(f"/add_worker?url={worker_url}").status_code, 200
        )
        # 3. Run mmlu with the worker attached.
        self._assert_mmlu()
        # 4. Remove the worker again.
        self.assertEqual(
            self._post_to_router(f"/remove_worker?url={worker_url}").status_code, 200
        )
        # 5. Run mmlu on the remaining workers.
        self._assert_mmlu()

    def test_3_lazy_fault_tolerance(self):
        print("Running test_3_lazy_fault_tolerance...")
        # DP size = 1
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
        )
        worker_url, worker_process = self._launch_worker()
        self.assertEqual(
            self._post_to_router(f"/add_worker?url={worker_url}").status_code, 200
        )
        # Kill the worker mid-run; the router must still finish the eval.
        self._kill_worker_soon(worker_process)
        self._assert_mmlu(num_examples=256)

    def test_4_payload_size(self):
        print("Running test_4_payload_size...")
        # Start router with a 1MB payload limit.
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            max_payload_size=1 * 1024 * 1024,  # 1MB limit
        )
        self._check_payload_limits()

    def test_5_api_key(self):
        print("Running test_5_api_key...")
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            api_key="correct_api_key",
        )
        self._check_api_key_auth()

    def test_6_mmlu_with_dp_aware(self):
        print("Running test_6_mmlu_with_dp_aware...")
        # DP size = 2
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=2,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="cache_aware",
            dp_aware=True,
        )
        self._assert_mmlu(label="dp aware MMLU")

    def test_7_add_and_remove_worker_with_dp_aware(self):
        print("Running test_7_add_and_remove_worker_with_dp_aware...")
        # Set dp_size = 1
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",  # make sure every worker processes requests
            dp_aware=True,  # dp aware strategy should work well with RR
        )
        worker_url, worker_process = self._launch_worker()
        self.assertEqual(
            self._post_to_router(f"/add_worker?url={worker_url}").status_code, 200
        )
        self._assert_mmlu()
        self.assertEqual(
            self._post_to_router(f"/remove_worker?url={worker_url}").status_code, 200
        )
        self._assert_mmlu()

        # Start a replacement worker that requires an api key. Adding it must
        # fail: the router queries the worker's /get_server_info endpoint for
        # the dp_size info but has no knowledge of the api key.
        terminate_and_wait(worker_process)
        worker_url, _ = self._launch_worker(api_key="correct_api_key")
        self.assertNotEqual(
            self._post_to_router(f"/add_worker?url={worker_url}").status_code, 200
        )

    def test_8_lazy_fault_tolerance_with_dp_aware(self):
        print("Running test_8_lazy_fault_tolerance_with_dp_aware...")
        # Set dp_size = 1
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            dp_aware=True,
        )
        worker_url, worker_process = self._launch_worker()
        self.assertEqual(
            self._post_to_router(f"/add_worker?url={worker_url}").status_code, 200
        )
        self._kill_worker_soon(worker_process)
        self._assert_mmlu(num_examples=256)

    def test_9_payload_size_with_dp_aware(self):
        print("Running test_9_payload_size_with_dp_aware...")
        # Start the router with a 1MB payload limit.
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            max_payload_size=1 * 1024 * 1024,  # 1MB limit
            dp_aware=True,
        )
        self._check_payload_limits()

    def test_10_api_key_with_dp_aware(self):
        print("Running test_10_api_key_with_dp_aware...")
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            api_key="correct_api_key",
            dp_aware=True,
        )
        self._check_api_key_auth()
|
||||
|
||||
|
||||
# Allow running this e2e suite directly with `python3 <file>` in addition
# to pytest/run_suite discovery.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user