Simplify Router arguments passing and build it in docker image (#9964)
This commit is contained in:
@@ -1,7 +1,3 @@
|
||||
# a lightweihgt wrapper on router with argument type and comments
|
||||
# no wrapper on policy type => direct export
|
||||
from sglang_router.router import Router
|
||||
from sglang_router.version import __version__
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
__all__ = ["Router", "PolicyType", "__version__"]
|
||||
__all__ = ["__version__"]
|
||||
|
||||
@@ -1,654 +1,22 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang_router import Router
|
||||
from sglang_router_rs import PolicyType
|
||||
import setproctitle
|
||||
from sglang_router.mini_lb import MiniLoadBalancer
|
||||
from sglang_router.router_args import RouterArgs
|
||||
|
||||
logger = logging.getLogger("router")
|
||||
|
||||
def setup_logger():
|
||||
logger = logging.getLogger("router")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
try:
|
||||
from sglang_router.router import Router
|
||||
except ImportError:
|
||||
Router = None
|
||||
logger.warning(
|
||||
"Rust Router is not installed, only python MiniLB (debugging only) is available"
|
||||
)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RouterArgs:
|
||||
# Worker configuration
|
||||
worker_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
|
||||
# PD-specific configuration
|
||||
pd_disaggregation: bool = False # Enable PD disaggregated mode
|
||||
prefill_urls: List[tuple] = dataclasses.field(
|
||||
default_factory=list
|
||||
) # List of (url, bootstrap_port)
|
||||
decode_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
# Routing policy
|
||||
policy: str = "cache_aware"
|
||||
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
|
||||
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
|
||||
worker_startup_timeout_secs: int = 600
|
||||
worker_startup_check_interval: int = 30
|
||||
cache_threshold: float = 0.3
|
||||
balance_abs_threshold: int = 64
|
||||
balance_rel_threshold: float = 1.5
|
||||
eviction_interval: int = 120
|
||||
max_tree_size: int = 2**26
|
||||
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
|
||||
dp_aware: bool = False
|
||||
api_key: Optional[str] = None
|
||||
log_dir: Optional[str] = None
|
||||
log_level: Optional[str] = None
|
||||
# Service discovery configuration
|
||||
service_discovery: bool = False
|
||||
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
service_discovery_port: int = 80
|
||||
service_discovery_namespace: Optional[str] = None
|
||||
# PD service discovery configuration
|
||||
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
|
||||
# Prometheus configuration
|
||||
prometheus_port: Optional[int] = None
|
||||
prometheus_host: Optional[str] = None
|
||||
# Request ID headers configuration
|
||||
request_id_headers: Optional[List[str]] = None
|
||||
# Request timeout in seconds
|
||||
request_timeout_secs: int = 1800
|
||||
# Max concurrent requests for rate limiting
|
||||
max_concurrent_requests: int = 256
|
||||
# Queue size for pending requests when max concurrent limit reached
|
||||
queue_size: int = 100
|
||||
# Maximum time (in seconds) a request can wait in queue before timing out
|
||||
queue_timeout_secs: int = 60
|
||||
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
|
||||
rate_limit_tokens_per_second: Optional[int] = None
|
||||
# CORS allowed origins
|
||||
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
|
||||
# Retry configuration
|
||||
retry_max_retries: int = 5
|
||||
retry_initial_backoff_ms: int = 50
|
||||
retry_max_backoff_ms: int = 30_000
|
||||
retry_backoff_multiplier: float = 1.5
|
||||
retry_jitter_factor: float = 0.2
|
||||
disable_retries: bool = False
|
||||
# Health check configuration
|
||||
health_failure_threshold: int = 3
|
||||
health_success_threshold: int = 2
|
||||
health_check_timeout_secs: int = 5
|
||||
health_check_interval_secs: int = 60
|
||||
health_check_endpoint: str = "/health"
|
||||
# Circuit breaker configuration
|
||||
cb_failure_threshold: int = 10
|
||||
cb_success_threshold: int = 3
|
||||
cb_timeout_duration_secs: int = 60
|
||||
cb_window_duration_secs: int = 120
|
||||
disable_circuit_breaker: bool = False
|
||||
# Tokenizer configuration
|
||||
model_path: Optional[str] = None
|
||||
tokenizer_path: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
use_router_prefix: bool = False,
|
||||
exclude_host_port: bool = False,
|
||||
):
|
||||
"""
|
||||
Add router-specific arguments to an argument parser.
|
||||
|
||||
Args:
|
||||
parser: The argument parser to add arguments to
|
||||
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
|
||||
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
|
||||
"""
|
||||
prefix = "router-" if use_router_prefix else ""
|
||||
|
||||
# Worker configuration
|
||||
if not exclude_host_port:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=RouterArgs.host,
|
||||
help="Host address to bind the router server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=RouterArgs.port,
|
||||
help="Port number to bind the router server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--worker-urls",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
||||
)
|
||||
|
||||
# Routing policy configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}policy",
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
|
||||
# PD-specific arguments
|
||||
parser.add_argument(
|
||||
f"--{prefix}pd-disaggregation",
|
||||
action="store_true",
|
||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill",
|
||||
nargs="+",
|
||||
action="append",
|
||||
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
|
||||
"Format: --prefill URL [BOOTSTRAP_PORT]. "
|
||||
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode",
|
||||
nargs=1,
|
||||
action="append",
|
||||
metavar=("URL",),
|
||||
help="Decode server URL. Can be specified multiple times.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_timeout_secs,
|
||||
help="Timeout in seconds for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-check-interval",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_check_interval,
|
||||
help="Interval in seconds between checks for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cache-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.cache_threshold,
|
||||
help="Cache threshold (0.0-1.0) for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-abs-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.balance_abs_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-rel-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.balance_rel_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}eviction-interval",
|
||||
type=int,
|
||||
default=RouterArgs.eviction_interval,
|
||||
help="Interval in seconds between cache eviction operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-tree-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_tree_size,
|
||||
help="Maximum size of the approximation tree for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-payload-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_payload_size,
|
||||
help="Maximum payload size in bytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}dp-aware",
|
||||
action="store_true",
|
||||
help="Enable data parallelism aware schedule",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to store log files. If not specified, logs are only output to console.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=["debug", "info", "warning", "error", "critical"],
|
||||
help="Set the logging level. If not specified, defaults to INFO.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery",
|
||||
action="store_true",
|
||||
help="Enable Kubernetes service discovery",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-port",
|
||||
type=int,
|
||||
default=RouterArgs.service_discovery_port,
|
||||
help="Port to use for discovered worker pods",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-namespace",
|
||||
type=str,
|
||||
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
# Prometheus configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-port",
|
||||
type=int,
|
||||
default=29000,
|
||||
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host address to bind the Prometheus metrics server",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-id-headers",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.request_timeout_secs,
|
||||
help="Request timeout in seconds",
|
||||
)
|
||||
# Retry configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-retries",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_retries,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-initial-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_initial_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-backoff-multiplier",
|
||||
type=float,
|
||||
default=RouterArgs.retry_backoff_multiplier,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-jitter-factor",
|
||||
type=float,
|
||||
default=RouterArgs.retry_jitter_factor,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-retries",
|
||||
action="store_true",
|
||||
help="Disable retries (equivalent to setting retry_max_retries=1)",
|
||||
)
|
||||
# Circuit breaker configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_failure_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_success_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-timeout-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_timeout_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-window-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_window_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-circuit-breaker",
|
||||
action="store_true",
|
||||
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
|
||||
)
|
||||
# Health check configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_failure_threshold,
|
||||
help="Number of consecutive health check failures before marking worker unhealthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_success_threshold,
|
||||
help="Number of consecutive health check successes before marking worker healthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_timeout_secs,
|
||||
help="Timeout in seconds for health check requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-interval-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_interval_secs,
|
||||
help="Interval in seconds between runtime health checks",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-endpoint",
|
||||
type=str,
|
||||
default=RouterArgs.health_check_endpoint,
|
||||
help="Health check endpoint path",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-concurrent-requests",
|
||||
type=int,
|
||||
default=RouterArgs.max_concurrent_requests,
|
||||
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-size",
|
||||
type=int,
|
||||
default=RouterArgs.queue_size,
|
||||
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.queue_timeout_secs,
|
||||
help="Maximum time (in seconds) a request can wait in queue before timing out",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}rate-limit-tokens-per-second",
|
||||
type=int,
|
||||
default=RouterArgs.rate_limit_tokens_per_second,
|
||||
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cors-allowed-origins",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
||||
)
|
||||
# Tokenizer configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}tokenizer-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
cls, args: argparse.Namespace, use_router_prefix: bool = False
|
||||
) -> "RouterArgs":
|
||||
"""
|
||||
Create RouterArgs instance from parsed command line arguments.
|
||||
|
||||
Args:
|
||||
args: Parsed command line arguments
|
||||
use_router_prefix: If True, look for arguments with 'router-' prefix
|
||||
"""
|
||||
prefix = "router_" if use_router_prefix else ""
|
||||
worker_urls = getattr(args, "worker_urls", [])
|
||||
|
||||
# Parse PD URLs
|
||||
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
|
||||
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
|
||||
|
||||
return cls(
|
||||
worker_urls=worker_urls,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
policy=getattr(args, f"{prefix}policy"),
|
||||
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
|
||||
decode_policy=getattr(args, f"{prefix}decode_policy", None),
|
||||
worker_startup_timeout_secs=getattr(
|
||||
args, f"{prefix}worker_startup_timeout_secs"
|
||||
),
|
||||
worker_startup_check_interval=getattr(
|
||||
args, f"{prefix}worker_startup_check_interval"
|
||||
),
|
||||
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
|
||||
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
|
||||
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
|
||||
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
|
||||
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
|
||||
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
|
||||
dp_aware=getattr(args, f"{prefix}dp_aware", False),
|
||||
api_key=getattr(args, f"{prefix}api_key", None),
|
||||
log_dir=getattr(args, f"{prefix}log_dir", None),
|
||||
log_level=getattr(args, f"{prefix}log_level", None),
|
||||
service_discovery=getattr(args, f"{prefix}service_discovery", False),
|
||||
selector=cls._parse_selector(getattr(args, f"{prefix}selector", None)),
|
||||
service_discovery_port=getattr(args, f"{prefix}service_discovery_port"),
|
||||
service_discovery_namespace=getattr(
|
||||
args, f"{prefix}service_discovery_namespace", None
|
||||
),
|
||||
prefill_selector=cls._parse_selector(
|
||||
getattr(args, f"{prefix}prefill_selector", None)
|
||||
),
|
||||
decode_selector=cls._parse_selector(
|
||||
getattr(args, f"{prefix}decode_selector", None)
|
||||
),
|
||||
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
|
||||
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
||||
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
||||
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
|
||||
request_timeout_secs=getattr(
|
||||
args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs
|
||||
),
|
||||
max_concurrent_requests=getattr(
|
||||
args,
|
||||
f"{prefix}max_concurrent_requests",
|
||||
RouterArgs.max_concurrent_requests,
|
||||
),
|
||||
queue_size=getattr(
|
||||
args,
|
||||
f"{prefix}queue_size",
|
||||
RouterArgs.queue_size,
|
||||
),
|
||||
queue_timeout_secs=getattr(
|
||||
args,
|
||||
f"{prefix}queue_timeout_secs",
|
||||
RouterArgs.queue_timeout_secs,
|
||||
),
|
||||
rate_limit_tokens_per_second=getattr(
|
||||
args,
|
||||
f"{prefix}rate_limit_tokens_per_second",
|
||||
RouterArgs.rate_limit_tokens_per_second,
|
||||
),
|
||||
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
|
||||
retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
|
||||
retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
|
||||
retry_max_backoff_ms=getattr(args, f"{prefix}retry_max_backoff_ms"),
|
||||
retry_backoff_multiplier=getattr(args, f"{prefix}retry_backoff_multiplier"),
|
||||
retry_jitter_factor=getattr(args, f"{prefix}retry_jitter_factor"),
|
||||
cb_failure_threshold=getattr(args, f"{prefix}cb_failure_threshold"),
|
||||
cb_success_threshold=getattr(args, f"{prefix}cb_success_threshold"),
|
||||
cb_timeout_duration_secs=getattr(args, f"{prefix}cb_timeout_duration_secs"),
|
||||
cb_window_duration_secs=getattr(args, f"{prefix}cb_window_duration_secs"),
|
||||
disable_retries=getattr(args, f"{prefix}disable_retries", False),
|
||||
disable_circuit_breaker=getattr(
|
||||
args, f"{prefix}disable_circuit_breaker", False
|
||||
),
|
||||
health_failure_threshold=getattr(
|
||||
args,
|
||||
f"{prefix}health_failure_threshold",
|
||||
RouterArgs.health_failure_threshold,
|
||||
),
|
||||
health_success_threshold=getattr(
|
||||
args,
|
||||
f"{prefix}health_success_threshold",
|
||||
RouterArgs.health_success_threshold,
|
||||
),
|
||||
health_check_timeout_secs=getattr(
|
||||
args,
|
||||
f"{prefix}health_check_timeout_secs",
|
||||
RouterArgs.health_check_timeout_secs,
|
||||
),
|
||||
health_check_interval_secs=getattr(
|
||||
args,
|
||||
f"{prefix}health_check_interval_secs",
|
||||
RouterArgs.health_check_interval_secs,
|
||||
),
|
||||
health_check_endpoint=getattr(
|
||||
args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint
|
||||
),
|
||||
model_path=getattr(args, f"{prefix}model_path", None),
|
||||
tokenizer_path=getattr(args, f"{prefix}tokenizer_path", None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_selector(selector_list):
|
||||
if not selector_list:
|
||||
return {}
|
||||
|
||||
selector = {}
|
||||
for item in selector_list:
|
||||
if "=" in item:
|
||||
key, value = item.split("=", 1)
|
||||
selector[key] = value
|
||||
return selector
|
||||
|
||||
@staticmethod
|
||||
def _parse_prefill_urls(prefill_list):
|
||||
"""Parse prefill URLs from --prefill arguments.
|
||||
|
||||
Format: --prefill URL [BOOTSTRAP_PORT]
|
||||
Example:
|
||||
--prefill http://prefill1:8080 9000 # With bootstrap port
|
||||
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
|
||||
--prefill http://prefill3:8080 # Defaults to no bootstrap port
|
||||
"""
|
||||
if not prefill_list:
|
||||
return []
|
||||
|
||||
prefill_urls = []
|
||||
for prefill_args in prefill_list:
|
||||
|
||||
url = prefill_args[0]
|
||||
|
||||
# Handle optional bootstrap port
|
||||
if len(prefill_args) >= 2:
|
||||
bootstrap_port_str = prefill_args[1]
|
||||
# Handle 'none' as None
|
||||
if bootstrap_port_str.lower() == "none":
|
||||
bootstrap_port = None
|
||||
else:
|
||||
try:
|
||||
bootstrap_port = int(bootstrap_port_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
||||
)
|
||||
else:
|
||||
# No bootstrap port specified, default to None
|
||||
bootstrap_port = None
|
||||
|
||||
prefill_urls.append((url, bootstrap_port))
|
||||
|
||||
return prefill_urls
|
||||
|
||||
@staticmethod
|
||||
def _parse_decode_urls(decode_list):
|
||||
"""Parse decode URLs from --decode arguments.
|
||||
|
||||
Format: --decode URL
|
||||
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||
"""
|
||||
if not decode_list:
|
||||
return []
|
||||
|
||||
# decode_list is a list of single-element lists due to nargs=1
|
||||
return [url[0] for url in decode_list]
|
||||
|
||||
|
||||
def policy_from_str(policy_str: str) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
policy_map = {
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
"power_of_two": PolicyType.PowerOfTwo,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
"""
|
||||
@@ -661,7 +29,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
Returns:
|
||||
Router instance if successful, None if failed
|
||||
"""
|
||||
logger = logging.getLogger("router")
|
||||
setproctitle.setproctitle("sglang::router")
|
||||
try:
|
||||
# Convert to RouterArgs if needed
|
||||
if not isinstance(args, RouterArgs):
|
||||
@@ -669,120 +37,15 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
else:
|
||||
router_args = args
|
||||
|
||||
# Validate configuration based on mode
|
||||
if router_args.pd_disaggregation:
|
||||
# Validate PD configuration - skip URL requirements if using service discovery
|
||||
if not router_args.service_discovery:
|
||||
if not router_args.prefill_urls:
|
||||
raise ValueError("PD disaggregation mode requires --prefill")
|
||||
if not router_args.decode_urls:
|
||||
raise ValueError("PD disaggregation mode requires --decode")
|
||||
|
||||
# Warn about policy usage in PD mode
|
||||
if (
|
||||
router_args.prefill_policy
|
||||
and router_args.decode_policy
|
||||
and router_args.policy
|
||||
):
|
||||
logger.warning(
|
||||
"Both --prefill-policy and --decode-policy are specified. "
|
||||
"The main --policy flag will be ignored for PD mode."
|
||||
)
|
||||
elif (
|
||||
router_args.prefill_policy
|
||||
and not router_args.decode_policy
|
||||
and router_args.policy
|
||||
):
|
||||
logger.info(
|
||||
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
|
||||
f"and --policy '{router_args.policy}' for decode nodes."
|
||||
)
|
||||
elif (
|
||||
router_args.decode_policy
|
||||
and not router_args.prefill_policy
|
||||
and router_args.policy
|
||||
):
|
||||
logger.info(
|
||||
f"Using --policy '{router_args.policy}' for prefill nodes "
|
||||
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
|
||||
)
|
||||
|
||||
# Create router with unified constructor
|
||||
router = Router(
|
||||
worker_urls=(
|
||||
[]
|
||||
if router_args.service_discovery or router_args.pd_disaggregation
|
||||
else router_args.worker_urls
|
||||
),
|
||||
host=router_args.host,
|
||||
port=router_args.port,
|
||||
policy=policy_from_str(router_args.policy),
|
||||
worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
|
||||
worker_startup_check_interval=router_args.worker_startup_check_interval,
|
||||
cache_threshold=router_args.cache_threshold,
|
||||
balance_abs_threshold=router_args.balance_abs_threshold,
|
||||
balance_rel_threshold=router_args.balance_rel_threshold,
|
||||
eviction_interval_secs=router_args.eviction_interval,
|
||||
max_tree_size=router_args.max_tree_size,
|
||||
max_payload_size=router_args.max_payload_size,
|
||||
dp_aware=router_args.dp_aware,
|
||||
api_key=router_args.api_key,
|
||||
log_dir=router_args.log_dir,
|
||||
log_level=router_args.log_level,
|
||||
service_discovery=router_args.service_discovery,
|
||||
selector=router_args.selector,
|
||||
service_discovery_port=router_args.service_discovery_port,
|
||||
service_discovery_namespace=router_args.service_discovery_namespace,
|
||||
prefill_selector=router_args.prefill_selector,
|
||||
decode_selector=router_args.decode_selector,
|
||||
prometheus_port=router_args.prometheus_port,
|
||||
prometheus_host=router_args.prometheus_host,
|
||||
request_timeout_secs=router_args.request_timeout_secs,
|
||||
pd_disaggregation=router_args.pd_disaggregation,
|
||||
prefill_urls=(
|
||||
router_args.prefill_urls if router_args.pd_disaggregation else None
|
||||
),
|
||||
decode_urls=(
|
||||
router_args.decode_urls if router_args.pd_disaggregation else None
|
||||
),
|
||||
prefill_policy=(
|
||||
policy_from_str(router_args.prefill_policy)
|
||||
if router_args.prefill_policy
|
||||
else None
|
||||
),
|
||||
decode_policy=(
|
||||
policy_from_str(router_args.decode_policy)
|
||||
if router_args.decode_policy
|
||||
else None
|
||||
),
|
||||
request_id_headers=router_args.request_id_headers,
|
||||
max_concurrent_requests=router_args.max_concurrent_requests,
|
||||
queue_size=router_args.queue_size,
|
||||
queue_timeout_secs=router_args.queue_timeout_secs,
|
||||
rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
|
||||
cors_allowed_origins=router_args.cors_allowed_origins,
|
||||
retry_max_retries=router_args.retry_max_retries,
|
||||
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
|
||||
retry_max_backoff_ms=router_args.retry_max_backoff_ms,
|
||||
retry_backoff_multiplier=router_args.retry_backoff_multiplier,
|
||||
retry_jitter_factor=router_args.retry_jitter_factor,
|
||||
cb_failure_threshold=router_args.cb_failure_threshold,
|
||||
cb_success_threshold=router_args.cb_success_threshold,
|
||||
cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
|
||||
cb_window_duration_secs=router_args.cb_window_duration_secs,
|
||||
disable_retries=router_args.disable_retries,
|
||||
disable_circuit_breaker=router_args.disable_circuit_breaker,
|
||||
health_failure_threshold=router_args.health_failure_threshold,
|
||||
health_success_threshold=router_args.health_success_threshold,
|
||||
health_check_timeout_secs=router_args.health_check_timeout_secs,
|
||||
health_check_interval_secs=router_args.health_check_interval_secs,
|
||||
health_check_endpoint=router_args.health_check_endpoint,
|
||||
model_path=router_args.model_path,
|
||||
tokenizer_path=router_args.tokenizer_path,
|
||||
)
|
||||
|
||||
router.start()
|
||||
return router
|
||||
if router_args.mini_lb:
|
||||
mini_lb = MiniLoadBalancer(router_args)
|
||||
mini_lb.start()
|
||||
else:
|
||||
if Router is None:
|
||||
raise RuntimeError("Rust Router is not installed")
|
||||
router_args._validate_router_args()
|
||||
router = Router.from_args(router_args)
|
||||
router.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting router: {e}")
|
||||
|
||||
395
sgl-router/py_src/sglang_router/mini_lb.py
Normal file
395
sgl-router/py_src/sglang_router/mini_lb.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Minimal HTTP load balancer for prefill and decode servers for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import random
|
||||
import urllib
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
from sglang_router.router_args import RouterArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
||||
1024 * 64
|
||||
) # 64KB, to prevent aiohttp's "Chunk too big" error
|
||||
|
||||
|
||||
def maybe_wrap_ipv6_address(address: str) -> str:
    """Bracket *address* if it is a literal IPv6 address, else return it unchanged."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        # Not IPv6 (hostname or IPv4): safe to use as-is in a URL.
        return address
    return f"[{address}]"
|
||||
|
||||
|
||||
class MiniLoadBalancer:
    """Minimal random-policy load balancer over PD (prefill/decode) server pairs.

    Debugging aid only: each request is fanned out to one randomly chosen
    prefill server and one randomly chosen decode server, and the decode
    server's response is returned (with prefill logprobs merged in when
    requested).
    """

    def __init__(
        self,
        router_args: RouterArgs,
    ):
        # Validates (and partially overrides) the args before reading them.
        self._validate_router_args(router_args)

        self.host = router_args.host
        self.port = router_args.port
        # Per-request total timeout (seconds) for the aiohttp sessions below.
        self.timeout = router_args.request_timeout_secs
        # prefill_urls entries are (url, bootstrap_port) tuples; split them so
        # select_pair() can index both lists with the same position.
        self.prefill_urls = [url[0] for url in router_args.prefill_urls]
        self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
        self.decode_urls = router_args.decode_urls

    def _validate_router_args(self, router_args: RouterArgs):
        """Check that *router_args* is usable by MiniLB, forcing policy to random."""
        logger.warning(
            "\x1b[33mMiniLB is only for debugging purposes, it only supports random policy!\033[0m"
        )

        # NOTE: too many arguments unsupported, just validate some important ones
        if router_args.policy != "random":
            logger.warning("[MiniLB] Overriding policy to random")
            router_args.policy = "random"

        if not router_args.pd_disaggregation:
            raise ValueError("MiniLB only supports PD disaggregation mode")

        if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0:
            raise ValueError(
                "MiniLB requires at least one prefill and one decode server"
            )

    def start(self):
        """Publish this instance as the module-global ``lb`` and serve the app (blocking)."""
        global lb
        lb = self
        uvicorn.run(app, host=self.host, port=self.port)

    def select_pair(self):
        """Pick a random (prefill_url, bootstrap_port, decode_url) triple."""
        assert len(self.prefill_urls) > 0, "No prefill servers available"
        assert len(self.decode_urls) > 0, "No decode servers available"
        pidx = random.randint(0, len(self.prefill_urls) - 1)
        didx = random.randint(0, len(self.decode_urls) - 1)
        return (
            self.prefill_urls[pidx],
            self.prefill_bootstrap_ports[pidx],
            self.decode_urls[didx],
        )

    async def generate(
        self, modified_request, prefill_server, decode_server, endpoint
    ) -> ORJSONResponse:
        """POST *modified_request* to both servers and return the decode response.

        When logprobs are requested, the prefill server's
        ``meta_info.input_token_logprobs`` are prepended to the decode
        server's before returning.
        """
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=self.timeout
            )  # Add timeout for request reliability
        ) as session:
            tasks = [
                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                session.post(f"{decode_server}/{endpoint}", json=modified_request),
            ]

            # Wait for both responses to complete. Prefill should end first.
            prefill_response, decode_response = await asyncio.gather(*tasks)

            # NOTE(review): this tests key *presence*, while generate_stream
            # below tests the key's truthiness — confirm the asymmetry is
            # intended.
            if "return_logprob" in modified_request:

                prefill_json = await prefill_response.json()
                ret_json = await decode_response.json()

                # merge `meta_info.input_token_logprobs` from prefill to decode
                if "meta_info" in ret_json:
                    if "input_token_logprobs" in ret_json["meta_info"]:
                        ret_json["meta_info"]["input_token_logprobs"] = (
                            prefill_json["meta_info"]["input_token_logprobs"]
                            + ret_json["meta_info"]["input_token_logprobs"]
                        )
            else:
                ret_json = await decode_response.json()

            return ORJSONResponse(
                content=ret_json,
                status_code=decode_response.status,
            )

    async def generate_stream(
        self, modified_request, prefill_server, decode_server, endpoint="generate"
    ):
        """Streaming variant of generate(): relay the decode server's SSE stream.

        With ``return_logprob`` set, each decode SSE chunk gets the prefill
        server's input_token_logprobs (taken from the prefill's first chunk)
        prepended before being re-emitted.
        """
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async def stream_results():
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(
                    total=self.timeout
                )  # Add timeout for request reliability
            ) as session:
                # Create the tasks for both prefill and decode requests
                tasks = [
                    session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                    session.post(f"{decode_server}/{endpoint}", json=modified_request),
                ]
                # Wait for both responses to complete. Since this is streaming, they return immediately.
                prefill_response, decode_response = await asyncio.gather(*tasks)

                if modified_request.get("return_logprob", False):
                    prefill_chunks = []
                    async for chunk in prefill_response.content:
                        prefill_chunks.append(chunk)

                    # Strip the "data:" SSE prefix and trailing newline from
                    # the prefill's first chunk before parsing it.
                    first_prefill_chunk = (
                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
                    )
                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)

                    async for chunk in decode_response.content:
                        # Note: This is inefficient
                        # merge prefill input_token_logprobs, output_token_logprobs to decode
                        decoded_chunk = chunk.decode("utf-8")
                        if (
                            decoded_chunk
                            and decoded_chunk.startswith("data:")
                            and "[DONE]" not in decoded_chunk
                        ):
                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
                            ret_json["meta_info"]["input_token_logprobs"] = (
                                first_prefill_chunk_json["meta_info"][
                                    "input_token_logprobs"
                                ]
                                + ret_json["meta_info"]["input_token_logprobs"]
                            )

                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
                        else:
                            yield chunk
                else:
                    # No merging needed: pass the decode stream through in
                    # bounded chunks (avoids aiohttp's "Chunk too big" error).
                    async for chunk in decode_response.content.iter_chunked(
                        AIOHTTP_STREAM_READ_CHUNK_SIZE
                    ):
                        yield chunk

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
        )
|
||||
|
||||
|
||||
app = FastAPI()
# Set by MiniLoadBalancer.start(); the route handlers below read the active
# balancer through this module-level global.
lb: Optional[MiniLoadBalancer] = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/health_generate")
|
||||
async def health_generate():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(lb.prefill_urls, lb.decode_urls):
|
||||
tasks.append(session.get(f"{server}/health_generate"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/flush_cache")
|
||||
async def flush_cache():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(lb.prefill_urls, lb.decode_urls):
|
||||
tasks.append(session.post(f"{server}/flush_cache"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/get_server_info")
|
||||
async def get_server_info():
|
||||
prefill_infos = []
|
||||
decode_infos = []
|
||||
all_internal_states = []
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for server in lb.prefill_urls:
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
prefill_infos.append(await server_info.json())
|
||||
for server in lb.decode_urls:
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
info_json = await server_info.json()
|
||||
decode_infos.append(info_json)
|
||||
# Extract internal_states from decode servers
|
||||
if "internal_states" in info_json:
|
||||
all_internal_states.extend(info_json["internal_states"])
|
||||
|
||||
# Return format expected by bench_one_batch_server.py
|
||||
if all_internal_states:
|
||||
return {
|
||||
"internal_states": all_internal_states,
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
else:
|
||||
# Fallback with dummy data if no internal states found
|
||||
return {
|
||||
"internal_states": [
|
||||
{
|
||||
"last_gen_throughput": 0.0,
|
||||
"avg_spec_accept_length": None,
|
||||
}
|
||||
],
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
async def get_model_info():
|
||||
if not lb or not lb.prefill_urls:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail="There is no server registered",
|
||||
)
|
||||
|
||||
target_server_url = lb.prefill_urls[0]
|
||||
endpoint_url = f"{target_server_url}/get_model_info"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(endpoint_url) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_GATEWAY,
|
||||
detail=(
|
||||
f"Failed to get model info from {target_server_url}"
|
||||
f"Status: {response.status}, Response: {error_text}"
|
||||
),
|
||||
)
|
||||
|
||||
model_info_json = await response.json()
|
||||
return ORJSONResponse(content=model_info_json)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail=f"Failed to get model info from backend",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def handle_generate_request(request_data: dict):
|
||||
prefill_server, bootstrap_port, decode_server = lb.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
|
||||
batch_size = _get_request_batch_size(modified_request)
|
||||
if batch_size is not None:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": [hostname] * batch_size,
|
||||
"bootstrap_port": [bootstrap_port] * batch_size,
|
||||
"bootstrap_room": [
|
||||
_generate_bootstrap_room() for _ in range(batch_size)
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await lb.generate_stream(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
else:
|
||||
return await lb.generate(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
|
||||
|
||||
async def _forward_to_backend(request_data: dict, endpoint_name: str):
    """Forward an OpenAI-style request to a random (prefill, decode) pair.

    Injects single-request bootstrap_host/port/room fields, then dispatches
    to the streaming or non-streaming generate path.
    """
    prefill_server, bootstrap_port, decode_server = lb.select_pair()

    # The decode server connects back to the prefill server's hostname.
    hostname = maybe_wrap_ipv6_address(
        urllib.parse.urlparse(prefill_server).hostname
    )

    modified_request = dict(request_data)
    modified_request["bootstrap_host"] = hostname
    modified_request["bootstrap_port"] = bootstrap_port
    modified_request["bootstrap_room"] = _generate_bootstrap_room()

    if request_data.get("stream", False):
        return await lb.generate_stream(
            modified_request,
            prefill_server,
            decode_server,
            endpoint=endpoint_name,
        )
    return await lb.generate(
        modified_request,
        prefill_server,
        decode_server,
        endpoint=endpoint_name,
    )
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def handle_chat_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/chat/completions")
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def handle_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/completions")
|
||||
|
||||
|
||||
def _generate_bootstrap_room():
|
||||
return random.randint(0, 2**63 - 1)
|
||||
|
||||
|
||||
# We may utilize `GenerateReqInput`'s logic later
|
||||
def _get_request_batch_size(request):
|
||||
if (text := request.get("text")) is not None:
|
||||
return None if isinstance(text, str) else len(text)
|
||||
if (input_ids := request.get("input_ids")) is not None:
|
||||
return None if isinstance(input_ids[0], int) else len(input_ids)
|
||||
return None
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def get_models():
|
||||
prefill_server = lb.prefill_urls[0] # Get the first prefill server
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
response = await session.get(f"{prefill_server}/v1/models")
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Prefill server error: Status {response.status}",
|
||||
)
|
||||
return ORJSONResponse(content=await response.json())
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,9 +1,23 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from sglang_router.router_args import RouterArgs
|
||||
from sglang_router_rs import PolicyType
|
||||
from sglang_router_rs import Router as _Router
|
||||
|
||||
|
||||
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
if policy_str is None:
|
||||
return None
|
||||
policy_map = {
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
"power_of_two": PolicyType.PowerOfTwo,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
class Router:
|
||||
"""
|
||||
A high-performance router for distributing requests across worker nodes.
|
||||
@@ -78,130 +92,34 @@ class Router:
|
||||
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_urls: List[str],
|
||||
policy: PolicyType = PolicyType.RoundRobin,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 3001,
|
||||
worker_startup_timeout_secs: int = 600,
|
||||
worker_startup_check_interval: int = 30,
|
||||
cache_threshold: float = 0.3,
|
||||
balance_abs_threshold: int = 64,
|
||||
balance_rel_threshold: float = 1.5,
|
||||
eviction_interval_secs: int = 120,
|
||||
max_tree_size: int = 2**26,
|
||||
max_payload_size: int = 512 * 1024 * 1024, # 512MB
|
||||
dp_aware: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
log_dir: Optional[str] = None,
|
||||
log_level: Optional[str] = None,
|
||||
service_discovery: bool = False,
|
||||
selector: Dict[str, str] = None,
|
||||
service_discovery_port: int = 80,
|
||||
service_discovery_namespace: Optional[str] = None,
|
||||
prefill_selector: Dict[str, str] = None,
|
||||
decode_selector: Dict[str, str] = None,
|
||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
|
||||
prometheus_port: Optional[int] = None,
|
||||
prometheus_host: Optional[str] = None,
|
||||
request_timeout_secs: int = 1800,
|
||||
request_id_headers: Optional[List[str]] = None,
|
||||
pd_disaggregation: bool = False,
|
||||
prefill_urls: Optional[List[tuple]] = None,
|
||||
decode_urls: Optional[List[str]] = None,
|
||||
prefill_policy: Optional[PolicyType] = None,
|
||||
decode_policy: Optional[PolicyType] = None,
|
||||
max_concurrent_requests: int = 256,
|
||||
queue_size: int = 100,
|
||||
queue_timeout_secs: int = 60,
|
||||
rate_limit_tokens_per_second: Optional[int] = None,
|
||||
cors_allowed_origins: List[str] = None,
|
||||
retry_max_retries: int = 5,
|
||||
retry_initial_backoff_ms: int = 50,
|
||||
retry_max_backoff_ms: int = 30_000,
|
||||
retry_backoff_multiplier: float = 1.5,
|
||||
retry_jitter_factor: float = 0.2,
|
||||
cb_failure_threshold: int = 10,
|
||||
cb_success_threshold: int = 3,
|
||||
cb_timeout_duration_secs: int = 60,
|
||||
cb_window_duration_secs: int = 120,
|
||||
disable_retries: bool = False,
|
||||
disable_circuit_breaker: bool = False,
|
||||
health_failure_threshold: int = 3,
|
||||
health_success_threshold: int = 2,
|
||||
health_check_timeout_secs: int = 5,
|
||||
health_check_interval_secs: int = 60,
|
||||
health_check_endpoint: str = "/health",
|
||||
model_path: Optional[str] = None,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
):
|
||||
if selector is None:
|
||||
selector = {}
|
||||
if prefill_selector is None:
|
||||
prefill_selector = {}
|
||||
if decode_selector is None:
|
||||
decode_selector = {}
|
||||
if cors_allowed_origins is None:
|
||||
cors_allowed_origins = []
|
||||
def __init__(self, router: _Router):
|
||||
self._router = router
|
||||
|
||||
self._router = _Router(
|
||||
worker_urls=worker_urls,
|
||||
policy=policy,
|
||||
host=host,
|
||||
port=port,
|
||||
worker_startup_timeout_secs=worker_startup_timeout_secs,
|
||||
worker_startup_check_interval=worker_startup_check_interval,
|
||||
cache_threshold=cache_threshold,
|
||||
balance_abs_threshold=balance_abs_threshold,
|
||||
balance_rel_threshold=balance_rel_threshold,
|
||||
eviction_interval_secs=eviction_interval_secs,
|
||||
max_tree_size=max_tree_size,
|
||||
max_payload_size=max_payload_size,
|
||||
dp_aware=dp_aware,
|
||||
api_key=api_key,
|
||||
log_dir=log_dir,
|
||||
log_level=log_level,
|
||||
service_discovery=service_discovery,
|
||||
selector=selector,
|
||||
service_discovery_port=service_discovery_port,
|
||||
service_discovery_namespace=service_discovery_namespace,
|
||||
prefill_selector=prefill_selector,
|
||||
decode_selector=decode_selector,
|
||||
bootstrap_port_annotation=bootstrap_port_annotation,
|
||||
prometheus_port=prometheus_port,
|
||||
prometheus_host=prometheus_host,
|
||||
request_timeout_secs=request_timeout_secs,
|
||||
request_id_headers=request_id_headers,
|
||||
pd_disaggregation=pd_disaggregation,
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
queue_size=queue_size,
|
||||
queue_timeout_secs=queue_timeout_secs,
|
||||
rate_limit_tokens_per_second=rate_limit_tokens_per_second,
|
||||
cors_allowed_origins=cors_allowed_origins,
|
||||
retry_max_retries=retry_max_retries,
|
||||
retry_initial_backoff_ms=retry_initial_backoff_ms,
|
||||
retry_max_backoff_ms=retry_max_backoff_ms,
|
||||
retry_backoff_multiplier=retry_backoff_multiplier,
|
||||
retry_jitter_factor=retry_jitter_factor,
|
||||
cb_failure_threshold=cb_failure_threshold,
|
||||
cb_success_threshold=cb_success_threshold,
|
||||
cb_timeout_duration_secs=cb_timeout_duration_secs,
|
||||
cb_window_duration_secs=cb_window_duration_secs,
|
||||
disable_retries=disable_retries,
|
||||
disable_circuit_breaker=disable_circuit_breaker,
|
||||
health_failure_threshold=health_failure_threshold,
|
||||
health_success_threshold=health_success_threshold,
|
||||
health_check_timeout_secs=health_check_timeout_secs,
|
||||
health_check_interval_secs=health_check_interval_secs,
|
||||
health_check_endpoint=health_check_endpoint,
|
||||
model_path=model_path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
@staticmethod
|
||||
def from_args(args: RouterArgs) -> "Router":
|
||||
"""Create a router from a RouterArgs instance."""
|
||||
|
||||
args_dict = vars(args)
|
||||
# Convert RouterArgs to _Router parameters
|
||||
args_dict["worker_urls"] = (
|
||||
[]
|
||||
if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
|
||||
else args_dict["worker_urls"]
|
||||
)
|
||||
args_dict["policy"] = policy_from_str(args_dict["policy"])
|
||||
args_dict["prefill_urls"] = (
|
||||
args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
|
||||
)
|
||||
args_dict["decode_urls"] = (
|
||||
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
|
||||
)
|
||||
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
|
||||
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
|
||||
|
||||
# remoge mini_lb parameter
|
||||
args_dict.pop("mini_lb")
|
||||
|
||||
return Router(_Router(**args_dict))
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the router server.
|
||||
|
||||
577
sgl-router/py_src/sglang_router/router_args.py
Normal file
577
sgl-router/py_src/sglang_router/router_args.py
Normal file
@@ -0,0 +1,577 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RouterArgs:
|
||||
# Worker configuration
|
||||
worker_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
|
||||
# PD-specific configuration
|
||||
mini_lb: bool = False
|
||||
pd_disaggregation: bool = False # Enable PD disaggregated mode
|
||||
prefill_urls: List[tuple] = dataclasses.field(
|
||||
default_factory=list
|
||||
) # List of (url, bootstrap_port)
|
||||
decode_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
# Routing policy
|
||||
policy: str = "cache_aware"
|
||||
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
|
||||
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
|
||||
worker_startup_timeout_secs: int = 600
|
||||
worker_startup_check_interval: int = 30
|
||||
cache_threshold: float = 0.3
|
||||
balance_abs_threshold: int = 64
|
||||
balance_rel_threshold: float = 1.5
|
||||
eviction_interval_secs: int = 120
|
||||
max_tree_size: int = 2**26
|
||||
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
|
||||
dp_aware: bool = False
|
||||
api_key: Optional[str] = None
|
||||
log_dir: Optional[str] = None
|
||||
log_level: Optional[str] = None
|
||||
# Service discovery configuration
|
||||
service_discovery: bool = False
|
||||
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
service_discovery_port: int = 80
|
||||
service_discovery_namespace: Optional[str] = None
|
||||
# PD service discovery configuration
|
||||
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
|
||||
# Prometheus configuration
|
||||
prometheus_port: Optional[int] = None
|
||||
prometheus_host: Optional[str] = None
|
||||
# Request ID headers configuration
|
||||
request_id_headers: Optional[List[str]] = None
|
||||
# Request timeout in seconds
|
||||
request_timeout_secs: int = 1800
|
||||
# Max concurrent requests for rate limiting
|
||||
max_concurrent_requests: int = 256
|
||||
# Queue size for pending requests when max concurrent limit reached
|
||||
queue_size: int = 100
|
||||
# Maximum time (in seconds) a request can wait in queue before timing out
|
||||
queue_timeout_secs: int = 60
|
||||
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
|
||||
rate_limit_tokens_per_second: Optional[int] = None
|
||||
# CORS allowed origins
|
||||
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
|
||||
# Retry configuration
|
||||
retry_max_retries: int = 5
|
||||
retry_initial_backoff_ms: int = 50
|
||||
retry_max_backoff_ms: int = 30_000
|
||||
retry_backoff_multiplier: float = 1.5
|
||||
retry_jitter_factor: float = 0.2
|
||||
disable_retries: bool = False
|
||||
# Health check configuration
|
||||
health_failure_threshold: int = 3
|
||||
health_success_threshold: int = 2
|
||||
health_check_timeout_secs: int = 5
|
||||
health_check_interval_secs: int = 60
|
||||
health_check_endpoint: str = "/health"
|
||||
# Circuit breaker configuration
|
||||
cb_failure_threshold: int = 10
|
||||
cb_success_threshold: int = 3
|
||||
cb_timeout_duration_secs: int = 60
|
||||
cb_window_duration_secs: int = 120
|
||||
disable_circuit_breaker: bool = False
|
||||
# Tokenizer configuration
|
||||
model_path: Optional[str] = None
|
||||
tokenizer_path: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
use_router_prefix: bool = False,
|
||||
exclude_host_port: bool = False,
|
||||
):
|
||||
"""
|
||||
Add router-specific arguments to an argument parser.
|
||||
|
||||
Args:
|
||||
parser: The argument parser to add arguments to
|
||||
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
|
||||
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
|
||||
"""
|
||||
prefix = "router-" if use_router_prefix else ""
|
||||
|
||||
# Worker configuration
|
||||
if not exclude_host_port:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=RouterArgs.host,
|
||||
help="Host address to bind the router server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=RouterArgs.port,
|
||||
help="Port number to bind the router server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--worker-urls",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
||||
)
|
||||
|
||||
# Routing policy configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}policy",
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
|
||||
# PD-specific arguments
|
||||
parser.add_argument(
|
||||
f"--{prefix}mini-lb",
|
||||
action="store_true",
|
||||
help="Enable MiniLB",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}pd-disaggregation",
|
||||
action="store_true",
|
||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill",
|
||||
nargs="+",
|
||||
action="append",
|
||||
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
|
||||
"Format: --prefill URL [BOOTSTRAP_PORT]. "
|
||||
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode",
|
||||
nargs=1,
|
||||
action="append",
|
||||
metavar=("URL",),
|
||||
help="Decode server URL. Can be specified multiple times.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_timeout_secs,
|
||||
help="Timeout in seconds for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-check-interval",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_check_interval,
|
||||
help="Interval in seconds between checks for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cache-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.cache_threshold,
|
||||
help="Cache threshold (0.0-1.0) for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-abs-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.balance_abs_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-rel-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.balance_rel_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}eviction-interval-secs",
|
||||
type=int,
|
||||
default=RouterArgs.eviction_interval_secs,
|
||||
help="Interval in seconds between cache eviction operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-tree-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_tree_size,
|
||||
help="Maximum size of the approximation tree for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-payload-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_payload_size,
|
||||
help="Maximum payload size in bytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}dp-aware",
|
||||
action="store_true",
|
||||
help="Enable data parallelism aware schedule",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to store log files. If not specified, logs are only output to console.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=["debug", "info", "warning", "error", "critical"],
|
||||
help="Set the logging level. If not specified, defaults to INFO.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery",
|
||||
action="store_true",
|
||||
help="Enable Kubernetes service discovery",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default={},
|
||||
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-port",
|
||||
type=int,
|
||||
default=RouterArgs.service_discovery_port,
|
||||
help="Port to use for discovered worker pods",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-namespace",
|
||||
type=str,
|
||||
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default={},
|
||||
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default={},
|
||||
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
# Prometheus configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-port",
|
||||
type=int,
|
||||
default=29000,
|
||||
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host address to bind the Prometheus metrics server",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-id-headers",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.request_timeout_secs,
|
||||
help="Request timeout in seconds",
|
||||
)
|
||||
# Retry configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-retries",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_retries,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-initial-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_initial_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-backoff-multiplier",
|
||||
type=float,
|
||||
default=RouterArgs.retry_backoff_multiplier,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-jitter-factor",
|
||||
type=float,
|
||||
default=RouterArgs.retry_jitter_factor,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-retries",
|
||||
action="store_true",
|
||||
help="Disable retries (equivalent to setting retry_max_retries=1)",
|
||||
)
|
||||
# Circuit breaker configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_failure_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_success_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-timeout-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_timeout_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-window-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_window_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-circuit-breaker",
|
||||
action="store_true",
|
||||
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
|
||||
)
|
||||
# Health check configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_failure_threshold,
|
||||
help="Number of consecutive health check failures before marking worker unhealthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_success_threshold,
|
||||
help="Number of consecutive health check successes before marking worker healthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_timeout_secs,
|
||||
help="Timeout in seconds for health check requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-interval-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_interval_secs,
|
||||
help="Interval in seconds between runtime health checks",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-endpoint",
|
||||
type=str,
|
||||
default=RouterArgs.health_check_endpoint,
|
||||
help="Health check endpoint path",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-concurrent-requests",
|
||||
type=int,
|
||||
default=RouterArgs.max_concurrent_requests,
|
||||
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-size",
|
||||
type=int,
|
||||
default=RouterArgs.queue_size,
|
||||
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.queue_timeout_secs,
|
||||
help="Maximum time (in seconds) a request can wait in queue before timing out",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}rate-limit-tokens-per-second",
|
||||
type=int,
|
||||
default=RouterArgs.rate_limit_tokens_per_second,
|
||||
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cors-allowed-origins",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
||||
)
|
||||
# Tokenizer configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}tokenizer-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
|
||||
)
|
||||
|
||||
@classmethod
def from_cli_args(
    cls, args: argparse.Namespace, use_router_prefix: bool = False
) -> "RouterArgs":
    """
    Build a RouterArgs instance from parsed command line arguments.

    Args:
        args: Parsed command line arguments
        use_router_prefix: If True, look for arguments with 'router-' prefix
    """
    prefix = "router_" if use_router_prefix else ""
    namespace = vars(args)
    collected = {}

    # Copy each dataclass field, preferring the prefixed CLI name when present.
    for field in dataclasses.fields(cls):
        prefixed_name = f"{prefix}{field.name}"
        if prefixed_name in namespace:
            collected[field.name] = namespace[prefixed_name]
        elif field.name in namespace:
            collected[field.name] = namespace[field.name]

    # Arguments that need parsing beyond plain assignment
    # (--prefill / --decode and the Kubernetes label selectors).
    collected["prefill_urls"] = cls._parse_prefill_urls(
        namespace.get(f"{prefix}prefill")
    )
    collected["decode_urls"] = cls._parse_decode_urls(
        namespace.get(f"{prefix}decode")
    )
    collected["selector"] = cls._parse_selector(
        namespace.get(f"{prefix}selector")
    )
    collected["prefill_selector"] = cls._parse_selector(
        namespace.get(f"{prefix}prefill_selector")
    )
    collected["decode_selector"] = cls._parse_selector(
        namespace.get(f"{prefix}decode_selector")
    )

    # Mooncake-specific pod annotation used for bootstrap-port discovery.
    collected["bootstrap_port_annotation"] = "sglang.ai/bootstrap-port"

    return cls(**collected)
