[router] add separate routing policies for prefill (P) and decode (D) nodes (#8395)
This commit is contained in:
@@ -40,6 +40,8 @@ class RouterArgs:
|
||||
|
||||
# Routing policy
|
||||
policy: str = "cache_aware"
|
||||
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
|
||||
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
|
||||
worker_startup_timeout_secs: int = 300
|
||||
worker_startup_check_interval: int = 10
|
||||
cache_threshold: float = 0.5
|
||||
@@ -108,7 +110,21 @@ class RouterArgs:
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
|
||||
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
|
||||
# PD-specific arguments
|
||||
@@ -266,6 +282,8 @@ class RouterArgs:
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
policy=getattr(args, f"{prefix}policy"),
|
||||
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
|
||||
decode_policy=getattr(args, f"{prefix}decode_policy", None),
|
||||
worker_startup_timeout_secs=getattr(
|
||||
args, f"{prefix}worker_startup_timeout_secs"
|
||||
),
|
||||
@@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
if not router_args.decode_urls:
|
||||
raise ValueError("PD disaggregation mode requires --decode")
|
||||
|
||||
# Log how --policy, --prefill-policy and --decode-policy combine in PD mode
|
||||
if (
|
||||
router_args.prefill_policy
|
||||
and router_args.decode_policy
|
||||
and router_args.policy
|
||||
):
|
||||
logger.warning(
|
||||
"Both --prefill-policy and --decode-policy are specified. "
|
||||
"The main --policy flag will be ignored for PD mode."
|
||||
)
|
||||
elif (
|
||||
router_args.prefill_policy
|
||||
and not router_args.decode_policy
|
||||
and router_args.policy
|
||||
):
|
||||
logger.info(
|
||||
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
|
||||
f"and --policy '{router_args.policy}' for decode nodes."
|
||||
)
|
||||
elif (
|
||||
router_args.decode_policy
|
||||
and not router_args.prefill_policy
|
||||
and router_args.policy
|
||||
):
|
||||
logger.info(
|
||||
f"Using --policy '{router_args.policy}' for prefill nodes "
|
||||
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
|
||||
)
|
||||
|
||||
# Create router with unified constructor
|
||||
router = Router(
|
||||
worker_urls=(
|
||||
@@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
decode_urls=(
|
||||
router_args.decode_urls if router_args.pd_disaggregation else None
|
||||
),
|
||||
prefill_policy=(
|
||||
policy_from_str(router_args.prefill_policy)
|
||||
if router_args.prefill_policy
|
||||
else None
|
||||
),
|
||||
decode_policy=(
|
||||
policy_from_str(router_args.decode_policy)
|
||||
if router_args.decode_policy
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
router.start()
|
||||
@@ -455,12 +512,18 @@ Examples:
|
||||
# Regular mode
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||
|
||||
# PD disaggregated mode
|
||||
# PD disaggregated mode with same policy for both
|
||||
python -m sglang_router.launch_router --pd-disaggregation \\
|
||||
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
||||
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||
--policy cache_aware
|
||||
|
||||
# PD mode with different policies for prefill and decode
|
||||
python -m sglang_router.launch_router --pd-disaggregation \\
|
||||
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
||||
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||
--prefill-policy cache_aware --decode-policy power_of_two
|
||||
|
||||
""",
|
||||
formatter_class=CustomHelpFormatter,
|
||||
)
|
||||
|
||||
@@ -50,6 +50,10 @@ class Router:
|
||||
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
|
||||
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
|
||||
decode_urls: List of URLs for decode servers (PD mode only)
|
||||
prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
|
||||
If not specified, uses the main policy. Default: None
|
||||
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
|
||||
If not specified, uses the main policy. Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -79,6 +83,8 @@ class Router:
|
||||
pd_disaggregation: bool = False,
|
||||
prefill_urls: Optional[List[tuple]] = None,
|
||||
decode_urls: Optional[List[str]] = None,
|
||||
prefill_policy: Optional[PolicyType] = None,
|
||||
decode_policy: Optional[PolicyType] = None,
|
||||
):
|
||||
if selector is None:
|
||||
selector = {}
|
||||
@@ -113,6 +119,8 @@ class Router:
|
||||
pd_disaggregation=pd_disaggregation,
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user