[misc] Add PD service discovery support in router (#7361)
This commit is contained in:
@@ -32,7 +32,7 @@ class RouterArgs:
|
||||
port: int = 30000
|
||||
|
||||
# PD-specific configuration
|
||||
pd_disaggregated: bool = False # Enable PD disaggregated mode
|
||||
pd_disaggregation: bool = False # Enable PD disaggregated mode
|
||||
prefill_urls: List[tuple] = dataclasses.field(
|
||||
default_factory=list
|
||||
) # List of (url, bootstrap_port)
|
||||
@@ -55,6 +55,10 @@ class RouterArgs:
|
||||
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
service_discovery_port: int = 80
|
||||
service_discovery_namespace: Optional[str] = None
|
||||
# PD service discovery configuration
|
||||
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
|
||||
# Prometheus configuration
|
||||
prometheus_port: Optional[int] = None
|
||||
prometheus_host: Optional[str] = None
|
||||
@@ -108,7 +112,7 @@ class RouterArgs:
|
||||
|
||||
# PD-specific arguments
|
||||
parser.add_argument(
|
||||
f"--{prefix}pd-disaggregated",
|
||||
f"--{prefix}pd-disaggregation",
|
||||
action="store_true",
|
||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||
)
|
||||
@@ -207,6 +211,18 @@ class RouterArgs:
|
||||
type=str,
|
||||
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
# Prometheus configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-port",
|
||||
@@ -243,7 +259,7 @@ class RouterArgs:
|
||||
worker_urls=worker_urls,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
|
||||
pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
policy=getattr(args, f"{prefix}policy"),
|
||||
@@ -267,6 +283,13 @@ class RouterArgs:
|
||||
service_discovery_namespace=getattr(
|
||||
args, f"{prefix}service_discovery_namespace", None
|
||||
),
|
||||
prefill_selector=cls._parse_selector(
|
||||
getattr(args, f"{prefix}prefill_selector", None)
|
||||
),
|
||||
decode_selector=cls._parse_selector(
|
||||
getattr(args, f"{prefix}decode_selector", None)
|
||||
),
|
||||
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
|
||||
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
||||
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
||||
)
|
||||
@@ -355,17 +378,20 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
router_args = args
|
||||
|
||||
# Validate configuration based on mode
|
||||
if router_args.pd_disaggregated:
|
||||
# Validate PD configuration
|
||||
if not router_args.prefill_urls:
|
||||
raise ValueError("PD disaggregated mode requires --prefill")
|
||||
if not router_args.decode_urls:
|
||||
raise ValueError("PD disaggregated mode requires --decode")
|
||||
if router_args.pd_disaggregation:
|
||||
# Validate PD configuration - skip URL requirements if using service discovery
|
||||
if not router_args.service_discovery:
|
||||
if not router_args.prefill_urls:
|
||||
raise ValueError("PD disaggregation mode requires --prefill")
|
||||
if not router_args.decode_urls:
|
||||
raise ValueError("PD disaggregation mode requires --decode")
|
||||
|
||||
# Create router with unified constructor
|
||||
router = Router(
|
||||
worker_urls=(
|
||||
router_args.worker_urls if not router_args.pd_disaggregated else []
|
||||
[]
|
||||
if router_args.service_discovery or router_args.pd_disaggregation
|
||||
else router_args.worker_urls
|
||||
),
|
||||
host=router_args.host,
|
||||
port=router_args.port,
|
||||
@@ -384,14 +410,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
selector=router_args.selector,
|
||||
service_discovery_port=router_args.service_discovery_port,
|
||||
service_discovery_namespace=router_args.service_discovery_namespace,
|
||||
prefill_selector=router_args.prefill_selector,
|
||||
decode_selector=router_args.decode_selector,
|
||||
prometheus_port=router_args.prometheus_port,
|
||||
prometheus_host=router_args.prometheus_host,
|
||||
pd_disaggregated=router_args.pd_disaggregated,
|
||||
pd_disaggregation=router_args.pd_disaggregation,
|
||||
prefill_urls=(
|
||||
router_args.prefill_urls if router_args.pd_disaggregated else None
|
||||
router_args.prefill_urls if router_args.pd_disaggregation else None
|
||||
),
|
||||
decode_urls=(
|
||||
router_args.decode_urls if router_args.pd_disaggregated else None
|
||||
router_args.decode_urls if router_args.pd_disaggregation else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -425,7 +453,7 @@ Examples:
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||
|
||||
# PD disaggregated mode
|
||||
python -m sglang_router.launch_router --pd-disaggregated \\
|
||||
python -m sglang_router.launch_router --pd-disaggregation \\
|
||||
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
||||
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||
--policy cache_aware
|
||||
|
||||
Reference in New Issue
Block a user