|
||||
|
||||
def _validate_router_args(self):
|
||||
# Validate configuration based on mode
|
||||
if self.pd_disaggregation:
|
||||
# Validate PD configuration - skip URL requirements if using service discovery
|
||||
if not self.service_discovery:
|
||||
if not self.prefill_urls:
|
||||
raise ValueError("PD disaggregation mode requires --prefill")
|
||||
if not self.decode_urls:
|
||||
raise ValueError("PD disaggregation mode requires --decode")
|
||||
|
||||
# Warn about policy usage in PD mode
|
||||
if self.prefill_policy and self.decode_policy and self.policy:
|
||||
logger.warning(
|
||||
"Both --prefill-policy and --decode-policy are specified. "
|
||||
"The main --policy flag will be ignored for PD mode."
|
||||
)
|
||||
elif self.prefill_policy and not self.decode_policy and self.policy:
|
||||
logger.info(
|
||||
f"Using --prefill-policy '{self.prefill_policy}' for prefill nodes "
|
||||
f"and --policy '{self.policy}' for decode nodes."
|
||||
)
|
||||
elif self.decode_policy and not self.prefill_policy and self.policy:
|
||||
logger.info(
|
||||
f"Using --policy '{self.policy}' for prefill nodes "
|
||||
f"and --decode-policy '{self.decode_policy}' for decode nodes."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_selector(selector_list):
|
||||
if not selector_list:
|
||||
return {}
|
||||
|
||||
selector = {}
|
||||
for item in selector_list:
|
||||
if "=" in item:
|
||||
key, value = item.split("=", 1)
|
||||
selector[key] = value
|
||||
return selector
|
||||
|
||||
@staticmethod
|
||||
def _parse_prefill_urls(prefill_list):
|
||||
"""Parse prefill URLs from --prefill arguments.
|
||||
|
||||
Format: --prefill URL [BOOTSTRAP_PORT]
|
||||
Example:
|
||||
--prefill http://prefill1:8080 9000 # With bootstrap port
|
||||
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
|
||||
--prefill http://prefill3:8080 # Defaults to no bootstrap port
|
||||
"""
|
||||
if not prefill_list:
|
||||
return []
|
||||
|
||||
prefill_urls = []
|
||||
for prefill_args in prefill_list:
|
||||
|
||||
url = prefill_args[0]
|
||||
|
||||
# Handle optional bootstrap port
|
||||
if len(prefill_args) >= 2:
|
||||
bootstrap_port_str = prefill_args[1]
|
||||
# Handle 'none' as None
|
||||
if bootstrap_port_str.lower() == "none":
|
||||
bootstrap_port = None
|
||||
else:
|
||||
try:
|
||||
bootstrap_port = int(bootstrap_port_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
||||
)
|
||||
else:
|
||||
# No bootstrap port specified, default to None
|
||||
bootstrap_port = None
|
||||
|
||||
prefill_urls.append((url, bootstrap_port))
|
||||
|
||||
return prefill_urls
|
||||
|
||||
@staticmethod
|
||||
def _parse_decode_urls(decode_list):
|
||||
"""Parse decode URLs from --decode arguments.
|
||||
|
||||
Format: --decode URL
|
||||
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||
"""
|
||||
if not decode_list:
|
||||
return []
|
||||
|
||||
# decode_list is a list of single-element lists due to nargs=1
|
||||
return [url[0] for url in decode_list]
|
||||
@@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
cache_threshold=0.5,
|
||||
balance_abs_threshold=32,
|
||||
balance_rel_threshold=1.0001,
|
||||
eviction_interval=60,
|
||||
eviction_interval_secs=60,
|
||||
max_tree_size=2**24,
|
||||
max_payload_size=256 * 1024 * 1024, # 256MB
|
||||
verbose=False,
|
||||
@@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
"""Test basic PD router functionality without actually starting servers."""
|
||||
# This test just verifies the PD router can be created and configured
|
||||
# without actually starting it (which would require real prefill/decode servers)
|
||||
from sglang_router import Router
|
||||
from sglang_router.launch_router import RouterArgs
|
||||
from sglang_router_rs import PolicyType
|
||||
from sglang_router.router import PolicyType, Router
|
||||
|
||||
# Test RouterArgs parsing for PD mode
|
||||
# Simulate the parsed args structure from argparse with action="append"
|
||||
@@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
|
||||
|
||||
# Test Router creation in PD mode
|
||||
router = Router(
|
||||
worker_urls=[], # Empty for PD mode
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[
|
||||
("http://prefill1:8080", 9000),
|
||||
("http://prefill2:8080", None),
|
||||
],
|
||||
decode_urls=["http://decode1:8081", "http://decode2:8081"],
|
||||
policy=PolicyType.CacheAware,
|
||||
host="127.0.0.1",
|
||||
port=3001,
|
||||
)
|
||||
router = Router.from_args(router_args)
|
||||
self.assertIsNotNone(router)
|
||||
|
||||
def test_policy_validation(self):
|
||||
|
||||
@@ -77,7 +77,7 @@ def popen_launch_router(
|
||||
port,
|
||||
"--dp",
|
||||
str(dp_size),
|
||||
"--router-eviction-interval",
|
||||
"--router-eviction-interval-secs",
|
||||
"5",
|
||||
"--router-policy",
|
||||
policy,
|
||||
|
||||
@@ -28,8 +28,3 @@ find = { where = ["py_src"] }
|
||||
# workaround for https://github.com/pypa/twine/issues/1216
|
||||
[tool.setuptools]
|
||||
license-files = []
|
||||
|
||||
[[tool.setuptools-rust.ext-modules]]
|
||||
target = "sglang_router_rs"
|
||||
path = "Cargo.toml"
|
||||
binding = "PyO3"
|
||||
|
||||
21
sgl-router/setup.py
Normal file
21
sgl-router/setup.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import os

from setuptools import setup
from setuptools_rust import Binding, RustExtension

# Skip compiling the Rust extension when explicitly disabled, e.g. for a
# pure-Python build of the router (MiniLB only) in a lightweight image.
no_rust = os.environ.get("SGLANG_ROUTER_BUILD_NO_RUST") == "1"

rust_extensions = (
    []
    if no_rust
    else [
        RustExtension(
            target="sglang_router_rs",
            path="Cargo.toml",
            binding=Binding.PyO3,
        )
    ]
)

setup(
    rust_extensions=rust_extensions,
    zip_safe=False,
)
|
||||
Reference in New Issue
Block a user