adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
7
sgl-router/py_src/sglang_router/__init__.py
Normal file
7
sgl-router/py_src/sglang_router/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# a lightweight wrapper on router with argument type and comments
|
||||
# no wrapper on policy type => direct export
|
||||
from sglang_router.router import Router
|
||||
from sglang_router.version import __version__
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
__all__ = ["Router", "PolicyType", "__version__"]
|
||||
846
sgl-router/py_src/sglang_router/launch_router.py
Normal file
846
sgl-router/py_src/sglang_router/launch_router.py
Normal file
@@ -0,0 +1,846 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sglang_router import Router
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
|
||||
def setup_logger():
    """Configure and return the shared "router" logger with console output.

    Returns:
        logging.Logger: Logger named "router" at INFO level with a single
        stream handler attached.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # Guard against duplicate handlers: this setup may run more than once
    # (launch_server.py configures the same "router" logger), and an
    # unconditional addHandler would duplicate every log line.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RouterArgs:
    """Configuration for the SGLang router; fields mirror the CLI flags
    registered by add_cli_args and are populated by from_cli_args."""

    # Worker configuration
    worker_urls: List[str] = dataclasses.field(default_factory=list)
    host: str = "127.0.0.1"
    port: int = 30000

    # PD-specific configuration
    pd_disaggregation: bool = False  # Enable PD disaggregated mode
    prefill_urls: List[tuple] = dataclasses.field(
        default_factory=list
    )  # List of (url, bootstrap_port)
    decode_urls: List[str] = dataclasses.field(default_factory=list)

    # Routing policy
    policy: str = "cache_aware"
    prefill_policy: Optional[str] = None  # Specific policy for prefill nodes in PD mode
    decode_policy: Optional[str] = None  # Specific policy for decode nodes in PD mode
    worker_startup_timeout_secs: int = 600
    worker_startup_check_interval: int = 30
    # Cache-aware routing knobs: cache_threshold is a 0.0-1.0 match ratio;
    # balance_* thresholds decide when to fall back to load balancing.
    cache_threshold: float = 0.3
    balance_abs_threshold: int = 64
    balance_rel_threshold: float = 1.5
    eviction_interval: int = 120
    max_tree_size: int = 2**26
    max_payload_size: int = 512 * 1024 * 1024  # 512MB default for large batches
    dp_aware: bool = False
    api_key: Optional[str] = None
    log_dir: Optional[str] = None
    log_level: Optional[str] = None
    # Service discovery configuration (Kubernetes label selectors)
    service_discovery: bool = False
    selector: Dict[str, str] = dataclasses.field(default_factory=dict)
    service_discovery_port: int = 80
    service_discovery_namespace: Optional[str] = None
    # PD service discovery configuration
    prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
    decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
    bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
    # Prometheus configuration (None => metrics disabled)
    prometheus_port: Optional[int] = None
    prometheus_host: Optional[str] = None
    # Request ID headers configuration
    request_id_headers: Optional[List[str]] = None
    # Request timeout in seconds
    request_timeout_secs: int = 1800
    # Max concurrent requests for rate limiting
    max_concurrent_requests: int = 256
    # Queue size for pending requests when max concurrent limit reached
    queue_size: int = 100
    # Maximum time (in seconds) a request can wait in queue before timing out
    queue_timeout_secs: int = 60
    # Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
    rate_limit_tokens_per_second: Optional[int] = None
    # CORS allowed origins
    cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
    # Retry configuration
    retry_max_retries: int = 5
    retry_initial_backoff_ms: int = 50
    retry_max_backoff_ms: int = 30_000
    retry_backoff_multiplier: float = 1.5
    retry_jitter_factor: float = 0.2
    disable_retries: bool = False
    # Health check configuration
    health_failure_threshold: int = 3
    health_success_threshold: int = 2
    health_check_timeout_secs: int = 5
    health_check_interval_secs: int = 60
    health_check_endpoint: str = "/health"
    # Circuit breaker configuration
    cb_failure_threshold: int = 10
    cb_success_threshold: int = 3
    cb_timeout_duration_secs: int = 60
    cb_window_duration_secs: int = 120
    disable_circuit_breaker: bool = False
    # Tokenizer configuration
    model_path: Optional[str] = None
    tokenizer_path: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
use_router_prefix: bool = False,
|
||||
exclude_host_port: bool = False,
|
||||
):
|
||||
"""
|
||||
Add router-specific arguments to an argument parser.
|
||||
|
||||
Args:
|
||||
parser: The argument parser to add arguments to
|
||||
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
|
||||
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
|
||||
"""
|
||||
prefix = "router-" if use_router_prefix else ""
|
||||
|
||||
# Worker configuration
|
||||
if not exclude_host_port:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=RouterArgs.host,
|
||||
help="Host address to bind the router server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=RouterArgs.port,
|
||||
help="Port number to bind the router server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--worker-urls",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
||||
)
|
||||
|
||||
# Routing policy configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}policy",
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
|
||||
# PD-specific arguments
|
||||
parser.add_argument(
|
||||
f"--{prefix}pd-disaggregation",
|
||||
action="store_true",
|
||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill",
|
||||
nargs="+",
|
||||
action="append",
|
||||
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
|
||||
"Format: --prefill URL [BOOTSTRAP_PORT]. "
|
||||
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode",
|
||||
nargs=1,
|
||||
action="append",
|
||||
metavar=("URL",),
|
||||
help="Decode server URL. Can be specified multiple times.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_timeout_secs,
|
||||
help="Timeout in seconds for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-check-interval",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_check_interval,
|
||||
help="Interval in seconds between checks for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cache-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.cache_threshold,
|
||||
help="Cache threshold (0.0-1.0) for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-abs-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.balance_abs_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-rel-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.balance_rel_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}eviction-interval",
|
||||
type=int,
|
||||
default=RouterArgs.eviction_interval,
|
||||
help="Interval in seconds between cache eviction operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-tree-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_tree_size,
|
||||
help="Maximum size of the approximation tree for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-payload-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_payload_size,
|
||||
help="Maximum payload size in bytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}dp-aware",
|
||||
action="store_true",
|
||||
help="Enable data parallelism aware schedule",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to store log files. If not specified, logs are only output to console.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=["debug", "info", "warning", "error", "critical"],
|
||||
help="Set the logging level. If not specified, defaults to INFO.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery",
|
||||
action="store_true",
|
||||
help="Enable Kubernetes service discovery",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-port",
|
||||
type=int,
|
||||
default=RouterArgs.service_discovery_port,
|
||||
help="Port to use for discovered worker pods",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-namespace",
|
||||
type=str,
|
||||
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
# Prometheus configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-port",
|
||||
type=int,
|
||||
default=29000,
|
||||
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host address to bind the Prometheus metrics server",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-id-headers",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.request_timeout_secs,
|
||||
help="Request timeout in seconds",
|
||||
)
|
||||
# Retry configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-retries",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_retries,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-initial-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_initial_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-backoff-multiplier",
|
||||
type=float,
|
||||
default=RouterArgs.retry_backoff_multiplier,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-jitter-factor",
|
||||
type=float,
|
||||
default=RouterArgs.retry_jitter_factor,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-retries",
|
||||
action="store_true",
|
||||
help="Disable retries (equivalent to setting retry_max_retries=1)",
|
||||
)
|
||||
# Circuit breaker configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_failure_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_success_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-timeout-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_timeout_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-window-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_window_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-circuit-breaker",
|
||||
action="store_true",
|
||||
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
|
||||
)
|
||||
# Health check configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_failure_threshold,
|
||||
help="Number of consecutive health check failures before marking worker unhealthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_success_threshold,
|
||||
help="Number of consecutive health check successes before marking worker healthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_timeout_secs,
|
||||
help="Timeout in seconds for health check requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-interval-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_interval_secs,
|
||||
help="Interval in seconds between runtime health checks",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-endpoint",
|
||||
type=str,
|
||||
default=RouterArgs.health_check_endpoint,
|
||||
help="Health check endpoint path",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-concurrent-requests",
|
||||
type=int,
|
||||
default=RouterArgs.max_concurrent_requests,
|
||||
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-size",
|
||||
type=int,
|
||||
default=RouterArgs.queue_size,
|
||||
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.queue_timeout_secs,
|
||||
help="Maximum time (in seconds) a request can wait in queue before timing out",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}rate-limit-tokens-per-second",
|
||||
type=int,
|
||||
default=RouterArgs.rate_limit_tokens_per_second,
|
||||
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cors-allowed-origins",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
||||
)
|
||||
# Tokenizer configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}tokenizer-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
|
||||
)
|
||||
|
||||
    @classmethod
    def from_cli_args(
        cls, args: argparse.Namespace, use_router_prefix: bool = False
    ) -> "RouterArgs":
        """
        Create RouterArgs instance from parsed command line arguments.

        Args:
            args: Parsed command line arguments
            use_router_prefix: If True, look for arguments with 'router-' prefix
        """
        # argparse converts "--router-foo" to attribute "router_foo", hence the
        # underscore form of the prefix here (add_cli_args uses "router-").
        prefix = "router_" if use_router_prefix else ""
        worker_urls = getattr(args, "worker_urls", [])

        # Parse PD URLs
        prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
        decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))

        # Fields registered unconditionally by add_cli_args use 2-arg getattr
        # (missing attribute would be a programming error); optional ones fall
        # back to the dataclass default.
        return cls(
            worker_urls=worker_urls,
            host=args.host,
            port=args.port,
            pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
            prefill_urls=prefill_urls,
            decode_urls=decode_urls,
            policy=getattr(args, f"{prefix}policy"),
            prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
            decode_policy=getattr(args, f"{prefix}decode_policy", None),
            worker_startup_timeout_secs=getattr(
                args, f"{prefix}worker_startup_timeout_secs"
            ),
            worker_startup_check_interval=getattr(
                args, f"{prefix}worker_startup_check_interval"
            ),
            cache_threshold=getattr(args, f"{prefix}cache_threshold"),
            balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
            balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
            eviction_interval=getattr(args, f"{prefix}eviction_interval"),
            max_tree_size=getattr(args, f"{prefix}max_tree_size"),
            max_payload_size=getattr(args, f"{prefix}max_payload_size"),
            dp_aware=getattr(args, f"{prefix}dp_aware", False),
            api_key=getattr(args, f"{prefix}api_key", None),
            log_dir=getattr(args, f"{prefix}log_dir", None),
            log_level=getattr(args, f"{prefix}log_level", None),
            service_discovery=getattr(args, f"{prefix}service_discovery", False),
            selector=cls._parse_selector(getattr(args, f"{prefix}selector", None)),
            service_discovery_port=getattr(args, f"{prefix}service_discovery_port"),
            service_discovery_namespace=getattr(
                args, f"{prefix}service_discovery_namespace", None
            ),
            prefill_selector=cls._parse_selector(
                getattr(args, f"{prefix}prefill_selector", None)
            ),
            decode_selector=cls._parse_selector(
                getattr(args, f"{prefix}decode_selector", None)
            ),
            bootstrap_port_annotation="sglang.ai/bootstrap-port",  # Mooncake-specific annotation
            prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
            prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
            request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
            request_timeout_secs=getattr(
                args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs
            ),
            max_concurrent_requests=getattr(
                args,
                f"{prefix}max_concurrent_requests",
                RouterArgs.max_concurrent_requests,
            ),
            queue_size=getattr(
                args,
                f"{prefix}queue_size",
                RouterArgs.queue_size,
            ),
            queue_timeout_secs=getattr(
                args,
                f"{prefix}queue_timeout_secs",
                RouterArgs.queue_timeout_secs,
            ),
            rate_limit_tokens_per_second=getattr(
                args,
                f"{prefix}rate_limit_tokens_per_second",
                RouterArgs.rate_limit_tokens_per_second,
            ),
            cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
            retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
            retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
            retry_max_backoff_ms=getattr(args, f"{prefix}retry_max_backoff_ms"),
            retry_backoff_multiplier=getattr(args, f"{prefix}retry_backoff_multiplier"),
            retry_jitter_factor=getattr(args, f"{prefix}retry_jitter_factor"),
            cb_failure_threshold=getattr(args, f"{prefix}cb_failure_threshold"),
            cb_success_threshold=getattr(args, f"{prefix}cb_success_threshold"),
            cb_timeout_duration_secs=getattr(args, f"{prefix}cb_timeout_duration_secs"),
            cb_window_duration_secs=getattr(args, f"{prefix}cb_window_duration_secs"),
            disable_retries=getattr(args, f"{prefix}disable_retries", False),
            disable_circuit_breaker=getattr(
                args, f"{prefix}disable_circuit_breaker", False
            ),
            health_failure_threshold=getattr(
                args,
                f"{prefix}health_failure_threshold",
                RouterArgs.health_failure_threshold,
            ),
            health_success_threshold=getattr(
                args,
                f"{prefix}health_success_threshold",
                RouterArgs.health_success_threshold,
            ),
            health_check_timeout_secs=getattr(
                args,
                f"{prefix}health_check_timeout_secs",
                RouterArgs.health_check_timeout_secs,
            ),
            health_check_interval_secs=getattr(
                args,
                f"{prefix}health_check_interval_secs",
                RouterArgs.health_check_interval_secs,
            ),
            health_check_endpoint=getattr(
                args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint
            ),
            model_path=getattr(args, f"{prefix}model_path", None),
            tokenizer_path=getattr(args, f"{prefix}tokenizer_path", None),
        )
|
||||
|
||||
@staticmethod
|
||||
def _parse_selector(selector_list):
|
||||
if not selector_list:
|
||||
return {}
|
||||
|
||||
selector = {}
|
||||
for item in selector_list:
|
||||
if "=" in item:
|
||||
key, value = item.split("=", 1)
|
||||
selector[key] = value
|
||||
return selector
|
||||
|
||||
@staticmethod
|
||||
def _parse_prefill_urls(prefill_list):
|
||||
"""Parse prefill URLs from --prefill arguments.
|
||||
|
||||
Format: --prefill URL [BOOTSTRAP_PORT]
|
||||
Example:
|
||||
--prefill http://prefill1:8080 9000 # With bootstrap port
|
||||
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
|
||||
--prefill http://prefill3:8080 # Defaults to no bootstrap port
|
||||
"""
|
||||
if not prefill_list:
|
||||
return []
|
||||
|
||||
prefill_urls = []
|
||||
for prefill_args in prefill_list:
|
||||
|
||||
url = prefill_args[0]
|
||||
|
||||
# Handle optional bootstrap port
|
||||
if len(prefill_args) >= 2:
|
||||
bootstrap_port_str = prefill_args[1]
|
||||
# Handle 'none' as None
|
||||
if bootstrap_port_str.lower() == "none":
|
||||
bootstrap_port = None
|
||||
else:
|
||||
try:
|
||||
bootstrap_port = int(bootstrap_port_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
||||
)
|
||||
else:
|
||||
# No bootstrap port specified, default to None
|
||||
bootstrap_port = None
|
||||
|
||||
prefill_urls.append((url, bootstrap_port))
|
||||
|
||||
return prefill_urls
|
||||
|
||||
@staticmethod
|
||||
def _parse_decode_urls(decode_list):
|
||||
"""Parse decode URLs from --decode arguments.
|
||||
|
||||
Format: --decode URL
|
||||
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||
"""
|
||||
if not decode_list:
|
||||
return []
|
||||
|
||||
# decode_list is a list of single-element lists due to nargs=1
|
||||
return [url[0] for url in decode_list]
|
||||
|
||||
|
||||
def policy_from_str(policy_str: str) -> PolicyType:
    """Convert a policy name string to the PolicyType enum.

    Args:
        policy_str: One of "random", "round_robin", "cache_aware",
            "power_of_two" (the same set accepted by the CLI `choices`).

    Raises:
        ValueError: if policy_str is not a recognized policy name
            (previously this surfaced as a bare KeyError).
    """
    policy_map = {
        "random": PolicyType.Random,
        "round_robin": PolicyType.RoundRobin,
        "cache_aware": PolicyType.CacheAware,
        "power_of_two": PolicyType.PowerOfTwo,
    }
    try:
        return policy_map[policy_str]
    except KeyError:
        raise ValueError(
            f"Unknown policy: {policy_str!r}. "
            f"Expected one of: {', '.join(sorted(policy_map))}"
        ) from None
|
||||
|
||||
|
||||
def launch_router(args: argparse.Namespace) -> Optional[Router]:
    """
    Launch the SGLang router with the configuration from parsed arguments.

    Args:
        args: Namespace object containing router configuration
            Can be either raw argparse.Namespace or converted RouterArgs

    Returns:
        Router instance if successful, None if failed

    Raises:
        ValueError: in PD mode without service discovery when --prefill or
            --decode URLs are missing.
        Exception: any error raised while constructing or starting the
            Router is logged and re-raised.
    """
    logger = logging.getLogger("router")
    try:
        # Convert to RouterArgs if needed
        if not isinstance(args, RouterArgs):
            router_args = RouterArgs.from_cli_args(args)
        else:
            router_args = args

        # Validate configuration based on mode
        if router_args.pd_disaggregation:
            # Validate PD configuration - skip URL requirements if using service discovery
            if not router_args.service_discovery:
                if not router_args.prefill_urls:
                    raise ValueError("PD disaggregation mode requires --prefill")
                if not router_args.decode_urls:
                    raise ValueError("PD disaggregation mode requires --decode")

            # Warn about policy usage in PD mode: explicit per-node policies
            # take precedence over the main --policy flag.
            if (
                router_args.prefill_policy
                and router_args.decode_policy
                and router_args.policy
            ):
                logger.warning(
                    "Both --prefill-policy and --decode-policy are specified. "
                    "The main --policy flag will be ignored for PD mode."
                )
            elif (
                router_args.prefill_policy
                and not router_args.decode_policy
                and router_args.policy
            ):
                logger.info(
                    f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
                    f"and --policy '{router_args.policy}' for decode nodes."
                )
            elif (
                router_args.decode_policy
                and not router_args.prefill_policy
                and router_args.policy
            ):
                logger.info(
                    f"Using --policy '{router_args.policy}' for prefill nodes "
                    f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
                )

        # Create router with unified constructor
        router = Router(
            # Worker URLs are discovered/managed elsewhere in service
            # discovery or PD mode, so pass an empty list in those cases.
            worker_urls=(
                []
                if router_args.service_discovery or router_args.pd_disaggregation
                else router_args.worker_urls
            ),
            host=router_args.host,
            port=router_args.port,
            policy=policy_from_str(router_args.policy),
            worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
            worker_startup_check_interval=router_args.worker_startup_check_interval,
            cache_threshold=router_args.cache_threshold,
            balance_abs_threshold=router_args.balance_abs_threshold,
            balance_rel_threshold=router_args.balance_rel_threshold,
            # Note the kwarg rename: RouterArgs.eviction_interval maps to
            # the Router's eviction_interval_secs parameter.
            eviction_interval_secs=router_args.eviction_interval,
            max_tree_size=router_args.max_tree_size,
            max_payload_size=router_args.max_payload_size,
            dp_aware=router_args.dp_aware,
            api_key=router_args.api_key,
            log_dir=router_args.log_dir,
            log_level=router_args.log_level,
            service_discovery=router_args.service_discovery,
            selector=router_args.selector,
            service_discovery_port=router_args.service_discovery_port,
            service_discovery_namespace=router_args.service_discovery_namespace,
            prefill_selector=router_args.prefill_selector,
            decode_selector=router_args.decode_selector,
            prometheus_port=router_args.prometheus_port,
            prometheus_host=router_args.prometheus_host,
            request_timeout_secs=router_args.request_timeout_secs,
            pd_disaggregation=router_args.pd_disaggregation,
            # PD URLs and per-node policies only apply in PD mode.
            prefill_urls=(
                router_args.prefill_urls if router_args.pd_disaggregation else None
            ),
            decode_urls=(
                router_args.decode_urls if router_args.pd_disaggregation else None
            ),
            prefill_policy=(
                policy_from_str(router_args.prefill_policy)
                if router_args.prefill_policy
                else None
            ),
            decode_policy=(
                policy_from_str(router_args.decode_policy)
                if router_args.decode_policy
                else None
            ),
            request_id_headers=router_args.request_id_headers,
            max_concurrent_requests=router_args.max_concurrent_requests,
            queue_size=router_args.queue_size,
            queue_timeout_secs=router_args.queue_timeout_secs,
            rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
            cors_allowed_origins=router_args.cors_allowed_origins,
            retry_max_retries=router_args.retry_max_retries,
            retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
            retry_max_backoff_ms=router_args.retry_max_backoff_ms,
            retry_backoff_multiplier=router_args.retry_backoff_multiplier,
            retry_jitter_factor=router_args.retry_jitter_factor,
            cb_failure_threshold=router_args.cb_failure_threshold,
            cb_success_threshold=router_args.cb_success_threshold,
            cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
            cb_window_duration_secs=router_args.cb_window_duration_secs,
            disable_retries=router_args.disable_retries,
            disable_circuit_breaker=router_args.disable_circuit_breaker,
            health_failure_threshold=router_args.health_failure_threshold,
            health_success_threshold=router_args.health_success_threshold,
            health_check_timeout_secs=router_args.health_check_timeout_secs,
            health_check_interval_secs=router_args.health_check_interval_secs,
            health_check_endpoint=router_args.health_check_endpoint,
            model_path=router_args.model_path,
            tokenizer_path=router_args.tokenizer_path,
        )

        router.start()
        return router

    except Exception as e:
        logger.error(f"Error starting router: {e}")
        raise e
|
||||
|
||||
|
||||
class CustomHelpFormatter(
    argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
):
    """Custom formatter that preserves both description formatting and shows defaults"""

    # All behavior comes from the two argparse mixins; no overrides needed.
    pass
|
||||
|
||||
|
||||
def parse_router_args(args: List[str]) -> RouterArgs:
    """Parse command line arguments and return RouterArgs instance.

    Args:
        args: argv-style list of CLI tokens (typically sys.argv[1:]).
    """
    parser = argparse.ArgumentParser(
        description="""SGLang Router - High-performance request distribution across worker nodes

Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and router separately.

