Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)

This commit is contained in:
Simo Lin
2025-06-18 11:28:15 -07:00
committed by GitHub
parent 712bf9ec9b
commit 09ae5b20f3
13 changed files with 4045 additions and 187 deletions

View File

@@ -31,6 +31,13 @@ class RouterArgs:
host: str = "127.0.0.1"
port: int = 30000
# PD-specific configuration
pd_disaggregated: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
decode_urls: List[str] = dataclasses.field(default_factory=list)
# Routing policy
policy: str = "cache_aware"
worker_startup_timeout_secs: int = 300
@@ -40,7 +47,7 @@ class RouterArgs:
balance_rel_threshold: float = 1.0001
eviction_interval: int = 60
max_tree_size: int = 2**24
max_payload_size: int = 4 * 1024 * 1024 # 4MB
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
verbose: bool = False
log_dir: Optional[str] = None
# Service discovery configuration
@@ -95,8 +102,29 @@ class RouterArgs:
f"--{prefix}policy",
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware"],
help="Load balancing policy to use",
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
)
# PD-specific arguments
parser.add_argument(
f"--{prefix}pd-disaggregated",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
parser.add_argument(
f"--{prefix}prefill",
nargs=2,
action="append",
metavar=("URL", "BOOTSTRAP_PORT"),
help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
)
parser.add_argument(
f"--{prefix}decode",
nargs=1,
action="append",
metavar=("URL",),
help="Decode server URL. Can be specified multiple times.",
)
parser.add_argument(
f"--{prefix}worker-startup-timeout-secs",
@@ -205,11 +233,19 @@ class RouterArgs:
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
worker_urls = args.worker_urls if args.worker_urls is not None else []
worker_urls = getattr(args, "worker_urls", [])
# Parse PD URLs
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
return cls(
worker_urls=worker_urls,
host=args.host,
port=args.port,
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"),
worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs"
@@ -247,6 +283,46 @@ class RouterArgs:
selector[key] = value
return selector
@staticmethod
def _parse_prefill_urls(prefill_list):
    """Parse prefill URLs from --prefill arguments.

    Format: --prefill URL BOOTSTRAP_PORT
    Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none

    Args:
        prefill_list: List of ``[url, bootstrap_port]`` string pairs as
            collected by argparse (``nargs=2`` with ``action="append"``),
            or None/empty when the flag was not given.

    Returns:
        List of ``(url, bootstrap_port)`` tuples. ``bootstrap_port`` is an
        int, or None when the literal 'none' (case-insensitive) was given.

    Raises:
        ValueError: If a bootstrap port is neither 'none' nor an integer in
            the range 0-65535.
    """
    if not prefill_list:
        return []

    prefill_urls = []
    for url, bootstrap_port_str in prefill_list:
        # int() already ignores surrounding whitespace; strip here so the
        # 'none' keyword is treated just as leniently.
        token = bootstrap_port_str.strip()
        if token.lower() == "none":
            bootstrap_port = None
        else:
            try:
                bootstrap_port = int(token)
            except ValueError as err:
                # Chain the original error so the traceback keeps the cause.
                raise ValueError(
                    f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
                ) from err
            # Reject integers that cannot be a port number at all instead of
            # handing them to the router and failing later.
            if not 0 <= bootstrap_port <= 65535:
                raise ValueError(
                    f"Invalid bootstrap port: {bootstrap_port_str}. Must be between 0 and 65535 or 'none'"
                )
        prefill_urls.append((url, bootstrap_port))
    return prefill_urls
@staticmethod
def _parse_decode_urls(decode_list):
    """Flatten the values collected from repeated --decode flags.

    Format: --decode URL
    Example: --decode http://decode1:8081 --decode http://decode2:8081

    Args:
        decode_list: Value argparse stored for --decode: a list of
            one-element lists (``nargs=1`` with ``action="append"``), or
            None/empty when the flag was not given.

    Returns:
        List of decode server URLs in the order given; empty list when
        nothing was supplied.
    """
    urls = []
    # `or ()` covers both None and an empty list with a single code path.
    for entry in decode_list or ():
        # Each entry is a single-item list because of nargs=1.
        urls.append(entry[0])
    return urls
def policy_from_str(policy_str: str) -> PolicyType:
"""Convert policy string to PolicyType enum."""
@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType:
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
}
return policy_map[policy_str]
@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
else:
router_args = args
# Validate configuration based on mode
if router_args.pd_disaggregated:
# Validate PD configuration
if not router_args.prefill_urls:
raise ValueError("PD disaggregated mode requires --prefill")
if not router_args.decode_urls:
raise ValueError("PD disaggregated mode requires --decode")
# Create router with unified constructor
router = Router(
worker_urls=router_args.worker_urls,
worker_urls=(
router_args.worker_urls if not router_args.pd_disaggregated else []
),
host=router_args.host,
port=router_args.port,
policy=policy_from_str(router_args.policy),
@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
service_discovery_namespace=router_args.service_discovery_namespace,
prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host,
pd_disaggregated=router_args.pd_disaggregated,
prefill_urls=(
router_args.prefill_urls if router_args.pd_disaggregated else None
),
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregated else None
),
)
router.start()
@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is
multi-node setups or when you want to start workers and router separately.
Examples:
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
# PD disaggregated mode
python -m sglang_router.launch_router --pd-disaggregated \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware
""",
formatter_class=CustomHelpFormatter,

View File

@@ -15,6 +15,7 @@ class Router:
- PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
@@ -28,7 +29,7 @@ class Router:
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 4MB
max_payload_size: Maximum payload size in bytes. Default: 256MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
@@ -42,6 +43,9 @@ class Router:
watches pods across all namespaces (requires cluster-wide permissions). Default: None
prometheus_port: Port to expose Prometheus metrics. Default: None
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
"""
def __init__(
@@ -57,7 +61,7 @@ class Router:
balance_rel_threshold: float = 1.0001,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
max_payload_size: int = 4 * 1024 * 1024, # 4MB
max_payload_size: int = 256 * 1024 * 1024, # 256MB
verbose: bool = False,
log_dir: Optional[str] = None,
service_discovery: bool = False,
@@ -66,6 +70,9 @@ class Router:
service_discovery_namespace: Optional[str] = None,
prometheus_port: Optional[int] = None,
prometheus_host: Optional[str] = None,
pd_disaggregated: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
):
if selector is None:
selector = {}
@@ -91,6 +98,9 @@ class Router:
service_discovery_namespace=service_discovery_namespace,
prometheus_port=prometheus_port,
prometheus_host=prometheus_host,
pd_disaggregated=pd_disaggregated,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
)
def start(self) -> None: