From a730ce8162145180b4d4cae8d6fe28cdb58dab69 Mon Sep 17 00:00:00 2001 From: Rui Chen Date: Wed, 30 Jul 2025 20:58:48 +0800 Subject: [PATCH] [feature] [sgl-router] Add a dp-aware routing strategy (#6869) --- docs/router/router.md | 8 + .../py_src/sglang_router/launch_router.py | 17 + sgl-router/py_src/sglang_router/router.py | 8 + sgl-router/py_test/run_suite.py | 2 +- sgl-router/py_test/test_launch_router.py | 47 +++ sgl-router/py_test/test_launch_server.py | 279 ++++++++++++++++ sgl-router/src/config/types.rs | 14 + sgl-router/src/config/validation.rs | 8 + sgl-router/src/core/error.rs | 5 + sgl-router/src/core/worker.rs | 28 +- sgl-router/src/lib.rs | 10 + sgl-router/src/routers/factory.rs | 2 + sgl-router/src/routers/router.rs | 298 +++++++++++++++++- sgl-router/src/service_discovery.rs | 2 +- sgl-router/tests/api_endpoints_test.rs | 4 + sgl-router/tests/common/mod.rs | 4 + sgl-router/tests/request_formats_test.rs | 2 + sgl-router/tests/streaming_tests.rs | 2 + sgl-router/tests/test_pd_routing.rs | 2 + 19 files changed, 726 insertions(+), 16 deletions(-) diff --git a/docs/router/router.md b/docs/router/router.md index 8267007e1..7339144fa 100644 --- a/docs/router/router.md +++ b/docs/router/router.md @@ -141,6 +141,14 @@ Process: For unbalanced systems, this strategy tracks pending request counts per worker and routes new requests to the least busy worker. This helps maintain optimal load distribution across workers. +***Data-Parallelism Aware Routing*** + +An additional DP-aware routing strategy can be enabled on top of the sgl-router’s hybrid cache-aware load-balancing strategy by setting the `--dp-aware` flag when starting the router. + +When this flag is enabled, the router attempts to contact the workers to retrieve the `dp_size` of each one and registers the new workers at the DP-rank level. In this mode, the router applies the cache-aware routing strategy in a more fine-grained manner, with assistance from the DP controller on the SRT side. + +By default (when the flag is not set), the SRT’s DP controller distributes incoming requests across DP ranks in a round-robin fashion. + ## Configuration Parameters 1. `cache_threshold`: (float, 0.0 to 1.0, default: 0.5) diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index 9337c4eaa..13fada0f5 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -50,6 +50,8 @@ class RouterArgs: eviction_interval: int = 60 max_tree_size: int = 2**24 max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches + dp_aware: bool = False + api_key: Optional[str] = None log_dir: Optional[str] = None log_level: Optional[str] = None # Service discovery configuration @@ -197,6 +199,17 @@ class RouterArgs: default=RouterArgs.max_payload_size, help="Maximum payload size in bytes", ) + parser.add_argument( + f"--{prefix}dp-aware", + action="store_true", + help="Enable data parallelism aware schedule", + ) + parser.add_argument( + f"--{prefix}api-key", + type=str, + default=None, + help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.", + ) parser.add_argument( f"--{prefix}log-dir", type=str, @@ -304,6 +317,8 @@ class RouterArgs: eviction_interval=getattr(args, f"{prefix}eviction_interval"), max_tree_size=getattr(args, f"{prefix}max_tree_size"), max_payload_size=getattr(args, f"{prefix}max_payload_size"), + dp_aware=getattr(args, f"{prefix}dp_aware", False), + api_key=getattr(args, f"{prefix}api_key", None), log_dir=getattr(args, f"{prefix}log_dir", None), log_level=getattr(args, f"{prefix}log_level", None), service_discovery=getattr(args, f"{prefix}service_discovery", False), @@ -463,6 +478,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: eviction_interval_secs=router_args.eviction_interval, max_tree_size=router_args.max_tree_size, max_payload_size=router_args.max_payload_size, + dp_aware=router_args.dp_aware, + api_key=router_args.api_key, log_dir=router_args.log_dir, log_level=router_args.log_level, service_discovery=router_args.service_discovery, diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index 7b85f7767..7bde7f022 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -31,6 +31,10 @@ class Router: routing. Default: 60 max_payload_size: Maximum payload size in bytes. Default: 256MB max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24 + dp_aware: Enable data parallelism aware schedule. Default: False + api_key: The api key used for the authorization with the worker. + Useful when the dp aware scheduling strategy is enabled. + Default: None log_dir: Directory to store log files. If None, logs are only output to console. Default: None log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'. service_discovery: Enable Kubernetes service discovery. When enabled, the router will @@ -73,6 +77,8 @@ class Router: eviction_interval_secs: int = 60, max_tree_size: int = 2**24, max_payload_size: int = 256 * 1024 * 1024, # 256MB + dp_aware: bool = False, + api_key: Optional[str] = None, log_dir: Optional[str] = None, log_level: Optional[str] = None, service_discovery: bool = False, @@ -110,6 +116,8 @@ class Router: eviction_interval_secs=eviction_interval_secs, max_tree_size=max_tree_size, max_payload_size=max_payload_size, + dp_aware=dp_aware, + api_key=api_key, log_dir=log_dir, log_level=log_level, service_discovery=service_discovery, diff --git a/sgl-router/py_test/run_suite.py b/sgl-router/py_test/run_suite.py index e1434b0e8..ac7f9c140 100644 --- a/sgl-router/py_test/run_suite.py +++ b/sgl-router/py_test/run_suite.py @@ -8,7 +8,7 @@ if __name__ == "__main__": arg_parser.add_argument( "--timeout-per-file", type=int, - default=1000, + default=2000, help="The time limit for running one file in seconds.", ) args = arg_parser.parse_args() diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index 90d8aa664..a014efac6 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -43,6 +43,7 @@ class TestLaunchRouter(unittest.TestCase): selector=None, service_discovery_port=80, service_discovery_namespace=None, + dp_aware=False, prometheus_port=None, prometheus_host=None, # PD-specific attributes @@ -111,6 +112,52 @@ class TestLaunchRouter(unittest.TestCase): ) self.run_router_process(args) + def test_launch_router_common_with_dp_aware(self): + args = self.create_router_args( + worker_urls=["http://localhost:8000"], + dp_aware=True, + ) + self.run_router_process(args) + + def test_launch_router_with_empty_worker_urls_with_dp_aware(self): + args = self.create_router_args( + worker_urls=[], + dp_aware=True, + ) + self.run_router_process(args) + + def test_launch_router_common_with_dp_aware_service_discovery(self): + # Test launch router with bot srevice_discovery and dp_aware enabled + # Should fail since service_discovery and dp_aware is conflict + args = self.create_router_args( + worker_urls=["http://localhost:8000"], + dp_aware=True, + service_discovery=True, + selector=["app=test-worker"], + ) + + def run_router(): + try: + from sglang_router.launch_router import launch_router + + router = launch_router(args) + if router is None: + return 1 + return 0 + except Exception as e: + print(e) + return 1 + + process = multiprocessing.Process(target=run_router) + try: + process.start() + # Wait 3 seconds + time.sleep(3) + # Should fail since service_discovery and dp_aware is conflict + self.assertFalse(process.is_alive()) + finally: + terminate_process(process) + def test_launch_router_pd_mode_basic(self): """Test basic PD router functionality without actually starting servers.""" # This test just verifies the PD router can be created and configured diff --git a/sgl-router/py_test/test_launch_server.py b/sgl-router/py_test/test_launch_server.py index bfba8a765..d361e8d66 100644 --- a/sgl-router/py_test/test_launch_server.py +++ b/sgl-router/py_test/test_launch_server.py @@ -30,6 +30,7 @@ def popen_launch_router( service_discovery_namespace: str = None, prometheus_port: int = None, prometheus_host: str = None, + dp_aware: bool = False, ): """ Launch the router server process. @@ -49,6 +50,7 @@ def popen_launch_router( service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces. prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled. prometheus_host: Host address to bind the Prometheus metrics server. + dp_aware: Enable data parallelism aware routing strategy. """ _, host, port = base_url.split(":") host = host[2:] @@ -69,10 +71,12 @@ def popen_launch_router( "5", "--router-policy", policy, + "--allow-auto-truncate", ] if api_key is not None: command.extend(["--api-key", api_key]) + command.extend(["--router-api-key", api_key]) if max_payload_size is not None: command.extend(["--router-max-payload-size", str(max_payload_size)]) @@ -100,6 +104,9 @@ def popen_launch_router( if log_dir is not None: command.extend(["--log-dir", log_dir]) + if dp_aware: + command.append("--router-dp-aware") + process = subprocess.Popen(command, stdout=None, stderr=None) start_time = time.perf_counter() @@ -127,6 +134,7 @@ def popen_launch_server( model: str, base_url: str, timeout: float, + api_key: str = None, ): _, host, port = base_url.split(":") host = host[2:] @@ -145,6 +153,9 @@ def popen_launch_server( "1", ] + if api_key is not None: + command.extend(["--api-key", api_key]) + process = subprocess.Popen(command, stdout=None, stderr=None) # intentionally don't wait and defer the job to the router health check @@ -426,6 +437,274 @@ class TestLaunchServer(unittest.TestCase): response.status_code, 200, "Request with correct api key should succeed" ) + def test_6_mmlu_with_dp_aware(self): + print("Running test_6_mmlu_with_dp_aware...") + # DP size = 2 + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=2, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="cache_aware", + dp_aware=True, + ) + + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"dp aware MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + def test_7_add_and_remove_worker_with_dp_aware(self): + print("Running test_7_add_and_remove_worker_with_dp_aware...") + + # Set dp_size = 1 + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", # make sure every worker processes requests + dp_aware=True, # dp aware strategy should work well with RR + ) + + # 1. Start a worker + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + self.other_process.append(worker_process) + + # 2. Use the /add_worker API to add it to the router + # It will be used by router after it is healthy + with requests.Session() as session: + response = session.post(f"{self.base_url}/add_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + + # 3. Run mmlu + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + # 4. Use the /remove_worker API to remove it from the router + with requests.Session() as session: + response = session.post(f"{self.base_url}/remove_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + + # 5. Run mmlu again + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + # 6. Start another worker with api_key set + terminate_and_wait(worker_process) # terminate the old worker process + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, + worker_url, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key="correct_api_key", + ) + self.other_process.append(worker_process) + + # 7. Use the /add_worker API to add it to the router + # Should fail since the router would contact the worker's + # /get_server_info endpoint for the dp_size info, but it + # has no knowledge of the api key + with requests.Session() as session: + response = session.post(f"{self.base_url}/add_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertNotEqual(response.status_code, 200) + + def test_8_lazy_fault_tolerance_with_dp_aware(self): + print("Running test_8_lazy_fault_tolerance_with_dp_aware...") + + # Set dp_size = 1 + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + dp_aware=True, + ) + + # 1. Start a worker + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + self.other_process.append(worker_process) + + # 2. Use the /add_worker API to add it to the router + # It will be used by router after it is healthy + with requests.Session() as session: + response = session.post(f"{self.base_url}/add_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + + # Start a thread to kill the worker after 10 seconds to mimic + # abrupt worker failure + def kill_worker(): + time.sleep(10) + kill_process_tree(worker_process.pid) + print("Worker process killed") + + import threading + + kill_thread = threading.Thread(target=kill_worker) + kill_thread.daemon = True + kill_thread.start() + + # 3. Run mmlu + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=256, + num_threads=32, + temperature=0.1, + ) + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + def test_9_payload_size_with_dp_aware(self): + print("Running test_9_payload_size_with_dp_aware...") + + # Start the router with 1MB limit + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + max_payload_size=1 * 1024 * 1024, # 1MB limit + dp_aware=True, + ) + + # Test case 1: Payload just under 1MB should succeed + payload_0_5_mb = { + "text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text + "temperature": 0.0, + } + + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json=payload_0_5_mb, + headers={"Content-Type": "application/json"}, + ) + self.assertEqual( + response.status_code, + 200, + f"0.5MB payload should succeed but got status {response.status_code}", + ) + + # Test case 2: Payload over 1MB should fail + payload_1_plus_mb = { + "text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text + "temperature": 0.0, + } + + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json=payload_1_plus_mb, + headers={"Content-Type": "application/json"}, + ) + self.assertEqual( + response.status_code, + 413, # Payload Too Large + f"1.2MB payload should fail with 413 but got status {response.status_code}", + ) + + def test_10_api_key_with_dp_aware(self): + print("Running test_10_api_key_with_dp_aware...") + + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + api_key="correct_api_key", + dp_aware=True, + ) + + # Test case 1: request without api key should fail + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is, ", "temperature": 0}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + f"Request without api key should fail with 401 but got status {response.status_code}", + ) + + # Test case 2: request with invalid api key should fail + with requests.Session() as session: + response = requests.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is, ", "temperature": 0}, + headers={"Authorization": "Bearer 123"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + f"Request without api key should fail with 401 but got status {response.status_code}", + ) + + # Test case 3: request with correct api key should succeed + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is ", "temperature": 0}, + headers={"Authorization": "Bearer correct_api_key"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 200, + f"Request with correct api key should succeed but got status {response.status_code}", + ) + if __name__ == "__main__": unittest.main() diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 537e2a119..67358caaa 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -21,6 +21,10 @@ pub struct RouterConfig { pub worker_startup_timeout_secs: u64, /// Worker health check interval in seconds pub worker_startup_check_interval_secs: u64, + /// Enable data parallelism aware schedule + pub dp_aware: bool, + /// The api key used for the authorization with the worker + pub api_key: Option, /// Service discovery configuration (optional) pub discovery: Option, /// Metrics configuration (optional) @@ -205,6 +209,8 @@ impl Default for RouterConfig { request_timeout_secs: 600, worker_startup_timeout_secs: 300, worker_startup_check_interval_secs: 10, + dp_aware: false, + api_key: None, discovery: None, metrics: None, log_dir: None, @@ -311,6 +317,8 @@ mod tests { request_timeout_secs: 30, worker_startup_timeout_secs: 60, worker_startup_check_interval_secs: 5, + dp_aware: false, + api_key: None, discovery: Some(DiscoveryConfig::default()), metrics: Some(MetricsConfig::default()), log_dir: Some("/var/log".to_string()), @@ -727,6 +735,8 @@ mod tests { request_timeout_secs: 120, worker_startup_timeout_secs: 60, worker_startup_check_interval_secs: 5, + dp_aware: false, + api_key: None, discovery: Some(DiscoveryConfig { enabled: true, namespace: Some("sglang".to_string()), @@ -774,6 +784,8 @@ mod tests { request_timeout_secs: 300, worker_startup_timeout_secs: 180, worker_startup_check_interval_secs: 15, + dp_aware: false, + api_key: None, discovery: Some(DiscoveryConfig { enabled: true, namespace: None, @@ -812,6 +824,8 @@ mod tests { request_timeout_secs: 900, worker_startup_timeout_secs: 600, worker_startup_check_interval_secs: 20, + dp_aware: false, + api_key: None, discovery: Some(DiscoveryConfig { enabled: true, namespace: Some("production".to_string()), diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs index 1e78a0f10..65eaef95f 100644 --- a/sgl-router/src/config/validation.rs +++ b/sgl-router/src/config/validation.rs @@ -313,6 +313,14 @@ impl ConfigValidator { } } + // Service discovery is conflict with dp_aware routing for now + // since it's not fully supported yet + if has_service_discovery && config.dp_aware { + return Err(ConfigError::IncompatibleConfig { + reason: "DP-aware routing is not compatible with service discovery".to_string(), + }); + } + Ok(()) } diff --git a/sgl-router/src/core/error.rs b/sgl-router/src/core/error.rs index 4d50ccee0..b89ba8032 100644 --- a/sgl-router/src/core/error.rs +++ b/sgl-router/src/core/error.rs @@ -17,6 +17,8 @@ pub enum WorkerError { NetworkError { url: String, error: String }, /// Worker is at capacity WorkerAtCapacity { url: String }, + /// Invalid URL format + InvalidUrl { url: String }, } impl fmt::Display for WorkerError { @@ -37,6 +39,9 @@ impl fmt::Display for WorkerError { WorkerError::WorkerAtCapacity { url } => { write!(f, "Worker at capacity: {}", url) } + WorkerError::InvalidUrl { url } => { + write!(f, "Invalid URL format: {}", url) + } } } } diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index fc91b1f5e..58db15991 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -162,6 +162,27 @@ impl BasicWorker { self.metadata.health_config = config; self } + + pub fn normalised_url(&self) -> WorkerResult<&str> { + if self.url().contains("@") { + // Need to extract the URL from "http://host:port@dp_rank" + let parts: Vec<&str> = self.url().split('@').collect(); + if parts.len() != 2 { + return Err(WorkerError::InvalidUrl { + url: self.url().to_string(), + }); + } + // Ensure the second part (the dp_rank) can be parsed as an integer + match parts[1].parse::() { + Ok(_) => Ok(parts[0]), + Err(_) => Err(WorkerError::InvalidUrl { + url: self.url().to_string(), + }), + } + } else { + Ok(self.url()) + } + } } #[async_trait] @@ -186,7 +207,8 @@ impl Worker for BasicWorker { use std::time::Duration; // Perform actual HTTP health check - let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint); + let url = self.normalised_url()?; + let health_url = format!("{}{}", url, self.metadata.health_config.endpoint); let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); // Use the shared client with a custom timeout for this request @@ -203,7 +225,7 @@ impl Worker for BasicWorker { } else { self.set_healthy(false); Err(WorkerError::HealthCheckFailed { - url: self.url().to_string(), + url: url.to_string(), reason: format!("Health check returned status: {}", response.status()), }) } @@ -211,7 +233,7 @@ impl Worker for BasicWorker { Err(e) => { self.set_healthy(false); Err(WorkerError::HealthCheckFailed { - url: self.url().to_string(), + url: url.to_string(), reason: format!("Health check request failed: {}", e), }) } diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index ede058f87..6bec3d418 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -37,6 +37,8 @@ struct Router { eviction_interval_secs: u64, max_tree_size: usize, max_payload_size: usize, + dp_aware: bool, + api_key: Option, log_dir: Option, log_level: Option, service_discovery: bool, @@ -136,6 +138,8 @@ impl Router { request_timeout_secs: self.request_timeout_secs, worker_startup_timeout_secs: self.worker_startup_timeout_secs, worker_startup_check_interval_secs: self.worker_startup_check_interval, + dp_aware: self.dp_aware, + api_key: self.api_key.clone(), discovery, metrics, log_dir: self.log_dir.clone(), @@ -161,6 +165,8 @@ impl Router { eviction_interval_secs = 60, max_tree_size = 2usize.pow(24), max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches + dp_aware = false, + api_key = None, log_dir = None, log_level = None, service_discovery = false, @@ -193,6 +199,8 @@ impl Router { eviction_interval_secs: u64, max_tree_size: usize, max_payload_size: usize, + dp_aware: bool, + api_key: Option, log_dir: Option, log_level: Option, service_discovery: bool, @@ -225,6 +233,8 @@ impl Router { eviction_interval_secs, max_tree_size, max_payload_size, + dp_aware, + api_key, log_dir, log_level, service_discovery, diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index edf063440..b97974367 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -45,6 +45,8 @@ impl RouterFactory { policy, router_config.worker_startup_timeout_secs, router_config.worker_startup_check_interval_secs, + router_config.dp_aware, + router_config.api_key.clone(), )?; Ok(Box::new(router)) diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index b065afafe..294fa4919 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -30,6 +30,8 @@ pub struct Router { policy: Arc, timeout_secs: u64, interval_secs: u64, + dp_aware: bool, + api_key: Option, _worker_loads: Arc>>, _load_monitor_handle: Option>>, _health_checker: Option, @@ -42,6 +44,8 @@ impl Router { policy: Arc, timeout_secs: u64, interval_secs: u64, + dp_aware: bool, + api_key: Option, ) -> Result { // Update active workers gauge RouterMetrics::set_active_workers(worker_urls.len()); @@ -51,6 +55,14 @@ impl Router { Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?; } + let worker_urls = if dp_aware { + // worker address now in the format of "http://host:port@dp_rank" + Self::get_dp_aware_workers(&worker_urls, &api_key) + .map_err(|e| format!("Failed to get dp-aware workers: {}", e))? + } else { + worker_urls + }; + // Create Worker trait objects from URLs let workers: Vec> = worker_urls .iter() @@ -89,6 +101,8 @@ impl Router { policy, timeout_secs, interval_secs, + dp_aware, + api_key, _worker_loads: worker_loads, _load_monitor_handle: load_monitor_handle, _health_checker: Some(health_checker), @@ -160,6 +174,62 @@ impl Router { } } + fn get_worker_dp_size(worker_url: &str, api_key: &Option) -> Result { + let sync_client = reqwest::blocking::Client::new(); + let mut req_builder = sync_client.get(&format!("{}/get_server_info", worker_url)); + if let Some(key) = api_key { + req_builder = req_builder.bearer_auth(key); + } + + match req_builder.send() { + Ok(res) => { + if res.status().is_success() { + let server_info = res + .text() + .map_err(|e| format!("failed to read text from response: {}", e))?; + + let server_info: serde_json::Value = serde_json::from_str(&server_info) + .map_err(|e| format!("failed to decode JSON: {}", e))?; + + let dp_size = server_info + .get("dp_size") + .and_then(|v| v.as_u64()) + .ok_or_else(|| String::from("dp_size not found or not an u64"))?; + + Ok(if dp_size > usize::MAX as u64 { + return Err(format!("dp_size is too large: {}", dp_size)); + } else { + dp_size as usize + }) + } else { + Err(format!("unexpected status code: {}", res.status())) + } + } + Err(e) => Err(format!("error response: {}", e)), + } + } + + // Given a list of workers, return a list of workers with dp_rank as suffix + fn get_dp_aware_workers( + worker_urls: &[String], + api_key: &Option, + ) -> Result, String> { + let mut dp_aware_workers: Vec = Vec::new(); + + for url in worker_urls { + match Self::get_worker_dp_size(url, api_key) { + Ok(dp_size) => { + for i in 0..dp_size { + dp_aware_workers.push(format!("{}@{}", url, i)); + } + } + Err(e) => return Err(format!("Failed to get DP size for {}: {}", url, e)), + } + } + + Ok(dp_aware_workers) + } + fn select_first_worker(&self) -> Result { let workers_guard = self.workers.read().unwrap(); if workers_guard.is_empty() { @@ -178,6 +248,21 @@ impl Router { ) -> HttpResponse { let request_id = get_request_id(req); let start = Instant::now(); + + let worker_url = if self.dp_aware { + // Need to extract the URL from "http://host:port@dp_rank" + let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { + Ok(tup) => tup, + Err(e) => { + error!("Failed to extract dp_rank: {}", e); + return HttpResponse::InternalServerError().finish(); + } + }; + worker_url_prefix + } else { + worker_url + }; + let mut request_builder = client.get(format!("{}{}", worker_url, route)); // Copy all headers from original request except for /health because it does not need authorization @@ -292,7 +377,7 @@ impl Router { worker_url = %worker_url, "Removing failed worker" ); - self.remove_worker(&worker_url); + self.remove_failed_worker(&worker_url); break; } } @@ -392,7 +477,7 @@ impl Router { request_id = %request_id, "Removing failed worker after typed request failures worker_url={}", worker_url ); - self.remove_worker(&worker_url); + self.remove_failed_worker(&worker_url); break; } } @@ -415,6 +500,23 @@ impl Router { } } + // TODO (rui): Better accommodate to the Worker abstraction + fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> { + let parts: Vec<&str> = worker_url.split('@').collect(); + if parts.len() != 2 { + return Err(format!("invalid worker_url format: {}", worker_url)); + } + + // Parse the second part (dp_rank) into an integer + match parts[1].parse::() { + Ok(dp_rank) => Ok((parts[0], dp_rank)), + Err(_) => Err(format!( + "failed to parse dp_rank from worker_url: {}", + worker_url + )), + } + } + // Send typed request directly without conversion async fn send_typed_request( &self, @@ -429,9 +531,47 @@ impl Router { let request_id = get_request_id(req); let start = Instant::now(); - let mut request_builder = client - .post(format!("{}{}", worker_url, route)) - .json(typed_req); // Use json() directly with typed request + let mut request_builder = if self.dp_aware { + let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) { + Ok(tup) => tup, + Err(e) => { + error!("Failed to extract dp_rank: {}", e); + return HttpResponse::InternalServerError().finish(); + } + }; + + // Parse the request body + let mut json_val = match serde_json::to_value(typed_req) { + Ok(j) => j, + Err(e) => { + return HttpResponse::BadRequest() + .body(format!("Convert into serde_json::Value failed: {}", e)); + } + }; + + // Insert the data_parallel_rank field + if let Some(map) = json_val.as_object_mut() { + map.insert( + String::from("data_parallel_rank"), + serde_json::json!(dp_rank), + ); + debug!( + "Modified request body: {}", + serde_json::to_string(&json_val).unwrap_or(String::from("ERR")) + ); + } else { + return HttpResponse::BadRequest() + .body("Failed to insert the data_parallel_rank field into the request body"); + } + + client + .post(format!("{}{}", worker_url_prefix, route)) + .json(&json_val) + } else { + client + .post(format!("{}{}", worker_url, route)) + .json(typed_req) // Use json() directly with typed request + }; // Copy all headers from original request for (name, value) in copy_request_headers(req) { @@ -560,12 +700,35 @@ impl Router { Ok(res) => { if res.status().is_success() { let mut workers_guard = self.workers.write().unwrap(); - if workers_guard.iter().any(|w| w.url() == worker_url) { - return Err(format!("Worker {} already exists", worker_url)); + if self.dp_aware { + // Need to contact the worker to extract the dp_size, + // and add them as multiple workers + let url_vec = vec![String::from(worker_url)]; + let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key) + .map_err(|e| format!("Failed to get dp-aware workers: {}", e))?; + let mut worker_added: bool = false; + for dp_url in &dp_url_vec { + if workers_guard.iter().any(|w| w.url() == dp_url) { + warn!("Worker {} already exists", dp_url); + continue; + } + info!("Added worker: {}", dp_url); + let new_worker = WorkerFactory::create_regular(dp_url.to_string()); + workers_guard.push(new_worker); + worker_added = true; + } + if !worker_added { + return Err(format!("No worker added for {}", worker_url)); + } + } else { + if workers_guard.iter().any(|w| w.url() == worker_url) { + return Err(format!("Worker {} already exists", worker_url)); + } + info!("Added worker: {}", worker_url); + let new_worker = WorkerFactory::create_regular(worker_url.to_string()); + workers_guard.push(new_worker); } - info!("Added worker: {}", worker_url); - let new_worker = WorkerFactory::create_regular(worker_url.to_string()); - workers_guard.push(new_worker); + RouterMetrics::set_active_workers(workers_guard.len()); // If cache aware policy, initialize the worker in the tree @@ -612,11 +775,81 @@ impl Router { } } + /// Remove all the worker(s) that match the URL prefix pub fn remove_worker(&self, worker_url: &str) { + if self.dp_aware { + // remove dp-aware workers in a prefix-matching fashion + // without contacting the remote worker + let mut candidate_workers: Vec = Vec::new(); + let mut removed_workers: Vec = Vec::new(); + let worker_url_prefix = format!("{}@", worker_url); + + { + // find the candidate workers to be removed + let workers_guard = self.workers.read().unwrap(); + for w in workers_guard.iter() { + if w.url().starts_with(&worker_url_prefix) { + candidate_workers.push(w.url().to_string()); + } + } + } + + { + // do the removing on the worker_urls + let mut workers_guard = self.workers.write().unwrap(); + for dp_url in candidate_workers.iter() { + if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) { + workers_guard.remove(index); + info!("Removed worker: {}", dp_url); + removed_workers.push(dp_url.to_string()); + } else { + warn!("Worker {} not found, skipping removal", dp_url); + continue; + } + } + RouterMetrics::set_active_workers(workers_guard.len()); + } + + // If cache aware policy, remove the workers from the tree + if let Some(cache_aware) = self + .policy + .as_any() + .downcast_ref::() + { + for dp_url in removed_workers.iter() { + cache_aware.remove_worker(dp_url); + info!("Removed worker from tree: {}", dp_url); + } + } + } else { + let mut workers_guard = self.workers.write().unwrap(); + if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) { + workers_guard.remove(index); + info!("Removed worker: {}", worker_url); + RouterMetrics::set_active_workers(workers_guard.len()); + } else { + warn!("Worker {} not found, skipping removal", worker_url); + return; + } + + // If cache aware policy, remove the workers from the tree + if let Some(cache_aware) = self + .policy + .as_any() + .downcast_ref::() + { + cache_aware.remove_worker(worker_url); + info!("Removed worker from tree: {}", worker_url); + } + } + } + + /// Remove a specific failed worker; for internal usage + fn remove_failed_worker(&self, worker_url: &str) { let mut workers_guard = self.workers.write().unwrap(); if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) { workers_guard.remove(index); - info!("Removed worker: {}", worker_url); + info!("Removed failed worker: {}", worker_url); RouterMetrics::set_active_workers(workers_guard.len()); } else { warn!("Worker {} not found, skipping removal", worker_url); @@ -634,6 +867,20 @@ impl Router { } async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option { + let worker_url = if self.dp_aware { + // Need to extract the URL from "http://host:port@dp_rank" + let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { + Ok(tup) => tup, + Err(e) => { + error!("Failed to extract dp_rank: {}", e); + return None; + } + }; + worker_url_prefix + } else { + worker_url + }; + match client.get(&format!("{}/get_load", worker_url)).send().await { Ok(res) if res.status().is_success() => match res.bytes().await { Ok(bytes) => match serde_json::from_slice::(&bytes) { @@ -710,6 +957,20 @@ impl Router { // Static version of get_worker_load for use in monitoring task async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option { + let worker_url = if worker_url.contains("@") { + // Need to extract the URL from "http://host:port@dp_rank" + let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { + Ok(tup) => tup, + Err(e) => { + debug!("Failed to extract dp_rank: {}", e); + return None; + } + }; + worker_url_prefix + } else { + worker_url + }; + match client.get(&format!("{}/get_load", worker_url)).send().await { Ok(res) if res.status().is_success() => match res.bytes().await { Ok(bytes) => match serde_json::from_slice::(&bytes) { @@ -862,6 +1123,19 @@ impl RouterTrait for Router { // Send requests to all workers concurrently without headers let mut tasks = Vec::new(); for worker_url in &worker_urls { + let worker_url = if self.dp_aware { + // Need to extract the URL from "http://host:port@dp_rank" + let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { + Ok(tup) => tup, + Err(e) => { + error!("Failed to extract dp_rank: {}", e); + return HttpResponse::InternalServerError().finish(); + } + }; + worker_url_prefix + } else { + worker_url + }; let request_builder = client.post(format!("{}/flush_cache", worker_url)); tasks.push(request_builder.send()); } @@ -948,6 +1222,8 @@ mod tests { policy: Arc::new(RandomPolicy::new()), timeout_secs: 5, interval_secs: 1, + dp_aware: false, + api_key: None, _worker_loads: Arc::new(rx), _load_monitor_handle: None, _health_checker: None, diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index fae09896d..717370d14 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -581,7 +581,7 @@ mod tests { use crate::routers::router::Router; let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); - let router = Router::new(vec![], policy, 5, 1).unwrap(); + let router = Router::new(vec![], policy, 5, 1, false, None).unwrap(); Arc::new(router) as Arc } diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index bf86d776b..c38843b77 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -31,6 +31,8 @@ impl TestContext { request_timeout_secs: 600, worker_startup_timeout_secs: 1, worker_startup_check_interval_secs: 1, + dp_aware: false, + api_key: None, discovery: None, metrics: None, log_dir: None, @@ -950,6 +952,8 @@ mod error_tests { request_timeout_secs: 600, worker_startup_timeout_secs: 1, worker_startup_check_interval_secs: 1, + dp_aware: false, + api_key: None, discovery: None, metrics: None, log_dir: None, diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index 62c99a46b..47aafae32 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -16,6 +16,8 @@ pub fn create_test_config(worker_urls: Vec) -> RouterConfig { request_timeout_secs: 600, worker_startup_timeout_secs: 300, worker_startup_check_interval_secs: 10, + dp_aware: false, + api_key: None, discovery: None, metrics: None, log_dir: None, @@ -37,6 +39,8 @@ pub fn create_test_config_no_workers() -> RouterConfig { request_timeout_secs: 600, worker_startup_timeout_secs: 0, // No wait worker_startup_check_interval_secs: 10, + dp_aware: false, + api_key: None, discovery: None, metrics: None, log_dir: None, diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index d265d1030..b6bc6ac4a 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -42,6 +42,8 @@ impl RequestTestContext { request_timeout_secs: 600, worker_startup_timeout_secs: 1, worker_startup_check_interval_secs: 1, + dp_aware: false, + api_key: None, discovery: None, metrics: None, log_dir: None, diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index ada8b7e45..3fce7b835 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -46,6 +46,8 @@ impl StreamingTestContext { request_timeout_secs: 600, worker_startup_timeout_secs: 1, worker_startup_check_interval_secs: 1, + dp_aware: false, + api_key: None, discovery: None, metrics: None, log_dir: None, diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index a6cb8d02d..8bf0c2ee2 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -169,6 +169,8 @@ mod test_pd_routing { request_timeout_secs: 60, worker_startup_timeout_secs: 10, worker_startup_check_interval_secs: 1, + dp_aware: false, + api_key: None, discovery: None, metrics: None, log_dir: None,