[misc] Add PD service discovery support in router (#7361)

This commit is contained in:
Simo Lin
2025-06-22 17:54:14 -07:00
committed by GitHub
parent bd4f581896
commit 30f2a44a96
11 changed files with 1362 additions and 120 deletions

View File

@@ -45,7 +45,7 @@ class TestLaunchRouter(unittest.TestCase):
prometheus_port=None,
prometheus_host=None,
# PD-specific attributes
pd_disaggregated=False,
pd_disaggregation=False,
prefill=None,
decode=None,
# Keep worker_urls for regular mode
@@ -119,7 +119,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append"
args = self.create_router_args(
pd_disaggregated=True,
pd_disaggregation=True,
policy="power_of_two", # PowerOfTwo is only valid in PD mode
prefill=[
["http://prefill1:8080", "9000"],
@@ -133,7 +133,7 @@ class TestLaunchRouter(unittest.TestCase):
)
router_args = RouterArgs.from_cli_args(args)
self.assertTrue(router_args.pd_disaggregated)
self.assertTrue(router_args.pd_disaggregation)
self.assertEqual(router_args.policy, "power_of_two")
self.assertEqual(len(router_args.prefill_urls), 2)
self.assertEqual(len(router_args.decode_urls), 2)
@@ -147,7 +147,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test Router creation in PD mode
router = Router(
worker_urls=[], # Empty for PD mode
pd_disaggregated=True,
pd_disaggregation=True,
prefill_urls=[
("http://prefill1:8080", 9000),
("http://prefill2:8080", None),
@@ -165,7 +165,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test 1: PowerOfTwo is only valid in PD mode
args = self.create_router_args(
pd_disaggregated=False,
pd_disaggregation=False,
policy="power_of_two",
worker_urls=["http://localhost:8000"],
)
@@ -180,7 +180,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test 2: RoundRobin is not valid in PD mode
args = self.create_router_args(
pd_disaggregated=True,
pd_disaggregation=True,
policy="round_robin",
prefill=[["http://prefill1:8080", "9000"]],
decode=[["http://decode1:8081"]],
@@ -198,7 +198,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test 3: Valid combinations should not raise errors
# Regular mode with RoundRobin
args = self.create_router_args(
pd_disaggregated=False,
pd_disaggregation=False,
policy="round_robin",
worker_urls=["http://localhost:8000"],
)
@@ -206,7 +206,7 @@ class TestLaunchRouter(unittest.TestCase):
# PD mode with PowerOfTwo
args = self.create_router_args(
pd_disaggregated=True,
pd_disaggregation=True,
policy="power_of_two",
prefill=[["http://prefill1:8080", "9000"]],
decode=[["http://decode1:8081"]],
@@ -214,6 +214,79 @@ class TestLaunchRouter(unittest.TestCase):
)
# This should not raise (though it may fail to connect)
def test_pd_service_discovery_args_parsing(self):
    """Check that PD-mode service discovery flags survive CLI parsing.

    Builds a fresh argparse parser from ``RouterArgs.add_cli_args``, feeds
    it a PD-disaggregation command line with separate prefill and decode
    label selectors, and verifies the values on the resulting
    ``RouterArgs`` object.
    """
    import argparse

    from sglang_router.launch_router import RouterArgs

    # Selector labels are passed as space-separated ``key=value`` tokens.
    prefill_labels = ["app=sglang", "component=prefill"]
    decode_labels = ["app=sglang", "component=decode"]
    argv = [
        "--pd-disaggregation",
        "--service-discovery",
        "--prefill-selector",
        *prefill_labels,
        "--decode-selector",
        *decode_labels,
        "--service-discovery-port",
        "8000",
        "--service-discovery-namespace",
        "production",
        "--policy",
        "cache_aware",
    ]

    cli_parser = argparse.ArgumentParser()
    RouterArgs.add_cli_args(cli_parser)
    namespace = cli_parser.parse_args(argv)
    parsed = RouterArgs.from_cli_args(namespace)

    # Mode switches.
    self.assertTrue(parsed.pd_disaggregation)
    self.assertTrue(parsed.service_discovery)

    # Each selector is expected as a label -> value mapping.
    expected_prefill = {"app": "sglang", "component": "prefill"}
    expected_decode = {"app": "sglang", "component": "decode"}
    self.assertEqual(parsed.prefill_selector, expected_prefill)
    self.assertEqual(parsed.decode_selector, expected_decode)

    # The port is compared as an int even though it was given as a string.
    self.assertEqual(parsed.service_discovery_port, 8000)
    self.assertEqual(parsed.service_discovery_namespace, "production")
def test_regular_service_discovery_args_parsing(self):
    """Test regular (non-PD) mode service discovery CLI argument parsing.

    Parses a command line that enables service discovery with a single
    ``--selector`` and checks every flag that was supplied, plus the
    PD-specific selectors, which must stay empty when PD mode is off.
    """
    import argparse

    from sglang_router.launch_router import RouterArgs

    parser = argparse.ArgumentParser()
    RouterArgs.add_cli_args(parser)
    args = parser.parse_args(
        [
            "--service-discovery",
            "--selector",
            "app=sglang-worker",
            "environment=staging",
            "--service-discovery-port",
            "8000",
            "--policy",
            "round_robin",
        ]
    )
    router_args = RouterArgs.from_cli_args(args)

    # --pd-disaggregation was not passed, so PD mode must be off.
    self.assertFalse(router_args.pd_disaggregation)
    self.assertTrue(router_args.service_discovery)
    self.assertEqual(
        router_args.selector, {"app": "sglang-worker", "environment": "staging"}
    )
    # Previously supplied on the command line but never verified: make sure
    # the port and policy are carried through in regular mode as well.
    self.assertEqual(router_args.service_discovery_port, 8000)
    self.assertEqual(router_args.policy, "round_robin")
    # PD-only selectors were not given and are expected to be empty dicts.
    self.assertEqual(router_args.prefill_selector, {})
    self.assertEqual(router_args.decode_selector, {})
# Run the unittest test runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()