Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
This commit is contained in:
@@ -35,13 +35,21 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
balance_rel_threshold=1.0001,
|
||||
eviction_interval=60,
|
||||
max_tree_size=2**24,
|
||||
max_payload_size=4 * 1024 * 1024, # 4MB
|
||||
max_payload_size=256 * 1024 * 1024, # 256MB
|
||||
verbose=False,
|
||||
log_dir=None,
|
||||
service_discovery=False,
|
||||
selector=None,
|
||||
service_discovery_port=80,
|
||||
service_discovery_namespace=None,
|
||||
prometheus_port=None,
|
||||
prometheus_host=None,
|
||||
# PD-specific attributes
|
||||
pd_disaggregated=False,
|
||||
prefill=None,
|
||||
decode=None,
|
||||
# Keep worker_urls for regular mode
|
||||
worker_urls=[],
|
||||
)
|
||||
|
||||
def create_router_args(self, **kwargs):
|
||||
@@ -81,7 +89,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
|
||||
def test_launch_router_with_empty_worker_urls(self):
    """Run the router process with an empty ``worker_urls`` list."""
    args = self.create_router_args(worker_urls=[])
    # The diff view carried both the pre-change and post-change copies of
    # this call; only a single launch is intended.
    self.run_router_process(args)  # Expected error
|
||||
|
||||
def test_launch_router_with_service_discovery(self):
|
||||
# Test router startup with service discovery enabled but no selectors
|
||||
@@ -100,6 +108,112 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
)
|
||||
self.run_router_process(args)
|
||||
|
||||
def test_launch_router_pd_mode_basic(self):
    """Check PD-mode argument parsing and Router construction.

    The router is only constructed, never started; starting it would
    require real prefill/decode servers.
    """
    from sglang_router import Router
    from sglang_router.launch_router import RouterArgs
    from sglang_router_rs import PolicyType

    # Raw CLI shape: what argparse yields for options using action="append".
    raw_prefill = [
        ["http://prefill1:8080", "9000"],
        ["http://prefill2:8080", "none"],
    ]
    raw_decode = [
        ["http://decode1:8081"],
        ["http://decode2:8081"],
    ]
    cli_args = self.create_router_args(
        pd_disaggregated=True,
        policy="power_of_two",  # PowerOfTwo is only valid in PD mode
        prefill=raw_prefill,
        decode=raw_decode,
        worker_urls=[],  # Empty for PD mode
    )

    parsed = RouterArgs.from_cli_args(cli_args)
    self.assertTrue(parsed.pd_disaggregated)
    self.assertEqual(parsed.policy, "power_of_two")
    self.assertEqual(len(parsed.prefill_urls), 2)
    self.assertEqual(len(parsed.decode_urls), 2)

    # Parsed form: (url, bootstrap_port) tuples for prefill, bare URLs for decode.
    want_prefill = [
        ("http://prefill1:8080", 9000),
        ("http://prefill2:8080", None),
    ]
    want_decode = ["http://decode1:8081", "http://decode2:8081"]
    for idx, expected in enumerate(want_prefill):
        self.assertEqual(parsed.prefill_urls[idx], expected)
    for idx, expected in enumerate(want_decode):
        self.assertEqual(parsed.decode_urls[idx], expected)

    # Build a Router directly in PD mode from the already-parsed values.
    pd_router = Router(
        worker_urls=[],  # Empty for PD mode
        pd_disaggregated=True,
        prefill_urls=list(want_prefill),
        decode_urls=list(want_decode),
        policy=PolicyType.CacheAware,
        host="127.0.0.1",
        port=3001,
    )
    self.assertIsNotNone(pd_router)
|
||||
|
||||
def test_policy_validation(self):
    """Test that policy validation works correctly for PD and regular modes.

    Invalid mode/policy pairs must make ``launch_router`` raise
    ``ValueError`` with a specific message. Valid pairs are checked at the
    argument-parsing level only (``RouterArgs.from_cli_args``);
    ``launch_router`` is not called for them because, per the original
    note, it may fail to connect to the placeholder URLs.
    """
    from sglang_router.launch_router import RouterArgs, launch_router

    def pd_endpoints():
        # Fresh lists per call so no case shares mutable state with another.
        return {
            "prefill": [["http://prefill1:8080", "9000"]],
            "decode": [["http://decode1:8081"]],
            "worker_urls": [],
        }

    # (overrides, expected error message) for combinations that must be rejected.
    rejected_cases = [
        (
            # PowerOfTwo is only valid in PD mode
            {
                "pd_disaggregated": False,
                "policy": "power_of_two",
                "worker_urls": ["http://localhost:8000"],
            },
            "PowerOfTwo policy is only supported in PD disaggregated mode",
        ),
        (
            # RoundRobin is not valid in PD mode
            {
                "pd_disaggregated": True,
                "policy": "round_robin",
                **pd_endpoints(),
            },
            "RoundRobin policy is not supported in PD disaggregated mode",
        ),
    ]
    for overrides, expected_message in rejected_cases:
        with self.subTest(
            expect="rejected",
            policy=overrides["policy"],
            pd_disaggregated=overrides["pd_disaggregated"],
        ):
            args = self.create_router_args(**overrides)
            with self.assertRaises(ValueError) as cm:
                launch_router(args)
            self.assertIn(expected_message, str(cm.exception))

    # Valid combinations. Previously these were built and then discarded
    # without any assertion; now the parsed result is verified.
    accepted_cases = [
        # Regular mode with RoundRobin
        {
            "pd_disaggregated": False,
            "policy": "round_robin",
            "worker_urls": ["http://localhost:8000"],
        },
        # PD mode with PowerOfTwo
        {
            "pd_disaggregated": True,
            "policy": "power_of_two",
            **pd_endpoints(),
        },
    ]
    for overrides in accepted_cases:
        with self.subTest(
            expect="accepted",
            policy=overrides["policy"],
            pd_disaggregated=overrides["pd_disaggregated"],
        ):
            args = self.create_router_args(**overrides)
            router_args = RouterArgs.from_cli_args(args)
            self.assertEqual(
                bool(router_args.pd_disaggregated),
                overrides["pd_disaggregated"],
            )
            self.assertEqual(router_args.policy, overrides["policy"])
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user