Examples:
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000

# PD disaggregated mode with same policy for both
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware

# PD mode with optional bootstrap ports
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 \\ # With bootstrap port
--prefill http://prefill2:8000 none \\ # Explicitly no bootstrap port
--prefill http://prefill3:8000 \\ # Defaults to no bootstrap port
--decode http://decode1:8001 --decode http://decode2:8001

# PD mode with different policies for prefill and decode
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 --prefill http://prefill2:8000 \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--prefill-policy cache_aware --decode-policy power_of_two

""",
        formatter_class=CustomHelpFormatter,
    )

    # Register un-prefixed router flags, then convert the parsed namespace.
    RouterArgs.add_cli_args(parser, use_router_prefix=False)
    return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: build RouterArgs from argv and run the router.

    Blocks until the router server shuts down.
    """
    launch_router(parse_router_args(sys.argv[1:]))


if __name__ == "__main__":
    main()
|
||||
203
sgl-router/py_src/sglang_router/launch_server.py
Normal file
203
sgl-router/py_src/sglang_router/launch_server.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from setproctitle import setproctitle
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import is_port_available
|
||||
|
||||
|
||||
def setup_logger():
    """Return the shared "router" logger, configured exactly once.

    Returns:
        logging.Logger: the "router" logger at INFO level with a single
        stream handler attached.

    Note:
        ``logging.getLogger("router")`` always returns the same object, so
        unconditionally adding a handler on every call would stack duplicate
        handlers and emit each record multiple times. The guard below makes
        this function idempotent.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # Only attach a handler the first time; repeated calls are no-ops.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


logger = setup_logger()
|
||||
|
||||
|
||||
def run_server(server_args, dp_rank):
    """Entry point for a single DP worker: isolate its process group, label
    the process, record the DP rank, then run the blocking HTTP server.

    Why a dedicated process group (``os.setpgrp``):

    1. Without it, every process shares the terminal's PGID, so Ctrl+C sends
       SIGINT to all processes in the group at once. Leaf processes
       (schedulers, detokenizers) can then die before their parents, which
       breaks the cleanup order and leaves orphaned processes:

       Terminal (PGID=100)
       └── Main Python Process (PGID=100)
           └── Server Process 1 (PGID=100)
               └── Scheduler 1
               └── Detokenizer 1
           └── Server Process 2 (PGID=100)
               └── Scheduler 2
               └── Detokenizer 2

    2. With ``os.setpgrp``, each server process owns a separate group, so the
       launcher alone receives the terminal's signal and can tear groups down
       deliberately (see ``cleanup_processes``):

       Terminal (PGID=100)
       └── Main Python Process (PGID=200)
           └── Server Process 1 (PGID=300)
               └── Scheduler 1
               └── Detokenizer 1
           └── Server Process 2 (PGID=400)
               └── Scheduler 2
               └── Detokenizer 2
    """
    # Detach from the launcher's process group before anything else.
    os.setpgrp()

    # Expose this replica's data-parallel rank to the server via the env.
    os.environ["SGLANG_DP_RANK"] = str(dp_rank)
    setproctitle("sglang::server")

    launch_server(server_args)
|
||||
|
||||
|
||||
def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Spawn one worker server process listening on *worker_port*.

    The ServerArgs are deep-copied so the per-worker overrides below never
    leak back into the caller's shared configuration.
    """
    worker_args = copy.deepcopy(server_args)
    worker_args.port = worker_port
    # Each DP replica owns a contiguous slice of GPUs starting at this index.
    worker_args.base_gpu_id = dp_id * worker_args.tp_size
    # The replica itself runs with data-parallel degree 1; DP fan-out is
    # handled by launching multiple processes.
    worker_args.dp_size = 1

    process = mp.Process(target=run_server, args=(worker_args, dp_id))
    process.start()
    return process
