From 9d68bdb240a5ba2713bd130f96e457f29fc7fba1 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 11 Aug 2025 21:37:36 -0700 Subject: [PATCH] [router] Add Rust Binary Entrypoint for SGLang Router (#9089) --- sgl-router/Cargo.toml | 5 + sgl-router/README.md | 33 +- sgl-router/src/main.rs | 490 +++++++++++++++++++++++ sgl-router/src/routers/factory.rs | 33 +- sgl-router/src/routers/pd_router.rs | 6 +- sgl-router/src/routers/router.rs | 95 +++-- sgl-router/src/server.rs | 2 +- sgl-router/src/service_discovery.rs | 21 +- sgl-router/tests/api_endpoints_test.rs | 13 +- sgl-router/tests/request_formats_test.rs | 6 +- sgl-router/tests/streaming_tests.rs | 6 +- sgl-router/tests/test_pd_routing.rs | 6 +- 12 files changed, 638 insertions(+), 78 deletions(-) create mode 100644 sgl-router/src/main.rs diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 1b85576c0..b187e0970 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -9,7 +9,12 @@ name = "sglang_router_rs" # Python/C binding + Rust library: Use ["cdylib", "rlib"] crate-type = ["cdylib", "rlib"] +[[bin]] +name = "sglang-router" +path = "src/main.rs" + [dependencies] +clap = { version = "4", features = ["derive"] } axum = { version = "0.8.4", features = ["macros", "ws", "tracing"] } tower = { version = "0.5", features = ["full"] } tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] } diff --git a/sgl-router/README.md b/sgl-router/README.md index 6e9eba39a..42d1bb314 100644 --- a/sgl-router/README.md +++ b/sgl-router/README.md @@ -56,7 +56,21 @@ pip install -e . cargo build ``` -#### Launch Router with Worker URLs in regular mode +#### Using the Rust Binary Directly (Alternative to Python) +```bash +# Build the Rust binary +cargo build --release + +# Launch router with worker URLs in regular mode +./target/release/sglang-router \ + --worker-urls http://worker1:8000 http://worker2:8000 + +# Or use cargo run +cargo run --release -- \ + --worker-urls http://worker1:8000 http://worker2:8000 +``` + +#### Launch Router with Python (Original Method) ```bash # Launch router with worker URLs python -m sglang_router.launch_router \ @@ -68,7 +82,22 @@ python -m sglang_router.launch_router \ # Note that the prefill and decode URLs must be provided in the following format: # http://: for decode nodes # http://: bootstrap-port for prefill nodes, where bootstrap-port is optional -# Launch router with worker URLs + +# Using Rust binary directly +./target/release/sglang-router \ + --pd-disaggregation \ + --policy cache_aware \ + --prefill http://127.0.0.1:30001 9001 \ + --prefill http://127.0.0.2:30002 9002 \ + --prefill http://127.0.0.3:30003 9003 \ + --prefill http://127.0.0.4:30004 9004 \ + --decode http://127.0.0.5:30005 \ + --decode http://127.0.0.6:30006 \ + --decode http://127.0.0.7:30007 \ + --host 0.0.0.0 \ + --port 8080 + +# Or using Python launcher python -m sglang_router.launch_router \ --pd-disaggregation \ --policy cache_aware \ diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs new file mode 100644 index 000000000..180545942 --- /dev/null +++ b/sgl-router/src/main.rs @@ -0,0 +1,490 @@ +use clap::{ArgAction, Parser}; +use sglang_router_rs::config::{ + CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, MetricsConfig, PolicyConfig, + RetryConfig, RouterConfig, RoutingMode, +}; +use sglang_router_rs::metrics::PrometheusConfig; +use sglang_router_rs::server::{self, ServerConfig}; +use sglang_router_rs::service_discovery::ServiceDiscoveryConfig; +use std::collections::HashMap; + +// Helper function to parse prefill arguments from command line +fn parse_prefill_args() -> Vec<(String, Option)> { + let args: Vec = std::env::args().collect(); + let mut prefill_entries = Vec::new(); + let mut i = 0; + + while i < args.len() { + if args[i] == "--prefill" && i + 1 < args.len() { + let url = args[i + 1].clone(); + let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") { + // Check if next arg is a port number + if let Ok(port) = args[i + 2].parse::() { + i += 1; // Skip the port argument + Some(port) + } else if args[i + 2].to_lowercase() == "none" { + i += 1; // Skip the "none" argument + None + } else { + None + } + } else { + None + }; + prefill_entries.push((url, bootstrap_port)); + i += 2; // Skip --prefill and URL + } else { + i += 1; + } + } + + prefill_entries +} + +#[derive(Parser, Debug)] +#[command(name = "sglang-router")] +#[command(about = "SGLang Router - High-performance request distribution across worker nodes")] +#[command(long_about = r#" +SGLang Router - High-performance request distribution across worker nodes + +Usage: +This launcher enables starting a router with individual worker instances. It is useful for +multi-node setups or when you want to start workers and router separately. + +Examples: + # Regular mode + sglang-router --worker-urls http://worker1:8000 http://worker2:8000 + + # PD disaggregated mode with same policy for both + sglang-router --pd-disaggregation \ + --prefill http://127.0.0.1:30001 9001 \ + --prefill http://127.0.0.2:30002 9002 \ + --decode http://127.0.0.3:30003 \ + --decode http://127.0.0.4:30004 \ + --policy cache_aware + + # PD mode with different policies for prefill and decode + sglang-router --pd-disaggregation \ + --prefill http://127.0.0.1:30001 9001 \ + --prefill http://127.0.0.2:30002 \ + --decode http://127.0.0.3:30003 \ + --decode http://127.0.0.4:30004 \ + --prefill-policy cache_aware --decode-policy power_of_two +"#)] +struct CliArgs { + /// Host address to bind the router server + #[arg(long, default_value = "127.0.0.1")] + host: String, + + /// Port number to bind the router server + #[arg(long, default_value_t = 30000)] + port: u16, + + /// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000) + #[arg(long, num_args = 0..)] + worker_urls: Vec, + + /// Load balancing policy to use + #[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])] + policy: String, + + /// Enable PD (Prefill-Decode) disaggregated mode + #[arg(long, default_value_t = false)] + pd_disaggregation: bool, + + /// Decode server URL (can be specified multiple times) + #[arg(long, action = ArgAction::Append)] + decode: Vec, + + /// Specific policy for prefill nodes in PD mode + #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])] + prefill_policy: Option, + + /// Specific policy for decode nodes in PD mode + #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])] + decode_policy: Option, + + /// Timeout in seconds for worker startup + #[arg(long, default_value_t = 300)] + worker_startup_timeout_secs: u64, + + /// Interval in seconds between checks for worker startup + #[arg(long, default_value_t = 10)] + worker_startup_check_interval: u64, + + /// Cache threshold (0.0-1.0) for cache-aware routing + #[arg(long, default_value_t = 0.5)] + cache_threshold: f32, + + /// Absolute threshold for load balancing + #[arg(long, default_value_t = 32)] + balance_abs_threshold: usize, + + /// Relative threshold for load balancing + #[arg(long, default_value_t = 1.0001)] + balance_rel_threshold: f32, + + /// Interval in seconds between cache eviction operations + #[arg(long, default_value_t = 60)] + eviction_interval: u64, + + /// Maximum size of the approximation tree for cache-aware routing + #[arg(long, default_value_t = 16777216)] // 2^24 + max_tree_size: usize, + + /// Maximum payload size in bytes + #[arg(long, default_value_t = 268435456)] // 256MB + max_payload_size: usize, + + /// Enable data parallelism aware schedule + #[arg(long, default_value_t = false)] + dp_aware: bool, + + /// API key for worker authorization + #[arg(long)] + api_key: Option, + + /// Directory to store log files + #[arg(long)] + log_dir: Option, + + /// Set the logging level + #[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])] + log_level: String, + + /// Enable Kubernetes service discovery + #[arg(long, default_value_t = false)] + service_discovery: bool, + + /// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2) + #[arg(long, num_args = 0..)] + selector: Vec, + + /// Port to use for discovered worker pods + #[arg(long, default_value_t = 80)] + service_discovery_port: u16, + + /// Kubernetes namespace to watch for pods + #[arg(long)] + service_discovery_namespace: Option, + + /// Label selector for prefill server pods in PD mode + #[arg(long, num_args = 0..)] + prefill_selector: Vec, + + /// Label selector for decode server pods in PD mode + #[arg(long, num_args = 0..)] + decode_selector: Vec, + + /// Port to expose Prometheus metrics + #[arg(long, default_value_t = 29000)] + prometheus_port: u16, + + /// Host address to bind the Prometheus metrics server + #[arg(long, default_value = "127.0.0.1")] + prometheus_host: String, + + /// Custom HTTP headers to check for request IDs + #[arg(long, num_args = 0..)] + request_id_headers: Vec, + + /// Request timeout in seconds + #[arg(long, default_value_t = 600)] + request_timeout_secs: u64, + + /// Maximum number of concurrent requests allowed + #[arg(long, default_value_t = 64)] + max_concurrent_requests: usize, + + /// CORS allowed origins + #[arg(long, num_args = 0..)] + cors_allowed_origins: Vec, + + // Retry configuration + /// Maximum number of retries + #[arg(long, default_value_t = 3)] + retry_max_retries: u32, + + /// Initial backoff in milliseconds for retries + #[arg(long, default_value_t = 100)] + retry_initial_backoff_ms: u64, + + /// Maximum backoff in milliseconds for retries + #[arg(long, default_value_t = 10000)] + retry_max_backoff_ms: u64, + + /// Backoff multiplier for exponential backoff + #[arg(long, default_value_t = 2.0)] + retry_backoff_multiplier: f32, + + /// Jitter factor for retry backoff + #[arg(long, default_value_t = 0.1)] + retry_jitter_factor: f32, + + /// Disable retries + #[arg(long, default_value_t = false)] + disable_retries: bool, + + // Circuit breaker configuration + /// Number of failures before circuit breaker opens + #[arg(long, default_value_t = 5)] + cb_failure_threshold: u32, + + /// Number of successes before circuit breaker closes + #[arg(long, default_value_t = 2)] + cb_success_threshold: u32, + + /// Timeout duration in seconds for circuit breaker + #[arg(long, default_value_t = 30)] + cb_timeout_duration_secs: u64, + + /// Window duration in seconds for circuit breaker + #[arg(long, default_value_t = 60)] + cb_window_duration_secs: u64, + + /// Disable circuit breaker + #[arg(long, default_value_t = false)] + disable_circuit_breaker: bool, +} + +impl CliArgs { + /// Parse selector strings into HashMap + fn parse_selector(selector_list: &[String]) -> HashMap { + let mut map = HashMap::new(); + for item in selector_list { + if let Some(eq_pos) = item.find('=') { + let key = item[..eq_pos].to_string(); + let value = item[eq_pos + 1..].to_string(); + map.insert(key, value); + } + } + map + } + + /// Convert policy string to PolicyConfig + fn parse_policy(&self, policy_str: &str) -> PolicyConfig { + match policy_str { + "random" => PolicyConfig::Random, + "round_robin" => PolicyConfig::RoundRobin, + "cache_aware" => PolicyConfig::CacheAware { + cache_threshold: self.cache_threshold, + balance_abs_threshold: self.balance_abs_threshold, + balance_rel_threshold: self.balance_rel_threshold, + eviction_interval_secs: self.eviction_interval, + max_tree_size: self.max_tree_size, + }, + "power_of_two" => PolicyConfig::PowerOfTwo { + load_check_interval_secs: 5, // Default value + }, + _ => PolicyConfig::RoundRobin, // Fallback + } + } + + /// Convert CLI arguments to RouterConfig + fn to_router_config( + &self, + prefill_urls: Vec<(String, Option)>, + ) -> ConfigResult { + // Determine routing mode + let mode = if self.pd_disaggregation { + let decode_urls = self.decode.clone(); + + // Validate PD configuration if not using service discovery + if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) { + return Err(ConfigError::ValidationFailed { + reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(), + }); + } + + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + prefill_policy: self.prefill_policy.as_ref().map(|p| self.parse_policy(p)), + decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)), + } + } else { + // Regular mode + if !self.service_discovery && self.worker_urls.is_empty() { + return Err(ConfigError::ValidationFailed { + reason: "Regular mode requires --worker-urls when not using service discovery" + .to_string(), + }); + } + RoutingMode::Regular { + worker_urls: self.worker_urls.clone(), + } + }; + + // Main policy + let policy = self.parse_policy(&self.policy); + + // Service discovery configuration + let discovery = if self.service_discovery { + Some(DiscoveryConfig { + enabled: true, + namespace: self.service_discovery_namespace.clone(), + port: self.service_discovery_port, + check_interval_secs: 60, + selector: Self::parse_selector(&self.selector), + prefill_selector: Self::parse_selector(&self.prefill_selector), + decode_selector: Self::parse_selector(&self.decode_selector), + bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(), + }) + } else { + None + }; + + // Metrics configuration + let metrics = Some(MetricsConfig { + port: self.prometheus_port, + host: self.prometheus_host.clone(), + }); + + // Build RouterConfig + Ok(RouterConfig { + mode, + policy, + host: self.host.clone(), + port: self.port, + max_payload_size: self.max_payload_size, + request_timeout_secs: self.request_timeout_secs, + worker_startup_timeout_secs: self.worker_startup_timeout_secs, + worker_startup_check_interval_secs: self.worker_startup_check_interval, + dp_aware: self.dp_aware, + api_key: self.api_key.clone(), + discovery, + metrics, + log_dir: self.log_dir.clone(), + log_level: Some(self.log_level.clone()), + request_id_headers: if self.request_id_headers.is_empty() { + None + } else { + Some(self.request_id_headers.clone()) + }, + max_concurrent_requests: self.max_concurrent_requests, + cors_allowed_origins: self.cors_allowed_origins.clone(), + retry: RetryConfig { + max_retries: self.retry_max_retries, + initial_backoff_ms: self.retry_initial_backoff_ms, + max_backoff_ms: self.retry_max_backoff_ms, + backoff_multiplier: self.retry_backoff_multiplier, + jitter_factor: self.retry_jitter_factor, + }, + circuit_breaker: CircuitBreakerConfig { + failure_threshold: self.cb_failure_threshold, + success_threshold: self.cb_success_threshold, + timeout_duration_secs: self.cb_timeout_duration_secs, + window_duration_secs: self.cb_window_duration_secs, + }, + disable_retries: self.disable_retries, + disable_circuit_breaker: self.disable_circuit_breaker, + }) + } + + /// Create ServerConfig from CLI args and RouterConfig + fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig { + // Create service discovery config if enabled + let service_discovery_config = if self.service_discovery { + Some(ServiceDiscoveryConfig { + enabled: true, + selector: Self::parse_selector(&self.selector), + check_interval: std::time::Duration::from_secs(60), + port: self.service_discovery_port, + namespace: self.service_discovery_namespace.clone(), + pd_mode: self.pd_disaggregation, + prefill_selector: Self::parse_selector(&self.prefill_selector), + decode_selector: Self::parse_selector(&self.decode_selector), + bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(), + }) + } else { + None + }; + + // Create Prometheus config + let prometheus_config = Some(PrometheusConfig { + port: self.prometheus_port, + host: self.prometheus_host.clone(), + }); + + ServerConfig { + host: self.host.clone(), + port: self.port, + router_config, + max_payload_size: self.max_payload_size, + log_dir: self.log_dir.clone(), + log_level: Some(self.log_level.clone()), + service_discovery_config, + prometheus_config, + request_timeout_secs: self.request_timeout_secs, + request_id_headers: if self.request_id_headers.is_empty() { + None + } else { + Some(self.request_id_headers.clone()) + }, + } + } +} + +fn main() -> Result<(), Box> { + // Parse prefill arguments manually before clap parsing + let prefill_urls = parse_prefill_args(); + + // Filter out prefill arguments and their values before passing to clap + let mut filtered_args: Vec = Vec::new(); + let raw_args: Vec = std::env::args().collect(); + let mut i = 0; + + while i < raw_args.len() { + if raw_args[i] == "--prefill" && i + 1 < raw_args.len() { + // Skip --prefill and its URL + i += 2; + // Also skip bootstrap port if present + if i < raw_args.len() && !raw_args[i].starts_with("--") { + if raw_args[i].parse::().is_ok() || raw_args[i].to_lowercase() == "none" { + i += 1; + } + } + } else { + filtered_args.push(raw_args[i].clone()); + i += 1; + } + } + + // Parse CLI arguments with clap using filtered args + let cli_args = CliArgs::parse_from(filtered_args); + + // Print startup info + println!("SGLang Router starting..."); + println!("Host: {}:{}", cli_args.host, cli_args.port); + println!( + "Mode: {}", + if cli_args.pd_disaggregation { + "PD Disaggregated" + } else { + "Regular" + } + ); + println!("Policy: {}", cli_args.policy); + + if cli_args.pd_disaggregation && !prefill_urls.is_empty() { + println!("Prefill nodes: {:?}", prefill_urls); + println!("Decode nodes: {:?}", cli_args.decode); + } + + // Convert to RouterConfig + let router_config = cli_args.to_router_config(prefill_urls)?; + + // Validate configuration + router_config.validate()?; + + // Create ServerConfig + let server_config = cli_args.to_server_config(router_config); + + // Create a new runtime for the server (like Python binding does) + let runtime = tokio::runtime::Runtime::new()?; + + // Block on the async startup function + runtime.block_on(async move { server::startup(server_config).await })?; + + Ok(()) +} diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index 357007278..78dbb932e 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -11,29 +11,32 @@ pub struct RouterFactory; impl RouterFactory { /// Create a router instance from application context - pub fn create_router(ctx: &Arc) -> Result, String> { + pub async fn create_router(ctx: &Arc) -> Result, String> { match &ctx.router_config.mode { RoutingMode::Regular { worker_urls } => { - Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx) + Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await } RoutingMode::PrefillDecode { prefill_urls, decode_urls, prefill_policy, decode_policy, - } => Self::create_pd_router( - prefill_urls, - decode_urls, - prefill_policy.as_ref(), - decode_policy.as_ref(), - &ctx.router_config.policy, - ctx, - ), + } => { + Self::create_pd_router( + prefill_urls, + decode_urls, + prefill_policy.as_ref(), + decode_policy.as_ref(), + &ctx.router_config.policy, + ctx, + ) + .await + } } } /// Create a regular router with injected policy - fn create_regular_router( + async fn create_regular_router( worker_urls: &[String], policy_config: &PolicyConfig, ctx: &Arc, @@ -52,13 +55,14 @@ impl RouterFactory { ctx.router_config.api_key.clone(), ctx.router_config.retry.clone(), ctx.router_config.circuit_breaker.clone(), - )?; + ) + .await?; Ok(Box::new(router)) } /// Create a PD router with injected policy - fn create_pd_router( + async fn create_pd_router( prefill_urls: &[(String, Option)], decode_urls: &[String], prefill_policy_config: Option<&PolicyConfig>, @@ -83,7 +87,8 @@ impl RouterFactory { ctx.router_config.worker_startup_check_interval_secs, ctx.router_config.retry.clone(), ctx.router_config.circuit_breaker.clone(), - )?; + ) + .await?; Ok(Box::new(router)) } diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index b0347e59f..729bca0e7 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -67,6 +67,7 @@ impl PDRouter { self.timeout_secs, self.interval_secs, ) + .await .map_err(|_| PDRouterError::HealthCheckFailed { url: url.to_string(), }) @@ -349,7 +350,7 @@ impl PDRouter { Ok(format!("Successfully removed decode server: {}", url)) } - pub fn new( + pub async fn new( prefill_urls: Vec<(String, Option)>, decode_urls: Vec, prefill_policy: Arc, @@ -392,7 +393,8 @@ impl PDRouter { &all_urls, timeout_secs, interval_secs, - )?; + ) + .await?; } // Initialize cache-aware policies with workers diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index d6ecb0960..023607ac1 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -17,7 +17,6 @@ use futures_util::StreamExt; use reqwest::Client; use std::collections::HashMap; use std::sync::{Arc, RwLock}; -use std::thread; use std::time::{Duration, Instant}; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, info, warn}; @@ -52,7 +51,7 @@ pub struct Router { impl Router { /// Create a new router with injected policy and client - pub fn new( + pub async fn new( worker_urls: Vec, policy: Arc, client: Client, @@ -68,7 +67,7 @@ impl Router { // Wait for workers to be healthy (skip if empty - for service discovery mode) if !worker_urls.is_empty() { - Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?; + Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs).await?; } let worker_urls = if dp_aware { @@ -156,7 +155,7 @@ impl Router { .collect() } - pub fn wait_for_healthy_workers( + pub async fn wait_for_healthy_workers( worker_urls: &[String], timeout_secs: u64, interval_secs: u64, @@ -167,9 +166,24 @@ impl Router { ); } + // Perform health check asynchronously + Self::wait_for_healthy_workers_async(worker_urls, timeout_secs, interval_secs).await + } + + async fn wait_for_healthy_workers_async( + worker_urls: &[String], + timeout_secs: u64, + interval_secs: u64, + ) -> Result<(), String> { + info!( + "Waiting for {} workers to become healthy (timeout: {}s)", + worker_urls.len(), + timeout_secs + ); + let start_time = std::time::Instant::now(); - let sync_client = reqwest::blocking::Client::builder() - .timeout(Duration::from_secs(timeout_secs)) + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(2)) .build() .map_err(|e| format!("Failed to create HTTP client: {}", e))?; @@ -185,20 +199,48 @@ impl Router { )); } + // Perform all health checks concurrently + let mut health_checks = Vec::new(); + for url in worker_urls { + let client_clone = client.clone(); + let url_clone = url.clone(); + + let check_health = tokio::spawn(async move { + let health_url = format!("{}/health", url_clone); + match client_clone.get(&health_url).send().await { + Ok(res) => { + if res.status().is_success() { + None + } else { + Some((url_clone, format!("status: {}", res.status()))) + } + } + Err(_) => Some((url_clone, "not ready".to_string())), + } + }); + + health_checks.push(check_health); + } + + // Wait for all health checks to complete + let results = futures::future::join_all(health_checks).await; + let mut all_healthy = true; let mut unhealthy_workers = Vec::new(); - for url in worker_urls { - match sync_client.get(&format!("{}/health", url)).send() { - Ok(res) => { - if !res.status().is_success() { - all_healthy = false; - unhealthy_workers.push((url, format!("status: {}", res.status()))); - } + for result in results { + match result { + Ok(None) => { + // Worker is healthy } - Err(_) => { + Ok(Some((url, reason))) => { all_healthy = false; - unhealthy_workers.push((url, "not ready".to_string())); + unhealthy_workers.push((url, reason)); + } + Err(e) => { + all_healthy = false; + unhealthy_workers + .push(("unknown".to_string(), format!("task error: {}", e))); } } } @@ -208,11 +250,12 @@ impl Router { return Ok(()); } else { debug!( - "Waiting for {} workers to become healthy ({} unhealthy)", + "Waiting for {} workers to become healthy ({} unhealthy: {:?})", worker_urls.len(), - unhealthy_workers.len() + unhealthy_workers.len(), + unhealthy_workers ); - thread::sleep(Duration::from_secs(interval_secs)); + tokio::time::sleep(Duration::from_secs(interval_secs)).await; } } } @@ -1246,19 +1289,19 @@ mod tests { assert_eq!(result.unwrap(), "http://worker1:8080"); } - #[test] - fn test_wait_for_healthy_workers_empty_list() { - // Empty list will timeout as there are no workers to check - let result = Router::wait_for_healthy_workers(&[], 1, 1); + #[tokio::test] + async fn test_wait_for_healthy_workers_empty_list() { + // Empty list will return error immediately + let result = Router::wait_for_healthy_workers(&[], 1, 1).await; assert!(result.is_err()); - assert!(result.unwrap_err().contains("Timeout")); + assert!(result.unwrap_err().contains("no workers provided")); } - #[test] - fn test_wait_for_healthy_workers_invalid_urls() { + #[tokio::test] + async fn test_wait_for_healthy_workers_invalid_urls() { // This test will timeout quickly since the URLs are invalid let result = - Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1); + Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1).await; assert!(result.is_err()); assert!(result.unwrap_err().contains("Timeout")); } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 1ca668374..9746e5845 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -285,7 +285,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Arc { + async fn create_test_router() -> Arc { use crate::config::PolicyConfig; use crate::policies::PolicyFactory; use crate::routers::router::Router; @@ -593,6 +593,7 @@ mod tests { crate::config::types::RetryConfig::default(), crate::config::types::CircuitBreakerConfig::default(), ) + .await .unwrap(); Arc::new(router) as Arc } @@ -896,7 +897,7 @@ mod tests { #[tokio::test] async fn test_handle_pod_event_add_unhealthy_pod() { - let router = create_test_router(); + let router = create_test_router().await; let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let pod_info = PodInfo { name: "pod1".into(), @@ -925,7 +926,7 @@ mod tests { #[tokio::test] async fn test_handle_pod_deletion_non_existing_pod() { - let router = create_test_router(); + let router = create_test_router().await; let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let pod_info = PodInfo { name: "pod1".into(), @@ -952,7 +953,7 @@ mod tests { #[tokio::test] async fn test_handle_pd_pod_event_prefill_pod() { - let router = create_test_router(); + let router = create_test_router().await; let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let pod_info = PodInfo { name: "prefill-pod".into(), @@ -981,7 +982,7 @@ mod tests { #[tokio::test] async fn test_handle_pd_pod_event_decode_pod() { - let router = create_test_router(); + let router = create_test_router().await; let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let pod_info = PodInfo { name: "decode-pod".into(), @@ -1008,7 +1009,7 @@ mod tests { #[tokio::test] async fn test_handle_pd_pod_deletion_tracked_pod() { - let router = create_test_router(); + let router = create_test_router().await; let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let pod_info = PodInfo { name: "test-pod".into(), @@ -1042,7 +1043,7 @@ mod tests { #[tokio::test] async fn test_handle_pd_pod_deletion_untracked_pod() { - let router = create_test_router(); + let router = create_test_router().await; let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let pod_info = PodInfo { name: "untracked-pod".into(), @@ -1071,7 +1072,7 @@ mod tests { #[tokio::test] async fn test_unified_handler_regular_mode() { - let router = create_test_router(); + let router = create_test_router().await; let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let pod_info = PodInfo { name: "regular-pod".into(), @@ -1099,7 +1100,7 @@ mod tests { #[tokio::test] async fn test_unified_handler_pd_mode_with_prefill() { - let router = create_test_router(); + let router = create_test_router().await; let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let pod_info = PodInfo { name: "prefill-pod".into(), @@ -1127,7 +1128,7 @@ mod tests { #[tokio::test] async fn test_unified_handler_deletion_with_pd_mode() { - let router = create_test_router(); + let router = create_test_router().await; let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let pod_info = PodInfo { name: "decode-pod".into(), diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index 71411d759..1cadead56 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -92,12 +92,8 @@ impl TestContext { // Create app context let app_context = common::create_test_context(config.clone()); - // Create router using sync factory in a blocking context - let router = - tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context)) - .await - .unwrap() - .unwrap(); + // Create router + let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router); // Wait for router to discover workers @@ -1451,10 +1447,7 @@ mod pd_mode_tests { let app_context = common::create_test_context(config); // Create router - this might fail due to health check issues - let router_result = - tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context)) - .await - .unwrap(); + let router_result = RouterFactory::create_router(&app_context).await; // Clean up workers prefill_worker.stop().await; diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index 251179e49..be04a9103 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -60,11 +60,7 @@ impl TestContext { config.mode = RoutingMode::Regular { worker_urls }; let app_context = common::create_test_context(config); - let router = - tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context)) - .await - .unwrap() - .unwrap(); + let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router); if !workers.is_empty() { diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index 66a15ae08..bb0090daa 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -61,11 +61,7 @@ impl TestContext { config.mode = RoutingMode::Regular { worker_urls }; let app_context = common::create_test_context(config); - let router = - tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context)) - .await - .unwrap() - .unwrap(); + let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router); if !workers.is_empty() { diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 95860bca6..72b1697cf 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -109,8 +109,8 @@ mod test_pd_routing { } } - #[test] - fn test_pd_router_configuration() { + #[tokio::test] + async fn test_pd_router_configuration() { // Test PD router configuration with various policies // In the new structure, RoutingMode and PolicyConfig are separate let test_cases = vec![ @@ -190,7 +190,7 @@ mod test_pd_routing { let app_context = sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64); let app_context = std::sync::Arc::new(app_context); - let result = RouterFactory::create_router(&app_context); + let result = RouterFactory::create_router(&app_context).await; assert!(result.is_err()); let error_msg = result.unwrap_err(); // Error should be about health/timeout, not configuration