[router] add different policies for p node and d node (#8395)

This commit is contained in:
Simo Lin
2025-07-27 00:39:20 -07:00
committed by GitHub
parent 0bcc195f4e
commit 2ab97023e3
10 changed files with 536 additions and 81 deletions

View File

@@ -40,6 +40,8 @@ class RouterArgs:
# Routing policy
policy: str = "cache_aware"
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
worker_startup_timeout_secs: int = 300
worker_startup_check_interval: int = 10
cache_threshold: float = 0.5
@@ -108,7 +110,21 @@ class RouterArgs:
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
)
parser.add_argument(
f"--{prefix}prefill-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
)
parser.add_argument(
f"--{prefix}decode-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
)
# PD-specific arguments
@@ -266,6 +282,8 @@ class RouterArgs:
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"),
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
decode_policy=getattr(args, f"{prefix}decode_policy", None),
worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs"
),
@@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
if not router_args.decode_urls:
raise ValueError("PD disaggregation mode requires --decode")
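# Warn about policy usage in PD mode
# NOTE(review): `policy` defaults to "cache_aware", so the `and router_args.policy`
# clauses below look always-true; the first warning would then fire even when the
# user never passed --policy explicitly — confirm whether that is intended.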
if (
router_args.prefill_policy
and router_args.decode_policy
and router_args.policy
):
logger.warning(
"Both --prefill-policy and --decode-policy are specified. "
"The main --policy flag will be ignored for PD mode."
)
elif (
router_args.prefill_policy
and not router_args.decode_policy
and router_args.policy
):
logger.info(
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
f"and --policy '{router_args.policy}' for decode nodes."
)
elif (
router_args.decode_policy
and not router_args.prefill_policy
and router_args.policy
):
logger.info(
f"Using --policy '{router_args.policy}' for prefill nodes "
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
)
# Create router with unified constructor
router = Router(
worker_urls=(
@@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregation else None
),
prefill_policy=(
policy_from_str(router_args.prefill_policy)
if router_args.prefill_policy
else None
),
decode_policy=(
policy_from_str(router_args.decode_policy)
if router_args.decode_policy
else None
),
)
router.start()
@@ -455,12 +512,18 @@ Examples:
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
# PD disaggregated mode
# PD disaggregated mode with same policy for both
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware
# PD mode with different policies for prefill and decode
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--prefill-policy cache_aware --decode-policy power_of_two
""",
formatter_class=CustomHelpFormatter,
)