182 lines
9.4 KiB
Python
182 lines
9.4 KiB
Python
from typing import Dict, List, Optional
|
|
|
|
from sglang_router_rs import PolicyType
|
|
from sglang_router_rs import Router as _Router
|
|
|
|
|
|
class Router:
    """
    A high-performance router for distributing requests across worker nodes.

    This class is a thin Python wrapper around ``sglang_router_rs.Router``:
    every constructor argument is forwarded by keyword to that object. The only
    processing done here is that ``None`` is replaced with an empty container for
    the ``selector``, ``prefill_selector``, ``decode_selector`` and
    ``cors_allowed_origins`` arguments.

    Args:
        worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
            the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
        policy: Load balancing policy to use. Options:
            - PolicyType.Random: Randomly select workers
            - PolicyType.RoundRobin: Distribute requests in round-robin fashion
            - PolicyType.CacheAware: Distribute requests based on cache state and load balance
            - PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
            Default: PolicyType.RoundRobin
        host: Host address to bind the router server. Default: '127.0.0.1'
        port: Port number to bind the router server. Default: 3001
        worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
        worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10
        cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
            if the match rate exceeds threshold, otherwise routes to the worker with the smallest
            tree. Default: 0.5
        balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
            AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32
        balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
            AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
        eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
            routing. Default: 60
        max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
        max_payload_size: Maximum payload size in bytes. Default: 256MB
        dp_aware: Enable data parallelism aware schedule. Default: False
        api_key: The api key used for the authorization with the worker.
            Useful when the dp aware scheduling strategy is enabled.
            Default: None
        log_dir: Directory to store log files. If None, logs are only output to console. Default: None
        log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'.
            Default: None
        service_discovery: Enable Kubernetes service discovery. When enabled, the router will
            automatically discover worker pods based on the selector. Default: False
        selector: Dictionary mapping of label keys to values for Kubernetes pod selection.
            Example: {"app": "sglang-worker"}. Default: {} (None is treated as {})
        service_discovery_port: Port to use for service discovery. The router will generate
            worker URLs using this port. Default: 80
        service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
            watches pods across all namespaces (requires cluster-wide permissions). Default: None
        prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
            for prefill servers (PD mode only). Default: {} (None is treated as {})
        decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
            for decode servers (PD mode only). Default: {} (None is treated as {})
        bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
            Default: 'sglang.ai/bootstrap-port'
        prometheus_port: Port to expose Prometheus metrics. Default: None
        prometheus_host: Host address to bind the Prometheus metrics server. Default: None
        request_timeout_secs: Request timeout in seconds. Default: 600
        request_id_headers: List of HTTP headers to check for request IDs. If not specified,
            uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
            Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
        pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
        prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only).
            Default: None
        decode_urls: List of URLs for decode servers (PD mode only). Default: None
        prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
            If not specified, uses the main policy. Default: None
        decode_policy: Specific load balancing policy for decode nodes (PD mode only).
            If not specified, uses the main policy. Default: None
        max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 64
        cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins.
            Default: [] (None is treated as [])
        retry_max_retries: Retry setting forwarded to the underlying router; presumably the
            maximum number of retries per request (confirm in sglang_router_rs). Default: 3
        retry_initial_backoff_ms: Retry setting forwarded to the underlying router; presumably
            the first backoff delay in milliseconds. Default: 100
        retry_max_backoff_ms: Retry setting forwarded to the underlying router; presumably the
            upper bound on the backoff delay in milliseconds. Default: 10000
        retry_backoff_multiplier: Retry setting forwarded to the underlying router; presumably
            the factor applied to the delay after each retry. Default: 2.0
        retry_jitter_factor: Retry setting forwarded to the underlying router; presumably the
            fraction of random jitter applied to each delay. Default: 0.1
        cb_failure_threshold: Circuit-breaker setting forwarded to the underlying router;
            presumably the failures needed to open the breaker. Default: 5
        cb_success_threshold: Circuit-breaker setting forwarded to the underlying router;
            presumably the successes needed to close it again. Default: 2
        cb_timeout_duration_secs: Circuit-breaker setting forwarded to the underlying router;
            presumably how long (seconds) the breaker stays open. Default: 30
        cb_window_duration_secs: Circuit-breaker setting forwarded to the underlying router;
            presumably the window (seconds) over which failures are counted. Default: 60
        disable_retries: Passed to the underlying router to turn request retries off. Default: False
        disable_circuit_breaker: Passed to the underlying router to turn the circuit breaker off.
            Default: False
    """

    def __init__(
        self,
        worker_urls: List[str],
        policy: PolicyType = PolicyType.RoundRobin,
        host: str = "127.0.0.1",
        port: int = 3001,
        worker_startup_timeout_secs: int = 300,
        worker_startup_check_interval: int = 10,
        cache_threshold: float = 0.50,
        balance_abs_threshold: int = 32,
        balance_rel_threshold: float = 1.0001,
        eviction_interval_secs: int = 60,
        max_tree_size: int = 2**24,
        max_payload_size: int = 256 * 1024 * 1024,  # 256MB
        dp_aware: bool = False,
        api_key: Optional[str] = None,
        log_dir: Optional[str] = None,
        log_level: Optional[str] = None,
        service_discovery: bool = False,
        selector: Optional[Dict[str, str]] = None,
        service_discovery_port: int = 80,
        service_discovery_namespace: Optional[str] = None,
        prefill_selector: Optional[Dict[str, str]] = None,
        decode_selector: Optional[Dict[str, str]] = None,
        bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
        prometheus_port: Optional[int] = None,
        prometheus_host: Optional[str] = None,
        request_timeout_secs: int = 600,
        request_id_headers: Optional[List[str]] = None,
        pd_disaggregation: bool = False,
        prefill_urls: Optional[List[tuple]] = None,
        decode_urls: Optional[List[str]] = None,
        prefill_policy: Optional[PolicyType] = None,
        decode_policy: Optional[PolicyType] = None,
        max_concurrent_requests: int = 64,
        cors_allowed_origins: Optional[List[str]] = None,
        retry_max_retries: int = 3,
        retry_initial_backoff_ms: int = 100,
        retry_max_backoff_ms: int = 10_000,
        retry_backoff_multiplier: float = 2.0,
        retry_jitter_factor: float = 0.1,
        cb_failure_threshold: int = 5,
        cb_success_threshold: int = 2,
        cb_timeout_duration_secs: int = 30,
        cb_window_duration_secs: int = 60,
        disable_retries: bool = False,
        disable_circuit_breaker: bool = False,
    ):
        # None is used as the default instead of {} / [] to avoid Python's shared
        # mutable-default pitfall; the underlying router always receives a real
        # (possibly empty) container for these arguments.
        if selector is None:
            selector = {}
        if prefill_selector is None:
            prefill_selector = {}
        if decode_selector is None:
            decode_selector = {}
        if cors_allowed_origins is None:
            cors_allowed_origins = []

        # Every argument is forwarded by keyword to the Rust-backed router.
        self._router = _Router(
            worker_urls=worker_urls,
            policy=policy,
            host=host,
            port=port,
            worker_startup_timeout_secs=worker_startup_timeout_secs,
            worker_startup_check_interval=worker_startup_check_interval,
            cache_threshold=cache_threshold,
            balance_abs_threshold=balance_abs_threshold,
            balance_rel_threshold=balance_rel_threshold,
            eviction_interval_secs=eviction_interval_secs,
            max_tree_size=max_tree_size,
            max_payload_size=max_payload_size,
            dp_aware=dp_aware,
            api_key=api_key,
            log_dir=log_dir,
            log_level=log_level,
            service_discovery=service_discovery,
            selector=selector,
            service_discovery_port=service_discovery_port,
            service_discovery_namespace=service_discovery_namespace,
            prefill_selector=prefill_selector,
            decode_selector=decode_selector,
            bootstrap_port_annotation=bootstrap_port_annotation,
            prometheus_port=prometheus_port,
            prometheus_host=prometheus_host,
            request_timeout_secs=request_timeout_secs,
            request_id_headers=request_id_headers,
            pd_disaggregation=pd_disaggregation,
            prefill_urls=prefill_urls,
            decode_urls=decode_urls,
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
            max_concurrent_requests=max_concurrent_requests,
            cors_allowed_origins=cors_allowed_origins,
            retry_max_retries=retry_max_retries,
            retry_initial_backoff_ms=retry_initial_backoff_ms,
            retry_max_backoff_ms=retry_max_backoff_ms,
            retry_backoff_multiplier=retry_backoff_multiplier,
            retry_jitter_factor=retry_jitter_factor,
            cb_failure_threshold=cb_failure_threshold,
            cb_success_threshold=cb_success_threshold,
            cb_timeout_duration_secs=cb_timeout_duration_secs,
            cb_window_duration_secs=cb_window_duration_secs,
            disable_retries=disable_retries,
            disable_circuit_breaker=disable_circuit_breaker,
        )

    def start(self) -> None:
        """Start the router server.

        Delegates to the underlying ``sglang_router_rs.Router.start()``.
        This method blocks until the server is shut down.
        """
        self._router.start()
|