Files
sglang/sgl-router/py_test/unit/test_validation.py

506 lines
18 KiB
Python

"""
Unit tests for validation logic in sglang_router.
These tests focus on testing the validation logic in isolation,
including parameter validation, URL validation, and configuration validation.
"""
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, launch_router
class TestURLValidation:
    """Check how RouterArgs stores worker, prefill and decode URLs."""

    def test_valid_worker_urls(self):
        """Well-formed worker URLs end up in ``worker_urls`` unchanged."""
        candidates = (
            "http://worker1:8000",
            "https://worker2:8000",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://192.168.1.100:8000",
            "http://worker.example.com:8000",
        )
        for candidate in candidates:
            router_args = RouterArgs(worker_urls=[candidate])
            # Construction itself must not raise for these inputs.
            assert candidate in router_args.worker_urls

    def test_valid_prefill_urls(self):
        """(url, bootstrap_port) pairs, with or without a port, are kept in ``prefill_urls``."""
        entries = (
            ("http://prefill1:8000", 9000),
            ("https://prefill2:8000", None),
            ("http://localhost:8000", 9000),
            ("http://127.0.0.1:8000", None),
        )
        for entry in entries:
            router_args = RouterArgs(prefill_urls=[entry])
            # Construction itself must not raise for these inputs.
            assert entry in router_args.prefill_urls

    def test_valid_decode_urls(self):
        """Well-formed decode URLs end up in ``decode_urls`` unchanged."""
        candidates = (
            "http://decode1:8001",
            "https://decode2:8001",
            "http://localhost:8001",
            "http://127.0.0.1:8001",
        )
        for candidate in candidates:
            router_args = RouterArgs(decode_urls=[candidate])
            # Construction itself must not raise for these inputs.
            assert candidate in router_args.decode_urls

    def test_malformed_urls(self):
        """Malformed worker URLs are currently accepted as-is.

        RouterArgs does not validate URL format today. This test pins that
        behavior so a future change that adds validation is noticed here.
        """
        suspicious = (
            "not-a-url",
            "ftp://worker1:8000",  # unsupported scheme
            "http://",  # no host
            ":8000",  # no scheme, no host
            "http://worker1",  # no port
        )
        for candidate in suspicious:
            router_args = RouterArgs(worker_urls=[candidate])
            # Accepted for now; tightening this is a possible future improvement.
            assert candidate in router_args.worker_urls
class TestPortValidation:
    """Check how RouterArgs stores the listening port and PD bootstrap ports."""

    def test_valid_ports(self):
        """Typical ports in the 1..65535 range are stored unchanged."""
        for candidate_port in (1, 80, 8000, 30000, 65535):
            router_args = RouterArgs(port=candidate_port)
            assert router_args.port == candidate_port

    def test_invalid_ports(self):
        """Out-of-range ports are currently accepted as-is.

        RouterArgs does not check port ranges today. This test pins that
        behavior so adding range checks later is a deliberate, visible change.
        """
        for out_of_range in (0, -1, 65536, 70000):
            router_args = RouterArgs(port=out_of_range)
            # Accepted for now; range validation may be added in the future.
            assert router_args.port == out_of_range

    def test_bootstrap_port_validation(self):
        """The bootstrap port (or None) of a prefill entry is stored unchanged."""
        for candidate_port in (1, 80, 9000, 30000, 65535, None):
            router_args = RouterArgs(
                prefill_urls=[("http://prefill1:8000", candidate_port)]
            )
            stored_port = router_args.prefill_urls[0][1]
            assert stored_port == candidate_port
