Simplify Router arguments passing and build it in docker image (#9964)

This commit is contained in:
Liangsheng Yin
2025-09-05 12:13:55 +08:00
committed by GitHub
parent 0e9387a95d
commit 6e95f5e5bd
24 changed files with 1157 additions and 1587 deletions

View File

@@ -1,7 +1,3 @@
# a lightweihgt wrapper on router with argument type and comments
# no wrapper on policy type => direct export
from sglang_router.router import Router
from sglang_router.version import __version__
from sglang_router_rs import PolicyType
__all__ = ["Router", "PolicyType", "__version__"]
__all__ = ["__version__"]

View File

@@ -1,654 +1,22 @@
import argparse
import dataclasses
import logging
import sys
from typing import Dict, List, Optional
from typing import List, Optional
from sglang_router import Router
from sglang_router_rs import PolicyType
import setproctitle
from sglang_router.mini_lb import MiniLoadBalancer
from sglang_router.router_args import RouterArgs
logger = logging.getLogger("router")
def setup_logger():
logger = logging.getLogger("router")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
try:
from sglang_router.router import Router
except ImportError:
Router = None
logger.warning(
"Rust Router is not installed, only python MiniLB (debugging only) is available"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
@dataclasses.dataclass
class RouterArgs:
# Worker configuration
worker_urls: List[str] = dataclasses.field(default_factory=list)
host: str = "127.0.0.1"
port: int = 30000
# PD-specific configuration
pd_disaggregation: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
decode_urls: List[str] = dataclasses.field(default_factory=list)
# Routing policy
policy: str = "cache_aware"
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
worker_startup_timeout_secs: int = 600
worker_startup_check_interval: int = 30
cache_threshold: float = 0.3
balance_abs_threshold: int = 64
balance_rel_threshold: float = 1.5
eviction_interval: int = 120
max_tree_size: int = 2**26
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
dp_aware: bool = False
api_key: Optional[str] = None
log_dir: Optional[str] = None
log_level: Optional[str] = None
# Service discovery configuration
service_discovery: bool = False
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
service_discovery_port: int = 80
service_discovery_namespace: Optional[str] = None
# PD service discovery configuration
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
# Prometheus configuration
prometheus_port: Optional[int] = None
prometheus_host: Optional[str] = None
# Request ID headers configuration
request_id_headers: Optional[List[str]] = None
# Request timeout in seconds
request_timeout_secs: int = 1800
# Max concurrent requests for rate limiting
max_concurrent_requests: int = 256
# Queue size for pending requests when max concurrent limit reached
queue_size: int = 100
# Maximum time (in seconds) a request can wait in queue before timing out
queue_timeout_secs: int = 60
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
rate_limit_tokens_per_second: Optional[int] = None
# CORS allowed origins
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
# Retry configuration
retry_max_retries: int = 5
retry_initial_backoff_ms: int = 50
retry_max_backoff_ms: int = 30_000
retry_backoff_multiplier: float = 1.5
retry_jitter_factor: float = 0.2
disable_retries: bool = False
# Health check configuration
health_failure_threshold: int = 3
health_success_threshold: int = 2
health_check_timeout_secs: int = 5
health_check_interval_secs: int = 60
health_check_endpoint: str = "/health"
# Circuit breaker configuration
cb_failure_threshold: int = 10
cb_success_threshold: int = 3
cb_timeout_duration_secs: int = 60
cb_window_duration_secs: int = 120
disable_circuit_breaker: bool = False
# Tokenizer configuration
model_path: Optional[str] = None
tokenizer_path: Optional[str] = None
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
use_router_prefix: bool = False,
exclude_host_port: bool = False,
):
"""
Add router-specific arguments to an argument parser.
Args:
parser: The argument parser to add arguments to
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
"""
prefix = "router-" if use_router_prefix else ""
# Worker configuration
if not exclude_host_port:
parser.add_argument(
"--host",
type=str,
default=RouterArgs.host,
help="Host address to bind the router server",
)
parser.add_argument(
"--port",
type=int,
default=RouterArgs.port,
help="Port number to bind the router server",
)
parser.add_argument(
"--worker-urls",
type=str,
nargs="*",
default=[],
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
)
# Routing policy configuration
parser.add_argument(
f"--{prefix}policy",
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
)
parser.add_argument(
f"--{prefix}prefill-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
)
parser.add_argument(
f"--{prefix}decode-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
)
# PD-specific arguments
parser.add_argument(
f"--{prefix}pd-disaggregation",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
parser.add_argument(
f"--{prefix}prefill",
nargs="+",
action="append",
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
"Format: --prefill URL [BOOTSTRAP_PORT]. "
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
)
parser.add_argument(
f"--{prefix}decode",
nargs=1,
action="append",
metavar=("URL",),
help="Decode server URL. Can be specified multiple times.",
)
parser.add_argument(
f"--{prefix}worker-startup-timeout-secs",
type=int,
default=RouterArgs.worker_startup_timeout_secs,
help="Timeout in seconds for worker startup",
)
parser.add_argument(
f"--{prefix}worker-startup-check-interval",
type=int,
default=RouterArgs.worker_startup_check_interval,
help="Interval in seconds between checks for worker startup",
)
parser.add_argument(
f"--{prefix}cache-threshold",
type=float,
default=RouterArgs.cache_threshold,
help="Cache threshold (0.0-1.0) for cache-aware routing",
)
parser.add_argument(
f"--{prefix}balance-abs-threshold",
type=int,
default=RouterArgs.balance_abs_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}balance-rel-threshold",
type=float,
default=RouterArgs.balance_rel_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}eviction-interval",
type=int,
default=RouterArgs.eviction_interval,
help="Interval in seconds between cache eviction operations",
)
parser.add_argument(
f"--{prefix}max-tree-size",
type=int,
default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing",
)
parser.add_argument(
f"--{prefix}max-payload-size",
type=int,
default=RouterArgs.max_payload_size,
help="Maximum payload size in bytes",
)
parser.add_argument(
f"--{prefix}dp-aware",
action="store_true",
help="Enable data parallelism aware schedule",
)
parser.add_argument(
f"--{prefix}api-key",
type=str,
default=None,
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
)
parser.add_argument(
f"--{prefix}log-dir",
type=str,
default=None,
help="Directory to store log files. If not specified, logs are only output to console.",
)
parser.add_argument(
f"--{prefix}log-level",
type=str,
default="info",
choices=["debug", "info", "warning", "error", "critical"],
help="Set the logging level. If not specified, defaults to INFO.",
)
parser.add_argument(
f"--{prefix}service-discovery",
action="store_true",
help="Enable Kubernetes service discovery",
)
parser.add_argument(
f"--{prefix}selector",
type=str,
nargs="+",
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}service-discovery-port",
type=int,
default=RouterArgs.service_discovery_port,
help="Port to use for discovered worker pods",
)
parser.add_argument(
f"--{prefix}service-discovery-namespace",
type=str,
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
)
parser.add_argument(
f"--{prefix}prefill-selector",
type=str,
nargs="+",
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}decode-selector",
type=str,
nargs="+",
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
)
# Prometheus configuration
parser.add_argument(
f"--{prefix}prometheus-port",
type=int,
default=29000,
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
)
parser.add_argument(
f"--{prefix}prometheus-host",
type=str,
default="127.0.0.1",
help="Host address to bind the Prometheus metrics server",
)
parser.add_argument(
f"--{prefix}request-id-headers",
type=str,
nargs="*",
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
)
parser.add_argument(
f"--{prefix}request-timeout-secs",
type=int,
default=RouterArgs.request_timeout_secs,
help="Request timeout in seconds",
)
# Retry configuration
parser.add_argument(
f"--{prefix}retry-max-retries",
type=int,
default=RouterArgs.retry_max_retries,
)
parser.add_argument(
f"--{prefix}retry-initial-backoff-ms",
type=int,
default=RouterArgs.retry_initial_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-max-backoff-ms",
type=int,
default=RouterArgs.retry_max_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-backoff-multiplier",
type=float,
default=RouterArgs.retry_backoff_multiplier,
)
parser.add_argument(
f"--{prefix}retry-jitter-factor",
type=float,
default=RouterArgs.retry_jitter_factor,
)
parser.add_argument(
f"--{prefix}disable-retries",
action="store_true",
help="Disable retries (equivalent to setting retry_max_retries=1)",
)
# Circuit breaker configuration
parser.add_argument(
f"--{prefix}cb-failure-threshold",
type=int,
default=RouterArgs.cb_failure_threshold,
)
parser.add_argument(
f"--{prefix}cb-success-threshold",
type=int,
default=RouterArgs.cb_success_threshold,
)
parser.add_argument(
f"--{prefix}cb-timeout-duration-secs",
type=int,
default=RouterArgs.cb_timeout_duration_secs,
)
parser.add_argument(
f"--{prefix}cb-window-duration-secs",
type=int,
default=RouterArgs.cb_window_duration_secs,
)
parser.add_argument(
f"--{prefix}disable-circuit-breaker",
action="store_true",
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
)
# Health check configuration
parser.add_argument(
f"--{prefix}health-failure-threshold",
type=int,
default=RouterArgs.health_failure_threshold,
help="Number of consecutive health check failures before marking worker unhealthy",
)
parser.add_argument(
f"--{prefix}health-success-threshold",
type=int,
default=RouterArgs.health_success_threshold,
help="Number of consecutive health check successes before marking worker healthy",
)
parser.add_argument(
f"--{prefix}health-check-timeout-secs",
type=int,
default=RouterArgs.health_check_timeout_secs,
help="Timeout in seconds for health check requests",
)
parser.add_argument(
f"--{prefix}health-check-interval-secs",
type=int,
default=RouterArgs.health_check_interval_secs,
help="Interval in seconds between runtime health checks",
)
parser.add_argument(
f"--{prefix}health-check-endpoint",
type=str,
default=RouterArgs.health_check_endpoint,
help="Health check endpoint path",
)
parser.add_argument(
f"--{prefix}max-concurrent-requests",
type=int,
default=RouterArgs.max_concurrent_requests,
help="Maximum number of concurrent requests allowed (for rate limiting)",
)
parser.add_argument(
f"--{prefix}queue-size",
type=int,
default=RouterArgs.queue_size,
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
)
parser.add_argument(
f"--{prefix}queue-timeout-secs",
type=int,
default=RouterArgs.queue_timeout_secs,
help="Maximum time (in seconds) a request can wait in queue before timing out",
)
parser.add_argument(
f"--{prefix}rate-limit-tokens-per-second",
type=int,
default=RouterArgs.rate_limit_tokens_per_second,
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
)
parser.add_argument(
f"--{prefix}cors-allowed-origins",
type=str,
nargs="*",
default=[],
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
)
# Tokenizer configuration
parser.add_argument(
f"--{prefix}model-path",
type=str,
default=None,
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
)
parser.add_argument(
f"--{prefix}tokenizer-path",
type=str,
default=None,
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
)
@classmethod
def from_cli_args(
cls, args: argparse.Namespace, use_router_prefix: bool = False
) -> "RouterArgs":
"""
Create RouterArgs instance from parsed command line arguments.
Args:
args: Parsed command line arguments
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
worker_urls = getattr(args, "worker_urls", [])
# Parse PD URLs
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
return cls(
worker_urls=worker_urls,
host=args.host,
port=args.port,
pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"),
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
decode_policy=getattr(args, f"{prefix}decode_policy", None),
worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs"
),
worker_startup_check_interval=getattr(
args, f"{prefix}worker_startup_check_interval"
),
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
dp_aware=getattr(args, f"{prefix}dp_aware", False),
api_key=getattr(args, f"{prefix}api_key", None),
log_dir=getattr(args, f"{prefix}log_dir", None),
log_level=getattr(args, f"{prefix}log_level", None),
service_discovery=getattr(args, f"{prefix}service_discovery", False),
selector=cls._parse_selector(getattr(args, f"{prefix}selector", None)),
service_discovery_port=getattr(args, f"{prefix}service_discovery_port"),
service_discovery_namespace=getattr(
args, f"{prefix}service_discovery_namespace", None
),
prefill_selector=cls._parse_selector(
getattr(args, f"{prefix}prefill_selector", None)
),
decode_selector=cls._parse_selector(
getattr(args, f"{prefix}decode_selector", None)
),
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
request_timeout_secs=getattr(
args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs
),
max_concurrent_requests=getattr(
args,
f"{prefix}max_concurrent_requests",
RouterArgs.max_concurrent_requests,
),
queue_size=getattr(
args,
f"{prefix}queue_size",
RouterArgs.queue_size,
),
queue_timeout_secs=getattr(
args,
f"{prefix}queue_timeout_secs",
RouterArgs.queue_timeout_secs,
),
rate_limit_tokens_per_second=getattr(
args,
f"{prefix}rate_limit_tokens_per_second",
RouterArgs.rate_limit_tokens_per_second,
),
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
retry_max_backoff_ms=getattr(args, f"{prefix}retry_max_backoff_ms"),
retry_backoff_multiplier=getattr(args, f"{prefix}retry_backoff_multiplier"),
retry_jitter_factor=getattr(args, f"{prefix}retry_jitter_factor"),
cb_failure_threshold=getattr(args, f"{prefix}cb_failure_threshold"),
cb_success_threshold=getattr(args, f"{prefix}cb_success_threshold"),
cb_timeout_duration_secs=getattr(args, f"{prefix}cb_timeout_duration_secs"),
cb_window_duration_secs=getattr(args, f"{prefix}cb_window_duration_secs"),
disable_retries=getattr(args, f"{prefix}disable_retries", False),
disable_circuit_breaker=getattr(
args, f"{prefix}disable_circuit_breaker", False
),
health_failure_threshold=getattr(
args,
f"{prefix}health_failure_threshold",
RouterArgs.health_failure_threshold,
),
health_success_threshold=getattr(
args,
f"{prefix}health_success_threshold",
RouterArgs.health_success_threshold,
),
health_check_timeout_secs=getattr(
args,
f"{prefix}health_check_timeout_secs",
RouterArgs.health_check_timeout_secs,
),
health_check_interval_secs=getattr(
args,
f"{prefix}health_check_interval_secs",
RouterArgs.health_check_interval_secs,
),
health_check_endpoint=getattr(
args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint
),
model_path=getattr(args, f"{prefix}model_path", None),
tokenizer_path=getattr(args, f"{prefix}tokenizer_path", None),
)
@staticmethod
def _parse_selector(selector_list):
if not selector_list:
return {}
selector = {}
for item in selector_list:
if "=" in item:
key, value = item.split("=", 1)
selector[key] = value
return selector
@staticmethod
def _parse_prefill_urls(prefill_list):
"""Parse prefill URLs from --prefill arguments.
Format: --prefill URL [BOOTSTRAP_PORT]
Example:
--prefill http://prefill1:8080 9000 # With bootstrap port
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
--prefill http://prefill3:8080 # Defaults to no bootstrap port
"""
if not prefill_list:
return []
prefill_urls = []
for prefill_args in prefill_list:
url = prefill_args[0]
# Handle optional bootstrap port
if len(prefill_args) >= 2:
bootstrap_port_str = prefill_args[1]
# Handle 'none' as None
if bootstrap_port_str.lower() == "none":
bootstrap_port = None
else:
try:
bootstrap_port = int(bootstrap_port_str)
except ValueError:
raise ValueError(
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
)
else:
# No bootstrap port specified, default to None
bootstrap_port = None
prefill_urls.append((url, bootstrap_port))
return prefill_urls
@staticmethod
def _parse_decode_urls(decode_list):
"""Parse decode URLs from --decode arguments.
Format: --decode URL
Example: --decode http://decode1:8081 --decode http://decode2:8081
"""
if not decode_list:
return []
# decode_list is a list of single-element lists due to nargs=1
return [url[0] for url in decode_list]
def policy_from_str(policy_str: str) -> PolicyType:
"""Convert policy string to PolicyType enum."""
policy_map = {
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
}
return policy_map[policy_str]
def launch_router(args: argparse.Namespace) -> Optional[Router]:
"""
@@ -661,7 +29,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
Returns:
Router instance if successful, None if failed
"""
logger = logging.getLogger("router")
setproctitle.setproctitle("sglang::router")
try:
# Convert to RouterArgs if needed
if not isinstance(args, RouterArgs):
@@ -669,120 +37,15 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
else:
router_args = args
# Validate configuration based on mode
if router_args.pd_disaggregation:
# Validate PD configuration - skip URL requirements if using service discovery
if not router_args.service_discovery:
if not router_args.prefill_urls:
raise ValueError("PD disaggregation mode requires --prefill")
if not router_args.decode_urls:
raise ValueError("PD disaggregation mode requires --decode")
# Warn about policy usage in PD mode
if (
router_args.prefill_policy
and router_args.decode_policy
and router_args.policy
):
logger.warning(
"Both --prefill-policy and --decode-policy are specified. "
"The main --policy flag will be ignored for PD mode."
)
elif (
router_args.prefill_policy
and not router_args.decode_policy
and router_args.policy
):
logger.info(
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
f"and --policy '{router_args.policy}' for decode nodes."
)
elif (
router_args.decode_policy
and not router_args.prefill_policy
and router_args.policy
):
logger.info(
f"Using --policy '{router_args.policy}' for prefill nodes "
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
)
# Create router with unified constructor
router = Router(
worker_urls=(
[]
if router_args.service_discovery or router_args.pd_disaggregation
else router_args.worker_urls
),
host=router_args.host,
port=router_args.port,
policy=policy_from_str(router_args.policy),
worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
worker_startup_check_interval=router_args.worker_startup_check_interval,
cache_threshold=router_args.cache_threshold,
balance_abs_threshold=router_args.balance_abs_threshold,
balance_rel_threshold=router_args.balance_rel_threshold,
eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size,
max_payload_size=router_args.max_payload_size,
dp_aware=router_args.dp_aware,
api_key=router_args.api_key,
log_dir=router_args.log_dir,
log_level=router_args.log_level,
service_discovery=router_args.service_discovery,
selector=router_args.selector,
service_discovery_port=router_args.service_discovery_port,
service_discovery_namespace=router_args.service_discovery_namespace,
prefill_selector=router_args.prefill_selector,
decode_selector=router_args.decode_selector,
prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host,
request_timeout_secs=router_args.request_timeout_secs,
pd_disaggregation=router_args.pd_disaggregation,
prefill_urls=(
router_args.prefill_urls if router_args.pd_disaggregation else None
),
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregation else None
),
prefill_policy=(
policy_from_str(router_args.prefill_policy)
if router_args.prefill_policy
else None
),
decode_policy=(
policy_from_str(router_args.decode_policy)
if router_args.decode_policy
else None
),
request_id_headers=router_args.request_id_headers,
max_concurrent_requests=router_args.max_concurrent_requests,
queue_size=router_args.queue_size,
queue_timeout_secs=router_args.queue_timeout_secs,
rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
cors_allowed_origins=router_args.cors_allowed_origins,
retry_max_retries=router_args.retry_max_retries,
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
retry_max_backoff_ms=router_args.retry_max_backoff_ms,
retry_backoff_multiplier=router_args.retry_backoff_multiplier,
retry_jitter_factor=router_args.retry_jitter_factor,
cb_failure_threshold=router_args.cb_failure_threshold,
cb_success_threshold=router_args.cb_success_threshold,
cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
cb_window_duration_secs=router_args.cb_window_duration_secs,
disable_retries=router_args.disable_retries,
disable_circuit_breaker=router_args.disable_circuit_breaker,
health_failure_threshold=router_args.health_failure_threshold,
health_success_threshold=router_args.health_success_threshold,
health_check_timeout_secs=router_args.health_check_timeout_secs,
health_check_interval_secs=router_args.health_check_interval_secs,
health_check_endpoint=router_args.health_check_endpoint,
model_path=router_args.model_path,
tokenizer_path=router_args.tokenizer_path,
)
router.start()
return router
if router_args.mini_lb:
mini_lb = MiniLoadBalancer(router_args)
mini_lb.start()
else:
if Router is None:
raise RuntimeError("Rust Router is not installed")
router_args._validate_router_args()
router = Router.from_args(router_args)
router.start()
except Exception as e:
logger.error(f"Error starting router: {e}")

View File

@@ -0,0 +1,395 @@
"""
Minimal HTTP load balancer for prefill and decode servers for testing.
"""
import asyncio
import ipaddress
import logging
import random
import urllib
from http import HTTPStatus
from itertools import chain
from typing import Optional
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang_router.router_args import RouterArgs
logger = logging.getLogger(__name__)
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
1024 * 64
) # 64KB, to prevent aiohttp's "Chunk too big" error
def maybe_wrap_ipv6_address(address: str) -> str:
try:
ipaddress.IPv6Address(address)
return f"[{address}]"
except ValueError:
return address
class MiniLoadBalancer:
def __init__(
self,
router_args: RouterArgs,
):
self._validate_router_args(router_args)
self.host = router_args.host
self.port = router_args.port
self.timeout = router_args.request_timeout_secs
self.prefill_urls = [url[0] for url in router_args.prefill_urls]
self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
self.decode_urls = router_args.decode_urls
def _validate_router_args(self, router_args: RouterArgs):
logger.warning(
"\x1b[33mMiniLB is only for debugging purposes, it only supports random policy!\033[0m"
)
# NOTE: too many arguments unsupported, just validate some important ones
if router_args.policy != "random":
logger.warning("[MiniLB] Overriding policy to random")
router_args.policy = "random"
if not router_args.pd_disaggregation:
raise ValueError("MiniLB only supports PD disaggregation mode")
if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0:
raise ValueError(
"MiniLB requires at least one prefill and one decode server"
)
def start(self):
global lb
lb = self
uvicorn.run(app, host=self.host, port=self.port)
def select_pair(self):
assert len(self.prefill_urls) > 0, "No prefill servers available"
assert len(self.decode_urls) > 0, "No decode servers available"
pidx = random.randint(0, len(self.prefill_urls) - 1)
didx = random.randint(0, len(self.decode_urls) - 1)
return (
self.prefill_urls[pidx],
self.prefill_bootstrap_ports[pidx],
self.decode_urls[didx],
)
async def generate(
self, modified_request, prefill_server, decode_server, endpoint
) -> ORJSONResponse:
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.timeout
) # Add timeout for request reliability
) as session:
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
]
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
if "return_logprob" in modified_request:
prefill_json = await prefill_response.json()
ret_json = await decode_response.json()
# merge `meta_info.input_token_logprobs` from prefill to decode
if "meta_info" in ret_json:
if "input_token_logprobs" in ret_json["meta_info"]:
ret_json["meta_info"]["input_token_logprobs"] = (
prefill_json["meta_info"]["input_token_logprobs"]
+ ret_json["meta_info"]["input_token_logprobs"]
)
else:
ret_json = await decode_response.json()
return ORJSONResponse(
content=ret_json,
status_code=decode_response.status,
)
async def generate_stream(
self, modified_request, prefill_server, decode_server, endpoint="generate"
):
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
async def stream_results():
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.timeout
) # Add timeout for request reliability
) as session:
# Create the tasks for both prefill and decode requests
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response, decode_response = await asyncio.gather(*tasks)
if modified_request.get("return_logprob", False):
prefill_chunks = []
async for chunk in prefill_response.content:
prefill_chunks.append(chunk)
first_prefill_chunk = (
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
)
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
async for chunk in decode_response.content:
# Note: This is inefficient
# merge prefill input_token_logprobs, output_token_logprobs to decode
decoded_chunk = chunk.decode("utf-8")
if (
decoded_chunk
and decoded_chunk.startswith("data:")
and "[DONE]" not in decoded_chunk
):
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
ret_json["meta_info"]["input_token_logprobs"] = (
first_prefill_chunk_json["meta_info"][
"input_token_logprobs"
]
+ ret_json["meta_info"]["input_token_logprobs"]
)
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
else:
yield chunk
else:
async for chunk in decode_response.content.iter_chunked(
AIOHTTP_STREAM_READ_CHUNK_SIZE
):
yield chunk
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
)
app = FastAPI()
lb: Optional[MiniLoadBalancer] = None
@app.get("/health")
async def health_check():
return Response(status_code=200)
@app.get("/health_generate")
async def health_generate():
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(lb.prefill_urls, lb.decode_urls):
tasks.append(session.get(f"{server}/health_generate"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
@app.post("/flush_cache")
async def flush_cache():
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(lb.prefill_urls, lb.decode_urls):
tasks.append(session.post(f"{server}/flush_cache"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
@app.get("/get_server_info")
async def get_server_info():
prefill_infos = []
decode_infos = []
all_internal_states = []
async with aiohttp.ClientSession() as session:
for server in lb.prefill_urls:
server_info = await session.get(f"{server}/get_server_info")
prefill_infos.append(await server_info.json())
for server in lb.decode_urls:
server_info = await session.get(f"{server}/get_server_info")
info_json = await server_info.json()
decode_infos.append(info_json)
# Extract internal_states from decode servers
if "internal_states" in info_json:
all_internal_states.extend(info_json["internal_states"])
# Return format expected by bench_one_batch_server.py
if all_internal_states:
return {
"internal_states": all_internal_states,
"prefill": prefill_infos,
"decode": decode_infos,
}
else:
# Fallback with dummy data if no internal states found
return {
"internal_states": [
{
"last_gen_throughput": 0.0,
"avg_spec_accept_length": None,
}
],
"prefill": prefill_infos,
"decode": decode_infos,
}
@app.get("/get_model_info")
async def get_model_info():
if not lb or not lb.prefill_urls:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail="There is no server registered",
)
target_server_url = lb.prefill_urls[0]
endpoint_url = f"{target_server_url}/get_model_info"
async with aiohttp.ClientSession() as session:
try:
async with session.get(endpoint_url) as response:
if response.status != 200:
error_text = await response.text()
raise HTTPException(
status_code=HTTPStatus.BAD_GATEWAY,
detail=(
f"Failed to get model info from {target_server_url}"
f"Status: {response.status}, Response: {error_text}"
),
)
model_info_json = await response.json()
return ORJSONResponse(content=model_info_json)
except aiohttp.ClientError as e:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail=f"Failed to get model info from backend",
)
@app.post("/generate")
async def handle_generate_request(request_data: dict):
prefill_server, bootstrap_port, decode_server = lb.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
batch_size = _get_request_batch_size(modified_request)
if batch_size is not None:
modified_request.update(
{
"bootstrap_host": [hostname] * batch_size,
"bootstrap_port": [bootstrap_port] * batch_size,
"bootstrap_room": [
_generate_bootstrap_room() for _ in range(batch_size)
],
}
)
else:
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
if request_data.get("stream", False):
return await lb.generate_stream(
modified_request, prefill_server, decode_server, "generate"
)
else:
return await lb.generate(
modified_request, prefill_server, decode_server, "generate"
)
async def _forward_to_backend(request_data: dict, endpoint_name: str):
prefill_server, bootstrap_port, decode_server = lb.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
if request_data.get("stream", False):
return await lb.generate_stream(
modified_request,
prefill_server,
decode_server,
endpoint=endpoint_name,
)
else:
return await lb.generate(
modified_request,
prefill_server,
decode_server,
endpoint=endpoint_name,
)
@app.post("/v1/chat/completions")
async def handle_chat_completion_request(request_data: dict):
return await _forward_to_backend(request_data, "v1/chat/completions")
@app.post("/v1/completions")
async def handle_completion_request(request_data: dict):
return await _forward_to_backend(request_data, "v1/completions")
def _generate_bootstrap_room():
return random.randint(0, 2**63 - 1)
# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
if (text := request.get("text")) is not None:
return None if isinstance(text, str) else len(text)
if (input_ids := request.get("input_ids")) is not None:
return None if isinstance(input_ids[0], int) else len(input_ids)
return None
@app.get("/v1/models")
async def get_models():
prefill_server = lb.prefill_urls[0] # Get the first prefill server
async with aiohttp.ClientSession() as session:
try:
response = await session.get(f"{prefill_server}/v1/models")
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Prefill server error: Status {response.status}",
)
return ORJSONResponse(content=await response.json())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,9 +1,23 @@
from typing import Dict, List, Optional
from typing import Optional
from sglang_router.router_args import RouterArgs
from sglang_router_rs import PolicyType
from sglang_router_rs import Router as _Router
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
"""Convert policy string to PolicyType enum."""
if policy_str is None:
return None
policy_map = {
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
}
return policy_map[policy_str]
class Router:
"""
A high-performance router for distributing requests across worker nodes.
@@ -78,130 +92,34 @@ class Router:
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
"""
def __init__(
self,
worker_urls: List[str],
policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1",
port: int = 3001,
worker_startup_timeout_secs: int = 600,
worker_startup_check_interval: int = 30,
cache_threshold: float = 0.3,
balance_abs_threshold: int = 64,
balance_rel_threshold: float = 1.5,
eviction_interval_secs: int = 120,
max_tree_size: int = 2**26,
max_payload_size: int = 512 * 1024 * 1024, # 512MB
dp_aware: bool = False,
api_key: Optional[str] = None,
log_dir: Optional[str] = None,
log_level: Optional[str] = None,
service_discovery: bool = False,
selector: Dict[str, str] = None,
service_discovery_port: int = 80,
service_discovery_namespace: Optional[str] = None,
prefill_selector: Dict[str, str] = None,
decode_selector: Dict[str, str] = None,
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
prometheus_port: Optional[int] = None,
prometheus_host: Optional[str] = None,
request_timeout_secs: int = 1800,
request_id_headers: Optional[List[str]] = None,
pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[PolicyType] = None,
decode_policy: Optional[PolicyType] = None,
max_concurrent_requests: int = 256,
queue_size: int = 100,
queue_timeout_secs: int = 60,
rate_limit_tokens_per_second: Optional[int] = None,
cors_allowed_origins: List[str] = None,
retry_max_retries: int = 5,
retry_initial_backoff_ms: int = 50,
retry_max_backoff_ms: int = 30_000,
retry_backoff_multiplier: float = 1.5,
retry_jitter_factor: float = 0.2,
cb_failure_threshold: int = 10,
cb_success_threshold: int = 3,
cb_timeout_duration_secs: int = 60,
cb_window_duration_secs: int = 120,
disable_retries: bool = False,
disable_circuit_breaker: bool = False,
health_failure_threshold: int = 3,
health_success_threshold: int = 2,
health_check_timeout_secs: int = 5,
health_check_interval_secs: int = 60,
health_check_endpoint: str = "/health",
model_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
):
if selector is None:
selector = {}
if prefill_selector is None:
prefill_selector = {}
if decode_selector is None:
decode_selector = {}
if cors_allowed_origins is None:
cors_allowed_origins = []
def __init__(self, router: _Router):
self._router = router
self._router = _Router(
worker_urls=worker_urls,
policy=policy,
host=host,
port=port,
worker_startup_timeout_secs=worker_startup_timeout_secs,
worker_startup_check_interval=worker_startup_check_interval,
cache_threshold=cache_threshold,
balance_abs_threshold=balance_abs_threshold,
balance_rel_threshold=balance_rel_threshold,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
max_payload_size=max_payload_size,
dp_aware=dp_aware,
api_key=api_key,
log_dir=log_dir,
log_level=log_level,
service_discovery=service_discovery,
selector=selector,
service_discovery_port=service_discovery_port,
service_discovery_namespace=service_discovery_namespace,
prefill_selector=prefill_selector,
decode_selector=decode_selector,
bootstrap_port_annotation=bootstrap_port_annotation,
prometheus_port=prometheus_port,
prometheus_host=prometheus_host,
request_timeout_secs=request_timeout_secs,
request_id_headers=request_id_headers,
pd_disaggregation=pd_disaggregation,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
prefill_policy=prefill_policy,
decode_policy=decode_policy,
max_concurrent_requests=max_concurrent_requests,
queue_size=queue_size,
queue_timeout_secs=queue_timeout_secs,
rate_limit_tokens_per_second=rate_limit_tokens_per_second,
cors_allowed_origins=cors_allowed_origins,
retry_max_retries=retry_max_retries,
retry_initial_backoff_ms=retry_initial_backoff_ms,
retry_max_backoff_ms=retry_max_backoff_ms,
retry_backoff_multiplier=retry_backoff_multiplier,
retry_jitter_factor=retry_jitter_factor,
cb_failure_threshold=cb_failure_threshold,
cb_success_threshold=cb_success_threshold,
cb_timeout_duration_secs=cb_timeout_duration_secs,
cb_window_duration_secs=cb_window_duration_secs,
disable_retries=disable_retries,
disable_circuit_breaker=disable_circuit_breaker,
health_failure_threshold=health_failure_threshold,
health_success_threshold=health_success_threshold,
health_check_timeout_secs=health_check_timeout_secs,
health_check_interval_secs=health_check_interval_secs,
health_check_endpoint=health_check_endpoint,
model_path=model_path,
tokenizer_path=tokenizer_path,
@staticmethod
def from_args(args: RouterArgs) -> "Router":
"""Create a router from a RouterArgs instance."""
args_dict = vars(args)
# Convert RouterArgs to _Router parameters
args_dict["worker_urls"] = (
[]
if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
else args_dict["worker_urls"]
)
args_dict["policy"] = policy_from_str(args_dict["policy"])
args_dict["prefill_urls"] = (
args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
)
args_dict["decode_urls"] = (
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
)
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
# remoge mini_lb parameter
args_dict.pop("mini_lb")
return Router(_Router(**args_dict))
def start(self) -> None:
"""Start the router server.

