[misc] Add PD service discovery support in router (#7361)
This commit is contained in:
@@ -45,7 +45,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
prometheus_port=None,
|
||||
prometheus_host=None,
|
||||
# PD-specific attributes
|
||||
pd_disaggregated=False,
|
||||
pd_disaggregation=False,
|
||||
prefill=None,
|
||||
decode=None,
|
||||
# Keep worker_urls for regular mode
|
||||
@@ -119,7 +119,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
# Test RouterArgs parsing for PD mode
|
||||
# Simulate the parsed args structure from argparse with action="append"
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=True,
|
||||
pd_disaggregation=True,
|
||||
policy="power_of_two", # PowerOfTwo is only valid in PD mode
|
||||
prefill=[
|
||||
["http://prefill1:8080", "9000"],
|
||||
@@ -133,7 +133,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
)
|
||||
|
||||
router_args = RouterArgs.from_cli_args(args)
|
||||
self.assertTrue(router_args.pd_disaggregated)
|
||||
self.assertTrue(router_args.pd_disaggregation)
|
||||
self.assertEqual(router_args.policy, "power_of_two")
|
||||
self.assertEqual(len(router_args.prefill_urls), 2)
|
||||
self.assertEqual(len(router_args.decode_urls), 2)
|
||||
@@ -147,7 +147,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
# Test Router creation in PD mode
|
||||
router = Router(
|
||||
worker_urls=[], # Empty for PD mode
|
||||
pd_disaggregated=True,
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[
|
||||
("http://prefill1:8080", 9000),
|
||||
("http://prefill2:8080", None),
|
||||
@@ -165,7 +165,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
|
||||
# Test 1: PowerOfTwo is only valid in PD mode
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=False,
|
||||
pd_disaggregation=False,
|
||||
policy="power_of_two",
|
||||
worker_urls=["http://localhost:8000"],
|
||||
)
|
||||
@@ -180,7 +180,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
|
||||
# Test 2: RoundRobin is not valid in PD mode
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=True,
|
||||
pd_disaggregation=True,
|
||||
policy="round_robin",
|
||||
prefill=[["http://prefill1:8080", "9000"]],
|
||||
decode=[["http://decode1:8081"]],
|
||||
@@ -198,7 +198,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
# Test 3: Valid combinations should not raise errors
|
||||
# Regular mode with RoundRobin
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=False,
|
||||
pd_disaggregation=False,
|
||||
policy="round_robin",
|
||||
worker_urls=["http://localhost:8000"],
|
||||
)
|
||||
@@ -206,7 +206,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
|
||||
# PD mode with PowerOfTwo
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=True,
|
||||
pd_disaggregation=True,
|
||||
policy="power_of_two",
|
||||
prefill=[["http://prefill1:8080", "9000"]],
|
||||
decode=[["http://decode1:8081"]],
|
||||
@@ -214,6 +214,79 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
)
|
||||
# This should not raise (though it may fail to connect)
|
||||
|
||||
def test_pd_service_discovery_args_parsing(self):
|
||||
"""Test PD service discovery CLI argument parsing."""
|
||||
import argparse
|
||||
|
||||
from sglang_router.launch_router import RouterArgs
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
RouterArgs.add_cli_args(parser)
|
||||
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--pd-disaggregation",
|
||||
"--service-discovery",
|
||||
"--prefill-selector",
|
||||
"app=sglang",
|
||||
"component=prefill",
|
||||
"--decode-selector",
|
||||
"app=sglang",
|
||||
"component=decode",
|
||||
"--service-discovery-port",
|
||||
"8000",
|
||||
"--service-discovery-namespace",
|
||||
"production",
|
||||
"--policy",
|
||||
"cache_aware",
|
||||
]
|
||||
)
|
||||
|
||||
router_args = RouterArgs.from_cli_args(args)
|
||||
|
||||
self.assertTrue(router_args.pd_disaggregation)
|
||||
self.assertTrue(router_args.service_discovery)
|
||||
self.assertEqual(
|
||||
router_args.prefill_selector, {"app": "sglang", "component": "prefill"}
|
||||
)
|
||||
self.assertEqual(
|
||||
router_args.decode_selector, {"app": "sglang", "component": "decode"}
|
||||
)
|
||||
self.assertEqual(router_args.service_discovery_port, 8000)
|
||||
self.assertEqual(router_args.service_discovery_namespace, "production")
|
||||
|
||||
def test_regular_service_discovery_args_parsing(self):
|
||||
"""Test regular mode service discovery CLI argument parsing."""
|
||||
import argparse
|
||||
|
||||
from sglang_router.launch_router import RouterArgs
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
RouterArgs.add_cli_args(parser)
|
||||
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--service-discovery",
|
||||
"--selector",
|
||||
"app=sglang-worker",
|
||||
"environment=staging",
|
||||
"--service-discovery-port",
|
||||
"8000",
|
||||
"--policy",
|
||||
"round_robin",
|
||||
]
|
||||
)
|
||||
|
||||
router_args = RouterArgs.from_cli_args(args)
|
||||
|
||||
self.assertFalse(router_args.pd_disaggregation)
|
||||
self.assertTrue(router_args.service_discovery)
|
||||
self.assertEqual(
|
||||
router_args.selector, {"app": "sglang-worker", "environment": "staging"}
|
||||
)
|
||||
self.assertEqual(router_args.prefill_selector, {})
|
||||
self.assertEqual(router_args.decode_selector, {})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user