class TestParameterValidation:
    """Test parameter validation logic.

    Each test passes a range of accepted values to ``RouterArgs``. It then
    checks that every value is stored unchanged on the matching attribute.
    """

    @staticmethod
    def _assert_round_trip(values, *fields):
        """Check that each value in *values* survives a ``RouterArgs`` round trip.

        For every value, a fresh ``RouterArgs`` is built with all *fields* set
        to that value at once. Each field is then read back and compared.

        Args:
            values: iterable of values to try; one ``RouterArgs`` per value.
            *fields: ``RouterArgs`` keyword/attribute names to set and check.
        """
        for value in values:
            args = RouterArgs(**dict.fromkeys(fields, value))
            for field in fields:
                stored = getattr(args, field)
                # Name the field and value so a failure inside the loop is
                # attributable without re-running under a debugger.
                assert stored == value, f"{field}={value!r} was stored as {stored!r}"

    def test_cache_threshold_validation(self):
        """Test cache threshold parameter validation."""
        self._assert_round_trip([0.0, 0.1, 0.5, 0.9, 1.0], "cache_threshold")

    def test_balance_threshold_validation(self):
        """Test load balancing threshold parameter validation."""
        # Valid absolute thresholds
        self._assert_round_trip([0, 1, 32, 64, 128, 1000], "balance_abs_threshold")
        # Valid relative thresholds
        self._assert_round_trip([1.0, 1.1, 1.5, 2.0, 10.0], "balance_rel_threshold")

    def test_timeout_validation(self):
        """Test timeout parameter validation."""
        self._assert_round_trip(
            [1, 30, 60, 300, 600, 1800, 3600],
            "worker_startup_timeout_secs",
            "worker_startup_check_interval",
            "request_timeout_secs",
            "queue_timeout_secs",
        )

    def test_retry_parameter_validation(self):
        """Test retry parameter validation."""
        # Valid retry counts
        self._assert_round_trip([0, 1, 3, 5, 10], "retry_max_retries")
        # Valid backoff values (initial and max set together)
        self._assert_round_trip(
            [1, 50, 100, 1000, 30000],
            "retry_initial_backoff_ms",
            "retry_max_backoff_ms",
        )
        # Valid multipliers
        self._assert_round_trip([1.0, 1.5, 2.0, 3.0], "retry_backoff_multiplier")
        # Valid jitter factors
        self._assert_round_trip([0.0, 0.1, 0.2, 0.5], "retry_jitter_factor")

    def test_circuit_breaker_parameter_validation(self):
        """Test circuit breaker parameter validation."""
        # Valid failure thresholds
        self._assert_round_trip([1, 3, 5, 10, 20], "cb_failure_threshold")
        # Valid success thresholds
        self._assert_round_trip([1, 2, 3, 5], "cb_success_threshold")
        # Valid timeout / window durations (set together)
        self._assert_round_trip(
            [10, 30, 60, 120, 300],
            "cb_timeout_duration_secs",
            "cb_window_duration_secs",
        )

    def test_health_check_parameter_validation(self):
        """Test health check parameter validation."""
        # Valid failure thresholds
        self._assert_round_trip([1, 2, 3, 5, 10], "health_failure_threshold")
        # Valid success thresholds
        self._assert_round_trip([1, 2, 3, 5], "health_success_threshold")
        # Valid timeouts and intervals (set together)
        self._assert_round_trip(
            [1, 5, 10, 30, 60, 120],
            "health_check_timeout_secs",
            "health_check_interval_secs",
        )

    def test_rate_limiting_parameter_validation(self):
        """Test rate limiting parameter validation."""
        # Valid concurrent request limits
        self._assert_round_trip(
            [1, 10, 64, 256, 512, 1000], "max_concurrent_requests"
        )
        # Valid queue sizes (0 included)
        self._assert_round_trip([0, 10, 50, 100, 500, 1000], "queue_size")
        # Valid token rates
        self._assert_round_trip(
            [1, 10, 50, 100, 500, 1000], "rate_limit_tokens_per_second"
        )

    def test_tree_size_validation(self):
        """Test tree size parameter validation."""
        # Valid tree sizes (powers of 2)
        self._assert_round_trip(
            [2**10, 2**20, 2**24, 2**26, 2**28, 2**30], "max_tree_size"
        )

    def test_payload_size_validation(self):
        """Test payload size parameter validation."""
        self._assert_round_trip(
            [
                1024,  # 1KB
                1024 * 1024,  # 1MB
                10 * 1024 * 1024,  # 10MB
                100 * 1024 * 1024,  # 100MB
                512 * 1024 * 1024,  # 512MB
                1024 * 1024 * 1024,  # 1GB
            ],
            "max_payload_size",
        )
