From 045ab92dc0b7a5de8c3f37411230774ffc01ee65 Mon Sep 17 00:00:00 2001 From: Keyang Ru Date: Fri, 5 Sep 2025 08:40:21 -0700 Subject: [PATCH] [router] add py binding unit tests to coverage 80% (#10043) --- .github/workflows/pr-test-rust.yml | 11 +- sgl-router/.coveragerc | 9 + sgl-router/py_test/conftest.py | 8 + sgl-router/py_test/unit/__init__.py | 7 + sgl-router/py_test/unit/test_arg_parser.py | 628 ++++++++++ sgl-router/py_test/unit/test_router_config.py | 421 +++++++ .../py_test/unit/test_startup_sequence.py | 1053 +++++++++++++++++ sgl-router/py_test/unit/test_validation.py | 506 ++++++++ sgl-router/pyproject.toml | 1 + sgl-router/pytest.ini | 6 + 10 files changed, 2649 insertions(+), 1 deletion(-) create mode 100644 sgl-router/.coveragerc create mode 100644 sgl-router/py_test/conftest.py create mode 100644 sgl-router/py_test/unit/__init__.py create mode 100644 sgl-router/py_test/unit/test_arg_parser.py create mode 100644 sgl-router/py_test/unit/test_router_config.py create mode 100644 sgl-router/py_test/unit/test_startup_sequence.py create mode 100644 sgl-router/py_test/unit/test_validation.py create mode 100644 sgl-router/pytest.ini diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 85107ed30..319cbce70 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -39,7 +39,7 @@ jobs: cd sgl-router/ cargo fmt -- --check - - name: Run test + - name: Run Rust tests timeout-minutes: 20 run: | source "$HOME/.cargo/env" @@ -83,6 +83,15 @@ jobs: pip install setuptools-rust wheel build python3 -m build pip install --force-reinstall dist/*.whl + + + - name: Run Python unit tests + run: | + cd sgl-router + source "$HOME/.cargo/env" + pip install pytest pytest-cov pytest-xdist + pytest -q py_test/unit + - name: Run e2e test run: | bash scripts/killall_sglang.sh "nuk_gpus" diff --git a/sgl-router/.coveragerc b/sgl-router/.coveragerc new file mode 100644 index 000000000..5bab1e8d2 --- /dev/null +++ 
b/sgl-router/.coveragerc @@ -0,0 +1,9 @@ +[run] +source = py_src/sglang_router +omit = + py_src/sglang_router/mini_lb.py + +[report] +fail_under = 80 +omit = + py_src/sglang_router/mini_lb.py diff --git a/sgl-router/py_test/conftest.py b/sgl-router/py_test/conftest.py new file mode 100644 index 000000000..894e12bf5 --- /dev/null +++ b/sgl-router/py_test/conftest.py @@ -0,0 +1,8 @@ +import sys +from pathlib import Path + +# Ensure local sources in py_src are importable ahead of any installed package +_ROOT = Path(__file__).resolve().parents[1] +_SRC = _ROOT / "py_src" +if str(_SRC) not in sys.path: + sys.path.insert(0, str(_SRC)) diff --git a/sgl-router/py_test/unit/__init__.py b/sgl-router/py_test/unit/__init__.py new file mode 100644 index 000000000..42cbd8bee --- /dev/null +++ b/sgl-router/py_test/unit/__init__.py @@ -0,0 +1,7 @@ +""" +Unit tests for sglang_router. + +This package contains fast, isolated unit tests for Python components +of the SGLang router. These tests focus on testing individual functions +and classes in isolation without starting actual router instances. +""" diff --git a/sgl-router/py_test/unit/test_arg_parser.py b/sgl-router/py_test/unit/test_arg_parser.py new file mode 100644 index 000000000..04d8a112d --- /dev/null +++ b/sgl-router/py_test/unit/test_arg_parser.py @@ -0,0 +1,628 @@ +""" +Unit tests for argument parsing functionality in sglang_router. + +These tests focus on testing the argument parsing logic in isolation, +without starting actual router instances. 
+""" + +import argparse +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from sglang_router.launch_router import RouterArgs, parse_router_args +from sglang_router.router import policy_from_str + + +class TestRouterArgs: + """Test RouterArgs dataclass and its methods.""" + + def test_default_values(self): + """Test that RouterArgs has correct default values.""" + args = RouterArgs() + + # Test basic defaults + assert args.host == "127.0.0.1" + assert args.port == 30000 + assert args.policy == "cache_aware" + assert args.worker_urls == [] + assert args.pd_disaggregation is False + assert args.prefill_urls == [] + assert args.decode_urls == [] + + # Test PD-specific defaults + assert args.prefill_policy is None + assert args.decode_policy is None + + # Test service discovery defaults + assert args.service_discovery is False + assert args.selector == {} + assert args.service_discovery_port == 80 + assert args.service_discovery_namespace is None + + # Test retry and circuit breaker defaults + assert args.retry_max_retries == 5 + assert args.cb_failure_threshold == 10 + assert args.disable_retries is False + assert args.disable_circuit_breaker is False + + def test_parse_selector_valid(self): + """Test parsing valid selector arguments.""" + # Test single key-value pair + result = RouterArgs._parse_selector(["app=worker"]) + assert result == {"app": "worker"} + + # Test multiple key-value pairs + result = RouterArgs._parse_selector(["app=worker", "env=prod", "version=v1"]) + assert result == {"app": "worker", "env": "prod", "version": "v1"} + + # Test empty list + result = RouterArgs._parse_selector([]) + assert result == {} + + # Test None + result = RouterArgs._parse_selector(None) + assert result == {} + + def test_parse_selector_invalid(self): + """Test parsing invalid selector arguments.""" + # Test malformed selector (no equals sign) + result = RouterArgs._parse_selector(["app"]) + assert result == {} + + # Test 
multiple equals signs (should use first one) + result = RouterArgs._parse_selector(["app=worker=extra"]) + assert result == {"app": "worker=extra"} + + def test_parse_prefill_urls_valid(self): + """Test parsing valid prefill URL arguments.""" + # Test with bootstrap port + result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "9000"]]) + assert result == [("http://prefill1:8000", 9000)] + + # Test with 'none' bootstrap port + result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "none"]]) + assert result == [("http://prefill1:8000", None)] + + # Test without bootstrap port + result = RouterArgs._parse_prefill_urls([["http://prefill1:8000"]]) + assert result == [("http://prefill1:8000", None)] + + # Test multiple prefill URLs + result = RouterArgs._parse_prefill_urls( + [ + ["http://prefill1:8000", "9000"], + ["http://prefill2:8000", "none"], + ["http://prefill3:8000"], + ] + ) + expected = [ + ("http://prefill1:8000", 9000), + ("http://prefill2:8000", None), + ("http://prefill3:8000", None), + ] + assert result == expected + + # Test empty list + result = RouterArgs._parse_prefill_urls([]) + assert result == [] + + # Test None + result = RouterArgs._parse_prefill_urls(None) + assert result == [] + + def test_parse_prefill_urls_invalid(self): + """Test parsing invalid prefill URL arguments.""" + # Test invalid bootstrap port + with pytest.raises(ValueError, match="Invalid bootstrap port"): + RouterArgs._parse_prefill_urls([["http://prefill1:8000", "invalid"]]) + + def test_parse_decode_urls_valid(self): + """Test parsing valid decode URL arguments.""" + # Test single decode URL + result = RouterArgs._parse_decode_urls([["http://decode1:8001"]]) + assert result == ["http://decode1:8001"] + + # Test multiple decode URLs + result = RouterArgs._parse_decode_urls( + [["http://decode1:8001"], ["http://decode2:8001"]] + ) + assert result == ["http://decode1:8001", "http://decode2:8001"] + + # Test empty list + result = 
RouterArgs._parse_decode_urls([]) + assert result == [] + + # Test None + result = RouterArgs._parse_decode_urls(None) + assert result == [] + + def test_from_cli_args_basic(self): + """Test creating RouterArgs from basic CLI arguments.""" + args = SimpleNamespace( + host="0.0.0.0", + port=30001, + worker_urls=["http://worker1:8000", "http://worker2:8000"], + policy="round_robin", + prefill=None, + decode=None, + router_policy="round_robin", + router_pd_disaggregation=False, + router_prefill_policy=None, + router_decode_policy=None, + router_worker_startup_timeout_secs=300, + router_worker_startup_check_interval=15, + router_cache_threshold=0.7, + router_balance_abs_threshold=128, + router_balance_rel_threshold=2.0, + router_eviction_interval=180, + router_max_tree_size=2**28, + router_max_payload_size=1024 * 1024 * 1024, # 1GB + router_dp_aware=True, + router_api_key="test-key", + router_log_dir="/tmp/logs", + router_log_level="debug", + router_service_discovery=True, + router_selector=["app=worker", "env=test"], + router_service_discovery_port=8080, + router_service_discovery_namespace="default", + router_prefill_selector=["app=prefill"], + router_decode_selector=["app=decode"], + router_prometheus_port=29000, + router_prometheus_host="0.0.0.0", + router_request_id_headers=["x-request-id", "x-trace-id"], + router_request_timeout_secs=1200, + router_max_concurrent_requests=512, + router_queue_size=200, + router_queue_timeout_secs=120, + router_rate_limit_tokens_per_second=100, + router_cors_allowed_origins=["http://localhost:3000"], + router_retry_max_retries=3, + router_retry_initial_backoff_ms=100, + router_retry_max_backoff_ms=10000, + router_retry_backoff_multiplier=2.0, + router_retry_jitter_factor=0.1, + router_cb_failure_threshold=5, + router_cb_success_threshold=2, + router_cb_timeout_duration_secs=30, + router_cb_window_duration_secs=60, + router_disable_retries=False, + router_disable_circuit_breaker=False, + router_health_failure_threshold=2, + 
router_health_success_threshold=1, + router_health_check_timeout_secs=3, + router_health_check_interval_secs=30, + router_health_check_endpoint="/healthz", + ) + + router_args = RouterArgs.from_cli_args(args, use_router_prefix=True) + + # Test basic configuration + assert router_args.host == "0.0.0.0" + assert router_args.port == 30001 + assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"] + assert router_args.policy == "round_robin" + + # Test PD configuration + assert router_args.pd_disaggregation is False + assert router_args.prefill_urls == [] + assert router_args.decode_urls == [] + + # Test service discovery + assert router_args.service_discovery is True + assert router_args.selector == {"app": "worker", "env": "test"} + assert router_args.service_discovery_port == 8080 + assert router_args.service_discovery_namespace == "default" + assert router_args.prefill_selector == {"app": "prefill"} + assert router_args.decode_selector == {"app": "decode"} + + # Test other configurations + assert router_args.dp_aware is True + assert router_args.api_key == "test-key" + assert router_args.log_dir == "/tmp/logs" + assert router_args.log_level == "debug" + assert router_args.prometheus_port == 29000 + assert router_args.prometheus_host == "0.0.0.0" + assert router_args.request_id_headers == ["x-request-id", "x-trace-id"] + assert router_args.request_timeout_secs == 1200 + assert router_args.max_concurrent_requests == 512 + assert router_args.queue_size == 200 + assert router_args.queue_timeout_secs == 120 + assert router_args.rate_limit_tokens_per_second == 100 + assert router_args.cors_allowed_origins == ["http://localhost:3000"] + + # Test retry configuration + assert router_args.retry_max_retries == 3 + assert router_args.retry_initial_backoff_ms == 100 + assert router_args.retry_max_backoff_ms == 10000 + assert router_args.retry_backoff_multiplier == 2.0 + assert router_args.retry_jitter_factor == 0.1 + + # Test circuit breaker 
configuration + assert router_args.cb_failure_threshold == 5 + assert router_args.cb_success_threshold == 2 + assert router_args.cb_timeout_duration_secs == 30 + assert router_args.cb_window_duration_secs == 60 + assert router_args.disable_retries is False + assert router_args.disable_circuit_breaker is False + + # Test health check configuration + assert router_args.health_failure_threshold == 2 + assert router_args.health_success_threshold == 1 + assert router_args.health_check_timeout_secs == 3 + assert router_args.health_check_interval_secs == 30 + assert router_args.health_check_endpoint == "/healthz" + + # Note: model_path and tokenizer_path are not available in current RouterArgs + + def test_from_cli_args_pd_mode(self): + """Test creating RouterArgs from CLI arguments in PD mode.""" + args = SimpleNamespace( + host="127.0.0.1", + port=30000, + worker_urls=[], + policy="cache_aware", + prefill=[ + ["http://prefill1:8000", "9000"], + ["http://prefill2:8000", "none"], + ], + decode=[["http://decode1:8001"], ["http://decode2:8001"]], + router_prefill=[ + ["http://prefill1:8000", "9000"], + ["http://prefill2:8000", "none"], + ], + router_decode=[["http://decode1:8001"], ["http://decode2:8001"]], + router_policy="cache_aware", + router_pd_disaggregation=True, + router_prefill_policy="power_of_two", + router_decode_policy="round_robin", + # Include all required fields with defaults + router_worker_startup_timeout_secs=600, + router_worker_startup_check_interval=30, + router_cache_threshold=0.3, + router_balance_abs_threshold=64, + router_balance_rel_threshold=1.5, + router_eviction_interval=120, + router_max_tree_size=2**26, + router_max_payload_size=512 * 1024 * 1024, + router_dp_aware=False, + router_api_key=None, + router_log_dir=None, + router_log_level=None, + router_service_discovery=False, + router_selector=None, + router_service_discovery_port=80, + router_service_discovery_namespace=None, + router_prefill_selector=None, + router_decode_selector=None, + 
router_prometheus_port=None, + router_prometheus_host=None, + router_request_id_headers=None, + router_request_timeout_secs=1800, + router_max_concurrent_requests=256, + router_queue_size=100, + router_queue_timeout_secs=60, + router_rate_limit_tokens_per_second=None, + router_cors_allowed_origins=[], + router_retry_max_retries=5, + router_retry_initial_backoff_ms=50, + router_retry_max_backoff_ms=30000, + router_retry_backoff_multiplier=1.5, + router_retry_jitter_factor=0.2, + router_cb_failure_threshold=10, + router_cb_success_threshold=3, + router_cb_timeout_duration_secs=60, + router_cb_window_duration_secs=120, + router_disable_retries=False, + router_disable_circuit_breaker=False, + router_health_failure_threshold=3, + router_health_success_threshold=2, + router_health_check_timeout_secs=5, + router_health_check_interval_secs=60, + router_health_check_endpoint="/health", + ) + + router_args = RouterArgs.from_cli_args(args, use_router_prefix=True) + + # Test PD configuration + assert router_args.pd_disaggregation is True + assert router_args.prefill_urls == [ + ("http://prefill1:8000", 9000), + ("http://prefill2:8000", None), + ] + assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"] + assert router_args.prefill_policy == "power_of_two" + assert router_args.decode_policy == "round_robin" + assert router_args.policy == "cache_aware" # Main policy still set + + def test_from_cli_args_without_prefix(self): + """Test creating RouterArgs from CLI arguments without router prefix.""" + args = SimpleNamespace( + host="127.0.0.1", + port=30000, + worker_urls=["http://worker1:8000"], + policy="random", + prefill=None, + decode=None, + pd_disaggregation=False, + prefill_policy=None, + decode_policy=None, + worker_startup_timeout_secs=600, + worker_startup_check_interval=30, + cache_threshold=0.3, + balance_abs_threshold=64, + balance_rel_threshold=1.5, + eviction_interval=120, + max_tree_size=2**26, + max_payload_size=512 * 1024 * 1024, + 
dp_aware=False, + api_key=None, + log_dir=None, + log_level=None, + service_discovery=False, + selector=None, + service_discovery_port=80, + service_discovery_namespace=None, + prefill_selector=None, + decode_selector=None, + prometheus_port=None, + prometheus_host=None, + request_id_headers=None, + request_timeout_secs=1800, + max_concurrent_requests=256, + queue_size=100, + queue_timeout_secs=60, + rate_limit_tokens_per_second=None, + cors_allowed_origins=[], + retry_max_retries=5, + retry_initial_backoff_ms=50, + retry_max_backoff_ms=30000, + retry_backoff_multiplier=1.5, + retry_jitter_factor=0.2, + cb_failure_threshold=10, + cb_success_threshold=3, + cb_timeout_duration_secs=60, + cb_window_duration_secs=120, + disable_retries=False, + disable_circuit_breaker=False, + health_failure_threshold=3, + health_success_threshold=2, + health_check_timeout_secs=5, + health_check_interval_secs=60, + health_check_endpoint="/health", + model_path=None, + tokenizer_path=None, + ) + + router_args = RouterArgs.from_cli_args(args, use_router_prefix=False) + + assert router_args.host == "127.0.0.1" + assert router_args.port == 30000 + assert router_args.worker_urls == ["http://worker1:8000"] + assert router_args.policy == "random" + assert router_args.pd_disaggregation is False + + +class TestPolicyFromStr: + """Test policy string to enum conversion.""" + + def test_valid_policies(self): + """Test conversion of valid policy strings.""" + from sglang_router_rs import PolicyType + + assert policy_from_str("random") == PolicyType.Random + assert policy_from_str("round_robin") == PolicyType.RoundRobin + assert policy_from_str("cache_aware") == PolicyType.CacheAware + assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo + + def test_invalid_policy(self): + """Test conversion of invalid policy string.""" + with pytest.raises(KeyError): + policy_from_str("invalid_policy") + + +class TestParseRouterArgs: + """Test the parse_router_args function.""" + + def 
test_parse_basic_args(self): + """Test parsing basic router arguments.""" + args = [ + "--host", + "0.0.0.0", + "--port", + "30001", + "--worker-urls", + "http://worker1:8000", + "http://worker2:8000", + "--policy", + "round_robin", + ] + + router_args = parse_router_args(args) + + assert router_args.host == "0.0.0.0" + assert router_args.port == 30001 + assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"] + assert router_args.policy == "round_robin" + + def test_parse_pd_args(self): + """Test parsing PD disaggregated mode arguments.""" + args = [ + "--pd-disaggregation", + "--prefill", + "http://prefill1:8000", + "9000", + "--prefill", + "http://prefill2:8000", + "none", + "--decode", + "http://decode1:8001", + "--decode", + "http://decode2:8001", + "--prefill-policy", + "power_of_two", + "--decode-policy", + "round_robin", + ] + + router_args = parse_router_args(args) + + assert router_args.pd_disaggregation is True + assert router_args.prefill_urls == [ + ("http://prefill1:8000", 9000), + ("http://prefill2:8000", None), + ] + assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"] + assert router_args.prefill_policy == "power_of_two" + assert router_args.decode_policy == "round_robin" + + def test_parse_service_discovery_args(self): + """Test parsing service discovery arguments.""" + args = [ + "--service-discovery", + "--selector", + "app=worker", + "env=prod", + "--service-discovery-port", + "8080", + "--service-discovery-namespace", + "default", + ] + + router_args = parse_router_args(args) + + assert router_args.service_discovery is True + assert router_args.selector == {"app": "worker", "env": "prod"} + assert router_args.service_discovery_port == 8080 + assert router_args.service_discovery_namespace == "default" + + def test_parse_retry_and_circuit_breaker_args(self): + """Test parsing retry and circuit breaker arguments.""" + args = [ + "--retry-max-retries", + "3", + "--retry-initial-backoff-ms", + 
"100", + "--retry-max-backoff-ms", + "10000", + "--retry-backoff-multiplier", + "2.0", + "--retry-jitter-factor", + "0.1", + "--disable-retries", + "--cb-failure-threshold", + "5", + "--cb-success-threshold", + "2", + "--cb-timeout-duration-secs", + "30", + "--cb-window-duration-secs", + "60", + "--disable-circuit-breaker", + ] + + router_args = parse_router_args(args) + + # Test retry configuration + assert router_args.retry_max_retries == 3 + assert router_args.retry_initial_backoff_ms == 100 + assert router_args.retry_max_backoff_ms == 10000 + assert router_args.retry_backoff_multiplier == 2.0 + assert router_args.retry_jitter_factor == 0.1 + assert router_args.disable_retries is True + + # Test circuit breaker configuration + assert router_args.cb_failure_threshold == 5 + assert router_args.cb_success_threshold == 2 + assert router_args.cb_timeout_duration_secs == 30 + assert router_args.cb_window_duration_secs == 60 + assert router_args.disable_circuit_breaker is True + + def test_parse_rate_limiting_args(self): + """Test parsing rate limiting arguments.""" + args = [ + "--max-concurrent-requests", + "512", + "--queue-size", + "200", + "--queue-timeout-secs", + "120", + "--rate-limit-tokens-per-second", + "100", + ] + + router_args = parse_router_args(args) + + assert router_args.max_concurrent_requests == 512 + assert router_args.queue_size == 200 + assert router_args.queue_timeout_secs == 120 + assert router_args.rate_limit_tokens_per_second == 100 + + def test_parse_health_check_args(self): + """Test parsing health check arguments.""" + args = [ + "--health-failure-threshold", + "2", + "--health-success-threshold", + "1", + "--health-check-timeout-secs", + "3", + "--health-check-interval-secs", + "30", + "--health-check-endpoint", + "/healthz", + ] + + router_args = parse_router_args(args) + + assert router_args.health_failure_threshold == 2 + assert router_args.health_success_threshold == 1 + assert router_args.health_check_timeout_secs == 3 + assert 
router_args.health_check_interval_secs == 30 + assert router_args.health_check_endpoint == "/healthz" + + def test_parse_cors_args(self): + """Test parsing CORS arguments.""" + args = [ + "--cors-allowed-origins", + "http://localhost:3000", + "https://example.com", + ] + + router_args = parse_router_args(args) + + assert router_args.cors_allowed_origins == [ + "http://localhost:3000", + "https://example.com", + ] + + def test_parse_tokenizer_args(self): + """Test parsing tokenizer arguments.""" + # Note: model-path and tokenizer-path arguments are not available in current implementation + # This test is skipped until those arguments are added + pytest.skip("Tokenizer arguments not available in current implementation") + + def test_parse_invalid_args(self): + """Test parsing invalid arguments.""" + # Test invalid policy + with pytest.raises(SystemExit): + parse_router_args(["--policy", "invalid_policy"]) + + # Test invalid bootstrap port + with pytest.raises(ValueError, match="Invalid bootstrap port"): + parse_router_args( + [ + "--pd-disaggregation", + "--prefill", + "http://prefill1:8000", + "invalid_port", + ] + ) + + def test_help_output(self): + """Test that help output is generated correctly.""" + with pytest.raises(SystemExit) as exc_info: + parse_router_args(["--help"]) + + # SystemExit with code 0 indicates help was displayed + assert exc_info.value.code == 0 diff --git a/sgl-router/py_test/unit/test_router_config.py b/sgl-router/py_test/unit/test_router_config.py new file mode 100644 index 000000000..ed0d9db4b --- /dev/null +++ b/sgl-router/py_test/unit/test_router_config.py @@ -0,0 +1,421 @@ +""" +Unit tests for router configuration validation and setup. + +These tests focus on testing the router configuration logic in isolation, +including validation of configuration parameters and their interactions. 
+""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from sglang_router.launch_router import RouterArgs, launch_router +from sglang_router.router import policy_from_str +from sglang_router_rs import PolicyType + + +class TestRouterConfigValidation: + """Test router configuration validation logic.""" + + def test_valid_basic_config(self): + """Test that a valid basic configuration passes validation.""" + args = RouterArgs( + host="127.0.0.1", + port=30000, + worker_urls=["http://worker1:8000", "http://worker2:8000"], + policy="cache_aware", + ) + + # Should not raise any exceptions + assert args.host == "127.0.0.1" + assert args.port == 30000 + assert args.worker_urls == ["http://worker1:8000", "http://worker2:8000"] + assert args.policy == "cache_aware" + + def test_valid_pd_config(self): + """Test that a valid PD configuration passes validation.""" + args = RouterArgs( + host="127.0.0.1", + port=30000, + pd_disaggregation=True, + prefill_urls=[ + ("http://prefill1:8000", 9000), + ("http://prefill2:8000", None), + ], + decode_urls=["http://decode1:8001", "http://decode2:8001"], + policy="cache_aware", + ) + + assert args.pd_disaggregation is True + assert args.prefill_urls == [ + ("http://prefill1:8000", 9000), + ("http://prefill2:8000", None), + ] + assert args.decode_urls == ["http://decode1:8001", "http://decode2:8001"] + assert args.policy == "cache_aware" + + def test_pd_config_without_urls_raises_error(self): + """Test that PD mode without URLs raises validation error.""" + args = RouterArgs( + pd_disaggregation=True, + prefill_urls=[], + decode_urls=[], + service_discovery=False, + ) + + # This should raise an error when trying to launch + with pytest.raises( + ValueError, match="PD disaggregation mode requires --prefill" + ): + launch_router(args) + + def test_pd_config_with_service_discovery_allows_empty_urls(self): + """Test that PD mode with service discovery allows empty URLs.""" + args = RouterArgs( + 
pd_disaggregation=True, + prefill_urls=[], + decode_urls=[], + service_discovery=True, + ) + + # Should not raise validation error when service discovery is enabled + with patch("sglang_router.launch_router.Router") as router_mod: + mock_router_instance = MagicMock() + router_mod.from_args = MagicMock(return_value=mock_router_instance) + + launch_router(args) + + # Should create router instance via from_args + router_mod.from_args.assert_called_once() + + def test_regular_mode_without_workers_allows_empty_urls(self): + """Test that regular mode allows empty worker URLs.""" + args = RouterArgs(worker_urls=[], service_discovery=False) + + # Should not raise validation error + with patch("sglang_router.launch_router.Router") as router_mod: + mock_router_instance = MagicMock() + router_mod.from_args = MagicMock(return_value=mock_router_instance) + + launch_router(args) + + # Should create router instance via from_args + router_mod.from_args.assert_called_once() + + def test_cache_threshold_validation(self): + """Test cache threshold validation.""" + # Valid cache threshold + args = RouterArgs(cache_threshold=0.5) + assert args.cache_threshold == 0.5 + + # Edge cases + args = RouterArgs(cache_threshold=0.0) + assert args.cache_threshold == 0.0 + + args = RouterArgs(cache_threshold=1.0) + assert args.cache_threshold == 1.0 + + def test_balance_threshold_validation(self): + """Test load balancing threshold validation.""" + # Valid thresholds + args = RouterArgs(balance_abs_threshold=64, balance_rel_threshold=1.5) + assert args.balance_abs_threshold == 64 + assert args.balance_rel_threshold == 1.5 + + # Edge cases + args = RouterArgs(balance_abs_threshold=0, balance_rel_threshold=1.0) + assert args.balance_abs_threshold == 0 + assert args.balance_rel_threshold == 1.0 + + def test_timeout_validation(self): + """Test timeout parameter validation.""" + # Valid timeouts + args = RouterArgs( + worker_startup_timeout_secs=600, + worker_startup_check_interval=30, + 
request_timeout_secs=1800, + queue_timeout_secs=60, + ) + assert args.worker_startup_timeout_secs == 600 + assert args.worker_startup_check_interval == 30 + assert args.request_timeout_secs == 1800 + assert args.queue_timeout_secs == 60 + + def test_retry_config_validation(self): + """Test retry configuration validation.""" + # Valid retry config + args = RouterArgs( + retry_max_retries=5, + retry_initial_backoff_ms=50, + retry_max_backoff_ms=30000, + retry_backoff_multiplier=1.5, + retry_jitter_factor=0.2, + disable_retries=False, + ) + assert args.retry_max_retries == 5 + assert args.retry_initial_backoff_ms == 50 + assert args.retry_max_backoff_ms == 30000 + assert args.retry_backoff_multiplier == 1.5 + assert args.retry_jitter_factor == 0.2 + assert args.disable_retries is False + + def test_circuit_breaker_config_validation(self): + """Test circuit breaker configuration validation.""" + # Valid circuit breaker config + args = RouterArgs( + cb_failure_threshold=10, + cb_success_threshold=3, + cb_timeout_duration_secs=60, + cb_window_duration_secs=120, + disable_circuit_breaker=False, + ) + assert args.cb_failure_threshold == 10 + assert args.cb_success_threshold == 3 + assert args.cb_timeout_duration_secs == 60 + assert args.cb_window_duration_secs == 120 + assert args.disable_circuit_breaker is False + + def test_health_check_config_validation(self): + """Test health check configuration validation.""" + # Valid health check config + args = RouterArgs( + health_failure_threshold=3, + health_success_threshold=2, + health_check_timeout_secs=5, + health_check_interval_secs=60, + health_check_endpoint="/health", + ) + assert args.health_failure_threshold == 3 + assert args.health_success_threshold == 2 + assert args.health_check_timeout_secs == 5 + assert args.health_check_interval_secs == 60 + assert args.health_check_endpoint == "/health" + + def test_rate_limiting_config_validation(self): + """Test rate limiting configuration validation.""" + # Valid rate 
limiting config + args = RouterArgs( + max_concurrent_requests=256, + queue_size=100, + queue_timeout_secs=60, + rate_limit_tokens_per_second=100, + ) + assert args.max_concurrent_requests == 256 + assert args.queue_size == 100 + assert args.queue_timeout_secs == 60 + assert args.rate_limit_tokens_per_second == 100 + + def test_service_discovery_config_validation(self): + """Test service discovery configuration validation.""" + # Valid service discovery config + args = RouterArgs( + service_discovery=True, + selector={"app": "worker", "env": "prod"}, + service_discovery_port=8080, + service_discovery_namespace="default", + ) + assert args.service_discovery is True + assert args.selector == {"app": "worker", "env": "prod"} + assert args.service_discovery_port == 8080 + assert args.service_discovery_namespace == "default" + + def test_pd_service_discovery_config_validation(self): + """Test PD service discovery configuration validation.""" + # Valid PD service discovery config + args = RouterArgs( + pd_disaggregation=True, + service_discovery=True, + prefill_selector={"app": "prefill"}, + decode_selector={"app": "decode"}, + bootstrap_port_annotation="sglang.ai/bootstrap-port", + ) + assert args.pd_disaggregation is True + assert args.service_discovery is True + assert args.prefill_selector == {"app": "prefill"} + assert args.decode_selector == {"app": "decode"} + assert args.bootstrap_port_annotation == "sglang.ai/bootstrap-port" + + def test_prometheus_config_validation(self): + """Test Prometheus configuration validation.""" + # Valid Prometheus config + args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1") + assert args.prometheus_port == 29000 + assert args.prometheus_host == "127.0.0.1" + + def test_cors_config_validation(self): + """Test CORS configuration validation.""" + # Valid CORS config + args = RouterArgs( + cors_allowed_origins=["http://localhost:3000", "https://example.com"] + ) + assert args.cors_allowed_origins == [ + 
"http://localhost:3000", + "https://example.com", + ] + + def test_tokenizer_config_validation(self): + """Test tokenizer configuration validation.""" + # Note: model_path and tokenizer_path are not available in current RouterArgs + pytest.skip("Tokenizer configuration not available in current implementation") + + def test_dp_aware_config_validation(self): + """Test data parallelism aware configuration validation.""" + # Valid DP aware config + args = RouterArgs(dp_aware=True, api_key="test-api-key") + assert args.dp_aware is True + assert args.api_key == "test-api-key" + + def test_request_id_headers_validation(self): + """Test request ID headers configuration validation.""" + # Valid request ID headers config + args = RouterArgs( + request_id_headers=["x-request-id", "x-trace-id", "x-correlation-id"] + ) + assert args.request_id_headers == [ + "x-request-id", + "x-trace-id", + "x-correlation-id", + ] + + def test_policy_consistency_validation(self): + """Test policy consistency validation in PD mode.""" + # Test with both prefill and decode policies specified + args = RouterArgs( + pd_disaggregation=True, + prefill_urls=[("http://prefill1:8000", None)], + decode_urls=["http://decode1:8001"], + policy="cache_aware", + prefill_policy="power_of_two", + decode_policy="round_robin", + ) + + # Should not raise validation error + with patch("sglang_router.launch_router.Router") as router_mod: + mock_router_instance = MagicMock() + router_mod.from_args = MagicMock(return_value=mock_router_instance) + + launch_router(args) + + # Should create router instance via from_args + router_mod.from_args.assert_called_once() + + def test_policy_fallback_validation(self): + """Test policy fallback validation in PD mode.""" + # Test with only prefill policy specified + args = RouterArgs( + pd_disaggregation=True, + prefill_urls=[("http://prefill1:8000", None)], + decode_urls=["http://decode1:8001"], + policy="cache_aware", + prefill_policy="power_of_two", + decode_policy=None, + ) + 
def test_config_immutability(self):
    """Show that a RouterArgs instance can still be modified after creation.

    Despite the test name, nothing is frozen: assigning a new ``host``
    succeeds and the new value is what subsequent reads return.
    """
    args = RouterArgs(
        host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
    )
    host_before = args.host

    # Plain attribute assignment is expected to go through without raising.
    args.host = "0.0.0.0"

    assert args.host != host_before
    assert args.host == "0.0.0.0"
