Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)

This commit is contained in:
Simo Lin
2025-06-18 11:28:15 -07:00
committed by GitHub
parent 712bf9ec9b
commit 09ae5b20f3
13 changed files with 4045 additions and 187 deletions

View File

@@ -31,6 +31,13 @@ class RouterArgs:
host: str = "127.0.0.1"
port: int = 30000
# PD-specific configuration
pd_disaggregated: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
decode_urls: List[str] = dataclasses.field(default_factory=list)
# Routing policy
policy: str = "cache_aware"
worker_startup_timeout_secs: int = 300
@@ -40,7 +47,7 @@ class RouterArgs:
balance_rel_threshold: float = 1.0001
eviction_interval: int = 60
max_tree_size: int = 2**24
max_payload_size: int = 4 * 1024 * 1024 # 4MB
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
verbose: bool = False
log_dir: Optional[str] = None
# Service discovery configuration
@@ -95,8 +102,29 @@ class RouterArgs:
f"--{prefix}policy",
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware"],
help="Load balancing policy to use",
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
)
# PD-specific arguments
parser.add_argument(
f"--{prefix}pd-disaggregated",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
parser.add_argument(
f"--{prefix}prefill",
nargs=2,
action="append",
metavar=("URL", "BOOTSTRAP_PORT"),
help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
)
parser.add_argument(
f"--{prefix}decode",
nargs=1,
action="append",
metavar=("URL",),
help="Decode server URL. Can be specified multiple times.",
)
parser.add_argument(
f"--{prefix}worker-startup-timeout-secs",
@@ -205,11 +233,19 @@ class RouterArgs:
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
worker_urls = args.worker_urls if args.worker_urls is not None else []
worker_urls = getattr(args, "worker_urls", [])
# Parse PD URLs
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
return cls(
worker_urls=worker_urls,
host=args.host,
port=args.port,
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"),
worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs"
@@ -247,6 +283,46 @@ class RouterArgs:
selector[key] = value
return selector
@staticmethod
def _parse_prefill_urls(prefill_list):
"""Parse prefill URLs from --prefill arguments.
Format: --prefill URL BOOTSTRAP_PORT
Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none
"""
if not prefill_list:
return []
prefill_urls = []
for url, bootstrap_port_str in prefill_list:
# Handle 'none' as None
if bootstrap_port_str.lower() == "none":
bootstrap_port = None
else:
try:
bootstrap_port = int(bootstrap_port_str)
except ValueError:
raise ValueError(
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
)
prefill_urls.append((url, bootstrap_port))
return prefill_urls
@staticmethod
def _parse_decode_urls(decode_list):
"""Parse decode URLs from --decode arguments.
Format: --decode URL
Example: --decode http://decode1:8081 --decode http://decode2:8081
"""
if not decode_list:
return []
# decode_list is a list of single-element lists due to nargs=1
return [url[0] for url in decode_list]
def policy_from_str(policy_str: str) -> PolicyType:
"""Convert policy string to PolicyType enum."""
@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType:
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
}
return policy_map[policy_str]
@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
else:
router_args = args
# Validate configuration based on mode
if router_args.pd_disaggregated:
# Validate PD configuration
if not router_args.prefill_urls:
raise ValueError("PD disaggregated mode requires --prefill")
if not router_args.decode_urls:
raise ValueError("PD disaggregated mode requires --decode")
# Create router with unified constructor
router = Router(
worker_urls=router_args.worker_urls,
worker_urls=(
router_args.worker_urls if not router_args.pd_disaggregated else []
),
host=router_args.host,
port=router_args.port,
policy=policy_from_str(router_args.policy),
@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
service_discovery_namespace=router_args.service_discovery_namespace,
prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host,
pd_disaggregated=router_args.pd_disaggregated,
prefill_urls=(
router_args.prefill_urls if router_args.pd_disaggregated else None
),
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregated else None
),
)
router.start()
@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is
multi-node setups or when you want to start workers and router separately.
Examples:
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
# PD disaggregated mode
python -m sglang_router.launch_router --pd-disaggregated \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware
""",
formatter_class=CustomHelpFormatter,