[misc] Add PD service discovery support in router (#7361)

This commit is contained in:
Simo Lin
2025-06-22 17:54:14 -07:00
committed by GitHub
parent bd4f581896
commit 30f2a44a96
11 changed files with 1362 additions and 120 deletions

View File

@@ -32,7 +32,7 @@ class RouterArgs:
port: int = 30000
# PD-specific configuration
pd_disaggregated: bool = False # Enable PD disaggregated mode
pd_disaggregation: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
@@ -55,6 +55,10 @@ class RouterArgs:
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
service_discovery_port: int = 80
service_discovery_namespace: Optional[str] = None
# PD service discovery configuration
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
# Prometheus configuration
prometheus_port: Optional[int] = None
prometheus_host: Optional[str] = None
@@ -108,7 +112,7 @@ class RouterArgs:
# PD-specific arguments
parser.add_argument(
f"--{prefix}pd-disaggregated",
f"--{prefix}pd-disaggregation",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
@@ -207,6 +211,18 @@ class RouterArgs:
type=str,
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
)
parser.add_argument(
f"--{prefix}prefill-selector",
type=str,
nargs="+",
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}decode-selector",
type=str,
nargs="+",
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
)
# Prometheus configuration
parser.add_argument(
f"--{prefix}prometheus-port",
@@ -243,7 +259,7 @@ class RouterArgs:
worker_urls=worker_urls,
host=args.host,
port=args.port,
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"),
@@ -267,6 +283,13 @@ class RouterArgs:
service_discovery_namespace=getattr(
args, f"{prefix}service_discovery_namespace", None
),
prefill_selector=cls._parse_selector(
getattr(args, f"{prefix}prefill_selector", None)
),
decode_selector=cls._parse_selector(
getattr(args, f"{prefix}decode_selector", None)
),
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
)
@@ -355,17 +378,20 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
router_args = args
# Validate configuration based on mode
if router_args.pd_disaggregated:
# Validate PD configuration
if not router_args.prefill_urls:
raise ValueError("PD disaggregated mode requires --prefill")
if not router_args.decode_urls:
raise ValueError("PD disaggregated mode requires --decode")
if router_args.pd_disaggregation:
# Validate PD configuration - skip URL requirements if using service discovery
if not router_args.service_discovery:
if not router_args.prefill_urls:
raise ValueError("PD disaggregation mode requires --prefill")
if not router_args.decode_urls:
raise ValueError("PD disaggregation mode requires --decode")
# Create router with unified constructor
router = Router(
worker_urls=(
router_args.worker_urls if not router_args.pd_disaggregated else []
[]
if router_args.service_discovery or router_args.pd_disaggregation
else router_args.worker_urls
),
host=router_args.host,
port=router_args.port,
@@ -384,14 +410,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
selector=router_args.selector,
service_discovery_port=router_args.service_discovery_port,
service_discovery_namespace=router_args.service_discovery_namespace,
prefill_selector=router_args.prefill_selector,
decode_selector=router_args.decode_selector,
prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host,
pd_disaggregated=router_args.pd_disaggregated,
pd_disaggregation=router_args.pd_disaggregation,
prefill_urls=(
router_args.prefill_urls if router_args.pd_disaggregated else None
router_args.prefill_urls if router_args.pd_disaggregation else None
),
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregated else None
router_args.decode_urls if router_args.pd_disaggregation else None
),
)
@@ -425,7 +453,7 @@ Examples:
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
# PD disaggregated mode
python -m sglang_router.launch_router --pd-disaggregated \\
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware

View File

@@ -41,9 +41,13 @@ class Router:
worker URLs using this port. Default: 80
service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
watches pods across all namespaces (requires cluster-wide permissions). Default: None
prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
for prefill servers (PD mode only). Default: {}
decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
for decode servers (PD mode only). Default: {}
prometheus_port: Port to expose Prometheus metrics. Default: None
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
"""
@@ -68,14 +72,20 @@ class Router:
selector: Dict[str, str] = None,
service_discovery_port: int = 80,
service_discovery_namespace: Optional[str] = None,
prefill_selector: Dict[str, str] = None,
decode_selector: Dict[str, str] = None,
prometheus_port: Optional[int] = None,
prometheus_host: Optional[str] = None,
pd_disaggregated: bool = False,
pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
):
if selector is None:
selector = {}
if prefill_selector is None:
prefill_selector = {}
if decode_selector is None:
decode_selector = {}
self._router = _Router(
worker_urls=worker_urls,
@@ -96,9 +106,11 @@ class Router:
selector=selector,
service_discovery_port=service_discovery_port,
service_discovery_namespace=service_discovery_namespace,
prefill_selector=prefill_selector,
decode_selector=decode_selector,
prometheus_port=prometheus_port,
prometheus_host=prometheus_host,
pd_disaggregated=pd_disaggregated,
pd_disaggregation=pd_disaggregation,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
)