def test_config_with_none_values(self):
    """Fields passed as ``None`` are stored as ``None`` on RouterArgs."""
    nullable_fields = (
        "api_key",
        "log_dir",
        "log_level",
        "prometheus_port",
        "prometheus_host",
        "request_id_headers",
        "rate_limit_tokens_per_second",
        "service_discovery_namespace",
    )

    # dict.fromkeys maps every listed field name to None.
    args = RouterArgs(**dict.fromkeys(nullable_fields))

    for field_name in nullable_fields:
        assert getattr(args, field_name) is None
def setup_logger():
    """Return the ``"router"`` logger, attaching a stream handler on first use.

    The level is set to INFO on every call. The handler (with the
    ``[Router (Python)]`` message prefix) is only added when the logger has
    no handlers yet, so repeated calls do not stack duplicates.
    """
    router_logger = logging.getLogger("router")
    router_logger.setLevel(logging.INFO)

    if router_logger.handlers:
        # Already has at least one handler; leave it untouched.
        return router_logger

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter(
            fmt="[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    router_logger.addHandler(stream_handler)
    return router_logger
string to enum conversion in startup context.""" + + def test_policy_conversion_in_startup(self): + """Test policy conversion during startup sequence.""" + # Test all valid policies + policies = ["random", "round_robin", "cache_aware", "power_of_two"] + expected_enums = [ + PolicyType.Random, + PolicyType.RoundRobin, + PolicyType.CacheAware, + PolicyType.PowerOfTwo, + ] + + for policy_str, expected_enum in zip(policies, expected_enums): + result = policy_from_str(policy_str) + assert result == expected_enum + + def test_invalid_policy_in_startup(self): + """Test handling of invalid policy during startup.""" + with pytest.raises(KeyError): + policy_from_str("invalid_policy") + + +class TestRouterInitialization: + """Test router initialization logic.""" + + def test_router_initialization_basic(self): + """Test basic router initialization.""" + args = RouterArgs( + host="127.0.0.1", + port=30000, + worker_urls=["http://worker1:8000"], + policy="cache_aware", + ) + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + + mock_router_instance = MagicMock() + + def fake_from_args(router_args): + # capture needed fields from RouterArgs + captured_args.update( + dict( + host=router_args.host, + port=router_args.port, + worker_urls=router_args.worker_urls, + policy=policy_from_str(router_args.policy), + ) + ) + return mock_router_instance + + router_mod.from_args = MagicMock(side_effect=fake_from_args) + + result = launch_router(args) + + # Verify Router.from_args was called and captured fields match + router_mod.from_args.assert_called_once() + assert captured_args["host"] == "127.0.0.1" + assert captured_args["port"] == 30000 + assert captured_args["worker_urls"] == ["http://worker1:8000"] + assert captured_args["policy"] == PolicyType.CacheAware + + # Verify router.start() was called + mock_router_instance.start.assert_called_once() + + # Function returns None; ensure start was invoked + + def test_router_initialization_pd_mode(self): + 
"""Test router initialization in PD mode.""" + args = RouterArgs( + pd_disaggregation=True, + prefill_urls=[("http://prefill1:8000", 9000)], + decode_urls=["http://decode1:8001"], + policy="power_of_two", + ) + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + mock_router_instance = MagicMock() + + def fake_from_args(router_args): + captured_args.update( + dict( + pd_disaggregation=router_args.pd_disaggregation, + prefill_urls=router_args.prefill_urls, + decode_urls=router_args.decode_urls, + policy=policy_from_str(router_args.policy), + ) + ) + return mock_router_instance + + router_mod.from_args = MagicMock(side_effect=fake_from_args) + + result = launch_router(args) + + # Verify Router.from_args was called with PD parameters + router_mod.from_args.assert_called_once() + assert captured_args["pd_disaggregation"] is True + assert captured_args["prefill_urls"] == [("http://prefill1:8000", 9000)] + assert captured_args["decode_urls"] == ["http://decode1:8001"] + assert captured_args["policy"] == PolicyType.PowerOfTwo + + # Verify router.start() was called + mock_router_instance.start.assert_called_once() + + # Function returns None; ensure start was invoked + + def test_router_initialization_with_service_discovery(self): + """Test router initialization with service discovery.""" + args = RouterArgs( + service_discovery=True, + selector={"app": "worker", "env": "prod"}, + service_discovery_port=8080, + service_discovery_namespace="default", + ) + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + mock_router_instance = MagicMock() + + def fake_from_args(router_args): + captured_args.update( + dict( + service_discovery=router_args.service_discovery, + selector=router_args.selector, + service_discovery_port=router_args.service_discovery_port, + service_discovery_namespace=router_args.service_discovery_namespace, + ) + ) + return mock_router_instance + + router_mod.from_args = 
MagicMock(side_effect=fake_from_args) + + result = launch_router(args) + + # Verify Router.from_args was called with service discovery parameters + router_mod.from_args.assert_called_once() + assert captured_args["service_discovery"] is True + assert captured_args["selector"] == {"app": "worker", "env": "prod"} + assert captured_args["service_discovery_port"] == 8080 + assert captured_args["service_discovery_namespace"] == "default" + + # Verify router.start() was called + mock_router_instance.start.assert_called_once() + + # Function returns None; ensure start was invoked + + def test_router_initialization_with_retry_config(self): + """Test router initialization with retry configuration.""" + args = RouterArgs( + retry_max_retries=3, + retry_initial_backoff_ms=100, + retry_max_backoff_ms=10000, + retry_backoff_multiplier=2.0, + retry_jitter_factor=0.1, + disable_retries=False, + ) + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + mock_router_instance = MagicMock() + + def fake_from_args(router_args): + captured_args.update( + dict( + retry_max_retries=router_args.retry_max_retries, + retry_initial_backoff_ms=router_args.retry_initial_backoff_ms, + retry_max_backoff_ms=router_args.retry_max_backoff_ms, + retry_backoff_multiplier=router_args.retry_backoff_multiplier, + retry_jitter_factor=router_args.retry_jitter_factor, + disable_retries=router_args.disable_retries, + ) + ) + return mock_router_instance + + router_mod.from_args = MagicMock(side_effect=fake_from_args) + + result = launch_router(args) + + # Verify router was created with retry parameters + router_mod.from_args.assert_called_once() + assert captured_args["retry_max_retries"] == 3 + assert captured_args["retry_initial_backoff_ms"] == 100 + assert captured_args["retry_max_backoff_ms"] == 10000 + assert captured_args["retry_backoff_multiplier"] == 2.0 + assert captured_args["retry_jitter_factor"] == 0.1 + assert captured_args["disable_retries"] is False + + # 
Verify router.start() was called + mock_router_instance.start.assert_called_once() + + # Function returns None; ensure start was invoked + + def test_router_initialization_with_circuit_breaker_config(self): + """Test router initialization with circuit breaker configuration.""" + args = RouterArgs( + cb_failure_threshold=5, + cb_success_threshold=2, + cb_timeout_duration_secs=30, + cb_window_duration_secs=60, + disable_circuit_breaker=False, + ) + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + mock_router_instance = MagicMock() + + def fake_from_args(router_args): + captured_args.update( + dict( + cb_failure_threshold=router_args.cb_failure_threshold, + cb_success_threshold=router_args.cb_success_threshold, + cb_timeout_duration_secs=router_args.cb_timeout_duration_secs, + cb_window_duration_secs=router_args.cb_window_duration_secs, + disable_circuit_breaker=router_args.disable_circuit_breaker, + ) + ) + return mock_router_instance + + router_mod.from_args = MagicMock(side_effect=fake_from_args) + + result = launch_router(args) + + # Verify router was created with circuit breaker parameters + router_mod.from_args.assert_called_once() + assert captured_args["cb_failure_threshold"] == 5 + assert captured_args["cb_success_threshold"] == 2 + assert captured_args["cb_timeout_duration_secs"] == 30 + assert captured_args["cb_window_duration_secs"] == 60 + assert captured_args["disable_circuit_breaker"] is False + + # Verify router.start() was called + mock_router_instance.start.assert_called_once() + + # Function returns None; ensure start was invoked + + def test_router_initialization_with_rate_limiting_config(self): + """Test router initialization with rate limiting configuration.""" + args = RouterArgs( + max_concurrent_requests=512, + queue_size=200, + queue_timeout_secs=120, + rate_limit_tokens_per_second=100, + ) + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + mock_router_instance = 
MagicMock() + + def fake_from_args(router_args): + captured_args.update( + dict( + max_concurrent_requests=router_args.max_concurrent_requests, + queue_size=router_args.queue_size, + queue_timeout_secs=router_args.queue_timeout_secs, + rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second, + ) + ) + return mock_router_instance + + router_mod.from_args = MagicMock(side_effect=fake_from_args) + + result = launch_router(args) + + # Verify router was created with rate limiting parameters + router_mod.from_args.assert_called_once() + assert captured_args["max_concurrent_requests"] == 512 + assert captured_args["queue_size"] == 200 + assert captured_args["queue_timeout_secs"] == 120 + assert captured_args["rate_limit_tokens_per_second"] == 100 + + # Verify router.start() was called + mock_router_instance.start.assert_called_once() + + # Function returns None; ensure start was invoked + + def test_router_initialization_with_health_check_config(self): + """Test router initialization with health check configuration.""" + args = RouterArgs( + health_failure_threshold=2, + health_success_threshold=1, + health_check_timeout_secs=3, + health_check_interval_secs=30, + health_check_endpoint="/healthz", + ) + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + mock_router_instance = MagicMock() + + def fake_from_args(router_args): + captured_args.update( + dict( + health_failure_threshold=router_args.health_failure_threshold, + health_success_threshold=router_args.health_success_threshold, + health_check_timeout_secs=router_args.health_check_timeout_secs, + health_check_interval_secs=router_args.health_check_interval_secs, + health_check_endpoint=router_args.health_check_endpoint, + ) + ) + return mock_router_instance + + router_mod.from_args = MagicMock(side_effect=fake_from_args) + + result = launch_router(args) + + # Verify router was created with health check parameters + router_mod.from_args.assert_called_once() + assert 
captured_args["health_failure_threshold"] == 2 + assert captured_args["health_success_threshold"] == 1 + assert captured_args["health_check_timeout_secs"] == 3 + assert captured_args["health_check_interval_secs"] == 30 + assert captured_args["health_check_endpoint"] == "/healthz" + + # Verify router.start() was called + mock_router_instance.start.assert_called_once() + + # Function returns None; ensure start was invoked + + def test_router_initialization_with_prometheus_config(self): + """Test router initialization with Prometheus configuration.""" + args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1") + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + mock_router_instance = MagicMock() + + def fake_from_args(router_args): + captured_args.update( + dict( + prometheus_port=router_args.prometheus_port, + prometheus_host=router_args.prometheus_host, + ) + ) + return mock_router_instance + + router_mod.from_args = MagicMock(side_effect=fake_from_args) + + result = launch_router(args) + + # Verify router was created with Prometheus parameters + router_mod.from_args.assert_called_once() + assert captured_args["prometheus_port"] == 29000 + assert captured_args["prometheus_host"] == "127.0.0.1" + + # Verify router.start() was called + mock_router_instance.start.assert_called_once() + + # Function returns None; ensure start was invoked + + def test_router_initialization_with_cors_config(self): + """Test router initialization with CORS configuration.""" + args = RouterArgs( + cors_allowed_origins=["http://localhost:3000", "https://example.com"] + ) + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + mock_router_instance = MagicMock() + + def fake_from_args(router_args): + captured_args.update( + dict(cors_allowed_origins=router_args.cors_allowed_origins) + ) + return mock_router_instance + + router_mod.from_args = MagicMock(side_effect=fake_from_args) + + result = 
class TestStartupValidation:
    """Startup-time validation and policy logging performed by ``launch_router``."""

    @staticmethod
    def _pd_args(prefill_policy, decode_policy):
        """Build PD-mode RouterArgs with one prefill/decode URL and the given policies."""
        return RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", None)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
        )

    @staticmethod
    def _launch_patched(args):
        """Run ``launch_router`` with ``Router`` and the router_args logger patched.

        Returns the ``(router_mod, mock_logger)`` mocks so callers can inspect
        the recorded calls after the launch has completed.
        """
        with patch("sglang_router.launch_router.Router") as router_mod, patch(
            "sglang_router.router_args.logger"
        ) as mock_logger:
            router_mod.from_args = MagicMock(return_value=MagicMock())
            launch_router(args)
        return router_mod, mock_logger

    def test_pd_mode_validation_during_startup(self):
        """PD mode with no URLs and no service discovery raises ValueError."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=False,
        )

        with pytest.raises(
            ValueError, match="PD disaggregation mode requires --prefill"
        ):
            launch_router(args)

    def test_pd_mode_with_service_discovery_validation(self):
        """PD mode with empty URL lists is accepted when service discovery is on."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=True,
        )

        # Should not raise validation error
        with patch("sglang_router.launch_router.Router") as router_mod:
            router_mod.from_args = MagicMock(return_value=MagicMock())

            launch_router(args)

            # Should create router instance
            router_mod.from_args.assert_called_once()

    def test_policy_warning_during_startup(self):
        """Specifying both prefill and decode policies logs exactly one warning."""
        args = self._pd_args(
            prefill_policy="power_of_two", decode_policy="round_robin"
        )

        # The policy messages are emitted by router_args logger
        router_mod, mock_logger = self._launch_patched(args)

        mock_logger.warning.assert_called_once()
        warning_call = mock_logger.warning.call_args[0][0]
        assert (
            "Both --prefill-policy and --decode-policy are specified"
            in warning_call
        )

        # Should create router instance
        router_mod.from_args.assert_called_once()

    def test_policy_info_during_startup(self):
        """Only a prefill policy: one info message naming it and the fallback policy."""
        args = self._pd_args(prefill_policy="power_of_two", decode_policy=None)

        router_mod, mock_logger = self._launch_patched(args)

        mock_logger.info.assert_called_once()
        info_call = mock_logger.info.call_args[0][0]
        assert "Using --prefill-policy 'power_of_two'" in info_call
        assert "and --policy 'cache_aware'" in info_call

        # Should create router instance
        router_mod.from_args.assert_called_once()

    def test_policy_info_decode_only_during_startup(self):
        """Only a decode policy: one info message naming the fallback policy and it."""
        args = self._pd_args(prefill_policy=None, decode_policy="round_robin")

        router_mod, mock_logger = self._launch_patched(args)

        mock_logger.info.assert_called_once()
        info_call = mock_logger.info.call_args[0][0]
        assert "Using --policy 'cache_aware'" in info_call
        assert "and --decode-policy 'round_robin'" in info_call

        # Should create router instance
        router_mod.from_args.assert_called_once()
def _install_sglang_stubs(monkeypatch):
    """Register stand-in modules in ``sys.modules`` for the launch_server tests.

    Stubs are provided for ``requests``, ``setproctitle`` and the ``sglang.srt``
    modules so that the real (heavy) dependencies are not needed. Everything is
    installed through ``monkeypatch.setitem``.
    """
    import sys
    import types

    class ServerArgs:
        # Minimal fields used by launch_server_process
        def __init__(self):
            self.port = 0
            self.base_gpu_id = 0
            self.dp_size = 1
            self.tp_size = 1

        @staticmethod
        def add_cli_args(_parser):
            return None

        @staticmethod
        def from_cli_args(_args):
            server_args = ServerArgs()
            # Copy a field only when the namespace carries it; otherwise keep
            # the constructor default (host falls back to loopback).
            server_args.dp_size = getattr(_args, "dp_size", server_args.dp_size)
            server_args.tp_size = getattr(_args, "tp_size", server_args.tp_size)
            server_args.host = getattr(_args, "host", "127.0.0.1")
            return server_args

    def launch_server(_args):
        return None

    def is_port_available(_port: int) -> bool:
        return True

    def _dummy_get(*_a, **_k):
        raise NotImplementedError

    http_server_mod = types.ModuleType("sglang.srt.entrypoints.http_server")
    http_server_mod.launch_server = launch_server

    server_args_mod = types.ModuleType("sglang.srt.server_args")
    server_args_mod.ServerArgs = ServerArgs

    utils_mod = types.ModuleType("sglang.srt.utils")
    utils_mod.is_port_available = is_port_available

    stubs = {
        # External deps imported at module top-level
        "requests": types.SimpleNamespace(
            exceptions=types.SimpleNamespace(RequestException=Exception),
            get=_dummy_get,
        ),
        "setproctitle": types.SimpleNamespace(setproctitle=lambda *_a, **_k: None),
        # sglang package hierarchy
        "sglang": types.ModuleType("sglang"),
        "sglang.srt": types.ModuleType("sglang.srt"),
        "sglang.srt.entrypoints": types.ModuleType("sglang.srt.entrypoints"),
        "sglang.srt.entrypoints.http_server": http_server_mod,
        "sglang.srt.server_args": server_args_mod,
        "sglang.srt.utils": utils_mod,
    }
    for module_name, stub in stubs.items():
        monkeypatch.setitem(sys.modules, module_name, stub)
def test_launch_server_process_and_cleanup(monkeypatch):
    """launch_server: process creation args and cleanup SIGTERM/SIGKILL logic."""
    import importlib
    import signal
    import sys

    _install_sglang_stubs(monkeypatch)
    ls = importlib.import_module("sglang_router.launch_server")

    created = {}

    class FakeProcess:
        """Stand-in for ``mp.Process`` that records how it was constructed."""

        def __init__(self, target, args):
            created["target"] = target
            created["args"] = args
            self.pid = 4242
            self._alive = True

        def start(self):
            created["started"] = True

        def join(self, timeout=None):
            return None

        def is_alive(self):
            return self._alive

    monkeypatch.setattr(ls.mp, "Process", FakeProcess)

    # Use the stubbed ServerArgs installed by _install_sglang_stubs.
    server_args = sys.modules["sglang.srt.server_args"].ServerArgs()
    server_args.tp_size = 2

    ls.launch_server_process(server_args, worker_port=31001, dp_id=3)
    assert created.get("started") is True
    assert created["target"] is ls.run_server
    passed_sa = created["args"][0]
    assert passed_sa.port == 31001
    assert passed_sa.base_gpu_id == 3 * 2
    assert passed_sa.dp_size == 1

    # cleanup_processes: one process already finished, one still alive.
    # NOTE(review): both fakes share pid 4242, so the assertions below cannot
    # tell which process received which signal.
    finished = FakeProcess(target=None, args=())
    finished._alive = False
    running = FakeProcess(target=None, args=())
    running._alive = True

    calls = []

    def fake_killpg(pid, sig):
        calls.append((pid, sig))

    monkeypatch.setattr(ls.os, "killpg", fake_killpg)

    ls.cleanup_processes([finished, running])

    assert (finished.pid, signal.SIGTERM) in calls
    assert (running.pid, signal.SIGTERM) in calls
    assert (running.pid, signal.SIGKILL) in calls


def test_validation_error_handling():
    """Test error handling when validation fails.

    This test was previously indented inside
    ``test_launch_server_process_and_cleanup`` and therefore defined as a
    nested function that pytest never collected or ran. It is now a
    module-level test.
    """
    args = RouterArgs(
        pd_disaggregation=True,
        prefill_urls=[],
        decode_urls=[],
        service_discovery=False,
    )

    with patch("sglang_router.launch_router.logger") as mock_logger:
        with pytest.raises(
            ValueError, match="PD disaggregation mode requires --prefill"
        ):
            launch_router(args)

        # Should log error for validation failures
        # NOTE(review): presumes launch_router logs validation errors the same
        # way it logs Router creation/start errors; confirm now that this runs.
        mock_logger.error.assert_called_once()
patch("sglang_router.launch_router.Router") as router_mod: + mock_router_instance = MagicMock() + router_mod.from_args = MagicMock(return_value=mock_router_instance) + + with patch("sglang_router.router_args.logger") as mock_logger: + result = launch_router(args) + + # Verify complete flow + router_mod.from_args.assert_called_once() + mock_router_instance.start.assert_called_once() + + # Verify policy warning was logged + mock_logger.warning.assert_called_once() + + def test_complete_startup_flow_with_all_features(self): + """Test complete startup flow with all features enabled.""" + args = RouterArgs( + host="0.0.0.0", + port=30001, + worker_urls=["http://worker1:8000"], + policy="round_robin", + service_discovery=True, + selector={"app": "worker"}, + service_discovery_port=8080, + service_discovery_namespace="default", + dp_aware=True, + api_key="test-key", + log_dir="/tmp/logs", + log_level="debug", + prometheus_port=29000, + prometheus_host="0.0.0.0", + request_id_headers=["x-request-id", "x-trace-id"], + request_timeout_secs=1200, + max_concurrent_requests=512, + queue_size=200, + queue_timeout_secs=120, + rate_limit_tokens_per_second=100, + cors_allowed_origins=["http://localhost:3000"], + retry_max_retries=3, + retry_initial_backoff_ms=100, + retry_max_backoff_ms=10000, + retry_backoff_multiplier=2.0, + retry_jitter_factor=0.1, + cb_failure_threshold=5, + cb_success_threshold=2, + cb_timeout_duration_secs=30, + cb_window_duration_secs=60, + health_failure_threshold=2, + health_success_threshold=1, + health_check_timeout_secs=3, + health_check_interval_secs=30, + health_check_endpoint="/healthz", + ) + + with patch("sglang_router.launch_router.Router") as router_mod: + captured_args = {} + mock_router_instance = MagicMock() + + def fake_from_args(router_args): + captured_args.update( + dict( + host=router_args.host, + port=router_args.port, + worker_urls=router_args.worker_urls, + policy=policy_from_str(router_args.policy), + 
service_discovery=router_args.service_discovery, + selector=router_args.selector, + service_discovery_port=router_args.service_discovery_port, + service_discovery_namespace=router_args.service_discovery_namespace, + dp_aware=router_args.dp_aware, + api_key=router_args.api_key, + log_dir=router_args.log_dir, + log_level=router_args.log_level, + prometheus_port=router_args.prometheus_port, + prometheus_host=router_args.prometheus_host, + request_id_headers=router_args.request_id_headers, + request_timeout_secs=router_args.request_timeout_secs, + max_concurrent_requests=router_args.max_concurrent_requests, + queue_size=router_args.queue_size, + queue_timeout_secs=router_args.queue_timeout_secs, + rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second, + cors_allowed_origins=router_args.cors_allowed_origins, + retry_max_retries=router_args.retry_max_retries, + retry_initial_backoff_ms=router_args.retry_initial_backoff_ms, + retry_max_backoff_ms=router_args.retry_max_backoff_ms, + retry_backoff_multiplier=router_args.retry_backoff_multiplier, + retry_jitter_factor=router_args.retry_jitter_factor, + cb_failure_threshold=router_args.cb_failure_threshold, + cb_success_threshold=router_args.cb_success_threshold, + cb_timeout_duration_secs=router_args.cb_timeout_duration_secs, + cb_window_duration_secs=router_args.cb_window_duration_secs, + health_failure_threshold=router_args.health_failure_threshold, + health_success_threshold=router_args.health_success_threshold, + health_check_timeout_secs=router_args.health_check_timeout_secs, + health_check_interval_secs=router_args.health_check_interval_secs, + health_check_endpoint=router_args.health_check_endpoint, + ) + ) + return mock_router_instance + + router_mod.from_args = MagicMock(side_effect=fake_from_args) + + result = launch_router(args) + + # Verify complete flow + router_mod.from_args.assert_called_once() + mock_router_instance.start.assert_called_once() + + # Verify key parameters were propagated into 
RouterArgs + assert captured_args["host"] == "0.0.0.0" + assert captured_args["port"] == 30001 + assert captured_args["worker_urls"] == ["http://worker1:8000"] + assert captured_args["policy"] == PolicyType.RoundRobin + assert captured_args["service_discovery"] is True + assert captured_args["selector"] == {"app": "worker"} + assert captured_args["service_discovery_port"] == 8080 + assert captured_args["service_discovery_namespace"] == "default" + assert captured_args["dp_aware"] is True + assert captured_args["api_key"] == "test-key" + assert captured_args["log_dir"] == "/tmp/logs" + assert captured_args["log_level"] == "debug" + assert captured_args["prometheus_port"] == 29000 + assert captured_args["prometheus_host"] == "0.0.0.0" + assert captured_args["request_id_headers"] == ["x-request-id", "x-trace-id"] + assert captured_args["request_timeout_secs"] == 1200 + assert captured_args["max_concurrent_requests"] == 512 + assert captured_args["queue_size"] == 200 + assert captured_args["queue_timeout_secs"] == 120 + assert captured_args["rate_limit_tokens_per_second"] == 100 + assert captured_args["cors_allowed_origins"] == ["http://localhost:3000"] + assert captured_args["retry_max_retries"] == 3 + assert captured_args["retry_initial_backoff_ms"] == 100 + assert captured_args["retry_max_backoff_ms"] == 10000 + assert captured_args["retry_backoff_multiplier"] == 2.0 + assert captured_args["retry_jitter_factor"] == 0.1 + assert captured_args["cb_failure_threshold"] == 5 + assert captured_args["cb_success_threshold"] == 2 + assert captured_args["cb_timeout_duration_secs"] == 30 + assert captured_args["cb_window_duration_secs"] == 60 + assert captured_args["health_failure_threshold"] == 2 + assert captured_args["health_success_threshold"] == 1 + assert captured_args["health_check_timeout_secs"] == 3 + assert captured_args["health_check_interval_secs"] == 30 + assert captured_args["health_check_endpoint"] == "/healthz" diff --git 
a/sgl-router/py_test/unit/test_validation.py b/sgl-router/py_test/unit/test_validation.py new file mode 100644 index 000000000..1a3e54612 --- /dev/null +++ b/sgl-router/py_test/unit/test_validation.py @@ -0,0 +1,506 @@ +""" +Unit tests for validation logic in sglang_router. + +These tests focus on testing the validation logic in isolation, +including parameter validation, URL validation, and configuration validation. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from sglang_router.launch_router import RouterArgs, launch_router + + +class TestURLValidation: + """Test URL validation logic.""" + + def test_valid_worker_urls(self): + """Test validation of valid worker URLs.""" + valid_urls = [ + "http://worker1:8000", + "https://worker2:8000", + "http://localhost:8000", + "http://127.0.0.1:8000", + "http://192.168.1.100:8000", + "http://worker.example.com:8000", + ] + + for url in valid_urls: + args = RouterArgs(worker_urls=[url]) + # Should not raise any validation errors + assert url in args.worker_urls + + def test_valid_prefill_urls(self): + """Test validation of valid prefill URLs.""" + valid_prefill_urls = [ + ("http://prefill1:8000", 9000), + ("https://prefill2:8000", None), + ("http://localhost:8000", 9000), + ("http://127.0.0.1:8000", None), + ] + + for url, bootstrap_port in valid_prefill_urls: + args = RouterArgs(prefill_urls=[(url, bootstrap_port)]) + # Should not raise any validation errors + assert (url, bootstrap_port) in args.prefill_urls + + def test_valid_decode_urls(self): + """Test validation of valid decode URLs.""" + valid_decode_urls = [ + "http://decode1:8001", + "https://decode2:8001", + "http://localhost:8001", + "http://127.0.0.1:8001", + ] + + for url in valid_decode_urls: + args = RouterArgs(decode_urls=[url]) + # Should not raise any validation errors + assert url in args.decode_urls + + def test_malformed_urls(self): + """Test handling of malformed URLs.""" + # Note: The current 
implementation doesn't validate URL format + # This test documents the current behavior + malformed_urls = [ + "not-a-url", + "ftp://worker1:8000", # Wrong protocol + "http://", # Missing host + ":8000", # Missing protocol and host + "http://worker1", # Missing port + ] + + for url in malformed_urls: + args = RouterArgs(worker_urls=[url]) + # Currently, malformed URLs are accepted + # This might be something to improve in the future + assert url in args.worker_urls + + +class TestPortValidation: + """Test port validation logic.""" + + def test_valid_ports(self): + """Test validation of valid port numbers.""" + valid_ports = [1, 80, 8000, 30000, 65535] + + for port in valid_ports: + args = RouterArgs(port=port) + assert args.port == port + + def test_invalid_ports(self): + """Test handling of invalid port numbers.""" + # Note: The current implementation doesn't validate port ranges + # This test documents the current behavior + invalid_ports = [0, -1, 65536, 70000] + + for port in invalid_ports: + args = RouterArgs(port=port) + # Currently, invalid ports are accepted + # This might be something to improve in the future + assert args.port == port + + def test_bootstrap_port_validation(self): + """Test validation of bootstrap ports in PD mode.""" + valid_bootstrap_ports = [1, 80, 9000, 30000, 65535, None] + + for bootstrap_port in valid_bootstrap_ports: + args = RouterArgs(prefill_urls=[("http://prefill1:8000", bootstrap_port)]) + assert args.prefill_urls[0][1] == bootstrap_port + + +class TestParameterValidation: + """Test parameter validation logic.""" + + def test_cache_threshold_validation(self): + """Test cache threshold parameter validation.""" + # Valid cache thresholds + valid_thresholds = [0.0, 0.1, 0.5, 0.9, 1.0] + + for threshold in valid_thresholds: + args = RouterArgs(cache_threshold=threshold) + assert args.cache_threshold == threshold + + def test_balance_threshold_validation(self): + """Test load balancing threshold parameter validation.""" + # Valid 
absolute thresholds + valid_abs_thresholds = [0, 1, 32, 64, 128, 1000] + for threshold in valid_abs_thresholds: + args = RouterArgs(balance_abs_threshold=threshold) + assert args.balance_abs_threshold == threshold + + # Valid relative thresholds + valid_rel_thresholds = [1.0, 1.1, 1.5, 2.0, 10.0] + for threshold in valid_rel_thresholds: + args = RouterArgs(balance_rel_threshold=threshold) + assert args.balance_rel_threshold == threshold + + def test_timeout_validation(self): + """Test timeout parameter validation.""" + # Valid timeouts + valid_timeouts = [1, 30, 60, 300, 600, 1800, 3600] + + for timeout in valid_timeouts: + args = RouterArgs( + worker_startup_timeout_secs=timeout, + worker_startup_check_interval=timeout, + request_timeout_secs=timeout, + queue_timeout_secs=timeout, + ) + assert args.worker_startup_timeout_secs == timeout + assert args.worker_startup_check_interval == timeout + assert args.request_timeout_secs == timeout + assert args.queue_timeout_secs == timeout + + def test_retry_parameter_validation(self): + """Test retry parameter validation.""" + # Valid retry parameters + valid_retry_counts = [0, 1, 3, 5, 10] + for count in valid_retry_counts: + args = RouterArgs(retry_max_retries=count) + assert args.retry_max_retries == count + + # Valid backoff parameters + valid_backoff_ms = [1, 50, 100, 1000, 30000] + for backoff in valid_backoff_ms: + args = RouterArgs( + retry_initial_backoff_ms=backoff, retry_max_backoff_ms=backoff + ) + assert args.retry_initial_backoff_ms == backoff + assert args.retry_max_backoff_ms == backoff + + # Valid multiplier parameters + valid_multipliers = [1.0, 1.5, 2.0, 3.0] + for multiplier in valid_multipliers: + args = RouterArgs(retry_backoff_multiplier=multiplier) + assert args.retry_backoff_multiplier == multiplier + + # Valid jitter parameters + valid_jitter = [0.0, 0.1, 0.2, 0.5] + for jitter in valid_jitter: + args = RouterArgs(retry_jitter_factor=jitter) + assert args.retry_jitter_factor == jitter + + def 
test_circuit_breaker_parameter_validation(self): + """Test circuit breaker parameter validation.""" + # Valid failure thresholds + valid_failure_thresholds = [1, 3, 5, 10, 20] + for threshold in valid_failure_thresholds: + args = RouterArgs(cb_failure_threshold=threshold) + assert args.cb_failure_threshold == threshold + + # Valid success thresholds + valid_success_thresholds = [1, 2, 3, 5] + for threshold in valid_success_thresholds: + args = RouterArgs(cb_success_threshold=threshold) + assert args.cb_success_threshold == threshold + + # Valid timeout durations + valid_timeouts = [10, 30, 60, 120, 300] + for timeout in valid_timeouts: + args = RouterArgs( + cb_timeout_duration_secs=timeout, cb_window_duration_secs=timeout + ) + assert args.cb_timeout_duration_secs == timeout + assert args.cb_window_duration_secs == timeout + + def test_health_check_parameter_validation(self): + """Test health check parameter validation.""" + # Valid failure thresholds + valid_failure_thresholds = [1, 2, 3, 5, 10] + for threshold in valid_failure_thresholds: + args = RouterArgs(health_failure_threshold=threshold) + assert args.health_failure_threshold == threshold + + # Valid success thresholds + valid_success_thresholds = [1, 2, 3, 5] + for threshold in valid_success_thresholds: + args = RouterArgs(health_success_threshold=threshold) + assert args.health_success_threshold == threshold + + # Valid timeouts and intervals + valid_times = [1, 5, 10, 30, 60, 120] + for time_val in valid_times: + args = RouterArgs( + health_check_timeout_secs=time_val, health_check_interval_secs=time_val + ) + assert args.health_check_timeout_secs == time_val + assert args.health_check_interval_secs == time_val + + def test_rate_limiting_parameter_validation(self): + """Test rate limiting parameter validation.""" + # Valid concurrent request limits + valid_limits = [1, 10, 64, 256, 512, 1000] + for limit in valid_limits: + args = RouterArgs(max_concurrent_requests=limit) + assert 
args.max_concurrent_requests == limit + + # Valid queue sizes + valid_queue_sizes = [0, 10, 50, 100, 500, 1000] + for size in valid_queue_sizes: + args = RouterArgs(queue_size=size) + assert args.queue_size == size + + # Valid token rates + valid_rates = [1, 10, 50, 100, 500, 1000] + for rate in valid_rates: + args = RouterArgs(rate_limit_tokens_per_second=rate) + assert args.rate_limit_tokens_per_second == rate + + def test_tree_size_validation(self): + """Test tree size parameter validation.""" + # Valid tree sizes (powers of 2) + valid_sizes = [2**10, 2**20, 2**24, 2**26, 2**28, 2**30] + + for size in valid_sizes: + args = RouterArgs(max_tree_size=size) + assert args.max_tree_size == size + + def test_payload_size_validation(self): + """Test payload size parameter validation.""" + # Valid payload sizes + valid_sizes = [ + 1024, # 1KB + 1024 * 1024, # 1MB + 10 * 1024 * 1024, # 10MB + 100 * 1024 * 1024, # 100MB + 512 * 1024 * 1024, # 512MB + 1024 * 1024 * 1024, # 1GB + ] + + for size in valid_sizes: + args = RouterArgs(max_payload_size=size) + assert args.max_payload_size == size + + +class TestConfigurationValidation: + """Test configuration validation logic.""" + + def test_pd_mode_validation(self): + """Test PD mode configuration validation.""" + # Valid PD configuration + args = RouterArgs( + pd_disaggregation=True, + prefill_urls=[("http://prefill1:8000", 9000)], + decode_urls=["http://decode1:8001"], + ) + + assert args.pd_disaggregation is True + assert len(args.prefill_urls) > 0 + assert len(args.decode_urls) > 0 + + def test_service_discovery_validation(self): + """Test service discovery configuration validation.""" + # Valid service discovery configuration + args = RouterArgs( + service_discovery=True, + selector={"app": "worker", "env": "prod"}, + service_discovery_port=8080, + service_discovery_namespace="default", + ) + + assert args.service_discovery is True + assert args.selector == {"app": "worker", "env": "prod"} + assert 
args.service_discovery_port == 8080 + assert args.service_discovery_namespace == "default" + + def test_pd_service_discovery_validation(self): + """Test PD service discovery configuration validation.""" + # Valid PD service discovery configuration + args = RouterArgs( + pd_disaggregation=True, + service_discovery=True, + prefill_selector={"app": "prefill"}, + decode_selector={"app": "decode"}, + ) + + assert args.pd_disaggregation is True + assert args.service_discovery is True + assert args.prefill_selector == {"app": "prefill"} + assert args.decode_selector == {"app": "decode"} + + def test_policy_validation(self): + """Test policy configuration validation.""" + # Valid policies + valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"] + + for policy in valid_policies: + args = RouterArgs(policy=policy) + assert args.policy == policy + + def test_pd_policy_validation(self): + """Test PD policy configuration validation.""" + # Valid PD policies + valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"] + + for prefill_policy in valid_policies: + for decode_policy in valid_policies: + args = RouterArgs( + pd_disaggregation=True, + prefill_urls=[("http://prefill1:8000", None)], + decode_urls=["http://decode1:8001"], + prefill_policy=prefill_policy, + decode_policy=decode_policy, + ) + assert args.prefill_policy == prefill_policy + assert args.decode_policy == decode_policy + + def test_cors_validation(self): + """Test CORS configuration validation.""" + # Valid CORS origins + valid_origins = [ + [], + ["http://localhost:3000"], + ["https://example.com"], + ["http://localhost:3000", "https://example.com"], + ["*"], # Wildcard (if supported) + ] + + for origins in valid_origins: + args = RouterArgs(cors_allowed_origins=origins) + assert args.cors_allowed_origins == origins + + def test_logging_validation(self): + """Test logging configuration validation.""" + # Valid log levels + valid_log_levels = ["debug", "info", "warning", 
"error", "critical"] + + for level in valid_log_levels: + args = RouterArgs(log_level=level) + assert args.log_level == level + + def test_prometheus_validation(self): + """Test Prometheus configuration validation.""" + # Valid Prometheus configuration + args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1") + + assert args.prometheus_port == 29000 + assert args.prometheus_host == "127.0.0.1" + + def test_tokenizer_validation(self): + """Test tokenizer configuration validation.""" + # Note: model_path and tokenizer_path are not available in current RouterArgs + pytest.skip("Tokenizer configuration not available in current implementation") + + def test_request_id_headers_validation(self): + """Test request ID headers configuration validation.""" + # Valid request ID headers + valid_headers = [ + ["x-request-id"], + ["x-request-id", "x-trace-id"], + ["x-request-id", "x-trace-id", "x-correlation-id"], + ["custom-header"], + ] + + for headers in valid_headers: + args = RouterArgs(request_id_headers=headers) + assert args.request_id_headers == headers + + +class TestLaunchValidation: + """Test launch-time validation logic.""" + + def test_pd_mode_requires_urls(self): + """Test that PD mode requires prefill and decode URLs.""" + # PD mode without URLs should fail + args = RouterArgs( + pd_disaggregation=True, + prefill_urls=[], + decode_urls=[], + service_discovery=False, + ) + + with pytest.raises( + ValueError, match="PD disaggregation mode requires --prefill" + ): + launch_router(args) + + def test_pd_mode_with_service_discovery_allows_empty_urls(self): + """Test that PD mode with service discovery allows empty URLs.""" + args = RouterArgs( + pd_disaggregation=True, + prefill_urls=[], + decode_urls=[], + service_discovery=True, + ) + + # Should not raise validation error + with patch("sglang_router.launch_router.Router") as router_mod: + mock_router_instance = MagicMock() + router_mod.from_args = MagicMock(return_value=mock_router_instance) + + 
launch_router(args) + + # Should create router instance via from_args + router_mod.from_args.assert_called_once() + + def test_regular_mode_allows_empty_worker_urls(self): + """Test that regular mode allows empty worker URLs.""" + args = RouterArgs(worker_urls=[], service_discovery=False) + + # Should not raise validation error + with patch("sglang_router.launch_router.Router") as router_mod: + mock_router_instance = MagicMock() + router_mod.from_args = MagicMock(return_value=mock_router_instance) + + launch_router(args) + + # Should create router instance via from_args + router_mod.from_args.assert_called_once() + + def test_launch_with_valid_config(self): + """Test launching with valid configuration.""" + args = RouterArgs( + host="127.0.0.1", + port=30000, + worker_urls=["http://worker1:8000"], + policy="cache_aware", + ) + + # Should not raise validation error + with patch("sglang_router.launch_router.Router") as router_mod: + mock_router_instance = MagicMock() + router_mod.from_args = MagicMock(return_value=mock_router_instance) + + launch_router(args) + + # Should create router instance via from_args + router_mod.from_args.assert_called_once() + + def test_launch_with_pd_config(self): + """Test launching with valid PD configuration.""" + args = RouterArgs( + pd_disaggregation=True, + prefill_urls=[("http://prefill1:8000", 9000)], + decode_urls=["http://decode1:8001"], + policy="cache_aware", + ) + + # Should not raise validation error + with patch("sglang_router.launch_router.Router") as router_mod: + mock_router_instance = MagicMock() + router_mod.from_args = MagicMock(return_value=mock_router_instance) + + launch_router(args) + + # Should create router instance via from_args + router_mod.from_args.assert_called_once() + + def test_launch_with_service_discovery_config(self): + """Test launching with valid service discovery configuration.""" + args = RouterArgs( + service_discovery=True, + selector={"app": "worker"}, + service_discovery_port=8080, + ) + + # 
Should not raise validation error + with patch("sglang_router.launch_router.Router") as router_mod: + mock_router_instance = MagicMock() + router_mod.from_args = MagicMock(return_value=mock_router_instance) + + launch_router(args) + + # Should create router instance via from_args + router_mod.from_args.assert_called_once() diff --git a/sgl-router/pyproject.toml b/sgl-router/pyproject.toml index bd0314aec..9a7606f6a 100644 --- a/sgl-router/pyproject.toml +++ b/sgl-router/pyproject.toml @@ -21,6 +21,7 @@ dev = [ "requests>=2.25.0", ] + # https://github.com/PyO3/setuptools-rust?tab=readme-ov-file [tool.setuptools.packages] find = { where = ["py_src"] } diff --git a/sgl-router/pytest.ini b/sgl-router/pytest.ini new file mode 100644 index 000000000..d28b847e6 --- /dev/null +++ b/sgl-router/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +testpaths = py_test +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = --cov=sglang_router --cov-report=term-missing