Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
This commit is contained in:
@@ -31,6 +31,13 @@ class RouterArgs:
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
|
||||
# PD-specific configuration
|
||||
pd_disaggregated: bool = False # Enable PD disaggregated mode
|
||||
prefill_urls: List[tuple] = dataclasses.field(
|
||||
default_factory=list
|
||||
) # List of (url, bootstrap_port)
|
||||
decode_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
# Routing policy
|
||||
policy: str = "cache_aware"
|
||||
worker_startup_timeout_secs: int = 300
|
||||
@@ -40,7 +47,7 @@ class RouterArgs:
|
||||
balance_rel_threshold: float = 1.0001
|
||||
eviction_interval: int = 60
|
||||
max_tree_size: int = 2**24
|
||||
max_payload_size: int = 4 * 1024 * 1024 # 4MB
|
||||
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
|
||||
verbose: bool = False
|
||||
log_dir: Optional[str] = None
|
||||
# Service discovery configuration
|
||||
@@ -95,8 +102,29 @@ class RouterArgs:
|
||||
f"--{prefix}policy",
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware"],
|
||||
help="Load balancing policy to use",
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
|
||||
)
|
||||
|
||||
# PD-specific arguments
|
||||
parser.add_argument(
|
||||
f"--{prefix}pd-disaggregated",
|
||||
action="store_true",
|
||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("URL", "BOOTSTRAP_PORT"),
|
||||
help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode",
|
||||
nargs=1,
|
||||
action="append",
|
||||
metavar=("URL",),
|
||||
help="Decode server URL. Can be specified multiple times.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-timeout-secs",
|
||||
@@ -205,11 +233,19 @@ class RouterArgs:
|
||||
use_router_prefix: If True, look for arguments with 'router-' prefix
|
||||
"""
|
||||
prefix = "router_" if use_router_prefix else ""
|
||||
worker_urls = args.worker_urls if args.worker_urls is not None else []
|
||||
worker_urls = getattr(args, "worker_urls", [])
|
||||
|
||||
# Parse PD URLs
|
||||
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
|
||||
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
|
||||
|
||||
return cls(
|
||||
worker_urls=worker_urls,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
policy=getattr(args, f"{prefix}policy"),
|
||||
worker_startup_timeout_secs=getattr(
|
||||
args, f"{prefix}worker_startup_timeout_secs"
|
||||
@@ -247,6 +283,46 @@ class RouterArgs:
|
||||
selector[key] = value
|
||||
return selector
|
||||
|
||||
@staticmethod
|
||||
def _parse_prefill_urls(prefill_list):
|
||||
"""Parse prefill URLs from --prefill arguments.
|
||||
|
||||
Format: --prefill URL BOOTSTRAP_PORT
|
||||
Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none
|
||||
"""
|
||||
if not prefill_list:
|
||||
return []
|
||||
|
||||
prefill_urls = []
|
||||
for url, bootstrap_port_str in prefill_list:
|
||||
# Handle 'none' as None
|
||||
if bootstrap_port_str.lower() == "none":
|
||||
bootstrap_port = None
|
||||
else:
|
||||
try:
|
||||
bootstrap_port = int(bootstrap_port_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
||||
)
|
||||
|
||||
prefill_urls.append((url, bootstrap_port))
|
||||
|
||||
return prefill_urls
|
||||
|
||||
@staticmethod
|
||||
def _parse_decode_urls(decode_list):
|
||||
"""Parse decode URLs from --decode arguments.
|
||||
|
||||
Format: --decode URL
|
||||
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||
"""
|
||||
if not decode_list:
|
||||
return []
|
||||
|
||||
# decode_list is a list of single-element lists due to nargs=1
|
||||
return [url[0] for url in decode_list]
|
||||
|
||||
|
||||
def policy_from_str(policy_str: str) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType:
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
"power_of_two": PolicyType.PowerOfTwo,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
else:
|
||||
router_args = args
|
||||
|
||||
# Validate configuration based on mode
|
||||
if router_args.pd_disaggregated:
|
||||
# Validate PD configuration
|
||||
if not router_args.prefill_urls:
|
||||
raise ValueError("PD disaggregated mode requires --prefill")
|
||||
if not router_args.decode_urls:
|
||||
raise ValueError("PD disaggregated mode requires --decode")
|
||||
|
||||
# Create router with unified constructor
|
||||
router = Router(
|
||||
worker_urls=router_args.worker_urls,
|
||||
worker_urls=(
|
||||
router_args.worker_urls if not router_args.pd_disaggregated else []
|
||||
),
|
||||
host=router_args.host,
|
||||
port=router_args.port,
|
||||
policy=policy_from_str(router_args.policy),
|
||||
@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
service_discovery_namespace=router_args.service_discovery_namespace,
|
||||
prometheus_port=router_args.prometheus_port,
|
||||
prometheus_host=router_args.prometheus_host,
|
||||
pd_disaggregated=router_args.pd_disaggregated,
|
||||
prefill_urls=(
|
||||
router_args.prefill_urls if router_args.pd_disaggregated else None
|
||||
),
|
||||
decode_urls=(
|
||||
router_args.decode_urls if router_args.pd_disaggregated else None
|
||||
),
|
||||
)
|
||||
|
||||
router.start()
|
||||
@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is
|
||||
multi-node setups or when you want to start workers and router separately.
|
||||
|
||||
Examples:
|
||||
# Regular mode
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
|
||||
|
||||
# PD disaggregated mode
|
||||
python -m sglang_router.launch_router --pd-disaggregated \\
|
||||
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
||||
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||
--policy cache_aware
|
||||
|
||||
""",
|
||||
formatter_class=CustomHelpFormatter,
|
||||
|
||||
Reference in New Issue
Block a user