class TestConfigurationValidation:
    """Check that mode, policy, CORS, logging and header options are stored as given."""

    def test_pd_mode_validation(self):
        """A prefill/decode disaggregated setup keeps its flag and both URL lists."""
        pd_args = RouterArgs(
            decode_urls=["http://decode1:8001"],
            prefill_urls=[("http://prefill1:8000", 9000)],
            pd_disaggregation=True,
        )
        assert pd_args.pd_disaggregation is True
        assert len(pd_args.prefill_urls) > 0
        assert len(pd_args.decode_urls) > 0

    def test_service_discovery_validation(self):
        """Service-discovery settings (selector, port, namespace) are kept verbatim."""
        sd_args = RouterArgs(
            service_discovery_namespace="default",
            service_discovery_port=8080,
            selector={"app": "worker", "env": "prod"},
            service_discovery=True,
        )
        assert sd_args.service_discovery is True
        assert sd_args.selector == {"app": "worker", "env": "prod"}
        assert sd_args.service_discovery_port == 8080
        assert sd_args.service_discovery_namespace == "default"

    def test_pd_service_discovery_validation(self):
        """PD mode combined with service discovery keeps per-role selectors."""
        pd_sd_args = RouterArgs(
            decode_selector={"app": "decode"},
            prefill_selector={"app": "prefill"},
            service_discovery=True,
            pd_disaggregation=True,
        )
        assert pd_sd_args.pd_disaggregation is True
        assert pd_sd_args.service_discovery is True
        assert pd_sd_args.prefill_selector == {"app": "prefill"}
        assert pd_sd_args.decode_selector == {"app": "decode"}

    def test_policy_validation(self):
        """Each supported routing policy name is stored unchanged."""
        for policy_name in ("random", "round_robin", "cache_aware", "power_of_two"):
            policy_args = RouterArgs(policy=policy_name)
            assert policy_args.policy == policy_name

    def test_pd_policy_validation(self):
        """Every prefill/decode policy combination is accepted in PD mode."""
        policy_names = ("random", "round_robin", "cache_aware", "power_of_two")
        # Same ordering as nested loops: prefill policy outer, decode policy inner.
        combos = [(pre, dec) for pre in policy_names for dec in policy_names]
        for prefill_choice, decode_choice in combos:
            pd_args = RouterArgs(
                pd_disaggregation=True,
                prefill_urls=[("http://prefill1:8000", None)],
                decode_urls=["http://decode1:8001"],
                prefill_policy=prefill_choice,
                decode_policy=decode_choice,
            )
            assert pd_args.prefill_policy == prefill_choice
            assert pd_args.decode_policy == decode_choice

    def test_cors_validation(self):
        """CORS origin lists, including empty and wildcard, are stored as given."""
        origin_sets = (
            [],
            ["http://localhost:3000"],
            ["https://example.com"],
            ["http://localhost:3000", "https://example.com"],
            ["*"],  # wildcard entry (whether it is honored is not checked here)
        )
        for origin_list in origin_sets:
            cors_args = RouterArgs(cors_allowed_origins=origin_list)
            assert cors_args.cors_allowed_origins == origin_list

    def test_logging_validation(self):
        """Standard log level names are stored unchanged."""
        for level_name in ("debug", "info", "warning", "error", "critical"):
            log_args = RouterArgs(log_level=level_name)
            assert log_args.log_level == level_name

    def test_prometheus_validation(self):
        """Prometheus host and port are stored unchanged."""
        metrics_args = RouterArgs(prometheus_host="127.0.0.1", prometheus_port=29000)
        assert metrics_args.prometheus_port == 29000
        assert metrics_args.prometheus_host == "127.0.0.1"

    def test_tokenizer_validation(self):
        """Placeholder: tokenizer options are not part of the current RouterArgs."""
        # model_path / tokenizer_path do not exist on RouterArgs, so nothing to check yet.
        pytest.skip("Tokenizer configuration not available in current implementation")

    def test_request_id_headers_validation(self):
        """Lists of request-ID header names are stored unchanged."""
        header_sets = (
            ["x-request-id"],
            ["x-request-id", "x-trace-id"],
            ["x-request-id", "x-trace-id", "x-correlation-id"],
            ["custom-header"],
        )
        for header_list in header_sets:
            header_args = RouterArgs(request_id_headers=header_list)
            assert header_args.request_id_headers == header_list
class TestLaunchValidation:
    """Test launch-time validation logic.

    ``sglang_router.launch_router.Router`` is patched in every test. As a
    result, no real router is constructed even if validation regresses and
    lets a bad configuration through.
    """

    @staticmethod
    def _launch_with_mock_router(args):
        """Call ``launch_router(args)`` with ``Router`` patched; return the patched mock.

        ``Router.from_args`` is replaced with a ``MagicMock`` that returns
        another ``MagicMock``. Any method ``launch_router`` calls on the
        "router instance" is therefore a recorded no-op.

        Returns:
            The mock that stood in for ``Router``. It is still inspectable
            after the patch has been undone.
        """
        with patch("sglang_router.launch_router.Router") as router_mod:
            router_mod.from_args = MagicMock(return_value=MagicMock())
            launch_router(args)
        return router_mod

    def test_pd_mode_requires_urls(self):
        """Test that PD mode requires prefill and decode URLs."""
        # PD mode without URLs (and without service discovery) should fail.
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=False,
        )
        # Patch Router like the other launch tests. If validation ever stops
        # rejecting this config, the test then fails on the assertions below
        # instead of constructing and starting a real router.
        with patch("sglang_router.launch_router.Router") as router_mod:
            with pytest.raises(
                ValueError, match="PD disaggregation mode requires --prefill"
            ):
                launch_router(args)
        # The invalid config must be rejected before any router is built.
        router_mod.from_args.assert_not_called()

    def test_pd_mode_with_service_discovery_allows_empty_urls(self):
        """Test that PD mode with service discovery allows empty URLs."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=True,
        )
        # Should not raise a validation error and should build the router via from_args.
        router_mod = self._launch_with_mock_router(args)
        router_mod.from_args.assert_called_once()

    def test_regular_mode_allows_empty_worker_urls(self):
        """Test that regular mode allows empty worker URLs."""
        args = RouterArgs(worker_urls=[], service_discovery=False)
        router_mod = self._launch_with_mock_router(args)
        router_mod.from_args.assert_called_once()

    def test_launch_with_valid_config(self):
        """Test launching with valid configuration."""
        args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000"],
            policy="cache_aware",
        )
        router_mod = self._launch_with_mock_router(args)
        router_mod.from_args.assert_called_once()

    def test_launch_with_pd_config(self):
        """Test launching with valid PD configuration."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", 9000)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
        )
        router_mod = self._launch_with_mock_router(args)
        router_mod.from_args.assert_called_once()

    def test_launch_with_service_discovery_config(self):
        """Test launching with valid service discovery configuration."""
        args = RouterArgs(
            service_discovery=True,
            selector={"app": "worker"},
            service_discovery_port=8080,
        )
        router_mod = self._launch_with_mock_router(args)
        router_mod.from_args.assert_called_once()