View File

@@ -0,0 +1,577 @@
import argparse
import dataclasses
import logging
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class RouterArgs:
# Worker configuration
worker_urls: List[str] = dataclasses.field(default_factory=list)
host: str = "127.0.0.1"
port: int = 30000
# PD-specific configuration
mini_lb: bool = False
pd_disaggregation: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
decode_urls: List[str] = dataclasses.field(default_factory=list)
# Routing policy
policy: str = "cache_aware"
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
worker_startup_timeout_secs: int = 600
worker_startup_check_interval: int = 30
cache_threshold: float = 0.3
balance_abs_threshold: int = 64
balance_rel_threshold: float = 1.5
eviction_interval_secs: int = 120
max_tree_size: int = 2**26
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
dp_aware: bool = False
api_key: Optional[str] = None
log_dir: Optional[str] = None
log_level: Optional[str] = None
# Service discovery configuration
service_discovery: bool = False
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
service_discovery_port: int = 80
service_discovery_namespace: Optional[str] = None
# PD service discovery configuration
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
# Prometheus configuration
prometheus_port: Optional[int] = None
prometheus_host: Optional[str] = None
# Request ID headers configuration
request_id_headers: Optional[List[str]] = None
# Request timeout in seconds
request_timeout_secs: int = 1800
# Max concurrent requests for rate limiting
max_concurrent_requests: int = 256
# Queue size for pending requests when max concurrent limit reached
queue_size: int = 100
# Maximum time (in seconds) a request can wait in queue before timing out
queue_timeout_secs: int = 60
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
rate_limit_tokens_per_second: Optional[int] = None
# CORS allowed origins
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
# Retry configuration
retry_max_retries: int = 5
retry_initial_backoff_ms: int = 50
retry_max_backoff_ms: int = 30_000
retry_backoff_multiplier: float = 1.5
retry_jitter_factor: float = 0.2
disable_retries: bool = False
# Health check configuration
health_failure_threshold: int = 3
health_success_threshold: int = 2
health_check_timeout_secs: int = 5
health_check_interval_secs: int = 60
health_check_endpoint: str = "/health"
# Circuit breaker configuration
cb_failure_threshold: int = 10
cb_success_threshold: int = 3
cb_timeout_duration_secs: int = 60
cb_window_duration_secs: int = 120
disable_circuit_breaker: bool = False
# Tokenizer configuration
model_path: Optional[str] = None
tokenizer_path: Optional[str] = None
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
use_router_prefix: bool = False,
exclude_host_port: bool = False,
):
"""
Add router-specific arguments to an argument parser.
Args:
parser: The argument parser to add arguments to
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
"""
prefix = "router-" if use_router_prefix else ""
# Worker configuration
if not exclude_host_port:
parser.add_argument(
"--host",
type=str,
default=RouterArgs.host,
help="Host address to bind the router server",
)
parser.add_argument(
"--port",
type=int,
default=RouterArgs.port,
help="Port number to bind the router server",
)
parser.add_argument(
"--worker-urls",
type=str,
nargs="*",
default=[],
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
)
# Routing policy configuration
parser.add_argument(
f"--{prefix}policy",
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
)
parser.add_argument(
f"--{prefix}prefill-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
)
parser.add_argument(
f"--{prefix}decode-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
)
# PD-specific arguments
parser.add_argument(
f"--{prefix}mini-lb",
action="store_true",
help="Enable MiniLB",
)
parser.add_argument(
f"--{prefix}pd-disaggregation",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
parser.add_argument(
f"--{prefix}prefill",
nargs="+",
action="append",
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
"Format: --prefill URL [BOOTSTRAP_PORT]. "
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
)
parser.add_argument(
f"--{prefix}decode",
nargs=1,
action="append",
metavar=("URL",),
help="Decode server URL. Can be specified multiple times.",
)
parser.add_argument(
f"--{prefix}worker-startup-timeout-secs",
type=int,
default=RouterArgs.worker_startup_timeout_secs,
help="Timeout in seconds for worker startup",
)
parser.add_argument(
f"--{prefix}worker-startup-check-interval",
type=int,
default=RouterArgs.worker_startup_check_interval,
help="Interval in seconds between checks for worker startup",
)
parser.add_argument(
f"--{prefix}cache-threshold",
type=float,
default=RouterArgs.cache_threshold,
help="Cache threshold (0.0-1.0) for cache-aware routing",
)
parser.add_argument(
f"--{prefix}balance-abs-threshold",
type=int,
default=RouterArgs.balance_abs_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}balance-rel-threshold",
type=float,
default=RouterArgs.balance_rel_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}eviction-interval-secs",
type=int,
default=RouterArgs.eviction_interval_secs,
help="Interval in seconds between cache eviction operations",
)
parser.add_argument(
f"--{prefix}max-tree-size",
type=int,
default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing",
)
parser.add_argument(
f"--{prefix}max-payload-size",
type=int,
default=RouterArgs.max_payload_size,
help="Maximum payload size in bytes",
)
parser.add_argument(
f"--{prefix}dp-aware",
action="store_true",
help="Enable data parallelism aware schedule",
)
parser.add_argument(
f"--{prefix}api-key",
type=str,
default=None,
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
)
parser.add_argument(
f"--{prefix}log-dir",
type=str,
default=None,
help="Directory to store log files. If not specified, logs are only output to console.",
)
parser.add_argument(
f"--{prefix}log-level",
type=str,
default="info",
choices=["debug", "info", "warning", "error", "critical"],
help="Set the logging level. If not specified, defaults to INFO.",
)
parser.add_argument(
f"--{prefix}service-discovery",
action="store_true",
help="Enable Kubernetes service discovery",
)
parser.add_argument(
f"--{prefix}selector",
type=str,
nargs="+",
default={},
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}service-discovery-port",
type=int,
default=RouterArgs.service_discovery_port,
help="Port to use for discovered worker pods",
)
parser.add_argument(
f"--{prefix}service-discovery-namespace",
type=str,
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
)
parser.add_argument(
f"--{prefix}prefill-selector",
type=str,
nargs="+",
default={},
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}decode-selector",
type=str,
nargs="+",
default={},
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
)
# Prometheus configuration
parser.add_argument(
f"--{prefix}prometheus-port",
type=int,
default=29000,
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
)
parser.add_argument(
f"--{prefix}prometheus-host",
type=str,
default="127.0.0.1",
help="Host address to bind the Prometheus metrics server",
)
parser.add_argument(
f"--{prefix}request-id-headers",
type=str,
nargs="*",
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
)
parser.add_argument(
f"--{prefix}request-timeout-secs",
type=int,
default=RouterArgs.request_timeout_secs,
help="Request timeout in seconds",
)
# Retry configuration
parser.add_argument(
f"--{prefix}retry-max-retries",
type=int,
default=RouterArgs.retry_max_retries,
)
parser.add_argument(
f"--{prefix}retry-initial-backoff-ms",
type=int,
default=RouterArgs.retry_initial_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-max-backoff-ms",
type=int,
default=RouterArgs.retry_max_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-backoff-multiplier",
type=float,
default=RouterArgs.retry_backoff_multiplier,
)
parser.add_argument(
f"--{prefix}retry-jitter-factor",
type=float,
default=RouterArgs.retry_jitter_factor,
)
parser.add_argument(
f"--{prefix}disable-retries",
action="store_true",
help="Disable retries (equivalent to setting retry_max_retries=1)",
)
# Circuit breaker configuration
parser.add_argument(
f"--{prefix}cb-failure-threshold",
type=int,
default=RouterArgs.cb_failure_threshold,
)
parser.add_argument(
f"--{prefix}cb-success-threshold",
type=int,
default=RouterArgs.cb_success_threshold,
)
parser.add_argument(
f"--{prefix}cb-timeout-duration-secs",
type=int,
default=RouterArgs.cb_timeout_duration_secs,
)
parser.add_argument(
f"--{prefix}cb-window-duration-secs",
type=int,
default=RouterArgs.cb_window_duration_secs,
)
parser.add_argument(
f"--{prefix}disable-circuit-breaker",
action="store_true",
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
)
# Health check configuration
parser.add_argument(
f"--{prefix}health-failure-threshold",
type=int,
default=RouterArgs.health_failure_threshold,
help="Number of consecutive health check failures before marking worker unhealthy",
)
parser.add_argument(
f"--{prefix}health-success-threshold",
type=int,
default=RouterArgs.health_success_threshold,
help="Number of consecutive health check successes before marking worker healthy",
)
parser.add_argument(
f"--{prefix}health-check-timeout-secs",
type=int,
default=RouterArgs.health_check_timeout_secs,
help="Timeout in seconds for health check requests",
)
parser.add_argument(
f"--{prefix}health-check-interval-secs",
type=int,
default=RouterArgs.health_check_interval_secs,
help="Interval in seconds between runtime health checks",
)
parser.add_argument(
f"--{prefix}health-check-endpoint",
type=str,
default=RouterArgs.health_check_endpoint,
help="Health check endpoint path",
)
parser.add_argument(
f"--{prefix}max-concurrent-requests",
type=int,
default=RouterArgs.max_concurrent_requests,
help="Maximum number of concurrent requests allowed (for rate limiting)",
)
parser.add_argument(
f"--{prefix}queue-size",
type=int,
default=RouterArgs.queue_size,
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
)
parser.add_argument(
f"--{prefix}queue-timeout-secs",
type=int,
default=RouterArgs.queue_timeout_secs,
help="Maximum time (in seconds) a request can wait in queue before timing out",
)
parser.add_argument(
f"--{prefix}rate-limit-tokens-per-second",
type=int,
default=RouterArgs.rate_limit_tokens_per_second,
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
)
parser.add_argument(
f"--{prefix}cors-allowed-origins",
type=str,
nargs="*",
default=[],
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
)
# Tokenizer configuration
parser.add_argument(
f"--{prefix}model-path",
type=str,
default=None,
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
)
parser.add_argument(
f"--{prefix}tokenizer-path",
type=str,
default=None,
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
)
@classmethod
def from_cli_args(
cls, args: argparse.Namespace, use_router_prefix: bool = False
) -> "RouterArgs":
"""
Create RouterArgs instance from parsed command line arguments.
Args:
args: Parsed command line arguments
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
cli_args_dict = vars(args)
args_dict = {}
for attr in dataclasses.fields(cls):
# Auto strip prefix from args
if f"{prefix}{attr.name}" in cli_args_dict:
args_dict[attr.name] = cli_args_dict[f"{prefix}{attr.name}"]
elif attr.name in cli_args_dict:
args_dict[attr.name] = cli_args_dict[attr.name]
# parse special arguments and remove "--prefill" and "--decode" from cli_args_dict
args_dict["prefill_urls"] = cls._parse_prefill_urls(
cli_args_dict.get(f"{prefix}prefill", None)
)
args_dict["decode_urls"] = cls._parse_decode_urls(
cli_args_dict.get(f"{prefix}decode", None)
)
args_dict["selector"] = cls._parse_selector(
cli_args_dict.get(f"{prefix}selector", None)
)
args_dict["prefill_selector"] = cls._parse_selector(
cli_args_dict.get(f"{prefix}prefill_selector", None)
)
args_dict["decode_selector"] = cls._parse_selector(
cli_args_dict.get(f"{prefix}decode_selector", None)
)
# Mooncake-specific annotation
args_dict["bootstrap_port_annotation"] = "sglang.ai/bootstrap-port"
return cls(**args_dict)
def _validate_router_args(self):
# Validate configuration based on mode
if self.pd_disaggregation:
# Validate PD configuration - skip URL requirements if using service discovery
if not self.service_discovery:
if not self.prefill_urls:
raise ValueError("PD disaggregation mode requires --prefill")
if not self.decode_urls:
raise ValueError("PD disaggregation mode requires --decode")
# Warn about policy usage in PD mode
if self.prefill_policy and self.decode_policy and self.policy:
logger.warning(
"Both --prefill-policy and --decode-policy are specified. "
"The main --policy flag will be ignored for PD mode."
)
elif self.prefill_policy and not self.decode_policy and self.policy:
logger.info(
f"Using --prefill-policy '{self.prefill_policy}' for prefill nodes "
f"and --policy '{self.policy}' for decode nodes."
)
elif self.decode_policy and not self.prefill_policy and self.policy:
logger.info(
f"Using --policy '{self.policy}' for prefill nodes "
f"and --decode-policy '{self.decode_policy}' for decode nodes."
)
@staticmethod
def _parse_selector(selector_list):
if not selector_list:
return {}
selector = {}
for item in selector_list:
if "=" in item:
key, value = item.split("=", 1)
selector[key] = value
return selector
@staticmethod
def _parse_prefill_urls(prefill_list):
"""Parse prefill URLs from --prefill arguments.
Format: --prefill URL [BOOTSTRAP_PORT]
Example:
--prefill http://prefill1:8080 9000 # With bootstrap port
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
--prefill http://prefill3:8080 # Defaults to no bootstrap port
"""
if not prefill_list:
return []
prefill_urls = []
for prefill_args in prefill_list:
url = prefill_args[0]
# Handle optional bootstrap port
if len(prefill_args) >= 2:
bootstrap_port_str = prefill_args[1]
# Handle 'none' as None
if bootstrap_port_str.lower() == "none":
bootstrap_port = None
else:
try:
bootstrap_port = int(bootstrap_port_str)
except ValueError:
raise ValueError(
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
)
else:
# No bootstrap port specified, default to None
bootstrap_port = None
prefill_urls.append((url, bootstrap_port))
return prefill_urls
@staticmethod
def _parse_decode_urls(decode_list):
"""Parse decode URLs from --decode arguments.
Format: --decode URL
Example: --decode http://decode1:8081 --decode http://decode2:8081
"""
if not decode_list:
return []
# decode_list is a list of single-element lists due to nargs=1
return [url[0] for url in decode_list]

View File

@@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase):
cache_threshold=0.5,
balance_abs_threshold=32,
balance_rel_threshold=1.0001,
eviction_interval=60,
eviction_interval_secs=60,
max_tree_size=2**24,
max_payload_size=256 * 1024 * 1024, # 256MB
verbose=False,
@@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase):
"""Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured
# without actually starting it (which would require real prefill/decode servers)
from sglang_router import Router
from sglang_router.launch_router import RouterArgs
from sglang_router_rs import PolicyType
from sglang_router.router import PolicyType, Router
# Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append"
@@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase):
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
# Test Router creation in PD mode
router = Router(
worker_urls=[], # Empty for PD mode
pd_disaggregation=True,
prefill_urls=[
("http://prefill1:8080", 9000),
("http://prefill2:8080", None),
],
decode_urls=["http://decode1:8081", "http://decode2:8081"],
policy=PolicyType.CacheAware,
host="127.0.0.1",
port=3001,
)
router = Router.from_args(router_args)
self.assertIsNotNone(router)
def test_policy_validation(self):

View File

@@ -77,7 +77,7 @@ def popen_launch_router(
port,
"--dp",
str(dp_size),
"--router-eviction-interval",
"--router-eviction-interval-secs",
"5",
"--router-policy",
policy,

View File

@@ -28,8 +28,3 @@ find = { where = ["py_src"] }
# workaround for https://github.com/pypa/twine/issues/1216
[tool.setuptools]
license-files = []
[[tool.setuptools-rust.ext-modules]]
target = "sglang_router_rs"
path = "Cargo.toml"
binding = "PyO3"

21
sgl-router/setup.py Normal file
View File

@@ -0,0 +1,21 @@
import os
from setuptools import setup
from setuptools_rust import Binding, RustExtension
no_rust = os.environ.get("SGLANG_ROUTER_BUILD_NO_RUST") == "1"
rust_extensions = []
if not no_rust:
rust_extensions.append(
RustExtension(
target="sglang_router_rs",
path="Cargo.toml",
binding=Binding.PyO3,
)
)
setup(
rust_extensions=rust_extensions,
zip_safe=False,
)