|
||||
|
||||
|
||||
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Poll the server's /health endpoint until it answers HTTP 200.

    Returns True as soon as the endpoint responds 200, or False if *timeout*
    seconds elapse first.
    """
    deadline = time.perf_counter() + timeout
    url = f"http://{host}:{port}/health"

    while time.perf_counter() < deadline:
        try:
            if requests.get(url, timeout=5).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            # Server not accepting connections yet; fall through and retry.
            pass
        time.sleep(1)
    return False
|
||||
|
||||
|
||||
def find_available_ports(base_port: int, count: int) -> List[int]:
    """Collect *count* free ports, probing upward from *base_port*.

    Note: the returned ports are NOT consecutive — each probe advances by a
    random stride of 100-1000, which spreads workers out and reduces clashes
    when several launchers scan from the same base port.
    """
    ports: List[int] = []
    candidate = base_port

    while len(ports) < count:
        if is_port_available(candidate):
            ports.append(candidate)
        # Advance whether or not the candidate was free.
        candidate += random.randint(100, 1000)

    return ports
|
||||
|
||||
|
||||
def cleanup_processes(processes: List[mp.Process]):
    """Terminate every worker's process group, escalating to SIGKILL for
    any group that ignores SIGTERM within the grace period."""
    # Phase 1: ask each group politely.
    for process in processes:
        logger.info(f"Terminating process group {process.pid}")
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except ProcessLookupError:
            # Process group may already be terminated
            pass

    # Phase 2: wait briefly, then force-kill whatever survived.
    for process in processes:
        process.join(timeout=5)
        if not process.is_alive():
            continue
        logger.warning(
            f"Process {process.pid} did not terminate gracefully, forcing kill"
        )
        try:
            os.killpg(process.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass

    logger.info("All process groups terminated")
|
||||
|
||||
|
||||
def main():
    """Launch a group of DP worker servers plus a router in front of them."""
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes.
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )

    # Server flags first, then router flags ("router-" prefixed, host/port
    # excluded so they don't collide with the server's own host/port).
    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )
    # No extra retry/CB flags here; RouterArgs.add_cli_args already defines
    # them with the router- prefix.

    parsed = parser.parse_args()
    server_args = ServerArgs.from_cli_args(parsed)
    router_args = RouterArgs.from_cli_args(parsed, use_router_prefix=True)

    # One free port per data-parallel replica.
    worker_ports = find_available_ports(
        parsed.router_dp_worker_base_port, server_args.dp_size
    )

    # Spawn one server process per worker port.
    server_processes = []
    for i, worker_port in enumerate(worker_ports):
        logger.info(f"Launching DP server process {i} on port {worker_port}")
        server_processes.append(launch_server_process(server_args, worker_port, i))

    # Any terminating signal tears down every worker's process group.
    for signum in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(signum, lambda sig, frame: cleanup_processes(server_processes))

    # Point the router at the workers that were just launched.
    router_args.worker_urls = [
        f"http://{server_args.host}:{port}" for port in worker_ports
    ]

    # Run the router (blocks until shutdown); clean up workers on failure.
    try:
        launch_router(router_args)
    except Exception as e:
        logger.error(f"Failed to start router: {e}")
        cleanup_processes(server_processes)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
||||
211
sgl-router/py_src/sglang_router/router.py
Normal file
211
sgl-router/py_src/sglang_router/router.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sglang_router_rs import PolicyType
|
||||
from sglang_router_rs import Router as _Router
|
||||
|
||||
|
||||
class Router:
    """
    A high-performance router for distributing requests across worker nodes.

    Thin Python wrapper around the Rust ``sglang_router_rs.Router``; every
    keyword argument below is forwarded unchanged to the Rust implementation.
    Documented defaults reflect this wrapper's signature.

    Args:
        worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
            the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
        policy: Load balancing policy to use. Default: PolicyType.RoundRobin. Options:
            - PolicyType.Random: Randomly select workers
            - PolicyType.RoundRobin: Distribute requests in round-robin fashion
            - PolicyType.CacheAware: Distribute requests based on cache state and load balance
            - PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
        host: Host address to bind the router server. Default: '127.0.0.1'
        port: Port number to bind the router server. Default: 3001
        worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 600
        worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 30
        cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
            if the match rate exceeds threshold, otherwise routes to the worker with the smallest
            tree. Default: 0.3
        balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
            AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 64
        balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
            AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.5
        eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
            routing. Default: 120
        max_payload_size: Maximum payload size in bytes. Default: 512MB
        max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^26
        dp_aware: Enable data parallelism aware schedule. Default: False
        api_key: The api key used for the authorization with the worker.
            Useful when the dp aware scheduling strategy is enabled.
            Default: None
        log_dir: Directory to store log files. If None, logs are only output to console. Default: None
        log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'.
            Default: None
        service_discovery: Enable Kubernetes service discovery. When enabled, the router will
            automatically discover worker pods based on the selector. Default: False
        selector: Dictionary mapping of label keys to values for Kubernetes pod selection.
            Example: {"app": "sglang-worker"}. Default: {}
        service_discovery_port: Port to use for service discovery. The router will generate
            worker URLs using this port. Default: 80
        service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
            watches pods across all namespaces (requires cluster-wide permissions). Default: None
        prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
            for prefill servers (PD mode only). Default: {}
        decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
            for decode servers (PD mode only). Default: {}
        prometheus_port: Port to expose Prometheus metrics. Default: None
        prometheus_host: Host address to bind the Prometheus metrics server. Default: None
        pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
        prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
        decode_urls: List of URLs for decode servers (PD mode only)
        prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
            If not specified, uses the main policy. Default: None
        decode_policy: Specific load balancing policy for decode nodes (PD mode only).
            If not specified, uses the main policy. Default: None
        request_id_headers: List of HTTP headers to check for request IDs. If not specified,
            uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
            Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
        bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
            Default: 'sglang.ai/bootstrap-port'
        request_timeout_secs: Request timeout in seconds. Default: 1800
        max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 256
        queue_size: Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately). Default: 100
        queue_timeout_secs: Maximum time (in seconds) a request can wait in queue before timing out. Default: 60
        rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests. Default: None
        cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
        retry_max_retries: Maximum number of retry attempts for a failed request. Default: 5
        retry_initial_backoff_ms: Initial retry backoff in milliseconds. Default: 50
        retry_max_backoff_ms: Upper bound on retry backoff in milliseconds. Default: 30000
        retry_backoff_multiplier: Multiplier applied to the backoff after each retry. Default: 1.5
        retry_jitter_factor: Random jitter factor applied to retry backoff. Default: 0.2
        cb_failure_threshold: Circuit breaker: failures before opening the circuit. Default: 10
        cb_success_threshold: Circuit breaker: successes before closing the circuit again. Default: 3
        cb_timeout_duration_secs: Circuit breaker: seconds the circuit stays open before probing. Default: 60
        cb_window_duration_secs: Circuit breaker: failure-tracking window length in seconds. Default: 120
        disable_retries: Disable request retries entirely. Default: False
        disable_circuit_breaker: Disable the circuit breaker entirely. Default: False
        health_failure_threshold: Number of consecutive health check failures before marking worker unhealthy. Default: 3
        health_success_threshold: Number of consecutive health check successes before marking worker healthy. Default: 2
        health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
        health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
        health_check_endpoint: Health check endpoint path. Default: '/health'
        model_path: Model path for loading tokenizer (HuggingFace model ID or local path). Default: None
        tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
    """

    def __init__(
        self,
        worker_urls: List[str],
        policy: PolicyType = PolicyType.RoundRobin,
        host: str = "127.0.0.1",
        port: int = 3001,
        worker_startup_timeout_secs: int = 600,
        worker_startup_check_interval: int = 30,
        cache_threshold: float = 0.3,
        balance_abs_threshold: int = 64,
        balance_rel_threshold: float = 1.5,
        eviction_interval_secs: int = 120,
        max_tree_size: int = 2**26,
        max_payload_size: int = 512 * 1024 * 1024,  # 512MB
        dp_aware: bool = False,
        api_key: Optional[str] = None,
        log_dir: Optional[str] = None,
        log_level: Optional[str] = None,
        service_discovery: bool = False,
        selector: Optional[Dict[str, str]] = None,
        service_discovery_port: int = 80,
        service_discovery_namespace: Optional[str] = None,
        prefill_selector: Optional[Dict[str, str]] = None,
        decode_selector: Optional[Dict[str, str]] = None,
        bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
        prometheus_port: Optional[int] = None,
        prometheus_host: Optional[str] = None,
        request_timeout_secs: int = 1800,
        request_id_headers: Optional[List[str]] = None,
        pd_disaggregation: bool = False,
        prefill_urls: Optional[List[tuple]] = None,
        decode_urls: Optional[List[str]] = None,
        prefill_policy: Optional[PolicyType] = None,
        decode_policy: Optional[PolicyType] = None,
        max_concurrent_requests: int = 256,
        queue_size: int = 100,
        queue_timeout_secs: int = 60,
        rate_limit_tokens_per_second: Optional[int] = None,
        cors_allowed_origins: Optional[List[str]] = None,
        retry_max_retries: int = 5,
        retry_initial_backoff_ms: int = 50,
        retry_max_backoff_ms: int = 30_000,
        retry_backoff_multiplier: float = 1.5,
        retry_jitter_factor: float = 0.2,
        cb_failure_threshold: int = 10,
        cb_success_threshold: int = 3,
        cb_timeout_duration_secs: int = 60,
        cb_window_duration_secs: int = 120,
        disable_retries: bool = False,
        disable_circuit_breaker: bool = False,
        health_failure_threshold: int = 3,
        health_success_threshold: int = 2,
        health_check_timeout_secs: int = 5,
        health_check_interval_secs: int = 60,
        health_check_endpoint: str = "/health",
        model_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
    ):
        # Materialize mutable defaults here (never as default argument values)
        # so separate Router instances don't share the same dict/list objects.
        if selector is None:
            selector = {}
        if prefill_selector is None:
            prefill_selector = {}
        if decode_selector is None:
            decode_selector = {}
        if cors_allowed_origins is None:
            cors_allowed_origins = []

        # All arguments are forwarded verbatim to the Rust router.
        self._router = _Router(
            worker_urls=worker_urls,
            policy=policy,
            host=host,
            port=port,
            worker_startup_timeout_secs=worker_startup_timeout_secs,
            worker_startup_check_interval=worker_startup_check_interval,
            cache_threshold=cache_threshold,
            balance_abs_threshold=balance_abs_threshold,
            balance_rel_threshold=balance_rel_threshold,
            eviction_interval_secs=eviction_interval_secs,
            max_tree_size=max_tree_size,
            max_payload_size=max_payload_size,
            dp_aware=dp_aware,
            api_key=api_key,
            log_dir=log_dir,
            log_level=log_level,
            service_discovery=service_discovery,
            selector=selector,
            service_discovery_port=service_discovery_port,
            service_discovery_namespace=service_discovery_namespace,
            prefill_selector=prefill_selector,
            decode_selector=decode_selector,
            bootstrap_port_annotation=bootstrap_port_annotation,
            prometheus_port=prometheus_port,
            prometheus_host=prometheus_host,
            request_timeout_secs=request_timeout_secs,
            request_id_headers=request_id_headers,
            pd_disaggregation=pd_disaggregation,
            prefill_urls=prefill_urls,
            decode_urls=decode_urls,
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
            max_concurrent_requests=max_concurrent_requests,
            queue_size=queue_size,
            queue_timeout_secs=queue_timeout_secs,
            rate_limit_tokens_per_second=rate_limit_tokens_per_second,
            cors_allowed_origins=cors_allowed_origins,
            retry_max_retries=retry_max_retries,
            retry_initial_backoff_ms=retry_initial_backoff_ms,
            retry_max_backoff_ms=retry_max_backoff_ms,
            retry_backoff_multiplier=retry_backoff_multiplier,
            retry_jitter_factor=retry_jitter_factor,
            cb_failure_threshold=cb_failure_threshold,
            cb_success_threshold=cb_success_threshold,
            cb_timeout_duration_secs=cb_timeout_duration_secs,
            cb_window_duration_secs=cb_window_duration_secs,
            disable_retries=disable_retries,
            disable_circuit_breaker=disable_circuit_breaker,
            health_failure_threshold=health_failure_threshold,
            health_success_threshold=health_success_threshold,
            health_check_timeout_secs=health_check_timeout_secs,
            health_check_interval_secs=health_check_interval_secs,
            health_check_endpoint=health_check_endpoint,
            model_path=model_path,
            tokenizer_path=tokenizer_path,
        )

    def start(self) -> None:
        """Start the router server.

        This method blocks until the server is shut down.
        """
        self._router.start()
|
||||
1
sgl-router/py_src/sglang_router/version.py
Normal file
1
sgl-router/py_src/sglang_router/version.py
Normal file
@@ -0,0 +1 @@
|
||||
# Version string of the sglang_router Python package.
__version__ = "0.1.9"
|
||||
Reference in New Issue
Block a user