[router] add py binding unit tests to coverage 80% (#10043)
This commit is contained in:
506
sgl-router/py_test/unit/test_validation.py
Normal file
506
sgl-router/py_test/unit/test_validation.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Unit tests for validation logic in sglang_router.
|
||||
|
||||
These tests focus on testing the validation logic in isolation,
|
||||
including parameter validation, URL validation, and configuration validation.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
|
||||
class TestURLValidation:
|
||||
"""Test URL validation logic."""
|
||||
|
||||
def test_valid_worker_urls(self):
|
||||
"""Test validation of valid worker URLs."""
|
||||
valid_urls = [
|
||||
"http://worker1:8000",
|
||||
"https://worker2:8000",
|
||||
"http://localhost:8000",
|
||||
"http://127.0.0.1:8000",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://worker.example.com:8000",
|
||||
]
|
||||
|
||||
for url in valid_urls:
|
||||
args = RouterArgs(worker_urls=[url])
|
||||
# Should not raise any validation errors
|
||||
assert url in args.worker_urls
|
||||
|
||||
def test_valid_prefill_urls(self):
|
||||
"""Test validation of valid prefill URLs."""
|
||||
valid_prefill_urls = [
|
||||
("http://prefill1:8000", 9000),
|
||||
("https://prefill2:8000", None),
|
||||
("http://localhost:8000", 9000),
|
||||
("http://127.0.0.1:8000", None),
|
||||
]
|
||||
|
||||
for url, bootstrap_port in valid_prefill_urls:
|
||||
args = RouterArgs(prefill_urls=[(url, bootstrap_port)])
|
||||
# Should not raise any validation errors
|
||||
assert (url, bootstrap_port) in args.prefill_urls
|
||||
|
||||
def test_valid_decode_urls(self):
|
||||
"""Test validation of valid decode URLs."""
|
||||
valid_decode_urls = [
|
||||
"http://decode1:8001",
|
||||
"https://decode2:8001",
|
||||
"http://localhost:8001",
|
||||
"http://127.0.0.1:8001",
|
||||
]
|
||||
|
||||
for url in valid_decode_urls:
|
||||
args = RouterArgs(decode_urls=[url])
|
||||
# Should not raise any validation errors
|
||||
assert url in args.decode_urls
|
||||
|
||||
def test_malformed_urls(self):
|
||||
"""Test handling of malformed URLs."""
|
||||
# Note: The current implementation doesn't validate URL format
|
||||
# This test documents the current behavior
|
||||
malformed_urls = [
|
||||
"not-a-url",
|
||||
"ftp://worker1:8000", # Wrong protocol
|
||||
"http://", # Missing host
|
||||
":8000", # Missing protocol and host
|
||||
"http://worker1", # Missing port
|
||||
]
|
||||
|
||||
for url in malformed_urls:
|
||||
args = RouterArgs(worker_urls=[url])
|
||||
# Currently, malformed URLs are accepted
|
||||
# This might be something to improve in the future
|
||||
assert url in args.worker_urls
|
||||
|
||||
|
||||
class TestPortValidation:
|
||||
"""Test port validation logic."""
|
||||
|
||||
def test_valid_ports(self):
|
||||
"""Test validation of valid port numbers."""
|
||||
valid_ports = [1, 80, 8000, 30000, 65535]
|
||||
|
||||
for port in valid_ports:
|
||||
args = RouterArgs(port=port)
|
||||
assert args.port == port
|
||||
|
||||
def test_invalid_ports(self):
|
||||
"""Test handling of invalid port numbers."""
|
||||
# Note: The current implementation doesn't validate port ranges
|
||||
# This test documents the current behavior
|
||||
invalid_ports = [0, -1, 65536, 70000]
|
||||
|
||||
for port in invalid_ports:
|
||||
args = RouterArgs(port=port)
|
||||
# Currently, invalid ports are accepted
|
||||
# This might be something to improve in the future
|
||||
assert args.port == port
|
||||
|
||||
def test_bootstrap_port_validation(self):
|
||||
"""Test validation of bootstrap ports in PD mode."""
|
||||
valid_bootstrap_ports = [1, 80, 9000, 30000, 65535, None]
|
||||
|
||||
for bootstrap_port in valid_bootstrap_ports:
|
||||
args = RouterArgs(prefill_urls=[("http://prefill1:8000", bootstrap_port)])
|
||||
assert args.prefill_urls[0][1] == bootstrap_port
|
||||
|
||||
|
||||
class TestParameterValidation:
|
||||
"""Test parameter validation logic."""
|
||||
|
||||
def test_cache_threshold_validation(self):
|
||||
"""Test cache threshold parameter validation."""
|
||||
# Valid cache thresholds
|
||||
valid_thresholds = [0.0, 0.1, 0.5, 0.9, 1.0]
|
||||
|
||||
for threshold in valid_thresholds:
|
||||
args = RouterArgs(cache_threshold=threshold)
|
||||
assert args.cache_threshold == threshold
|
||||
|
||||
def test_balance_threshold_validation(self):
|
||||
"""Test load balancing threshold parameter validation."""
|
||||
# Valid absolute thresholds
|
||||
valid_abs_thresholds = [0, 1, 32, 64, 128, 1000]
|
||||
for threshold in valid_abs_thresholds:
|
||||
args = RouterArgs(balance_abs_threshold=threshold)
|
||||
assert args.balance_abs_threshold == threshold
|
||||
|
||||
# Valid relative thresholds
|
||||
valid_rel_thresholds = [1.0, 1.1, 1.5, 2.0, 10.0]
|
||||
for threshold in valid_rel_thresholds:
|
||||
args = RouterArgs(balance_rel_threshold=threshold)
|
||||
assert args.balance_rel_threshold == threshold
|
||||
|
||||
def test_timeout_validation(self):
|
||||
"""Test timeout parameter validation."""
|
||||
# Valid timeouts
|
||||
valid_timeouts = [1, 30, 60, 300, 600, 1800, 3600]
|
||||
|
||||
for timeout in valid_timeouts:
|
||||
args = RouterArgs(
|
||||
worker_startup_timeout_secs=timeout,
|
||||
worker_startup_check_interval=timeout,
|
||||
request_timeout_secs=timeout,
|
||||
queue_timeout_secs=timeout,
|
||||
)
|
||||
assert args.worker_startup_timeout_secs == timeout
|
||||
assert args.worker_startup_check_interval == timeout
|
||||
assert args.request_timeout_secs == timeout
|
||||
assert args.queue_timeout_secs == timeout
|
||||
|
||||
def test_retry_parameter_validation(self):
|
||||
"""Test retry parameter validation."""
|
||||
# Valid retry parameters
|
||||
valid_retry_counts = [0, 1, 3, 5, 10]
|
||||
for count in valid_retry_counts:
|
||||
args = RouterArgs(retry_max_retries=count)
|
||||
assert args.retry_max_retries == count
|
||||
|
||||
# Valid backoff parameters
|
||||
valid_backoff_ms = [1, 50, 100, 1000, 30000]
|
||||
for backoff in valid_backoff_ms:
|
||||
args = RouterArgs(
|
||||
retry_initial_backoff_ms=backoff, retry_max_backoff_ms=backoff
|
||||
)
|
||||
assert args.retry_initial_backoff_ms == backoff
|
||||
assert args.retry_max_backoff_ms == backoff
|
||||
|
||||
# Valid multiplier parameters
|
||||
valid_multipliers = [1.0, 1.5, 2.0, 3.0]
|
||||
for multiplier in valid_multipliers:
|
||||
args = RouterArgs(retry_backoff_multiplier=multiplier)
|
||||
assert args.retry_backoff_multiplier == multiplier
|
||||
|
||||
# Valid jitter parameters
|
||||
valid_jitter = [0.0, 0.1, 0.2, 0.5]
|
||||
for jitter in valid_jitter:
|
||||
args = RouterArgs(retry_jitter_factor=jitter)
|
||||
assert args.retry_jitter_factor == jitter
|
||||
|
||||
def test_circuit_breaker_parameter_validation(self):
|
||||
"""Test circuit breaker parameter validation."""
|
||||
# Valid failure thresholds
|
||||
valid_failure_thresholds = [1, 3, 5, 10, 20]
|
||||
for threshold in valid_failure_thresholds:
|
||||
args = RouterArgs(cb_failure_threshold=threshold)
|
||||
assert args.cb_failure_threshold == threshold
|
||||
|
||||
# Valid success thresholds
|
||||
valid_success_thresholds = [1, 2, 3, 5]
|
||||
for threshold in valid_success_thresholds:
|
||||
args = RouterArgs(cb_success_threshold=threshold)
|
||||
assert args.cb_success_threshold == threshold
|
||||
|
||||
# Valid timeout durations
|
||||
valid_timeouts = [10, 30, 60, 120, 300]
|
||||
for timeout in valid_timeouts:
|
||||
args = RouterArgs(
|
||||
cb_timeout_duration_secs=timeout, cb_window_duration_secs=timeout
|
||||
)
|
||||
assert args.cb_timeout_duration_secs == timeout
|
||||
assert args.cb_window_duration_secs == timeout
|
||||
|
||||
def test_health_check_parameter_validation(self):
|
||||
"""Test health check parameter validation."""
|
||||
# Valid failure thresholds
|
||||
valid_failure_thresholds = [1, 2, 3, 5, 10]
|
||||
for threshold in valid_failure_thresholds:
|
||||
args = RouterArgs(health_failure_threshold=threshold)
|
||||
assert args.health_failure_threshold == threshold
|
||||
|
||||
# Valid success thresholds
|
||||
valid_success_thresholds = [1, 2, 3, 5]
|
||||
for threshold in valid_success_thresholds:
|
||||
args = RouterArgs(health_success_threshold=threshold)
|
||||
assert args.health_success_threshold == threshold
|
||||
|
||||
# Valid timeouts and intervals
|
||||
valid_times = [1, 5, 10, 30, 60, 120]
|
||||
for time_val in valid_times:
|
||||
args = RouterArgs(
|
||||
health_check_timeout_secs=time_val, health_check_interval_secs=time_val
|
||||
)
|
||||
assert args.health_check_timeout_secs == time_val
|
||||
assert args.health_check_interval_secs == time_val
|
||||
|
||||
def test_rate_limiting_parameter_validation(self):
|
||||
"""Test rate limiting parameter validation."""
|
||||
# Valid concurrent request limits
|
||||
valid_limits = [1, 10, 64, 256, 512, 1000]
|
||||
for limit in valid_limits:
|
||||
args = RouterArgs(max_concurrent_requests=limit)
|
||||
assert args.max_concurrent_requests == limit
|
||||
|
||||
# Valid queue sizes
|
||||
valid_queue_sizes = [0, 10, 50, 100, 500, 1000]
|
||||
for size in valid_queue_sizes:
|
||||
args = RouterArgs(queue_size=size)
|
||||
assert args.queue_size == size
|
||||
|
||||
# Valid token rates
|
||||
valid_rates = [1, 10, 50, 100, 500, 1000]
|
||||
for rate in valid_rates:
|
||||
args = RouterArgs(rate_limit_tokens_per_second=rate)
|
||||
assert args.rate_limit_tokens_per_second == rate
|
||||
|
||||
def test_tree_size_validation(self):
|
||||
"""Test tree size parameter validation."""
|
||||
# Valid tree sizes (powers of 2)
|
||||
valid_sizes = [2**10, 2**20, 2**24, 2**26, 2**28, 2**30]
|
||||
|
||||
for size in valid_sizes:
|
||||
args = RouterArgs(max_tree_size=size)
|
||||
assert args.max_tree_size == size
|
||||
|
||||
def test_payload_size_validation(self):
|
||||
"""Test payload size parameter validation."""
|
||||
# Valid payload sizes
|
||||
valid_sizes = [
|
||||
1024, # 1KB
|
||||
1024 * 1024, # 1MB
|
||||
10 * 1024 * 1024, # 10MB
|
||||
100 * 1024 * 1024, # 100MB
|
||||
512 * 1024 * 1024, # 512MB
|
||||
1024 * 1024 * 1024, # 1GB
|
||||
]
|
||||
|
||||
for size in valid_sizes:
|
||||
args = RouterArgs(max_payload_size=size)
|
||||
assert args.max_payload_size == size
|
||||
|
||||
|
||||
class TestConfigurationValidation:
|
||||
"""Test configuration validation logic."""
|
||||
|
||||
def test_pd_mode_validation(self):
|
||||
"""Test PD mode configuration validation."""
|
||||
# Valid PD configuration
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", 9000)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
)
|
||||
|
||||
assert args.pd_disaggregation is True
|
||||
assert len(args.prefill_urls) > 0
|
||||
assert len(args.decode_urls) > 0
|
||||
|
||||
def test_service_discovery_validation(self):
|
||||
"""Test service discovery configuration validation."""
|
||||
# Valid service discovery configuration
|
||||
args = RouterArgs(
|
||||
service_discovery=True,
|
||||
selector={"app": "worker", "env": "prod"},
|
||||
service_discovery_port=8080,
|
||||
service_discovery_namespace="default",
|
||||
)
|
||||
|
||||
assert args.service_discovery is True
|
||||
assert args.selector == {"app": "worker", "env": "prod"}
|
||||
assert args.service_discovery_port == 8080
|
||||
assert args.service_discovery_namespace == "default"
|
||||
|
||||
def test_pd_service_discovery_validation(self):
|
||||
"""Test PD service discovery configuration validation."""
|
||||
# Valid PD service discovery configuration
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
service_discovery=True,
|
||||
prefill_selector={"app": "prefill"},
|
||||
decode_selector={"app": "decode"},
|
||||
)
|
||||
|
||||
assert args.pd_disaggregation is True
|
||||
assert args.service_discovery is True
|
||||
assert args.prefill_selector == {"app": "prefill"}
|
||||
assert args.decode_selector == {"app": "decode"}
|
||||
|
||||
def test_policy_validation(self):
|
||||
"""Test policy configuration validation."""
|
||||
# Valid policies
|
||||
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
|
||||
|
||||
for policy in valid_policies:
|
||||
args = RouterArgs(policy=policy)
|
||||
assert args.policy == policy
|
||||
|
||||
def test_pd_policy_validation(self):
|
||||
"""Test PD policy configuration validation."""
|
||||
# Valid PD policies
|
||||
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
|
||||
|
||||
for prefill_policy in valid_policies:
|
||||
for decode_policy in valid_policies:
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", None)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
)
|
||||
assert args.prefill_policy == prefill_policy
|
||||
assert args.decode_policy == decode_policy
|
||||
|
||||
def test_cors_validation(self):
|
||||
"""Test CORS configuration validation."""
|
||||
# Valid CORS origins
|
||||
valid_origins = [
|
||||
[],
|
||||
["http://localhost:3000"],
|
||||
["https://example.com"],
|
||||
["http://localhost:3000", "https://example.com"],
|
||||
["*"], # Wildcard (if supported)
|
||||
]
|
||||
|
||||
for origins in valid_origins:
|
||||
args = RouterArgs(cors_allowed_origins=origins)
|
||||
assert args.cors_allowed_origins == origins
|
||||
|
||||
def test_logging_validation(self):
|
||||
"""Test logging configuration validation."""
|
||||
# Valid log levels
|
||||
valid_log_levels = ["debug", "info", "warning", "error", "critical"]
|
||||
|
||||
for level in valid_log_levels:
|
||||
args = RouterArgs(log_level=level)
|
||||
assert args.log_level == level
|
||||
|
||||
def test_prometheus_validation(self):
|
||||
"""Test Prometheus configuration validation."""
|
||||
# Valid Prometheus configuration
|
||||
args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
|
||||
|
||||
assert args.prometheus_port == 29000
|
||||
assert args.prometheus_host == "127.0.0.1"
|
||||
|
||||
def test_tokenizer_validation(self):
|
||||
"""Test tokenizer configuration validation."""
|
||||
# Note: model_path and tokenizer_path are not available in current RouterArgs
|
||||
pytest.skip("Tokenizer configuration not available in current implementation")
|
||||
|
||||
def test_request_id_headers_validation(self):
|
||||
"""Test request ID headers configuration validation."""
|
||||
# Valid request ID headers
|
||||
valid_headers = [
|
||||
["x-request-id"],
|
||||
["x-request-id", "x-trace-id"],
|
||||
["x-request-id", "x-trace-id", "x-correlation-id"],
|
||||
["custom-header"],
|
||||
]
|
||||
|
||||
for headers in valid_headers:
|
||||
args = RouterArgs(request_id_headers=headers)
|
||||
assert args.request_id_headers == headers
|
||||
|
||||
|
||||
class TestLaunchValidation:
|
||||
"""Test launch-time validation logic."""
|
||||
|
||||
def test_pd_mode_requires_urls(self):
|
||||
"""Test that PD mode requires prefill and decode URLs."""
|
||||
# PD mode without URLs should fail
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[],
|
||||
decode_urls=[],
|
||||
service_discovery=False,
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="PD disaggregation mode requires --prefill"
|
||||
):
|
||||
launch_router(args)
|
||||
|
||||
def test_pd_mode_with_service_discovery_allows_empty_urls(self):
|
||||
"""Test that PD mode with service discovery allows empty URLs."""
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[],
|
||||
decode_urls=[],
|
||||
service_discovery=True,
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_regular_mode_allows_empty_worker_urls(self):
|
||||
"""Test that regular mode allows empty worker URLs."""
|
||||
args = RouterArgs(worker_urls=[], service_discovery=False)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_launch_with_valid_config(self):
|
||||
"""Test launching with valid configuration."""
|
||||
args = RouterArgs(
|
||||
host="127.0.0.1",
|
||||
port=30000,
|
||||
worker_urls=["http://worker1:8000"],
|
||||
policy="cache_aware",
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_launch_with_pd_config(self):
|
||||
"""Test launching with valid PD configuration."""
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", 9000)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
policy="cache_aware",
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_launch_with_service_discovery_config(self):
|
||||
"""Test launching with valid service discovery configuration."""
|
||||
args = RouterArgs(
|
||||
service_discovery=True,
|
||||
selector={"app": "worker"},
|
||||
service_discovery_port=8080,
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
Reference in New Issue
Block a user