[router] add Python binding unit tests to reach 80% coverage (#10043)
This commit is contained in:
11
.github/workflows/pr-test-rust.yml
vendored
11
.github/workflows/pr-test-rust.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
cd sgl-router/
|
||||
cargo fmt -- --check
|
||||
|
||||
- name: Run test
|
||||
- name: Run Rust tests
|
||||
timeout-minutes: 20
|
||||
run: |
|
||||
source "$HOME/.cargo/env"
|
||||
@@ -83,6 +83,15 @@ jobs:
|
||||
pip install setuptools-rust wheel build
|
||||
python3 -m build
|
||||
pip install --force-reinstall dist/*.whl
|
||||
|
||||
|
||||
- name: Run Python unit tests
  run: |
    cd sgl-router
    source "$HOME/.cargo/env"
    pip install pytest pytest-cov pytest-xdist
    # Collect coverage while running the unit tests. A bare --cov defers to
    # the `source`/`omit` settings in .coveragerc, and pytest-cov picks up
    # `fail_under` from the same file, so the step fails below the threshold.
    pytest -q py_test/unit --cov --cov-config=.coveragerc --cov-report=term-missing
|
||||
|
||||
- name: Run e2e test
|
||||
run: |
|
||||
bash scripts/killall_sglang.sh "nuk_gpus"
|
||||
|
||||
9
sgl-router/.coveragerc
Normal file
9
sgl-router/.coveragerc
Normal file
@@ -0,0 +1,9 @@
|
||||
# Coverage is measured only for the Python sources of the router package;
# mini_lb.py is excluded from measurement.
[run]
source = py_src/sglang_router
omit = py_src/sglang_router/mini_lb.py

# The report fails when total coverage is below 80%; mini_lb.py is excluded
# from the report as well.
[report]
fail_under = 80
omit = py_src/sglang_router/mini_lb.py
|
||||
8
sgl-router/py_test/conftest.py
Normal file
8
sgl-router/py_test/conftest.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Put the in-tree ``py_src`` directory at the front of ``sys.path`` (unless it
# is already listed) so the local sources win over any installed package.
_ROOT = Path(__file__).resolve().parents[1]
_SRC = _ROOT / "py_src"
if sys.path.count(str(_SRC)) == 0:
    sys.path[:0] = [str(_SRC)]
|
||||
7
sgl-router/py_test/unit/__init__.py
Normal file
7
sgl-router/py_test/unit/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Unit tests for sglang_router.
|
||||
|
||||
This package contains fast, isolated unit tests for Python components
|
||||
of the SGLang router. These tests focus on testing individual functions
|
||||
and classes in isolation without starting actual router instances.
|
||||
"""
|
||||
628
sgl-router/py_test/unit/test_arg_parser.py
Normal file
628
sgl-router/py_test/unit/test_arg_parser.py
Normal file
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
Unit tests for argument parsing functionality in sglang_router.
|
||||
|
||||
These tests focus on testing the argument parsing logic in isolation,
|
||||
without starting actual router instances.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, parse_router_args
|
||||
from sglang_router.router import policy_from_str
|
||||
|
||||
|
||||
class TestRouterArgs:
    """Cover RouterArgs: default values, static parsing helpers, from_cli_args."""

    def test_default_values(self):
        """A RouterArgs built with no arguments carries the expected defaults."""
        defaults = RouterArgs()

        # Basic, service-discovery, retry and circuit-breaker defaults that
        # are compared by value.
        by_value = {
            "host": "127.0.0.1",
            "port": 30000,
            "policy": "cache_aware",
            "worker_urls": [],
            "prefill_urls": [],
            "decode_urls": [],
            "selector": {},
            "service_discovery_port": 80,
            "retry_max_retries": 5,
            "cb_failure_threshold": 10,
        }
        for field_name, expected in by_value.items():
            assert getattr(defaults, field_name) == expected, field_name

        # Boolean switches must be the literal False, not merely falsy.
        for flag in (
            "pd_disaggregation",
            "service_discovery",
            "disable_retries",
            "disable_circuit_breaker",
        ):
            assert getattr(defaults, flag) is False, flag

        # Optional PD / service-discovery settings stay unset.
        for optional in (
            "prefill_policy",
            "decode_policy",
            "service_discovery_namespace",
        ):
            assert getattr(defaults, optional) is None, optional

    def test_parse_selector_valid(self):
        """_parse_selector turns key=value strings into a dict."""
        cases = [
            # single key-value pair
            (["app=worker"], {"app": "worker"}),
            # multiple key-value pairs
            (
                ["app=worker", "env=prod", "version=v1"],
                {"app": "worker", "env": "prod", "version": "v1"},
            ),
            # empty list
            ([], {}),
            # None
            (None, {}),
        ]
        for raw, parsed in cases:
            assert RouterArgs._parse_selector(raw) == parsed

    def test_parse_selector_invalid(self):
        """_parse_selector handling of malformed selector strings."""
        # An entry without "=" yields an empty mapping.
        no_equals = RouterArgs._parse_selector(["app"])
        assert no_equals == {}

        # With several "=" the first one separates key from value.
        extra_equals = RouterArgs._parse_selector(["app=worker=extra"])
        assert extra_equals == {"app": "worker=extra"}

    def test_parse_prefill_urls_valid(self):
        """_parse_prefill_urls yields (url, bootstrap_port) tuples."""
        url = "http://prefill1:8000"
        single_entry_cases = [
            # explicit bootstrap port
            ([url, "9000"], (url, 9000)),
            # literal 'none' bootstrap port
            ([url, "none"], (url, None)),
            # bootstrap port omitted
            ([url], (url, None)),
        ]
        for entry, parsed in single_entry_cases:
            assert RouterArgs._parse_prefill_urls([entry]) == [parsed]

        # Several prefill URLs in one call.
        combined = RouterArgs._parse_prefill_urls(
            [
                ["http://prefill1:8000", "9000"],
                ["http://prefill2:8000", "none"],
                ["http://prefill3:8000"],
            ]
        )
        assert combined == [
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
            ("http://prefill3:8000", None),
        ]

        # Empty list and None both give an empty result.
        for nothing in ([], None):
            assert RouterArgs._parse_prefill_urls(nothing) == []

    def test_parse_prefill_urls_invalid(self):
        """A non-numeric bootstrap port is rejected with ValueError."""
        bad_entry = ["http://prefill1:8000", "invalid"]
        with pytest.raises(ValueError, match="Invalid bootstrap port"):
            RouterArgs._parse_prefill_urls([bad_entry])

    def test_parse_decode_urls_valid(self):
        """_parse_decode_urls flattens the nested CLI lists into URLs."""
        cases = [
            # single decode URL
            ([["http://decode1:8001"]], ["http://decode1:8001"]),
            # multiple decode URLs
            (
                [["http://decode1:8001"], ["http://decode2:8001"]],
                ["http://decode1:8001", "http://decode2:8001"],
            ),
            # empty list
            ([], []),
            # None
            (None, []),
        ]
        for raw, parsed in cases:
            assert RouterArgs._parse_decode_urls(raw) == parsed

    def test_from_cli_args_basic(self):
        """from_cli_args with use_router_prefix=True on a non-PD namespace."""
        # Attributes supplied without the "router_" prefix.
        plain = {
            "host": "0.0.0.0",
            "port": 30001,
            "worker_urls": ["http://worker1:8000", "http://worker2:8000"],
            "policy": "round_robin",
            "prefill": None,
            "decode": None,
        }
        # Attributes supplied as "router_<name>".
        prefixed = {
            "policy": "round_robin",
            "pd_disaggregation": False,
            "prefill_policy": None,
            "decode_policy": None,
            "worker_startup_timeout_secs": 300,
            "worker_startup_check_interval": 15,
            "cache_threshold": 0.7,
            "balance_abs_threshold": 128,
            "balance_rel_threshold": 2.0,
            "eviction_interval": 180,
            "max_tree_size": 2**28,
            "max_payload_size": 1024 * 1024 * 1024,  # 1GB
            "dp_aware": True,
            "api_key": "test-key",
            "log_dir": "/tmp/logs",
            "log_level": "debug",
            "service_discovery": True,
            "selector": ["app=worker", "env=test"],
            "service_discovery_port": 8080,
            "service_discovery_namespace": "default",
            "prefill_selector": ["app=prefill"],
            "decode_selector": ["app=decode"],
            "prometheus_port": 29000,
            "prometheus_host": "0.0.0.0",
            "request_id_headers": ["x-request-id", "x-trace-id"],
            "request_timeout_secs": 1200,
            "max_concurrent_requests": 512,
            "queue_size": 200,
            "queue_timeout_secs": 120,
            "rate_limit_tokens_per_second": 100,
            "cors_allowed_origins": ["http://localhost:3000"],
            "retry_max_retries": 3,
            "retry_initial_backoff_ms": 100,
            "retry_max_backoff_ms": 10000,
            "retry_backoff_multiplier": 2.0,
            "retry_jitter_factor": 0.1,
            "cb_failure_threshold": 5,
            "cb_success_threshold": 2,
            "cb_timeout_duration_secs": 30,
            "cb_window_duration_secs": 60,
            "disable_retries": False,
            "disable_circuit_breaker": False,
            "health_failure_threshold": 2,
            "health_success_threshold": 1,
            "health_check_timeout_secs": 3,
            "health_check_interval_secs": 30,
            "health_check_endpoint": "/healthz",
        }
        args = SimpleNamespace(
            **plain,
            **{f"router_{name}": value for name, value in prefixed.items()},
        )

        router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

        # Expected field values, compared with ==. Grouped as: basic, PD,
        # service discovery, misc, retry, circuit breaker, health check.
        by_value = {
            "host": "0.0.0.0",
            "port": 30001,
            "worker_urls": ["http://worker1:8000", "http://worker2:8000"],
            "policy": "round_robin",
            "prefill_urls": [],
            "decode_urls": [],
            "selector": {"app": "worker", "env": "test"},
            "service_discovery_port": 8080,
            "service_discovery_namespace": "default",
            "prefill_selector": {"app": "prefill"},
            "decode_selector": {"app": "decode"},
            "api_key": "test-key",
            "log_dir": "/tmp/logs",
            "log_level": "debug",
            "prometheus_port": 29000,
            "prometheus_host": "0.0.0.0",
            "request_id_headers": ["x-request-id", "x-trace-id"],
            "request_timeout_secs": 1200,
            "max_concurrent_requests": 512,
            "queue_size": 200,
            "queue_timeout_secs": 120,
            "rate_limit_tokens_per_second": 100,
            "cors_allowed_origins": ["http://localhost:3000"],
            "retry_max_retries": 3,
            "retry_initial_backoff_ms": 100,
            "retry_max_backoff_ms": 10000,
            "retry_backoff_multiplier": 2.0,
            "retry_jitter_factor": 0.1,
            "cb_failure_threshold": 5,
            "cb_success_threshold": 2,
            "cb_timeout_duration_secs": 30,
            "cb_window_duration_secs": 60,
            "health_failure_threshold": 2,
            "health_success_threshold": 1,
            "health_check_timeout_secs": 3,
            "health_check_interval_secs": 30,
            "health_check_endpoint": "/healthz",
        }
        for field_name, expected in by_value.items():
            assert getattr(router_args, field_name) == expected, field_name

        # Boolean fields must be the exact singleton, so compare by identity.
        by_identity = {
            "pd_disaggregation": False,
            "service_discovery": True,
            "dp_aware": True,
            "disable_retries": False,
            "disable_circuit_breaker": False,
        }
        for field_name, expected in by_identity.items():
            assert getattr(router_args, field_name) is expected, field_name

        # Note: model_path and tokenizer_path are not available in current RouterArgs

    def test_from_cli_args_pd_mode(self):
        """from_cli_args with use_router_prefix=True on a PD-mode namespace."""
        # Attributes supplied without the "router_" prefix.
        plain = {
            "host": "127.0.0.1",
            "port": 30000,
            "worker_urls": [],
            "policy": "cache_aware",
            "prefill": [
                ["http://prefill1:8000", "9000"],
                ["http://prefill2:8000", "none"],
            ],
            "decode": [["http://decode1:8001"], ["http://decode2:8001"]],
        }
        # Attributes supplied as "router_<name>"; everything after the PD
        # policies is present only so that all required fields exist.
        prefixed = {
            "prefill": [
                ["http://prefill1:8000", "9000"],
                ["http://prefill2:8000", "none"],
            ],
            "decode": [["http://decode1:8001"], ["http://decode2:8001"]],
            "policy": "cache_aware",
            "pd_disaggregation": True,
            "prefill_policy": "power_of_two",
            "decode_policy": "round_robin",
            "worker_startup_timeout_secs": 600,
            "worker_startup_check_interval": 30,
            "cache_threshold": 0.3,
            "balance_abs_threshold": 64,
            "balance_rel_threshold": 1.5,
            "eviction_interval": 120,
            "max_tree_size": 2**26,
            "max_payload_size": 512 * 1024 * 1024,
            "dp_aware": False,
            "api_key": None,
            "log_dir": None,
            "log_level": None,
            "service_discovery": False,
            "selector": None,
            "service_discovery_port": 80,
            "service_discovery_namespace": None,
            "prefill_selector": None,
            "decode_selector": None,
            "prometheus_port": None,
            "prometheus_host": None,
            "request_id_headers": None,
            "request_timeout_secs": 1800,
            "max_concurrent_requests": 256,
            "queue_size": 100,
            "queue_timeout_secs": 60,
            "rate_limit_tokens_per_second": None,
            "cors_allowed_origins": [],
            "retry_max_retries": 5,
            "retry_initial_backoff_ms": 50,
            "retry_max_backoff_ms": 30000,
            "retry_backoff_multiplier": 1.5,
            "retry_jitter_factor": 0.2,
            "cb_failure_threshold": 10,
            "cb_success_threshold": 3,
            "cb_timeout_duration_secs": 60,
            "cb_window_duration_secs": 120,
            "disable_retries": False,
            "disable_circuit_breaker": False,
            "health_failure_threshold": 3,
            "health_success_threshold": 2,
            "health_check_timeout_secs": 5,
            "health_check_interval_secs": 60,
            "health_check_endpoint": "/health",
        }
        args = SimpleNamespace(
            **plain,
            **{f"router_{name}": value for name, value in prefixed.items()},
        )

        router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

        # PD configuration
        assert router_args.pd_disaggregation is True
        expected_prefill = [
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
        ]
        expected_decode = ["http://decode1:8001", "http://decode2:8001"]
        assert router_args.prefill_urls == expected_prefill
        assert router_args.decode_urls == expected_decode
        assert router_args.prefill_policy == "power_of_two"
        assert router_args.decode_policy == "round_robin"
        # The main policy is still set alongside the per-role policies.
        assert router_args.policy == "cache_aware"

    def test_from_cli_args_without_prefix(self):
        """from_cli_args with use_router_prefix=False reads unprefixed names."""
        cli_values = {
            "host": "127.0.0.1",
            "port": 30000,
            "worker_urls": ["http://worker1:8000"],
            "policy": "random",
            "prefill": None,
            "decode": None,
            "pd_disaggregation": False,
            "prefill_policy": None,
            "decode_policy": None,
            "worker_startup_timeout_secs": 600,
            "worker_startup_check_interval": 30,
            "cache_threshold": 0.3,
            "balance_abs_threshold": 64,
            "balance_rel_threshold": 1.5,
            "eviction_interval": 120,
            "max_tree_size": 2**26,
            "max_payload_size": 512 * 1024 * 1024,
            "dp_aware": False,
            "api_key": None,
            "log_dir": None,
            "log_level": None,
            "service_discovery": False,
            "selector": None,
            "service_discovery_port": 80,
            "service_discovery_namespace": None,
            "prefill_selector": None,
            "decode_selector": None,
            "prometheus_port": None,
            "prometheus_host": None,
            "request_id_headers": None,
            "request_timeout_secs": 1800,
            "max_concurrent_requests": 256,
            "queue_size": 100,
            "queue_timeout_secs": 60,
            "rate_limit_tokens_per_second": None,
            "cors_allowed_origins": [],
            "retry_max_retries": 5,
            "retry_initial_backoff_ms": 50,
            "retry_max_backoff_ms": 30000,
            "retry_backoff_multiplier": 1.5,
            "retry_jitter_factor": 0.2,
            "cb_failure_threshold": 10,
            "cb_success_threshold": 3,
            "cb_timeout_duration_secs": 60,
            "cb_window_duration_secs": 120,
            "disable_retries": False,
            "disable_circuit_breaker": False,
            "health_failure_threshold": 3,
            "health_success_threshold": 2,
            "health_check_timeout_secs": 5,
            "health_check_interval_secs": 60,
            "health_check_endpoint": "/health",
            "model_path": None,
            "tokenizer_path": None,
        }
        args = SimpleNamespace(**cli_values)

        router_args = RouterArgs.from_cli_args(args, use_router_prefix=False)

        assert router_args.host == "127.0.0.1"
        assert router_args.port == 30000
        assert router_args.worker_urls == ["http://worker1:8000"]
        assert router_args.policy == "random"
        assert router_args.pd_disaggregation is False
|
||||
|
||||
|
||||
class TestPolicyFromStr:
    """Check policy_from_str, which maps policy names to PolicyType members."""

    def test_valid_policies(self):
        """Each known policy name maps to its PolicyType member."""
        from sglang_router_rs import PolicyType

        name_to_member = {
            "random": PolicyType.Random,
            "round_robin": PolicyType.RoundRobin,
            "cache_aware": PolicyType.CacheAware,
            "power_of_two": PolicyType.PowerOfTwo,
        }
        for policy_name, member in name_to_member.items():
            assert policy_from_str(policy_name) == member, policy_name

    def test_invalid_policy(self):
        """An unknown policy name raises KeyError."""
        unknown_name = "invalid_policy"
        with pytest.raises(KeyError):
            policy_from_str(unknown_name)
|
||||
|
||||
|
||||
class TestParseRouterArgs:
    """Drive parse_router_args with argv-style lists and check the result."""

    def test_parse_basic_args(self):
        """Host, port, worker URLs and policy are parsed from the CLI."""
        # No token contains whitespace, so str.split() rebuilds the argv list.
        argv = (
            "--host 0.0.0.0 "
            "--port 30001 "
            "--worker-urls http://worker1:8000 http://worker2:8000 "
            "--policy round_robin"
        ).split()

        parsed = parse_router_args(argv)

        assert parsed.host == "0.0.0.0"
        assert parsed.port == 30001
        assert parsed.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
        assert parsed.policy == "round_robin"

    def test_parse_pd_args(self):
        """PD disaggregated mode flags are parsed from the CLI."""
        argv = (
            "--pd-disaggregation "
            "--prefill http://prefill1:8000 9000 "
            "--prefill http://prefill2:8000 none "
            "--decode http://decode1:8001 "
            "--decode http://decode2:8001 "
            "--prefill-policy power_of_two "
            "--decode-policy round_robin"
        ).split()

        parsed = parse_router_args(argv)

        assert parsed.pd_disaggregation is True
        expected_prefill = [
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
        ]
        assert parsed.prefill_urls == expected_prefill
        assert parsed.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
        assert parsed.prefill_policy == "power_of_two"
        assert parsed.decode_policy == "round_robin"

    def test_parse_service_discovery_args(self):
        """Service discovery flags are parsed from the CLI."""
        argv = (
            "--service-discovery "
            "--selector app=worker env=prod "
            "--service-discovery-port 8080 "
            "--service-discovery-namespace default"
        ).split()

        parsed = parse_router_args(argv)

        assert parsed.service_discovery is True
        assert parsed.selector == {"app": "worker", "env": "prod"}
        assert parsed.service_discovery_port == 8080
        assert parsed.service_discovery_namespace == "default"

    def test_parse_retry_and_circuit_breaker_args(self):
        """Retry and circuit breaker flags are parsed from the CLI."""
        retry_argv = (
            "--retry-max-retries 3 "
            "--retry-initial-backoff-ms 100 "
            "--retry-max-backoff-ms 10000 "
            "--retry-backoff-multiplier 2.0 "
            "--retry-jitter-factor 0.1 "
            "--disable-retries"
        ).split()
        breaker_argv = (
            "--cb-failure-threshold 5 "
            "--cb-success-threshold 2 "
            "--cb-timeout-duration-secs 30 "
            "--cb-window-duration-secs 60 "
            "--disable-circuit-breaker"
        ).split()

        parsed = parse_router_args(retry_argv + breaker_argv)

        # Retry configuration
        expected_retry = {
            "retry_max_retries": 3,
            "retry_initial_backoff_ms": 100,
            "retry_max_backoff_ms": 10000,
            "retry_backoff_multiplier": 2.0,
            "retry_jitter_factor": 0.1,
        }
        for field_name, expected in expected_retry.items():
            assert getattr(parsed, field_name) == expected, field_name
        assert parsed.disable_retries is True

        # Circuit breaker configuration
        expected_breaker = {
            "cb_failure_threshold": 5,
            "cb_success_threshold": 2,
            "cb_timeout_duration_secs": 30,
            "cb_window_duration_secs": 60,
        }
        for field_name, expected in expected_breaker.items():
            assert getattr(parsed, field_name) == expected, field_name
        assert parsed.disable_circuit_breaker is True

    def test_parse_rate_limiting_args(self):
        """Concurrency, queue and rate-limit flags are parsed from the CLI."""
        argv = (
            "--max-concurrent-requests 512 "
            "--queue-size 200 "
            "--queue-timeout-secs 120 "
            "--rate-limit-tokens-per-second 100"
        ).split()

        parsed = parse_router_args(argv)

        expected_limits = {
            "max_concurrent_requests": 512,
            "queue_size": 200,
            "queue_timeout_secs": 120,
            "rate_limit_tokens_per_second": 100,
        }
        for field_name, expected in expected_limits.items():
            assert getattr(parsed, field_name) == expected, field_name

    def test_parse_health_check_args(self):
        """Health check flags are parsed from the CLI."""
        argv = (
            "--health-failure-threshold 2 "
            "--health-success-threshold 1 "
            "--health-check-timeout-secs 3 "
            "--health-check-interval-secs 30 "
            "--health-check-endpoint /healthz"
        ).split()

        parsed = parse_router_args(argv)

        expected_health = {
            "health_failure_threshold": 2,
            "health_success_threshold": 1,
            "health_check_timeout_secs": 3,
            "health_check_interval_secs": 30,
            "health_check_endpoint": "/healthz",
        }
        for field_name, expected in expected_health.items():
            assert getattr(parsed, field_name) == expected, field_name

    def test_parse_cors_args(self):
        """Multiple CORS origins are collected into a list."""
        origins = ["http://localhost:3000", "https://example.com"]

        parsed = parse_router_args(["--cors-allowed-origins", *origins])

        assert parsed.cors_allowed_origins == [
            "http://localhost:3000",
            "https://example.com",
        ]

    def test_parse_tokenizer_args(self):
        """Placeholder for tokenizer argument parsing; currently skipped."""
        # Note: model-path and tokenizer-path arguments are not available in
        # the current implementation, so this test is skipped until they are
        # added.
        skip_reason = "Tokenizer arguments not available in current implementation"
        pytest.skip(skip_reason)

    def test_parse_invalid_args(self):
        """Invalid CLI input is rejected."""
        # An unknown policy makes the parser exit.
        with pytest.raises(SystemExit):
            parse_router_args(["--policy", "invalid_policy"])

        # A non-numeric bootstrap port raises ValueError.
        bad_prefill_argv = (
            "--pd-disaggregation --prefill http://prefill1:8000 invalid_port"
        ).split()
        with pytest.raises(ValueError, match="Invalid bootstrap port"):
            parse_router_args(bad_prefill_argv)

    def test_help_output(self):
        """--help terminates via SystemExit with exit code 0."""
        with pytest.raises(SystemExit) as raised:
            parse_router_args(["--help"])

        # Exit code 0 indicates help was displayed rather than an error.
        assert raised.value.code == 0
|
||||
421
sgl-router/py_test/unit/test_router_config.py
Normal file
421
sgl-router/py_test/unit/test_router_config.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""
|
||||
Unit tests for router configuration validation and setup.
|
||||
|
||||
These tests focus on testing the router configuration logic in isolation,
|
||||
including validation of configuration parameters and their interactions.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
from sglang_router.router import policy_from_str
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
|
||||
class TestRouterConfigValidation:
|
||||
"""Test router configuration validation logic."""
|
||||
|
||||
def test_valid_basic_config(self):
    """A basic RouterArgs can be constructed and keeps the values it was given."""
    # Construction itself must not raise.
    basic = RouterArgs(
        policy="cache_aware",
        worker_urls=["http://worker1:8000", "http://worker2:8000"],
        port=30000,
        host="127.0.0.1",
    )

    expected_workers = ["http://worker1:8000", "http://worker2:8000"]
    assert basic.host == "127.0.0.1"
    assert basic.port == 30000
    assert basic.worker_urls == expected_workers
    assert basic.policy == "cache_aware"
|
||||
|
||||
def test_valid_pd_config(self):
    """A PD-mode RouterArgs can be constructed and keeps its prefill/decode URLs."""
    pd_args = RouterArgs(
        host="127.0.0.1",
        port=30000,
        policy="cache_aware",
        pd_disaggregation=True,
        prefill_urls=[
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
        ],
        decode_urls=["http://decode1:8001", "http://decode2:8001"],
    )

    expected_prefill = [
        ("http://prefill1:8000", 9000),
        ("http://prefill2:8000", None),
    ]
    expected_decode = ["http://decode1:8001", "http://decode2:8001"]

    assert pd_args.pd_disaggregation is True
    assert pd_args.prefill_urls == expected_prefill
    assert pd_args.decode_urls == expected_decode
    assert pd_args.policy == "cache_aware"
|
||||
|
||||
def test_pd_config_without_urls_raises_error(self):
    """launch_router raises ValueError for PD mode with no URLs and no discovery."""
    empty_pd = RouterArgs(
        service_discovery=False,
        pd_disaggregation=True,
        decode_urls=[],
        prefill_urls=[],
    )

    # The error is expected at launch time, not at construction time.
    expected_message = "PD disaggregation mode requires --prefill"
    with pytest.raises(ValueError, match=expected_message):
        launch_router(empty_pd)
|
||||
|
||||
def test_pd_config_with_service_discovery_allows_empty_urls(self):
    """PD mode with service discovery enabled launches with empty URL lists."""
    discovered_pd = RouterArgs(
        service_discovery=True,
        pd_disaggregation=True,
        decode_urls=[],
        prefill_urls=[],
    )

    # Router is replaced by a mock so no real router is started; the launch
    # must get as far as building the router through from_args.
    with patch("sglang_router.launch_router.Router") as router_cls:
        router_cls.from_args = MagicMock(return_value=MagicMock())

        launch_router(discovered_pd)

        router_cls.from_args.assert_called_once()
|
||||
|
||||
def test_regular_mode_without_workers_allows_empty_urls(self):
    """Regular (non-PD) mode launches even when no worker URLs are given."""
    no_workers = RouterArgs(service_discovery=False, worker_urls=[])

    # Router is replaced by a mock so no real router is started; the launch
    # must get as far as building the router through from_args.
    with patch("sglang_router.launch_router.Router") as router_cls:
        router_cls.from_args = MagicMock(return_value=MagicMock())

        launch_router(no_workers)

        router_cls.from_args.assert_called_once()
|
||||
|
||||
def test_cache_threshold_validation(self):
    """RouterArgs keeps the cache_threshold it is given, including 0.0 and 1.0."""
    # A mid-range value first, then the two edge cases.
    for threshold in (0.5, 0.0, 1.0):
        args = RouterArgs(cache_threshold=threshold)
        assert args.cache_threshold == threshold
|
||||
|
||||
def test_balance_threshold_validation(self):
    """RouterArgs keeps the absolute and relative balance thresholds it is given."""
    # (absolute, relative): a typical pair, then the edge-case pair.
    threshold_pairs = [(64, 1.5), (0, 1.0)]
    for absolute, relative in threshold_pairs:
        args = RouterArgs(
            balance_abs_threshold=absolute,
            balance_rel_threshold=relative,
        )
        assert args.balance_abs_threshold == absolute
        assert args.balance_rel_threshold == relative
|
||||
|
||||
def test_timeout_validation(self):
    """RouterArgs keeps the timeout and interval settings it is given."""
    timeouts = {
        "worker_startup_timeout_secs": 600,
        "worker_startup_check_interval": 30,
        "request_timeout_secs": 1800,
        "queue_timeout_secs": 60,
    }

    args = RouterArgs(**timeouts)

    for field_name, expected in timeouts.items():
        assert getattr(args, field_name) == expected, field_name
|
||||
|
||||
def test_retry_config_validation(self):
    """RouterArgs keeps the retry settings it is given."""
    retry_numbers = {
        "retry_max_retries": 5,
        "retry_initial_backoff_ms": 50,
        "retry_max_backoff_ms": 30000,
        "retry_backoff_multiplier": 1.5,
        "retry_jitter_factor": 0.2,
    }

    args = RouterArgs(disable_retries=False, **retry_numbers)

    for field_name, expected in retry_numbers.items():
        assert getattr(args, field_name) == expected, field_name
    # The switch must be the literal False, not merely falsy.
    assert args.disable_retries is False
|
||||
|
||||
def test_circuit_breaker_config_validation(self):
    """RouterArgs keeps the circuit breaker settings it is given."""
    breaker_numbers = {
        "cb_failure_threshold": 10,
        "cb_success_threshold": 3,
        "cb_timeout_duration_secs": 60,
        "cb_window_duration_secs": 120,
    }

    args = RouterArgs(disable_circuit_breaker=False, **breaker_numbers)

    for field_name, expected in breaker_numbers.items():
        assert getattr(args, field_name) == expected, field_name
    # The switch must be the literal False, not merely falsy.
    assert args.disable_circuit_breaker is False
|
||||
|
||||
def test_health_check_config_validation(self):
    """RouterArgs keeps the health check settings it is given."""
    health_settings = {
        "health_failure_threshold": 3,
        "health_success_threshold": 2,
        "health_check_timeout_secs": 5,
        "health_check_interval_secs": 60,
        "health_check_endpoint": "/health",
    }

    args = RouterArgs(**health_settings)

    for field_name, expected in health_settings.items():
        assert getattr(args, field_name) == expected, field_name
|
||||
|
||||
def test_rate_limiting_config_validation(self):
    """Build RouterArgs with explicit rate limiting settings and read each one back."""
    limiter_settings = {
        "max_concurrent_requests": 256,
        "queue_size": 100,
        "queue_timeout_secs": 60,
        "rate_limit_tokens_per_second": 100,
    }
    args = RouterArgs(**limiter_settings)

    for field_name, expected in limiter_settings.items():
        assert getattr(args, field_name) == expected
|
||||
|
||||
def test_service_discovery_config_validation(self):
    """Build RouterArgs with service discovery options and read each one back."""
    label_selector = {"app": "worker", "env": "prod"}
    args = RouterArgs(
        service_discovery_namespace="default",
        service_discovery_port=8080,
        # Pass a copy so the comparison below is against an untouched dict.
        selector=dict(label_selector),
        service_discovery=True,
    )

    assert args.service_discovery is True
    assert args.selector == label_selector
    assert args.service_discovery_port == 8080
    assert args.service_discovery_namespace == "default"
|
||||
|
||||
def test_pd_service_discovery_config_validation(self):
    """Build RouterArgs with PD-mode service discovery options and read each one back."""
    prefill_labels = {"app": "prefill"}
    decode_labels = {"app": "decode"}
    annotation = "sglang.ai/bootstrap-port"
    args = RouterArgs(
        bootstrap_port_annotation=annotation,
        # Copies keep the expected dicts independent of what RouterArgs holds.
        decode_selector=dict(decode_labels),
        prefill_selector=dict(prefill_labels),
        service_discovery=True,
        pd_disaggregation=True,
    )

    assert args.pd_disaggregation is True
    assert args.service_discovery is True
    assert args.prefill_selector == prefill_labels
    assert args.decode_selector == decode_labels
    assert args.bootstrap_port_annotation == annotation
|
||||
|
||||
def test_prometheus_config_validation(self):
    """Build RouterArgs with a Prometheus host and port and read both back."""
    metrics_port = 29000
    metrics_host = "127.0.0.1"
    args = RouterArgs(prometheus_host=metrics_host, prometheus_port=metrics_port)

    assert args.prometheus_port == metrics_port
    assert args.prometheus_host == metrics_host
|
||||
|
||||
def test_cors_config_validation(self):
    """Build RouterArgs with two CORS origins and expect the same list back."""
    expected_origins = ("http://localhost:3000", "https://example.com")
    # A fresh list is handed over so the expectation stays independent of it.
    args = RouterArgs(cors_allowed_origins=list(expected_origins))

    assert args.cors_allowed_origins == list(expected_origins)
|
||||
|
||||
def test_tokenizer_config_validation(self):
    """Placeholder for tokenizer configuration checks; always skipped."""
    # The original note states that model_path and tokenizer_path are not
    # available on the current RouterArgs, so there is nothing to assert yet.
    skip_reason = "Tokenizer configuration not available in current implementation"
    pytest.skip(skip_reason)
|
||||
|
||||
def test_dp_aware_config_validation(self):
    """Build RouterArgs with dp_aware enabled plus an API key and read both back."""
    key = "test-api-key"
    args = RouterArgs(api_key=key, dp_aware=True)

    assert args.dp_aware is True
    assert args.api_key == key
|
||||
|
||||
def test_request_id_headers_validation(self):
    """Build RouterArgs with three request ID headers and expect the same list back."""
    header_names = ("x-request-id", "x-trace-id", "x-correlation-id")
    # A fresh list is handed over so the expectation stays independent of it.
    args = RouterArgs(request_id_headers=list(header_names))

    assert args.request_id_headers == list(header_names)
|
||||
|
||||
def test_policy_consistency_validation(self):
    """Launch in PD mode with both prefill and decode policies set explicitly."""
    pd_kwargs = dict(
        pd_disaggregation=True,
        prefill_urls=[("http://prefill1:8000", None)],
        decode_urls=["http://decode1:8001"],
        policy="cache_aware",
        prefill_policy="power_of_two",
        decode_policy="round_robin",
    )
    args = RouterArgs(**pd_kwargs)

    # Router is swapped for a mock so no real router is started. The launch
    # must not raise and must call Router.from_args exactly once.
    with patch("sglang_router.launch_router.Router") as patched_router:
        patched_router.from_args = MagicMock(return_value=MagicMock())
        launch_router(args)
        patched_router.from_args.assert_called_once()
|
||||
|
||||
def test_policy_fallback_validation(self):
    """Launch in PD mode with only the prefill policy set (decode_policy=None)."""
    pd_kwargs = dict(
        pd_disaggregation=True,
        prefill_urls=[("http://prefill1:8000", None)],
        decode_urls=["http://decode1:8001"],
        policy="cache_aware",
        prefill_policy="power_of_two",
        decode_policy=None,
    )
    args = RouterArgs(**pd_kwargs)

    # Router is swapped for a mock so no real router is started. The launch
    # must not raise and must call Router.from_args exactly once.
    with patch("sglang_router.launch_router.Router") as patched_router:
        patched_router.from_args = MagicMock(return_value=MagicMock())
        launch_router(args)
        patched_router.from_args.assert_called_once()
|
||||
|
||||
def test_policy_enum_conversion(self):
    """Check that each supported policy name maps to its PolicyType member."""
    expected_by_name = {
        "random": PolicyType.Random,
        "round_robin": PolicyType.RoundRobin,
        "cache_aware": PolicyType.CacheAware,
        "power_of_two": PolicyType.PowerOfTwo,
    }

    for policy_name, expected_member in expected_by_name.items():
        assert policy_from_str(policy_name) == expected_member
|
||||
|
||||
def test_invalid_policy_enum_conversion(self):
    """Check that an unknown policy name makes policy_from_str raise KeyError."""
    unknown_name = "invalid_policy"
    with pytest.raises(KeyError):
        policy_from_str(unknown_name)
|
||||
|
||||
def test_config_immutability(self):
    """Document that a RouterArgs attribute can be reassigned after construction.

    Despite the test name, nothing here requires immutability: the assertions
    pass only if assigning to ``host`` succeeds and the new value is visible.
    """
    args = RouterArgs(
        worker_urls=["http://worker1:8000"],
        port=30000,
        host="127.0.0.1",
    )

    # The original author's note attributes this to dataclasses being mutable
    # by default; the RouterArgs definition is not visible from this test.
    host_before = args.host
    args.host = "0.0.0.0"
    assert args.host == "0.0.0.0"
    assert args.host != host_before
|
||||
|
||||
def test_config_defaults_consistency(self):
    """Check that two default-constructed RouterArgs agree on key fields."""
    first = RouterArgs()
    second = RouterArgs()

    compared_fields = (
        "host",
        "port",
        "policy",
        "worker_urls",
        "pd_disaggregation",
    )
    for field_name in compared_fields:
        assert getattr(first, field_name) == getattr(second, field_name)
|
||||
|
||||
def test_config_serialization(self):
    """Check that values given to RouterArgs can be read back from its attributes.

    NOTE(review): despite the name, no serializer is exercised here; the test
    only reads attributes. The earlier version asserted ``hasattr`` alone,
    which would pass even if the constructor ignored the supplied values, so
    the values themselves are now compared as well.
    """
    expected = {
        "host": "127.0.0.1",
        "port": 30000,
        "worker_urls": ["http://worker1:8000"],
        "policy": "cache_aware",
        "cache_threshold": 0.5,
    }
    args = RouterArgs(
        host="127.0.0.1",
        port=30000,
        # Separate list object so the expectation is independent of it.
        worker_urls=["http://worker1:8000"],
        policy="cache_aware",
        cache_threshold=0.5,
    )

    for field_name, expected_value in expected.items():
        assert hasattr(args, field_name)
        assert getattr(args, field_name) == expected_value
|
||||
|
||||
def test_config_with_none_values(self):
    """Pass None for several optional fields and expect None back on each."""
    none_fields = (
        "api_key",
        "log_dir",
        "log_level",
        "prometheus_port",
        "prometheus_host",
        "request_id_headers",
        "rate_limit_tokens_per_second",
        "service_discovery_namespace",
    )
    # dict.fromkeys maps every listed field to None.
    args = RouterArgs(**dict.fromkeys(none_fields))

    for field_name in none_fields:
        assert getattr(args, field_name) is None
|
||||
|
||||
def test_config_with_empty_lists(self):
    """Pass empty lists for the list-valued fields and expect empty lists back."""
    list_fields = (
        "worker_urls",
        "prefill_urls",
        "decode_urls",
        "cors_allowed_origins",
    )
    # Each field receives its own list object, as in separate ``[]`` literals.
    args = RouterArgs(**{field_name: [] for field_name in list_fields})

    for field_name in list_fields:
        assert getattr(args, field_name) == []
|
||||
|
||||
def test_config_with_empty_dicts(self):
    """Pass empty dicts for the selector fields and expect empty dicts back."""
    selector_fields = ("selector", "prefill_selector", "decode_selector")
    # Each field receives its own dict object, as in separate ``{}`` literals.
    args = RouterArgs(**{field_name: {} for field_name in selector_fields})

    for field_name in selector_fields:
        assert getattr(args, field_name) == {}
|
||||
1053
sgl-router/py_test/unit/test_startup_sequence.py
Normal file
1053
sgl-router/py_test/unit/test_startup_sequence.py
Normal file
File diff suppressed because it is too large
Load Diff
506
sgl-router/py_test/unit/test_validation.py
Normal file
506
sgl-router/py_test/unit/test_validation.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Unit tests for validation logic in sglang_router.
|
||||
|
||||
These tests focus on testing the validation logic in isolation,
|
||||
including parameter validation, URL validation, and configuration validation.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
|
||||
class TestURLValidation:
    """Checks that RouterArgs keeps the worker, prefill and decode URLs it is given."""

    def test_valid_worker_urls(self):
        """Each well-formed worker URL is found in worker_urls after construction."""
        candidates = (
            "http://worker1:8000",
            "https://worker2:8000",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://192.168.1.100:8000",
            "http://worker.example.com:8000",
        )

        for candidate in candidates:
            # Construction itself must not raise.
            router_args = RouterArgs(worker_urls=[candidate])
            assert candidate in router_args.worker_urls

    def test_valid_prefill_urls(self):
        """Each (url, bootstrap_port) pair is found in prefill_urls after construction."""
        candidate_pairs = (
            ("http://prefill1:8000", 9000),
            ("https://prefill2:8000", None),
            ("http://localhost:8000", 9000),
            ("http://127.0.0.1:8000", None),
        )

        for pair in candidate_pairs:
            # Construction itself must not raise.
            router_args = RouterArgs(prefill_urls=[pair])
            assert pair in router_args.prefill_urls

    def test_valid_decode_urls(self):
        """Each well-formed decode URL is found in decode_urls after construction."""
        candidates = (
            "http://decode1:8001",
            "https://decode2:8001",
            "http://localhost:8001",
            "http://127.0.0.1:8001",
        )

        for candidate in candidates:
            # Construction itself must not raise.
            router_args = RouterArgs(decode_urls=[candidate])
            assert candidate in router_args.decode_urls

    def test_malformed_urls(self):
        """Pin the current behavior: malformed URL strings are accepted as-is.

        RouterArgs construction is expected to succeed for these inputs; the
        test records that behavior rather than endorsing it.
        """
        questionable = (
            "not-a-url",
            "ftp://worker1:8000",  # Wrong protocol
            "http://",  # Missing host
            ":8000",  # Missing protocol and host
            "http://worker1",  # Missing port
        )

        for candidate in questionable:
            router_args = RouterArgs(worker_urls=[candidate])
            # Accepted today; rejecting these would be a possible future change.
            assert candidate in router_args.worker_urls
|
||||
|
||||
|
||||
class TestPortValidation:
    """Checks how RouterArgs stores listening ports and prefill bootstrap ports."""

    def test_valid_ports(self):
        """Each in-range port number is stored unchanged on ``port``."""
        for candidate in (1, 80, 8000, 30000, 65535):
            router_args = RouterArgs(port=candidate)
            assert router_args.port == candidate

    def test_invalid_ports(self):
        """Pin the current behavior: out-of-range port numbers are accepted as-is.

        RouterArgs construction is expected to succeed for these inputs; the
        test records that behavior rather than endorsing it.
        """
        for candidate in (0, -1, 65536, 70000):
            router_args = RouterArgs(port=candidate)
            # Accepted today; rejecting these would be a possible future change.
            assert router_args.port == candidate

    def test_bootstrap_port_validation(self):
        """Each bootstrap port, including None, is stored in the first prefill entry."""
        for candidate in (1, 80, 9000, 30000, 65535, None):
            router_args = RouterArgs(prefill_urls=[("http://prefill1:8000", candidate)])
            first_entry = router_args.prefill_urls[0]
            assert first_entry[1] == candidate
|
||||
|
||||
|
||||
class TestParameterValidation:
    """Checks that numeric tuning parameters given to RouterArgs are stored unchanged."""

    def test_cache_threshold_validation(self):
        """Each cache threshold between 0.0 and 1.0 is stored unchanged."""
        for value in (0.0, 0.1, 0.5, 0.9, 1.0):
            router_args = RouterArgs(cache_threshold=value)
            assert router_args.cache_threshold == value

    def test_balance_threshold_validation(self):
        """Absolute and relative balance thresholds are stored unchanged."""
        # Absolute thresholds are exercised first, then relative ones.
        cases = {
            "balance_abs_threshold": (0, 1, 32, 64, 128, 1000),
            "balance_rel_threshold": (1.0, 1.1, 1.5, 2.0, 10.0),
        }
        for field_name, values in cases.items():
            for value in values:
                router_args = RouterArgs(**{field_name: value})
                assert getattr(router_args, field_name) == value

    def test_timeout_validation(self):
        """The same timeout value is applied to all four timeout fields and read back."""
        timeout_fields = (
            "worker_startup_timeout_secs",
            "worker_startup_check_interval",
            "request_timeout_secs",
            "queue_timeout_secs",
        )

        for seconds in (1, 30, 60, 300, 600, 1800, 3600):
            router_args = RouterArgs(**dict.fromkeys(timeout_fields, seconds))
            for field_name in timeout_fields:
                assert getattr(router_args, field_name) == seconds

    def test_retry_parameter_validation(self):
        """Retry count, backoff, multiplier and jitter values are stored unchanged."""
        # Retry counts.
        for attempts in (0, 1, 3, 5, 10):
            router_args = RouterArgs(retry_max_retries=attempts)
            assert router_args.retry_max_retries == attempts

        # Backoff: initial and max are set to the same value together.
        for millis in (1, 50, 100, 1000, 30000):
            router_args = RouterArgs(
                retry_max_backoff_ms=millis, retry_initial_backoff_ms=millis
            )
            assert router_args.retry_initial_backoff_ms == millis
            assert router_args.retry_max_backoff_ms == millis

        # Backoff multipliers.
        for factor in (1.0, 1.5, 2.0, 3.0):
            router_args = RouterArgs(retry_backoff_multiplier=factor)
            assert router_args.retry_backoff_multiplier == factor

        # Jitter factors.
        for fraction in (0.0, 0.1, 0.2, 0.5):
            router_args = RouterArgs(retry_jitter_factor=fraction)
            assert router_args.retry_jitter_factor == fraction

    def test_circuit_breaker_parameter_validation(self):
        """Circuit breaker thresholds and durations are stored unchanged."""
        # Failure thresholds.
        for limit in (1, 3, 5, 10, 20):
            router_args = RouterArgs(cb_failure_threshold=limit)
            assert router_args.cb_failure_threshold == limit

        # Success thresholds.
        for limit in (1, 2, 3, 5):
            router_args = RouterArgs(cb_success_threshold=limit)
            assert router_args.cb_success_threshold == limit

        # Durations: timeout and window are set to the same value together.
        for seconds in (10, 30, 60, 120, 300):
            router_args = RouterArgs(
                cb_window_duration_secs=seconds, cb_timeout_duration_secs=seconds
            )
            assert router_args.cb_timeout_duration_secs == seconds
            assert router_args.cb_window_duration_secs == seconds

    def test_health_check_parameter_validation(self):
        """Health check thresholds, timeouts and intervals are stored unchanged."""
        # Failure thresholds.
        for limit in (1, 2, 3, 5, 10):
            router_args = RouterArgs(health_failure_threshold=limit)
            assert router_args.health_failure_threshold == limit

        # Success thresholds.
        for limit in (1, 2, 3, 5):
            router_args = RouterArgs(health_success_threshold=limit)
            assert router_args.health_success_threshold == limit

        # Timeout and interval are set to the same value together.
        for seconds in (1, 5, 10, 30, 60, 120):
            router_args = RouterArgs(
                health_check_interval_secs=seconds, health_check_timeout_secs=seconds
            )
            assert router_args.health_check_timeout_secs == seconds
            assert router_args.health_check_interval_secs == seconds

    def test_rate_limiting_parameter_validation(self):
        """Concurrency limits, queue sizes and token rates are stored unchanged."""
        # Exercised in this order: concurrency limits, queue sizes, token rates.
        cases = {
            "max_concurrent_requests": (1, 10, 64, 256, 512, 1000),
            "queue_size": (0, 10, 50, 100, 500, 1000),
            "rate_limit_tokens_per_second": (1, 10, 50, 100, 500, 1000),
        }
        for field_name, values in cases.items():
            for value in values:
                router_args = RouterArgs(**{field_name: value})
                assert getattr(router_args, field_name) == value

    def test_tree_size_validation(self):
        """Power-of-two tree sizes are stored unchanged on max_tree_size."""
        tree_sizes = [2**exponent for exponent in (10, 20, 24, 26, 28, 30)]

        for node_count in tree_sizes:
            router_args = RouterArgs(max_tree_size=node_count)
            assert router_args.max_tree_size == node_count

    def test_payload_size_validation(self):
        """Payload sizes from 1KB to 1GB are stored unchanged on max_payload_size."""
        kib = 1024
        mib = kib * kib
        payload_sizes = (
            kib,  # 1KB
            mib,  # 1MB
            10 * mib,  # 10MB
            100 * mib,  # 100MB
            512 * mib,  # 512MB
            kib * mib,  # 1GB
        )

        for byte_count in payload_sizes:
            router_args = RouterArgs(max_payload_size=byte_count)
            assert router_args.max_payload_size == byte_count
|
||||
|
||||
|
||||
class TestConfigurationValidation:
    """Checks that feature-level option groups survive RouterArgs construction."""

    def test_pd_mode_validation(self):
        """A PD configuration keeps its flag and non-empty prefill/decode URL lists."""
        router_args = RouterArgs(
            decode_urls=["http://decode1:8001"],
            prefill_urls=[("http://prefill1:8000", 9000)],
            pd_disaggregation=True,
        )

        assert router_args.pd_disaggregation is True
        assert len(router_args.prefill_urls) > 0
        assert len(router_args.decode_urls) > 0

    def test_service_discovery_validation(self):
        """Service discovery flag, selector, port and namespace are read back."""
        label_selector = {"app": "worker", "env": "prod"}
        router_args = RouterArgs(
            service_discovery_namespace="default",
            service_discovery_port=8080,
            # Pass a copy so the comparison below is against an untouched dict.
            selector=dict(label_selector),
            service_discovery=True,
        )

        assert router_args.service_discovery is True
        assert router_args.selector == label_selector
        assert router_args.service_discovery_port == 8080
        assert router_args.service_discovery_namespace == "default"

    def test_pd_service_discovery_validation(self):
        """PD-mode service discovery keeps both flags and both selectors."""
        prefill_labels = {"app": "prefill"}
        decode_labels = {"app": "decode"}
        router_args = RouterArgs(
            # Copies keep the expected dicts independent of what RouterArgs holds.
            decode_selector=dict(decode_labels),
            prefill_selector=dict(prefill_labels),
            service_discovery=True,
            pd_disaggregation=True,
        )

        assert router_args.pd_disaggregation is True
        assert router_args.service_discovery is True
        assert router_args.prefill_selector == prefill_labels
        assert router_args.decode_selector == decode_labels

    def test_policy_validation(self):
        """Each supported policy name is stored unchanged on ``policy``."""
        for name in ("random", "round_robin", "cache_aware", "power_of_two"):
            router_args = RouterArgs(policy=name)
            assert router_args.policy == name

    def test_pd_policy_validation(self):
        """Every prefill/decode policy combination is stored unchanged in PD mode."""
        policy_names = ("random", "round_robin", "cache_aware", "power_of_two")
        # Same iteration order as a nested loop: prefill outer, decode inner.
        combinations = [
            (prefill_name, decode_name)
            for prefill_name in policy_names
            for decode_name in policy_names
        ]

        for prefill_name, decode_name in combinations:
            router_args = RouterArgs(
                pd_disaggregation=True,
                prefill_urls=[("http://prefill1:8000", None)],
                decode_urls=["http://decode1:8001"],
                prefill_policy=prefill_name,
                decode_policy=decode_name,
            )
            assert router_args.prefill_policy == prefill_name
            assert router_args.decode_policy == decode_name

    def test_cors_validation(self):
        """Each CORS origin list, including empty and wildcard, is stored unchanged."""
        origin_lists = (
            [],
            ["http://localhost:3000"],
            ["https://example.com"],
            ["http://localhost:3000", "https://example.com"],
            ["*"],  # Wildcard (if supported)
        )

        for allowed in origin_lists:
            router_args = RouterArgs(cors_allowed_origins=allowed)
            assert router_args.cors_allowed_origins == allowed

    def test_logging_validation(self):
        """Each log level name is stored unchanged on ``log_level``."""
        for name in ("debug", "info", "warning", "error", "critical"):
            router_args = RouterArgs(log_level=name)
            assert router_args.log_level == name

    def test_prometheus_validation(self):
        """Prometheus host and port are read back as given."""
        metrics_port = 29000
        metrics_host = "127.0.0.1"
        router_args = RouterArgs(
            prometheus_host=metrics_host, prometheus_port=metrics_port
        )

        assert router_args.prometheus_port == metrics_port
        assert router_args.prometheus_host == metrics_host

    def test_tokenizer_validation(self):
        """Placeholder for tokenizer configuration checks; always skipped."""
        # The original note states that model_path and tokenizer_path are not
        # available on the current RouterArgs, so there is nothing to assert yet.
        skip_reason = "Tokenizer configuration not available in current implementation"
        pytest.skip(skip_reason)

    def test_request_id_headers_validation(self):
        """Each request ID header list is stored unchanged."""
        header_lists = (
            ["x-request-id"],
            ["x-request-id", "x-trace-id"],
            ["x-request-id", "x-trace-id", "x-correlation-id"],
            ["custom-header"],
        )

        for names in header_lists:
            router_args = RouterArgs(request_id_headers=names)
            assert router_args.request_id_headers == names
|
||||
|
||||
|
||||
class TestLaunchValidation:
    """Checks the validation that launch_router performs before building a router.

    The success-path tests all follow the same steps: patch the ``Router``
    name in ``sglang_router.launch_router``, call ``launch_router``, then
    require a single ``Router.from_args`` call. Those steps live in one
    private helper instead of being repeated in every test.
    """

    @staticmethod
    def _assert_launch_builds_router(args):
        """Run launch_router with Router mocked and require one from_args call.

        Args:
            args: The RouterArgs instance to launch with.

        Raises:
            AssertionError: If ``Router.from_args`` was not called exactly once.
            Exception: Anything raised by ``launch_router`` propagates, which
                fails the calling test.
        """
        # Router is replaced with a mock so no real router is started.
        with patch("sglang_router.launch_router.Router") as router_mod:
            router_mod.from_args = MagicMock(return_value=MagicMock())
            launch_router(args)
            router_mod.from_args.assert_called_once()

    def test_pd_mode_requires_urls(self):
        """PD mode with no URLs and no service discovery must raise ValueError."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=False,
        )

        with pytest.raises(
            ValueError, match="PD disaggregation mode requires --prefill"
        ):
            launch_router(args)

    def test_pd_mode_with_service_discovery_allows_empty_urls(self):
        """PD mode with service discovery enabled launches with empty URL lists."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=True,
        )

        self._assert_launch_builds_router(args)

    def test_regular_mode_allows_empty_worker_urls(self):
        """Regular mode launches with an empty worker URL list."""
        args = RouterArgs(worker_urls=[], service_discovery=False)

        self._assert_launch_builds_router(args)

    def test_launch_with_valid_config(self):
        """A plain regular-mode configuration launches and builds a router."""
        args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000"],
            policy="cache_aware",
        )

        self._assert_launch_builds_router(args)

    def test_launch_with_pd_config(self):
        """A PD configuration with prefill and decode URLs launches and builds a router."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", 9000)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
        )

        self._assert_launch_builds_router(args)

    def test_launch_with_service_discovery_config(self):
        """A service discovery configuration launches and builds a router."""
        args = RouterArgs(
            service_discovery=True,
            selector={"app": "worker"},
            service_discovery_port=8080,
        )

        self._assert_launch_builds_router(args)
|
||||
@@ -21,6 +21,7 @@ dev = [
|
||||
"requests>=2.25.0",
|
||||
]
|
||||
|
||||
|
||||
# https://github.com/PyO3/setuptools-rust?tab=readme-ov-file
|
||||
[tool.setuptools.packages]
|
||||
find = { where = ["py_src"] }
|
||||
|
||||
6
sgl-router/pytest.ini
Normal file
6
sgl-router/pytest.ini
Normal file
@@ -0,0 +1,6 @@
|
||||
[pytest]
|
||||
testpaths = py_test
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts = --cov=sglang_router --cov-report=term-missing
|
||||
Reference in New Issue
Block a user