From 00eb5eb7213c7b75984ef44eae80bda2228d4cfd Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 19 Sep 2025 15:37:57 -0400 Subject: [PATCH] [router] refactor router and worker management 2/n (#10666) --- sgl-router/src/routers/factory.rs | 20 +- sgl-router/src/routers/http/pd_router.rs | 94 ++--- sgl-router/src/routers/http/router.rs | 79 +--- sgl-router/src/routers/mod.rs | 2 + sgl-router/src/routers/worker_initializer.rs | 361 +++++++++++++++++++ sgl-router/src/server.rs | 28 +- sgl-router/src/service_discovery.rs | 2 +- sgl-router/tests/api_endpoints_test.rs | 8 + sgl-router/tests/request_formats_test.rs | 15 +- sgl-router/tests/streaming_tests.rs | 15 +- sgl-router/tests/test_pd_routing.rs | 16 +- 11 files changed, 483 insertions(+), 157 deletions(-) create mode 100644 sgl-router/src/routers/worker_initializer.rs diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index 9fec8be13..a1a2cfd64 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -47,18 +47,17 @@ impl RouterFactory { ConnectionMode::Http => { // Route to HTTP implementation based on routing mode match &ctx.router_config.mode { - RoutingMode::Regular { worker_urls } => { - Self::create_regular_router(worker_urls, ctx).await + RoutingMode::Regular { .. } => { + // Workers already initialized in registry + Self::create_regular_router(ctx).await } RoutingMode::PrefillDecode { - prefill_urls, - decode_urls, prefill_policy, decode_policy, + .. } => { + // Workers already initialized in registry Self::create_pd_router( - prefill_urls, - decode_urls, prefill_policy.as_ref(), decode_policy.as_ref(), &ctx.router_config.policy, @@ -76,19 +75,17 @@ impl RouterFactory { /// Create a regular router pub async fn create_regular_router( - worker_urls: &[String], ctx: &Arc, ) -> Result, String> { // Create regular router with context - let router = Router::new(worker_urls.to_vec(), ctx).await?; + // Workers should already be initialized in the registry + let router = Router::new(ctx).await?; Ok(Box::new(router)) } /// Create a PD router with injected policy pub async fn create_pd_router( - prefill_urls: &[(String, Option)], - decode_urls: &[String], prefill_policy_config: Option<&PolicyConfig>, decode_policy_config: Option<&PolicyConfig>, main_policy_config: &PolicyConfig, @@ -105,7 +102,8 @@ impl RouterFactory { ctx.policy_registry.set_decode_policy(decode_policy); // Create PD router with context (policies are in PolicyRegistry) - let router = PDRouter::new(prefill_urls.to_vec(), decode_urls.to_vec(), ctx).await?; + // Workers should already be initialized in the registry + let router = PDRouter::new(ctx).await?; Ok(Box::new(router)) } diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index df2d2e987..b0002f62e 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -3,7 +3,7 @@ use super::pd_types::{api_path, PDRouterError}; use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor, + is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; @@ -371,12 +371,30 @@ impl PDRouter { } } - #[allow(clippy::too_many_arguments)] - pub async fn new( - prefill_urls: Vec<(String, Option)>, - decode_urls: Vec, - ctx: &Arc, - ) -> Result { + pub async fn new(ctx: &Arc) -> Result { + let prefill_workers = ctx.worker_registry.get_workers_filtered( + None, // any model + Some(WorkerType::Prefill { + bootstrap_port: None, + }), + Some(ConnectionMode::Http), + false, // include all workers + ); + + let decode_workers = ctx.worker_registry.get_workers_filtered( + None, // any model + Some(WorkerType::Decode), + Some(ConnectionMode::Http), + false, // include all workers + ); + + // Get all worker URLs for monitoring + let all_urls: Vec = prefill_workers + .iter() + .chain(decode_workers.iter()) + .map(|w| w.url().to_string()) + .collect(); + // Convert config CircuitBreakerConfig to core CircuitBreakerConfig let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config(); let core_cb_config = CircuitBreakerConfig { @@ -386,60 +404,6 @@ impl PDRouter { window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), }; - // Register prefill workers in the registry - for (url, port) in prefill_urls { - let worker = BasicWorkerBuilder::new(url) - .worker_type(WorkerType::Prefill { - bootstrap_port: port, - }) - .circuit_breaker_config(core_cb_config.clone()) - .health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }) - .build(); - ctx.worker_registry.register(Arc::new(worker)); - } - - // Register decode workers in the registry - for url in decode_urls { - let worker = BasicWorkerBuilder::new(url) - .worker_type(WorkerType::Decode) - .circuit_breaker_config(core_cb_config.clone()) - .health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }) - .build(); - ctx.worker_registry.register(Arc::new(worker)); - } - - // Get all workers from registry for health check - let all_workers = ctx.worker_registry.get_all(); - let all_urls: Vec = all_workers - .iter() - .map(|worker| worker.url().to_string()) - .collect(); - if !all_urls.is_empty() { - crate::routers::http::router::Router::wait_for_healthy_workers( - &all_urls, - ctx.router_config.worker_startup_timeout_secs, - ctx.router_config.worker_startup_check_interval_secs, - ) - .await?; - } - - // Initialize cache-aware policies with workers from registry - // Note: We need to get workers by type and convert to Box for CacheAwarePolicy - // This is a temporary workaround until CacheAwarePolicy is updated to work with Arc - // TODO: Update CacheAwarePolicy to accept Arc instead of Box - // Set up background load monitoring for power-of-two selection let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let worker_loads = Arc::new(rx); @@ -471,11 +435,8 @@ impl PDRouter { None }; - // Note: Health checking is now handled centrally by RouterManager - // Individual routers no longer need to manage health checkers - // Build a dedicated prefill client for fire-and-forget semantics - let prefill_client = reqwest::Client::builder() + let prefill_client = Client::builder() .pool_max_idle_per_host(0) .http1_only() .connect_timeout(Duration::from_millis(300)) @@ -489,6 +450,7 @@ impl PDRouter { // Spawn a coordinator with limited concurrent drain tasks // This prevents unbounded task spawning under extreme load + // TODO reevaluate a simpler approach (e.g. do we really need to deal with fire and forget) tokio::spawn(async move { info!("Prefill drain coordinator started"); @@ -513,7 +475,7 @@ impl PDRouter { // Drain the response body efficiently // Use streaming to avoid loading entire body into memory - let start = std::time::Instant::now(); + let start = Instant::now(); let mut stream = response.bytes_stream(); let mut bytes_drained = 0; diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index 91ad1d948..e5305cb56 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -1,6 +1,6 @@ use crate::config::types::RetryConfig; use crate::core::{ - is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor, + is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType, }; use crate::metrics::RouterMetrics; @@ -47,31 +47,19 @@ pub struct Router { impl Router { /// Create a new router with injected policy and client - #[allow(clippy::too_many_arguments)] - pub async fn new( - worker_urls: Vec, - ctx: &Arc, - ) -> Result { + pub async fn new(ctx: &Arc) -> Result { + let workers = ctx.worker_registry.get_workers_filtered( + None, // any model + Some(WorkerType::Regular), + Some(ConnectionMode::Http), + false, // include all workers + ); + // Update active workers gauge - RouterMetrics::set_active_workers(worker_urls.len()); + RouterMetrics::set_active_workers(workers.len()); - // Wait for workers to be healthy (skip if empty - for service discovery mode) - if !worker_urls.is_empty() { - Self::wait_for_healthy_workers( - &worker_urls, - ctx.router_config.worker_startup_timeout_secs, - ctx.router_config.worker_startup_check_interval_secs, - ) - .await?; - } - - let worker_urls = if ctx.router_config.dp_aware { - // worker address now in the format of "http://host:port@dp_rank" - Self::get_dp_aware_workers(&worker_urls, &ctx.router_config.api_key) - .map_err(|e| format!("Failed to get dp-aware workers: {}", e))? - } else { - worker_urls - }; + // Get worker URLs for monitoring + let worker_urls: Vec = workers.iter().map(|w| w.url().to_string()).collect(); // Convert config CircuitBreakerConfig to core CircuitBreakerConfig let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config(); @@ -82,40 +70,14 @@ impl Router { window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), }; - // Register workers in the registry - // In IGW mode, we need to fetch model info from workers - for url in &worker_urls { - // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint - // For now, create worker without model_id - let worker = BasicWorkerBuilder::new(url.clone()) - .worker_type(WorkerType::Regular) - .circuit_breaker_config(core_cb_config.clone()) - .health_config(HealthConfig { - timeout_secs: ctx.router_config.health_check.timeout_secs, - check_interval_secs: ctx.router_config.health_check.check_interval_secs, - endpoint: ctx.router_config.health_check.endpoint.clone(), - failure_threshold: ctx.router_config.health_check.failure_threshold, - success_threshold: ctx.router_config.health_check.success_threshold, - }) - .build(); - - let worker_arc = Arc::new(worker); - ctx.worker_registry.register(worker_arc.clone()); - - // Notify PolicyRegistry about the new worker - let model_id = worker_arc.model_id(); - let policy = ctx.policy_registry.on_worker_added(model_id, None); - - // If this is a cache-aware policy and it's the first worker for this model, - // initialize it with the worker - if policy.name() == "cache_aware" { - if let Some(cache_aware) = policy - .as_any() - .downcast_ref::() - { - let worker_dyn: Arc = worker_arc.clone(); - cache_aware.init_workers(std::slice::from_ref(&worker_dyn)); - } + // Initialize cache-aware policy with workers if needed + let default_policy = ctx.policy_registry.get_default_policy(); + if default_policy.name() == "cache_aware" { + if let Some(cache_aware) = default_policy + .as_any() + .downcast_ref::() + { + cache_aware.init_workers(&workers); } } @@ -124,7 +86,6 @@ impl Router { let worker_loads = Arc::new(rx); // Check if default policy is power_of_two for load monitoring - let default_policy = ctx.policy_registry.get_default_policy(); let load_monitor_handle = if default_policy.name() == "power_of_two" { let monitor_urls = worker_urls.clone(); let monitor_interval = ctx.router_config.worker_startup_check_interval_secs; diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index ea64c12e1..76ac33847 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -19,8 +19,10 @@ pub mod grpc; pub mod header_utils; pub mod http; pub mod router_manager; +pub mod worker_initializer; pub use factory::RouterFactory; +pub use worker_initializer::WorkerInitializer; // Re-export HTTP routers for convenience (keeps routers::openai_router path working) pub use http::{openai_router, pd_router, pd_types, router}; diff --git a/sgl-router/src/routers/worker_initializer.rs b/sgl-router/src/routers/worker_initializer.rs new file mode 100644 index 000000000..e71fba1c8 --- /dev/null +++ b/sgl-router/src/routers/worker_initializer.rs @@ -0,0 +1,361 @@ +// Worker Initialization Module +// Separates worker lifecycle management from router construction + +use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode}; +use crate::core::{ + BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, WorkerRegistry, + WorkerType, +}; +use std::sync::Arc; +use std::time::Duration; +use tracing::{info, warn}; + +/// WorkerInitializer handles the creation and registration of workers +/// based on routing configuration, separating this concern from router constructors +pub struct WorkerInitializer; + +impl WorkerInitializer { + /// Initialize workers based on configuration and register them in the WorkerRegistry + pub async fn initialize_workers( + config: &RouterConfig, + worker_registry: &Arc, + ) -> Result<(), String> { + info!("Initializing workers for routing mode: {:?}", config.mode); + + match &config.mode { + RoutingMode::Regular { worker_urls } => { + Self::create_regular_workers( + worker_urls, + &config.connection_mode, + config, + worker_registry, + ) + .await?; + } + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + .. + } => { + Self::create_prefill_workers( + prefill_urls, + &config.connection_mode, + config, + worker_registry, + ) + .await?; + Self::create_decode_workers( + decode_urls, + &config.connection_mode, + config, + worker_registry, + ) + .await?; + } + RoutingMode::OpenAI { .. } => { + info!("OpenAI routing mode - no local workers to initialize"); + } + } + + // Wait for workers to be healthy if any were registered + if worker_registry.stats().total_workers > 0 { + Self::wait_for_healthy_workers( + worker_registry, + config.worker_startup_timeout_secs, + config.worker_startup_check_interval_secs, + ) + .await?; + } + + Ok(()) + } + + /// Create regular workers for standard routing mode + async fn create_regular_workers( + urls: &[String], + config_connection_mode: &ConfigConnectionMode, + config: &RouterConfig, + registry: &Arc, + ) -> Result<(), String> { + info!("Creating {} regular workers", urls.len()); + + // Convert config connection mode to core connection mode + let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first()); + + // Convert circuit breaker config + let circuit_breaker_config = config.effective_circuit_breaker_config(); + let core_cb_config = CircuitBreakerConfig { + failure_threshold: circuit_breaker_config.failure_threshold, + success_threshold: circuit_breaker_config.success_threshold, + timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs), + window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), + }; + + // Convert health check config + let health_config = HealthConfig { + timeout_secs: config.health_check.timeout_secs, + check_interval_secs: config.health_check.check_interval_secs, + endpoint: config.health_check.endpoint.clone(), + failure_threshold: config.health_check.failure_threshold, + success_threshold: config.health_check.success_threshold, + }; + + for url in urls { + // TODO: Add DP-aware support when we have dp_rank/dp_size info + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Regular) + .connection_mode(connection_mode.clone()) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(health_config.clone()) + .build(); + + let worker_id = registry.register(Arc::new(worker)); + info!("Registered regular worker {} with ID {:?}", url, worker_id); + } + + Ok(()) + } + + /// Create prefill workers for disaggregated routing mode + async fn create_prefill_workers( + prefill_entries: &[(String, Option)], + config_connection_mode: &ConfigConnectionMode, + config: &RouterConfig, + registry: &Arc, + ) -> Result<(), String> { + info!("Creating {} prefill workers", prefill_entries.len()); + + // Convert config connection mode to core connection mode + let connection_mode = Self::convert_connection_mode( + config_connection_mode, + prefill_entries.first().map(|(url, _)| url), + ); + + // Convert circuit breaker config + let circuit_breaker_config = config.effective_circuit_breaker_config(); + let core_cb_config = CircuitBreakerConfig { + failure_threshold: circuit_breaker_config.failure_threshold, + success_threshold: circuit_breaker_config.success_threshold, + timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs), + window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), + }; + + // Convert health check config + let health_config = HealthConfig { + timeout_secs: config.health_check.timeout_secs, + check_interval_secs: config.health_check.check_interval_secs, + endpoint: config.health_check.endpoint.clone(), + failure_threshold: config.health_check.failure_threshold, + success_threshold: config.health_check.success_threshold, + }; + + for (url, bootstrap_port) in prefill_entries { + // TODO: Add DP-aware support when we have dp_rank/dp_size info + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Prefill { + bootstrap_port: *bootstrap_port, + }) + .connection_mode(connection_mode.clone()) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(health_config.clone()) + .build(); + + let worker_id = registry.register(Arc::new(worker)); + info!("Registered prefill worker {} with ID {:?}", url, worker_id); + } + + Ok(()) + } + + /// Create decode workers for disaggregated routing mode + async fn create_decode_workers( + urls: &[String], + config_connection_mode: &ConfigConnectionMode, + config: &RouterConfig, + registry: &Arc, + ) -> Result<(), String> { + info!("Creating {} decode workers", urls.len()); + + // Convert config connection mode to core connection mode + let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first()); + + // Convert circuit breaker config + let circuit_breaker_config = config.effective_circuit_breaker_config(); + let core_cb_config = CircuitBreakerConfig { + failure_threshold: circuit_breaker_config.failure_threshold, + success_threshold: circuit_breaker_config.success_threshold, + timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs), + window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), + }; + + // Convert health check config + let health_config = HealthConfig { + timeout_secs: config.health_check.timeout_secs, + check_interval_secs: config.health_check.check_interval_secs, + endpoint: config.health_check.endpoint.clone(), + failure_threshold: config.health_check.failure_threshold, + success_threshold: config.health_check.success_threshold, + }; + + for url in urls { + // TODO: Add DP-aware support when we have dp_rank/dp_size info + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(WorkerType::Decode) + .connection_mode(connection_mode.clone()) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(health_config.clone()) + .build(); + + let worker_id = registry.register(Arc::new(worker)); + info!("Registered decode worker {} with ID {:?}", url, worker_id); + } + + Ok(()) + } + + /// Convert config connection mode to core connection mode + fn convert_connection_mode( + config_mode: &ConfigConnectionMode, + _sample_url: Option<&String>, + ) -> ConnectionMode { + match config_mode { + ConfigConnectionMode::Http => ConnectionMode::Http, + ConfigConnectionMode::Grpc => ConnectionMode::Grpc { port: None }, + } + } + + /// Wait for workers to become healthy + async fn wait_for_healthy_workers( + registry: &Arc, + timeout_secs: u64, + check_interval_secs: u64, + ) -> Result<(), String> { + let timeout = Duration::from_secs(timeout_secs); + let check_interval = Duration::from_secs(check_interval_secs); + let start_time = std::time::Instant::now(); + + info!( + "Waiting for workers to become healthy (timeout: {}s)", + timeout_secs + ); + + loop { + let stats = registry.stats(); + + if stats.healthy_workers > 0 { + info!( + "Workers healthy: {}/{} workers are ready", + stats.healthy_workers, stats.total_workers + ); + + // If we have at least one healthy worker, we can proceed + // This allows partial degradation rather than total failure + return Ok(()); + } + + if start_time.elapsed() > timeout { + let error_msg = format!( + "Timeout waiting for workers to become healthy after {}s. Total workers: {}, Healthy: {}", + timeout_secs, stats.total_workers, stats.healthy_workers + ); + warn!("{}", error_msg); + + // If we have workers but none are healthy, it's still a failure + if stats.total_workers > 0 { + return Err(error_msg); + } else { + // No workers at all might be OK for some configurations + warn!("No workers registered, proceeding anyway"); + return Ok(()); + } + } + + tokio::time::sleep(check_interval).await; + } + } + + /// Initialize workers for gRPC connections specifically + /// This is used when gRPC clients are pre-connected + pub async fn initialize_grpc_workers( + worker_urls: &[String], + worker_type: WorkerType, + config: &RouterConfig, + registry: &Arc, + grpc_clients: &mut std::collections::HashMap, + ) -> Result<(), String> { + info!( + "Creating {} gRPC workers of type {:?}", + worker_urls.len(), + worker_type + ); + + // Convert circuit breaker config + let circuit_breaker_config = config.effective_circuit_breaker_config(); + let core_cb_config = CircuitBreakerConfig { + failure_threshold: circuit_breaker_config.failure_threshold, + success_threshold: circuit_breaker_config.success_threshold, + timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs), + window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), + }; + + // Convert health check config + let health_config = HealthConfig { + timeout_secs: config.health_check.timeout_secs, + check_interval_secs: config.health_check.check_interval_secs, + endpoint: config.health_check.endpoint.clone(), + failure_threshold: config.health_check.failure_threshold, + success_threshold: config.health_check.success_threshold, + }; + + for url in worker_urls { + if let Some(client) = grpc_clients.remove(url) { + let worker = BasicWorkerBuilder::new(url.clone()) + .worker_type(worker_type.clone()) + .connection_mode(ConnectionMode::Grpc { port: None }) + .circuit_breaker_config(core_cb_config.clone()) + .health_config(health_config.clone()) + .grpc_client(client) + .build(); + + let worker_id = registry.register(Arc::new(worker)); + info!("Registered gRPC worker {} with ID {:?}", url, worker_id); + } else { + warn!("No gRPC client available for worker {}, skipping", url); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_convert_connection_mode() { + // HTTP mode + assert!(matches!( + WorkerInitializer::convert_connection_mode( + &ConfigConnectionMode::Http, + Some(&"http://localhost:8080".to_string()) + ), + ConnectionMode::Http + )); + + // gRPC mode + assert!(matches!( + WorkerInitializer::convert_connection_mode( + &ConfigConnectionMode::Grpc, + Some(&"grpc://localhost:50051".to_string()) + ), + ConnectionMode::Grpc { .. } + )); + + // No URL provided + assert!(matches!( + WorkerInitializer::convert_connection_mode(&ConfigConnectionMode::Http, None), + ConnectionMode::Http + )); + } +} diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 215e2b54c..f04e08183 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -14,6 +14,7 @@ use crate::{ worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, }, reasoning_parser::ParserFactory, + routers::WorkerInitializer, routers::{ router_manager::{RouterId, RouterManager}, RouterFactory, RouterTrait, @@ -594,6 +595,22 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box, Option>) = if config.router_config.enable_igw { @@ -608,12 +625,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box { info!("Created HTTP Regular router"); router_manager.register_router( @@ -628,8 +640,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Result<(), Box } diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index b715dad55..2c3d10a23 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -101,6 +101,14 @@ impl TestContext { // Create app context let app_context = common::create_test_context(config.clone()); + // Initialize workers in the registry before creating router + if !worker_urls.is_empty() { + use sglang_router_rs::routers::WorkerInitializer; + WorkerInitializer::initialize_workers(&config, &app_context.worker_registry) + .await + .expect("Failed to initialize workers"); + } + // Create router let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router); diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index 2ec7f0039..60b95852b 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -39,9 +39,20 @@ impl TestContext { tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; } - config.mode = RoutingMode::Regular { worker_urls }; + config.mode = RoutingMode::Regular { + worker_urls: worker_urls.clone(), + }; + + let app_context = common::create_test_context(config.clone()); + + // Initialize workers in the registry before creating router + if !worker_urls.is_empty() { + use sglang_router_rs::routers::WorkerInitializer; + WorkerInitializer::initialize_workers(&config, &app_context.worker_registry) + .await + .expect("Failed to initialize workers"); + } - let app_context = common::create_test_context(config); let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router); diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index b998625c1..4cf1eff1d 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -40,9 +40,20 @@ impl TestContext { tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; } - config.mode = RoutingMode::Regular { worker_urls }; + config.mode = RoutingMode::Regular { + worker_urls: worker_urls.clone(), + }; + + let app_context = common::create_test_context(config.clone()); + + // Initialize workers in the registry before creating router + if !worker_urls.is_empty() { + use sglang_router_rs::routers::WorkerInitializer; + WorkerInitializer::initialize_workers(&config, &app_context.worker_registry) + .await + .expect("Failed to initialize workers"); + } - let app_context = common::create_test_context(config); let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router); diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 5b0f9dd96..fa52492ce 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -207,19 +207,21 @@ mod test_pd_routing { history_backend: sglang_router_rs::config::HistoryBackend::Memory, }; - // Router creation will fail due to health checks, but config should be valid let app_context = sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None) .expect("Failed to create AppContext"); let app_context = std::sync::Arc::new(app_context); let result = RouterFactory::create_router(&app_context).await; - assert!(result.is_err()); - let error_msg = result.unwrap_err(); - // Error should be about health/timeout, not configuration assert!( - error_msg.contains("healthy") || error_msg.contains("timeout"), - "Unexpected error: {}", - error_msg + result.is_ok(), + "Router creation should succeed with empty worker" + ); + + // Verify that no workers are registered since we didn't initialize them + let stats = app_context.worker_registry.stats(); + assert_eq!( + stats.total_workers, 0, + "No workers should be registered without initialization" ); } }